ai-station/.venv/lib/python3.12/site-packages/literalai/wrappers.py

135 lines
4.7 KiB
Python

import time
from functools import wraps
from importlib import import_module
from typing import TYPE_CHECKING, Callable, Optional, TypedDict, Union
from literalai.context import active_steps_var
if TYPE_CHECKING:
from literalai.observability.generation import ChatGeneration, CompletionGeneration
from literalai.observability.step import Step
class BeforeContext(TypedDict):
    """Typed shape of a hook context dictionary.

    NOTE(review): presumably this describes the ``context`` dict that the
    wrappers in this module hand to "before" hooks -- confirm against the
    hook implementations.
    """

    # The callable being wrapped.
    original_func: Callable
    # Generation object attached to the call, or None when there is none.
    generation: Union["ChatGeneration", "CompletionGeneration", None]
    # Step attached to the call, or None when there is none.
    step: Union["Step", None]
    # Start time as a float; the wrappers in this module store time.time() here.
    start: float
class AfterContext(TypedDict):
    """Typed shape of a hook context dictionary.

    NOTE(review): presumably this describes the ``context`` dict that the
    wrappers in this module hand to "after" hooks -- confirm against the
    hook implementations.
    """

    # The callable that was wrapped.
    original_func: Callable
    # Generation object attached to the call, or None when there is none.
    generation: Union["ChatGeneration", "CompletionGeneration", None]
    # Step attached to the call, or None when there is none.
    step: Union["Step", None]
    # Start time as a float; the wrappers in this module store time.time() here.
    start: float
def remove_literalai_args(kargs):
    """Strip every ``literalai_``-prefixed entry out of ``kargs``.

    The mapping is modified in place; the removed entries are handed back so
    that they can be merged in again later.

    Args:
        kargs: Keyword-argument dict to clean up. Mutated by this call.

    Returns:
        A new dict holding the removed key/value pairs, in the order in which
        they appeared in ``kargs``.
    """
    extracted = {}
    # Walk a snapshot of the keys: the dict itself shrinks as entries are popped.
    for name in tuple(kargs):
        if not name.startswith("literalai_"):
            continue
        extracted[name] = kargs.pop(name)
    return extracted
def restore_literalai_args(kargs, largs):
    """Merge previously extracted ``literalai_`` arguments back into ``kargs``.

    This is the inverse of ``remove_literalai_args``: every entry of ``largs``
    is written into ``kargs`` in place, overwriting any entry with the same
    key.

    Args:
        kargs: Keyword-argument dict to update. Mutated by this call.
        largs: Dict of entries to copy into ``kargs``; left unchanged.

    Returns:
        None.
    """
    kargs.update(largs)
def sync_wrapper(before_func=None, after_func=None):
    """Build a decorator that runs optional hooks around a synchronous callable.

    Args:
        before_func: Optional hook invoked as
            ``before_func(context, *args, **kwargs)`` before the wrapped
            callable runs. It receives the keyword arguments with any
            ``literalai_``-prefixed entries still present.
        after_func: Optional hook invoked as
            ``after_func(result, context, *args, **kwargs)`` once the wrapped
            callable has returned. Its return value replaces the result.

    Returns:
        A decorator. The decorated callable strips ``literalai_`` keyword
        arguments before delegating to the original, and re-raises any
        ``Exception`` the original raises after recording it on the last
        active step, if there is one.
    """

    def decorator(original_func):
        @wraps(original_func)
        def wrapped(*args, **kwargs):
            # One dict is shared between both hooks for the duration of the call.
            shared = {"original_func": original_func}

            if before_func:
                before_func(shared, *args, **kwargs)

            # The wrapped callable must not receive literalai_* keywords.
            held_back = remove_literalai_args(kwargs)
            shared["start"] = time.time()

            try:
                outcome = original_func(*args, **kwargs)
            except Exception as e:
                open_steps = active_steps_var.get()
                if open_steps:
                    # Record the failure on the last active step and flush it
                    # before letting the exception propagate.
                    failed_step = open_steps[-1]
                    failed_step.error = str(e)
                    failed_step.end()
                    failed_step.processor.flush()
                raise e

            if not after_func:
                return outcome

            # The after-hook sees the same keyword arguments the caller passed.
            restore_literalai_args(kwargs, held_back)
            return after_func(outcome, shared, *args, **kwargs)

        return wrapped

    return decorator
