# ai-station/.venv/lib/python3.12/site-packages/literalai/callback/langchain_callback.py

import time
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union, cast

from literalai.helper import ensure_values_serializable
from literalai.observability.generation import (
    ChatGeneration,
    CompletionGeneration,
    GenerationMessage,
    GenerationMessageRole,
)
from literalai.observability.step import Step

if TYPE_CHECKING:
    from uuid import UUID

    from literalai.client import LiteralClient
    from literalai.observability.step import TrueStepType


def process_variable_value(value: Any) -> str:
    return str(value) if value is not None else ""


def _convert_message_role(role: str):
    message_role = "assistant"

    if "human" in role.lower():
        message_role = "user"
    elif "system" in role.lower():
        message_role = "system"
    elif "function" in role.lower():
        message_role = "function"
    elif "tool" in role.lower():
        message_role = "tool"

    return cast(GenerationMessageRole, message_role)
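

# The dependency check and the langchain imports happen inside
# get_langchain_callback(), so importing this module does not require
# langchain to be installed.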
def get_langchain_callback():
    try:
        version("langchain")
    except Exception:
        raise Exception(
            "Please install langchain to use the langchain callback. "
            "You can install it with `pip install langchain`"
        )

    from langchain.callbacks.tracers.base import BaseTracer
    from langchain.callbacks.tracers.schemas import Run
    from langchain.schema import BaseMessage
    from langchain_core.outputs import ChatGenerationChunk, GenerationChunk

    class ChatGenerationStart(TypedDict):
        input_messages: List[BaseMessage]
        start: float
        token_count: int
        tt_first_token: Optional[float]

    class CompletionGenerationStart(TypedDict):
        prompt: str
        start: float
        token_count: int
        tt_first_token: Optional[float]
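
    # Conversion helpers shared by the tracer: they turn LangChain messages and
    # run inputs into the message/content payloads expected by Literal AI.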
    class GenerationHelper:
        chat_generations: Dict[str, ChatGenerationStart]
        completion_generations: Dict[str, CompletionGenerationStart]
        generation_inputs: Dict[str, Dict]

        def __init__(self) -> None:
            self.chat_generations = {}
            self.completion_generations = {}
            self.generation_inputs = {}

        def _convert_message_dict(
            self,
            message: Dict,
        ):
            class_name = message["id"][-1]
            kwargs = message.get("kwargs", {})
            function_call = kwargs.get("additional_kwargs", {}).get("function_call")
            tool_calls = kwargs.get("additional_kwargs", {}).get("tool_calls")

            msg = GenerationMessage(
                name=kwargs.get("name"),
                role=_convert_message_role(class_name),
                content="",
                tool_call_id=getattr(message, "tool_call_id", None),
            )

            if function_call:
                msg["function_call"] = function_call
            else:
                msg["content"] = kwargs.get("content", "")

            if tool_calls:
                msg["tool_calls"] = tool_calls

            return msg

        def _convert_message(
            self,
            message: Union[Dict, BaseMessage],
        ):
            if isinstance(message, dict):
                return self._convert_message_dict(
                    message,
                )

            function_call = message.additional_kwargs.get("function_call")
            tool_calls = message.additional_kwargs.get("tool_calls")

            msg = GenerationMessage(
                name=getattr(message, "name", None),
                role=_convert_message_role(message.type),
                content="",
                tool_call_id=getattr(message, "tool_call_id", None),
            )

            if literal_uuid := message.additional_kwargs.get("uuid"):
                msg["uuid"] = literal_uuid
                msg["templated"] = True

            if function_call:
                msg["function_call"] = function_call
            else:
                msg["content"] = message.content  # type: ignore

            if tool_calls:
                msg["tool_calls"] = tool_calls

            return msg

        def _is_message(self, to_check: Any) -> bool:
            return isinstance(to_check, BaseMessage)

        def _is_message_array(self, to_check: Any) -> bool:
            return isinstance(to_check, list) and all(
                self._is_message(item) for item in to_check
            )

        def process_content(self, content: Any, root=True):
            if content is None:
                return {}
            if self._is_message_array(content):
                if root:
                    return {"messages": [self._convert_message(m) for m in content]}
                else:
                    return [self._convert_message(m) for m in content]
            elif self._is_message(content):
                return self._convert_message(content)
            elif isinstance(content, dict):
                processed_dict = {}
                for key, value in content.items():
                    processed_value = self.process_content(value, root=False)
                    processed_dict[key] = processed_value
                return processed_dict
            elif isinstance(content, list):
                return [self.process_content(item, root=False) for item in content]
            elif isinstance(content, str):
                if root:
                    return {"content": content}
                return content
            else:
                if root:
                    return {"content": str(content)}
                return str(content)
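
        # Merges invocation params, model kwargs and serialized kwargs into
        # (provider, model, tools, settings), dropping any *_api_key entries.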
        def _build_llm_settings(
            self,
            serialized: Dict,
            invocation_params: Optional[Dict] = None,
        ):
            # invocation_params = run.extra.get("invocation_params")
            if invocation_params is None:
                # Callers unpack four values, so return a four-element tuple.
                return None, None, None, None

            provider = invocation_params.pop("_type", "")  # type: str

            model_kwargs = invocation_params.pop("model_kwargs", {})
            if model_kwargs is None:
                model_kwargs = {}

            merged = {
                **invocation_params,
                **model_kwargs,
                **serialized.get("kwargs", {}),
            }

            # make sure there is no api key specification
            settings = self.process_content(
                {k: v for k, v in merged.items() if not k.endswith("_api_key")}
            )

            model_keys = ["azure_deployment", "deployment_name", "model", "model_name"]
            model = next((settings[k] for k in model_keys if k in settings), None)
            if isinstance(model, str):
                model = model.replace("models/", "")

            tools = None
            if "functions" in settings:
                tools = [
                    {"type": "function", "function": f} for f in settings["functions"]
                ]
            if "tools" in settings:
                tools = settings["tools"]

            return provider, model, tools, settings

    DEFAULT_TO_IGNORE = [
        "RunnableSequence",
        "RunnableParallel",
        "RunnableAssign",
        "RunnableLambda",
        "structured_outputs_parser",
        "<lambda>",
    ]
    DEFAULT_TO_KEEP = ["retriever", "llm", "agent", "chain", "tool"]
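
    # Maps LangChain runs onto Literal AI steps. Runs whose names match
    # to_ignore are skipped for readability, and kept children of ignored runs
    # are re-attached to their closest non-ignored ancestor.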
    class LangchainTracer(BaseTracer, GenerationHelper):
        steps: Dict[str, Step]
        parent_id_map: Dict[str, str]
        ignored_runs: set
        client: "LiteralClient"

        def __init__(
            self,
            client: "LiteralClient",
            # Runs to ignore to enhance readability
            to_ignore: Optional[List[str]] = None,
            # Runs to keep within ignored runs
            to_keep: Optional[List[str]] = None,
            **kwargs: Any,
        ) -> None:
            BaseTracer.__init__(self, **kwargs)
            GenerationHelper.__init__(self)
            self.client = client
            self.steps = {}
            self.parent_id_map = {}
            self.ignored_runs = set()

            if to_ignore is None:
                self.to_ignore = DEFAULT_TO_IGNORE
            else:
                self.to_ignore = to_ignore

            if to_keep is None:
                self.to_keep = DEFAULT_TO_KEEP
            else:
                self.to_keep = to_keep

        def model_dump(self):
            return {}

        def on_chat_model_start(
            self,
            serialized: Dict[str, Any],
            messages: List[List[BaseMessage]],
            *,
            run_id: "UUID",
            parent_run_id: Optional["UUID"] = None,
            tags: Optional[List[str]] = None,
            metadata: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
        ) -> Any:
            lc_messages = messages[0]
            self.chat_generations[str(run_id)] = {
                "input_messages": lc_messages,
                "start": time.time(),
                "token_count": 0,
                "tt_first_token": None,
            }

            return super().on_chat_model_start(
                serialized,
                messages,
                run_id=run_id,
                parent_run_id=parent_run_id,
                tags=tags,
                metadata=metadata,
                **kwargs,
            )

        def on_llm_start(
            self,
            serialized: Dict[str, Any],
            prompts: List[str],
            *,
            run_id: "UUID",
            tags: Optional[List[str]] = None,
            parent_run_id: Optional["UUID"] = None,
            metadata: Optional[Dict[str, Any]] = None,
            name: Optional[str] = None,
            **kwargs: Any,
        ) -> Run:
            self.completion_generations[str(run_id)] = {
                "prompt": prompts[0],
                "start": time.time(),
                "token_count": 0,
                "tt_first_token": None,
            }

            return super().on_llm_start(
                serialized,
                prompts,
                run_id=run_id,
                parent_run_id=parent_run_id,
                tags=tags,
                metadata=metadata,
                name=name,
                **kwargs,
            )
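
        # Counts streamed tokens and records time-to-first-token (in ms) so the
        # finished generation can report throughput.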
        def on_llm_new_token(
            self,
            token: str,
            *,
            chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
            run_id: "UUID",
            parent_run_id: Optional["UUID"] = None,
            **kwargs: Any,
        ) -> Run:
            if isinstance(chunk, ChatGenerationChunk):
                start = self.chat_generations[str(run_id)]
            else:
                start = self.completion_generations[str(run_id)]  # type: ignore
            start["token_count"] += 1
            if start["tt_first_token"] is None:
                start["tt_first_token"] = (time.time() - start["start"]) * 1000

            return super().on_llm_new_token(
                token,
                chunk=chunk,
                run_id=run_id,
                parent_run_id=parent_run_id,
            )

        def _persist_run(self, run: Run) -> None:
            pass

        def _get_run_parent_id(self, run: Run):
            parent_id = str(run.parent_run_id) if run.parent_run_id else None

            return parent_id

        def _get_non_ignored_parent_id(self, current_parent_id: Optional[str] = None):
            if not current_parent_id or current_parent_id not in self.parent_id_map:
                return None

            while current_parent_id in self.parent_id_map:
                # If the parent id is in the ignored runs, we need to get the parent id of the ignored run
                if current_parent_id in self.ignored_runs:
                    current_parent_id = self.parent_id_map[current_parent_id]
                else:
                    return current_parent_id

            return None

        def _should_ignore_run(self, run: Run):
            parent_id = self._get_run_parent_id(run)

            if parent_id:
                # Add the parent id of the ignored run in the mapping
                # so we can re-attach a kept child to the right parent id
                self.parent_id_map[str(run.id)] = parent_id

            ignore_by_name = False
            ignore_by_parent = parent_id in self.ignored_runs

            for filter in self.to_ignore:
                if filter in run.name:
                    ignore_by_name = True
                    break

            ignore = ignore_by_name or ignore_by_parent

            # If the ignore cause is the parent being ignored, check if we should nonetheless keep the child
            if ignore_by_parent and not ignore_by_name and run.run_type in self.to_keep:
                return False, self._get_non_ignored_parent_id(parent_id)
            else:
                if ignore:
                    # Tag the run as ignored
                    self.ignored_runs.add(str(run.id))
                return ignore, parent_id

        def _start_trace(self, run: Run) -> None:
            super()._start_trace(run)
            ignore, parent_id = self._should_ignore_run(run)

            if run.run_type in ["chain", "prompt"]:
                self.generation_inputs[str(run.id)] = ensure_values_serializable(
                    run.inputs
                )

            if ignore:
                return

            step_type: "TrueStepType" = "run" if not self.steps else "undefined"
            if run.run_type == "agent":
                step_type = "run"
            elif run.run_type == "llm":
                step_type = "llm"
            elif run.run_type == "retriever":
                step_type = "retrieval"
            elif run.run_type == "tool":
                step_type = "tool"
            elif run.run_type == "embedding":
                step_type = "embedding"

            step = self.client.start_step(
                id=str(run.id),
                name=run.name,
                type=step_type,
                parent_id=parent_id,
            )
            step.tags = run.tags
            step.metadata = run.metadata
            step.input = self.process_content(run.inputs)

            self.steps[str(run.id)] = step

        def _on_run_update(self, run: Run) -> None:
            """Process a run upon update."""
            ignore, parent_id = self._should_ignore_run(run)

            if ignore:
                return

            current_step = self.steps.get(str(run.id), None)

            if run.run_type == "llm" and current_step:
                provider, model, tools, llm_settings = self._build_llm_settings(
                    (run.serialized or {}), (run.extra or {}).get("invocation_params")
                )
                generations = (run.outputs or {}).get("generations", [])
                generation = generations[0][0]

                variables = self.generation_inputs.get(str(run.parent_run_id), {})
                variables = {
                    k: process_variable_value(v)
                    for k, v in variables.items()
                    if v is not None
                }

                if message := generation.get("message"):
                    chat_start = self.chat_generations[str(run.id)]
                    duration = time.time() - chat_start["start"]

                    if duration and chat_start["token_count"]:
                        throughput = chat_start["token_count"] / duration
                    else:
                        throughput = None

                    kwargs = message.get("kwargs", {})
                    usage_metadata = kwargs.get("usage_metadata", {})

                    message_completion = self._convert_message(message)
                    current_step.generation = ChatGeneration(
                        provider=provider,
                        model=model,
                        tools=tools,
                        variables=variables,
                        settings=llm_settings,
                        duration=duration,
                        token_throughput_in_s=throughput,
                        tt_first_token=chat_start.get("tt_first_token"),
                        messages=[
                            self._convert_message(m)
                            for m in chat_start["input_messages"]
                        ],
                        message_completion=message_completion,
                        input_token_count=usage_metadata.get("input_tokens"),
                        output_token_count=usage_metadata.get("output_tokens"),
                        token_count=usage_metadata.get("total_tokens"),
                    )

                    # find first message with prompt_id
                    prompt_id = None
                    variables_with_defaults: Optional[Dict] = None
                    for m in chat_start["input_messages"]:
                        if m.additional_kwargs.get("prompt_id"):
                            prompt_id = m.additional_kwargs["prompt_id"]
                            variables_with_defaults = m.additional_kwargs.get(
                                "variables"
                            )
                            break

                    if prompt_id:
                        current_step.generation.prompt_id = prompt_id
                        if variables_with_defaults:
                            current_step.generation.variables = {
                                k: process_variable_value(v)
                                for k, v in variables_with_defaults.items()
                                if v is not None
                            }

                    current_step.output = message_completion
                else:
                    completion_start = self.completion_generations[str(run.id)]
                    duration = time.time() - completion_start["start"]

                    if duration and completion_start["token_count"]:
                        throughput = completion_start["token_count"] / duration
                    else:
                        throughput = None

                    completion = generation.get("text", "")

                    # `message` is falsy in this branch, so read usage metadata defensively.
                    kwargs = (message or {}).get("kwargs", {})
                    usage_metadata = kwargs.get("usage_metadata", {})

                    current_step.generation = CompletionGeneration(
                        provider=provider,
                        model=model,
                        settings=llm_settings,
                        variables=variables,
                        duration=duration,
                        token_throughput_in_s=throughput,
                        tt_first_token=completion_start.get("tt_first_token"),
                        prompt=completion_start["prompt"],
                        completion=completion,
                        input_token_count=usage_metadata.get("input_tokens"),
                        output_token_count=usage_metadata.get("output_tokens"),
                        token_count=usage_metadata.get("total_tokens"),
                    )
                    current_step.output = {"content": completion}

                if current_step:
                    if current_step.metadata is None:
                        current_step.metadata = {}
                    current_step.end()

                return

            outputs = run.outputs or {}

            if current_step:
                current_step.output = self.process_content(outputs)
                current_step.end()

        def _on_error(self, error: BaseException, *, run_id: "UUID", **kwargs: Any):
            if current_step := self.steps.get(str(run_id), None):
                if current_step.metadata is None:
                    current_step.metadata = {}
                current_step.error = str(error)
                current_step.end()
                self.client.flush()

        on_llm_error = _on_error
        on_chain_error = _on_error
        on_tool_error = _on_error
        on_retriever_error = _on_error

    return LangchainTracer
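

# Usage sketch (not part of the original module): a minimal, illustrative example
# of wiring the returned tracer into a LangChain call. The LiteralClient
# configuration and the `chain` runnable below are assumptions.
#
#     from literalai import LiteralClient
#
#     literal_client = LiteralClient(api_key="...")
#     LangchainTracer = get_langchain_callback()
#     tracer = LangchainTracer(client=literal_client)
#     chain.invoke({"question": "..."}, config={"callbacks": [tracer]})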