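"""LangChain callback handler for Chainlit.

Bridges LangChain's async tracer callbacks to Chainlit: each traced run is
rendered as a Chainlit ``Step``, LLM runs are enriched with Literal AI
``ChatGeneration``/``CompletionGeneration`` metadata (model, settings, token
throughput, time to first token), and the chain's final answer can optionally
be streamed to the UI as it is produced.
"""
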
import time
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
from uuid import UUID

import pydantic
from langchain_core.load import dumps
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from langchain_core.tracers.base import AsyncBaseTracer
from langchain_core.tracers.schemas import Run
from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
from literalai.observability.step import TrueStepType

from chainlit.context import context_var
from chainlit.message import Message
from chainlit.step import Step
from chainlit.utils import utc_now

DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]


class FinalStreamHelper:
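    """Detect the answer prefix (e.g. "Final Answer:") in the token stream and stream whatever follows to the UI."""
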
    # The stream we can use to stream the final answer from a chain
    final_stream: Union[Message, None]
    # Should we stream the final answer?
    stream_final_answer: bool = False
    # Token sequence that prefixes the answer
    answer_prefix_tokens: List[str]
    # Ignore white spaces and new lines when comparing answer_prefix_tokens to last tokens? (to determine if answer has been reached)
    strip_tokens: bool

    answer_reached: bool

    def __init__(
        self,
        answer_prefix_tokens: Optional[List[str]] = None,
        stream_final_answer: bool = False,
        force_stream_final_answer: bool = False,
        strip_tokens: bool = True,
    ) -> None:
        # Langchain final answer streaming logic
        if answer_prefix_tokens is None:
            self.answer_prefix_tokens = DEFAULT_ANSWER_PREFIX_TOKENS
        else:
            self.answer_prefix_tokens = answer_prefix_tokens
        if strip_tokens:
            self.answer_prefix_tokens_stripped = [
                token.strip() for token in self.answer_prefix_tokens
            ]
        else:
            self.answer_prefix_tokens_stripped = self.answer_prefix_tokens

        self.last_tokens = [""] * len(self.answer_prefix_tokens)
        self.last_tokens_stripped = [""] * len(self.answer_prefix_tokens)
        self.strip_tokens = strip_tokens
        self.answer_reached = force_stream_final_answer

        # Our own final answer streaming logic
        self.stream_final_answer = stream_final_answer
        self.final_stream = None
        self.has_streamed_final_answer = False

    def _check_if_answer_reached(self) -> bool:
        if self.strip_tokens:
            return self._compare_last_tokens(self.last_tokens_stripped)
        else:
            return self._compare_last_tokens(self.last_tokens)

    def _compare_last_tokens(self, last_tokens: List[str]):
        if last_tokens == self.answer_prefix_tokens_stripped:
            # If tokens match perfectly we are done
            return True
        else:
            # Some LLMs will consider all the tokens of the final answer as one token
            # so we check if any last token contains all answer tokens
            return any(
                [
                    all(
                        answer_token in last_token
                        for answer_token in self.answer_prefix_tokens_stripped
                    )
                    for last_token in last_tokens
                ]
            )

    def _append_to_last_tokens(self, token: str) -> None:
        self.last_tokens.append(token)
        self.last_tokens_stripped.append(token.strip())
        if len(self.last_tokens) > len(self.answer_prefix_tokens):
            self.last_tokens.pop(0)
            self.last_tokens_stripped.pop(0)


class ChatGenerationStart(TypedDict):
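    """Bookkeeping for an in-flight chat model run: input messages, start time, token count, time to first token."""
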
    input_messages: List[BaseMessage]
    start: float
    token_count: int
    tt_first_token: Optional[float]


class CompletionGenerationStart(TypedDict):
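    """Bookkeeping for an in-flight completion (non-chat) LLM run."""
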
    prompt: str
    start: float
    token_count: int
    tt_first_token: Optional[float]


class GenerationHelper:
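    """Tracks per-run generation state and converts LangChain messages and settings into Literal AI structures."""
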
    chat_generations: Dict[str, ChatGenerationStart]
    completion_generations: Dict[str, CompletionGenerationStart]
    generation_inputs: Dict[str, Dict]

    def __init__(self) -> None:
        self.chat_generations = {}
        self.completion_generations = {}
        self.generation_inputs = {}

    def ensure_values_serializable(self, data):
        """
        Recursively ensures that all values in the input (dict or list) are JSON serializable.
        """
        if isinstance(data, dict):
            return {
                key: self.ensure_values_serializable(value)
                for key, value in data.items()
            }
        elif isinstance(data, pydantic.BaseModel):
            # Fallback to support pydantic v1
            # https://docs.pydantic.dev/latest/migration/#changes-to-pydanticbasemodel
            if pydantic.VERSION.startswith("1"):
                return data.dict()

            # pydantic v2
            return data.model_dump()  # pyright: ignore reportAttributeAccessIssue
        elif isinstance(data, list):
            return [self.ensure_values_serializable(item) for item in data]
        elif isinstance(data, (str, int, float, bool, type(None))):
            return data
        elif isinstance(data, (tuple, set)):
            return list(data)  # Convert tuples and sets to lists
        else:
            return str(data)  # Fallback: convert other types to string

    def _convert_message_role(self, role: str):
        if "human" in role.lower():
            return "user"
        elif "system" in role.lower():
            return "system"
        elif "function" in role.lower():
            return "function"
        elif "tool" in role.lower():
            return "tool"
        else:
            return "assistant"

    def _convert_message_dict(
        self,
        message: Dict,
    ):
        class_name = message["id"][-1]
        kwargs = message.get("kwargs", {})
        function_call = kwargs.get("additional_kwargs", {}).get("function_call")

        msg = GenerationMessage(
            role=self._convert_message_role(class_name),
            content="",
        )
        if name := kwargs.get("name"):
            msg["name"] = name
        if function_call:
            msg["function_call"] = function_call
        else:
            content = kwargs.get("content")
            if isinstance(content, list):
                tool_calls = []
                content_parts = []
                for item in content:
                    if item.get("type") == "tool_use":
                        tool_calls.append(
                            {
                                "id": item.get("id"),
                                "type": "function",
                                "function": {
                                    "name": item.get("name"),
                                    "arguments": item.get("input"),
                                },
                            }
                        )
                    elif item.get("type") == "text":
                        content_parts.append({"type": "text", "text": item.get("text")})

                if tool_calls:
                    msg["tool_calls"] = tool_calls
                if content_parts:
                    msg["content"] = content_parts  # type: ignore
            else:
                msg["content"] = content  # type: ignore

        return msg

    def _convert_message(
        self,
        message: Union[Dict, BaseMessage],
    ):
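        """Convert a LangChain message (dict or BaseMessage) into a Literal AI GenerationMessage."""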
        if isinstance(message, dict):
            return self._convert_message_dict(
                message,
            )

        function_call = message.additional_kwargs.get("function_call")

        msg = GenerationMessage(
            role=self._convert_message_role(message.type),
            content="",
        )

        if literal_uuid := message.additional_kwargs.get("uuid"):
            msg["uuid"] = literal_uuid
            msg["templated"] = True

        if name := getattr(message, "name", None):
            msg["name"] = name

        if function_call:
            msg["function_call"] = function_call
        else:
            if isinstance(message.content, list):
                tool_calls = []
                content_parts = []
                for item in message.content:
                    if isinstance(item, str):
                        continue
                    if item.get("type") == "tool_use":
                        tool_calls.append(
                            {
                                "id": item.get("id"),
                                "type": "function",
                                "function": {
                                    "name": item.get("name"),
                                    "arguments": item.get("input"),
                                },
                            }
                        )
                    elif item.get("type") == "text":
                        content_parts.append({"type": "text", "text": item.get("text")})

                if tool_calls:
                    msg["tool_calls"] = tool_calls
                if content_parts:
                    msg["content"] = content_parts  # type: ignore
            else:
                msg["content"] = message.content  # type: ignore

        return msg

    def _build_llm_settings(
        self,
        serialized: Dict,
        invocation_params: Optional[Dict] = None,
    ):
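        """Derive (provider, model, tools, settings) from the serialized LLM and its
        invocation params, dropping any ``*_api_key`` entries from the settings."""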
        # invocation_params = run.extra.get("invocation_params")
        if invocation_params is None:
            # Return one value per expected slot so callers can always unpack 4 values
            return None, None, None, None

        provider = invocation_params.pop("_type", "")  # type: str

        model_kwargs = invocation_params.pop("model_kwargs", {})

        if model_kwargs is None:
            model_kwargs = {}

        merged = {
            **invocation_params,
            **model_kwargs,
            **serialized.get("kwargs", {}),
        }

        # make sure there is no api key specification
        settings = {k: v for k, v in merged.items() if not k.endswith("_api_key")}

        model_keys = ["azure_deployment", "deployment_name", "model", "model_name"]
        model = next((settings[k] for k in model_keys if k in settings), None)
        if isinstance(model, str):
            model = model.replace("models/", "")
        tools = None
        if "functions" in settings:
            tools = [{"type": "function", "function": f} for f in settings["functions"]]
        if "tools" in settings:
            tools = [
                {"type": "function", "function": t}
                if t.get("type") != "function"
                else t
                for t in settings["tools"]
            ]
        return provider, model, tools, settings


def process_content(content: Any) -> Tuple[Dict | str, Optional[str]]:
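    """Return a (content, language) pair for display: strings are wrapped as text, anything else is serialized to JSON."""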
    if content is None:
        return {}, None
    if isinstance(content, str):
        return {"content": content}, "text"
    else:
        return dumps(content), "json"


DEFAULT_TO_IGNORE = [
    "RunnableSequence",
    "RunnableParallel",
    "RunnableAssign",
    "RunnableLambda",
    "<lambda>",
]
DEFAULT_TO_KEEP = ["retriever", "llm", "agent", "chain", "tool"]
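
# Runs whose name matches an entry in to_ignore are hidden from the UI; child runs
# whose run_type is listed in to_keep are re-attached to the closest non-ignored
# ancestor instead (see LangchainTracer._should_ignore_run).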


class LangchainTracer(AsyncBaseTracer, GenerationHelper, FinalStreamHelper):
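    """Async LangChain tracer that mirrors runs as Chainlit steps.

    Agents and top-level chains become "run" steps, tools and retrievers become
    "tool" steps, and LLM runs carry a Literal AI generation payload. Noise runs
    (RunnableSequence, lambdas, ...) are filtered out via to_ignore/to_keep.
    """
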
    steps: Dict[str, Step]
    parent_id_map: Dict[str, str]
    ignored_runs: set

    def __init__(
        self,
        # Token sequence that prefixes the answer
        answer_prefix_tokens: Optional[List[str]] = None,
        # Should we stream the final answer?
        stream_final_answer: bool = False,
        # Should force stream the first response?
        force_stream_final_answer: bool = False,
        # Runs to ignore to enhance readability
        to_ignore: Optional[List[str]] = None,
        # Runs to keep within ignored runs
        to_keep: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        AsyncBaseTracer.__init__(self, **kwargs)
        GenerationHelper.__init__(self)
        FinalStreamHelper.__init__(
            self,
            answer_prefix_tokens=answer_prefix_tokens,
            stream_final_answer=stream_final_answer,
            force_stream_final_answer=force_stream_final_answer,
        )
        self.context = context_var.get()
        self.steps = {}
        self.parent_id_map = {}
        self.ignored_runs = set()

        if self.context.current_step:
            self.root_parent_id = self.context.current_step.id
        else:
            self.root_parent_id = None

        if to_ignore is None:
            self.to_ignore = DEFAULT_TO_IGNORE
        else:
            self.to_ignore = to_ignore

        if to_keep is None:
            self.to_keep = DEFAULT_TO_KEEP
        else:
            self.to_keep = to_keep

    async def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: "UUID",
        parent_run_id: Optional["UUID"] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> Run:
        lc_messages = messages[0]
        self.chat_generations[str(run_id)] = {
            "input_messages": lc_messages,
            "start": time.time(),
            "token_count": 0,
            "tt_first_token": None,
        }

        return await super().on_chat_model_start(
            serialized,
            messages,
            run_id=run_id,
            parent_run_id=parent_run_id,
            tags=tags,
            metadata=metadata,
            name=name,
            **kwargs,
        )

    async def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: "UUID",
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        await super().on_llm_start(
            serialized,
            prompts,
            run_id=run_id,
            parent_run_id=parent_run_id,
            tags=tags,
            metadata=metadata,
            **kwargs,
        )

        self.completion_generations[str(run_id)] = {
            "prompt": prompts[0],
            "start": time.time(),
            "token_count": 0,
            "tt_first_token": None,
        }

        return None

    async def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        run_id: "UUID",
        parent_run_id: Optional["UUID"] = None,
        **kwargs: Any,
    ) -> None:
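        """Count streamed tokens, record time to first token and, when enabled,
        stream everything after the answer prefix to the UI."""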
        await super().on_llm_new_token(
            token=token,
            chunk=chunk,
            run_id=run_id,
            parent_run_id=parent_run_id,
            **kwargs,
        )
        if isinstance(chunk, ChatGenerationChunk):
            start = self.chat_generations[str(run_id)]
        else:
            start = self.completion_generations[str(run_id)]  # type: ignore
        start["token_count"] += 1
        if start["tt_first_token"] is None:
            start["tt_first_token"] = (time.time() - start["start"]) * 1000

        # Process token to ensure it's a string, as strip() will be called on it.
        processed_token: str
        # Handle case where token is a list (can occur with some model outputs).
        # Join all elements into a single string to maintain compatibility with downstream processing.
        if isinstance(token, list):
            # If token is a list, join its elements (converted to strings) into a single string.
            processed_token = "".join(map(str, token))
        elif not isinstance(token, str):
            # If token is neither a list nor a string, convert it to a string.
            processed_token = str(token)
        else:
            # If token is already a string, use it as is.
            processed_token = token

        if self.stream_final_answer:
            self._append_to_last_tokens(processed_token)

            if self.answer_reached:
                if not self.final_stream:
                    self.final_stream = Message(content="")
                    await self.final_stream.send()
                await self.final_stream.stream_token(processed_token)
                self.has_streamed_final_answer = True
            else:
                self.answer_reached = self._check_if_answer_reached()

    async def _persist_run(self, run: Run) -> None:
        pass

    def _get_run_parent_id(self, run: Run):
        parent_id = str(run.parent_run_id) if run.parent_run_id else self.root_parent_id

        return parent_id

    def _get_non_ignored_parent_id(self, current_parent_id: Optional[str] = None):
        if not current_parent_id:
            return self.root_parent_id

        if current_parent_id not in self.parent_id_map:
            return None

        while current_parent_id in self.parent_id_map:
            # If the parent id is in the ignored runs, we need to get the parent id of the ignored run
            if current_parent_id in self.ignored_runs:
                current_parent_id = self.parent_id_map[current_parent_id]
            else:
                return current_parent_id

        return self.root_parent_id

    def _should_ignore_run(self, run: Run):
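        """Return (ignore, parent_id): whether to hide this run and which parent a kept child should attach to."""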
        parent_id = self._get_run_parent_id(run)

        if parent_id:
            # Add the parent id of the ignored run in the mapping
            # so we can re-attach a kept child to the right parent id
            self.parent_id_map[str(run.id)] = parent_id

        ignore_by_name = False
        ignore_by_parent = parent_id in self.ignored_runs

        for filter in self.to_ignore:
            if filter in run.name:
                ignore_by_name = True
                break

        ignore = ignore_by_name or ignore_by_parent

        # If the ignore cause is the parent being ignored, check if we should nonetheless keep the child
        if ignore_by_parent and not ignore_by_name and run.run_type in self.to_keep:
            return False, self._get_non_ignored_parent_id(parent_id)
        else:
            if ignore:
                # Tag the run as ignored
                self.ignored_runs.add(str(run.id))
            return ignore, parent_id

    async def _start_trace(self, run: Run) -> None:
        await super()._start_trace(run)
        context_var.set(self.context)

        ignore, parent_id = self._should_ignore_run(run)

        if run.run_type in ["chain", "prompt"]:
            self.generation_inputs[str(run.id)] = self.ensure_values_serializable(
                run.inputs
            )

        if ignore:
            return

        step_type: TrueStepType = "undefined"
        if run.run_type == "agent":
            step_type = "run"
        elif run.run_type == "chain":
            if not self.steps:
                step_type = "run"
        elif run.run_type == "llm":
            step_type = "llm"
        elif run.run_type == "retriever":
            step_type = "tool"
        elif run.run_type == "tool":
            step_type = "tool"
        elif run.run_type == "embedding":
            step_type = "embedding"

        step = Step(
            id=str(run.id),
            name=run.name,
            type=step_type,
            parent_id=parent_id,
        )
        step.start = utc_now()
        if step_type != "llm":
            step.input, language = process_content(run.inputs)
            step.show_input = language or False

        step.tags = run.tags
        self.steps[str(run.id)] = step

        await step.send()

    async def _on_run_update(self, run: Run) -> None:
        """Process a run upon update."""
        context_var.set(self.context)

        ignore, _parent_id = self._should_ignore_run(run)

        if ignore:
            return

        current_step = self.steps.get(str(run.id), None)

        if run.run_type == "llm" and current_step:
            provider, model, tools, llm_settings = self._build_llm_settings(
                (run.serialized or {}), (run.extra or {}).get("invocation_params")
            )
            generations = (run.outputs or {}).get("generations", [])
            generation = generations[0][0]
            variables = self.generation_inputs.get(str(run.parent_run_id), {})
            variables = {k: str(v) for k, v in variables.items() if v is not None}
            if message := generation.get("message"):
                chat_start = self.chat_generations[str(run.id)]
                duration = time.time() - chat_start["start"]
                if duration and chat_start["token_count"]:
                    throughput = chat_start["token_count"] / duration
                else:
                    throughput = None
                message_completion = self._convert_message(message)
                current_step.generation = ChatGeneration(
                    provider=provider,
                    model=model,
                    tools=tools,
                    variables=variables,
                    settings=llm_settings,
                    duration=duration,
                    token_throughput_in_s=throughput,
                    tt_first_token=chat_start.get("tt_first_token"),
                    messages=[
                        self._convert_message(m) for m in chat_start["input_messages"]
                    ],
                    message_completion=message_completion,
                )

                # find first message with prompt_id
                for m in chat_start["input_messages"]:
                    if m.additional_kwargs.get("prompt_id"):
                        current_step.generation.prompt_id = m.additional_kwargs[
                            "prompt_id"
                        ]
                        if custom_variables := m.additional_kwargs.get("variables"):
                            current_step.generation.variables = {
                                k: str(v)
                                for k, v in custom_variables.items()
                                if v is not None
                            }
                        break

                current_step.language = "json"
            else:
                completion_start = self.completion_generations[str(run.id)]
                completion = generation.get("text", "")
                duration = time.time() - completion_start["start"]
                if duration and completion_start["token_count"]:
                    throughput = completion_start["token_count"] / duration
                else:
                    throughput = None
                current_step.generation = CompletionGeneration(
                    provider=provider,
                    model=model,
                    settings=llm_settings,
                    variables=variables,
                    duration=duration,
                    token_throughput_in_s=throughput,
                    tt_first_token=completion_start.get("tt_first_token"),
                    prompt=completion_start["prompt"],
                    completion=completion,
                )
                current_step.output = completion

            if current_step:
                current_step.end = utc_now()
                await current_step.update()

            if self.final_stream and self.has_streamed_final_answer:
                await self.final_stream.update()

            return

        if current_step:
            if current_step.type != "llm":
                current_step.output, current_step.language = process_content(
                    run.outputs
                )
            current_step.end = utc_now()
            await current_step.update()

    async def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
        context_var.set(self.context)

        if current_step := self.steps.get(str(run_id), None):
            current_step.is_error = True
            current_step.output = str(error)
            current_step.end = utc_now()
            await current_step.update()

    on_llm_error = _on_error
    on_chain_error = _on_error
    on_tool_error = _on_error
    on_retriever_error = _on_error


LangchainCallbackHandler = LangchainTracer
AsyncLangchainCallbackHandler = LangchainTracer
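
# Illustrative usage sketch (not part of this module): in a Chainlit app the handler
# is typically passed to a LangChain runnable via its config. The `chain` object below
# is an assumption, standing in for any LCEL runnable built elsewhere in your app.
#
#     import chainlit as cl
#     from langchain_core.runnables.config import RunnableConfig
#
#     @cl.on_message
#     async def on_message(message: cl.Message):
#         config = RunnableConfig(callbacks=[cl.AsyncLangchainCallbackHandler()])
#         answer = await chain.ainvoke({"question": message.content}, config=config)
#         await cl.Message(content=str(answer)).send()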