def async_wrapper(before_func=None, after_func=None):
    """Build a decorator that runs optional async hooks around a coroutine function.

    Args:
        before_func: Optional awaitable hook invoked as
            ``await before_func(context, *args, **kwargs)`` before the wrapped
            callable runs. It receives the keyword arguments with any
            ``literalai_``-prefixed entries still present.
        after_func: Optional awaitable hook invoked as
            ``await after_func(result, context, *args, **kwargs)`` once the
            wrapped callable has completed. Its result replaces the original
            result.

    Returns:
        A decorator. The decorated coroutine function strips ``literalai_``
        keyword arguments before awaiting the original, and re-raises any
        ``Exception`` the original raises after recording it on the last
        active step, if there is one.
    """

    def decorator(original_func):
        @wraps(original_func)
        async def wrapped(*args, **kwargs):
            # One dict is shared between both hooks for the duration of the call.
            shared = {"original_func": original_func}

            if before_func:
                await before_func(shared, *args, **kwargs)

            # The wrapped callable must not receive literalai_* keywords.
            held_back = remove_literalai_args(kwargs)
            shared["start"] = time.time()

            try:
                outcome = await original_func(*args, **kwargs)
            except Exception as e:
                open_steps = active_steps_var.get()
                if open_steps:
                    # Record the failure on the last active step and flush it
                    # before letting the exception propagate.
                    failed_step = open_steps[-1]
                    failed_step.error = str(e)
                    failed_step.end()
                    failed_step.processor.flush()
                raise e

            if not after_func:
                return outcome

            # The after-hook sees the same keyword arguments the caller passed.
            restore_literalai_args(kwargs, held_back)
            return await after_func(outcome, shared, *args, **kwargs)

        return wrapped

    return decorator
def wrap_all(
    to_wrap: list,
    before_wrapper,
    after_wrapper,
    async_before_wrapper,
    async_after_wrapper,
):
    """Monkey-patch every method described in ``to_wrap`` with hook wrappers.

    Args:
        to_wrap: Patch descriptions. Each entry is indexed with the keys
            ``"module"`` (import path), ``"object"`` (attribute of that
            module), ``"method"`` (attribute of that object to replace),
            ``"async"`` (truthy to use the async wrapper) and ``"metadata"``
            (passed to the hook factories).
        before_wrapper: Factory called as ``before_wrapper(metadata=...)`` to
            obtain the before-hook for synchronous methods.
        after_wrapper: Factory called as ``after_wrapper(metadata=...)`` to
            obtain the after-hook for synchronous methods.
        async_before_wrapper: Same as ``before_wrapper`` for async methods.
        async_after_wrapper: Same as ``after_wrapper`` for async methods.

    Returns:
        None. The target objects are modified in place via ``setattr``.
    """
    for spec in to_wrap:
        owner = getattr(import_module(str(spec["module"])), str(spec["object"]))
        method_name = str(spec["method"])
        unpatched = getattr(owner, method_name)

        # Pick the wrapper flavour and its matching hook factories together.
        if spec["async"]:
            build = async_wrapper
            make_before, make_after = async_before_wrapper, async_after_wrapper
        else:
            build = sync_wrapper
            make_before, make_after = before_wrapper, after_wrapper

        meta = spec["metadata"]
        patched = build(
            before_func=make_before(metadata=meta),
            after_func=make_after(metadata=meta),
        )(unpatched)
        setattr(owner, method_name, patched)