# ai-station/.venv/lib/python3.12/site-packages/literalai/instrumentation/llamaindex/event_handler.py
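"""Literal AI event handler for LlamaIndex instrumentation.

Maps LlamaIndex agent, query, retrieval, embedding, LLM, and synthesis
events onto Literal AI runs, steps, generations, and thread messages.
"""
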
import logging
import uuid
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.base.response.schema import Response, StreamingResponse
from llama_index.core.instrumentation.event_handlers import BaseEventHandler
from llama_index.core.instrumentation.events import BaseEvent
from llama_index.core.instrumentation.events.agent import (
    AgentChatWithStepEndEvent,
    AgentChatWithStepStartEvent,
    AgentRunStepEndEvent,
    AgentRunStepStartEvent,
)
from llama_index.core.instrumentation.events.embedding import (
    EmbeddingEndEvent,
    EmbeddingStartEvent,
)
from llama_index.core.instrumentation.events.llm import (
    LLMChatEndEvent,
    LLMChatStartEvent,
)
from llama_index.core.instrumentation.events.query import QueryEndEvent, QueryStartEvent
from llama_index.core.instrumentation.events.retrieval import (
    RetrievalEndEvent,
    RetrievalStartEvent,
)
from llama_index.core.instrumentation.events.synthesis import SynthesizeEndEvent
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from pydantic import PrivateAttr
from literalai.context import active_thread_var
from literalai.instrumentation.llamaindex.span_handler import LiteralSpanHandler
from literalai.observability.generation import (
    ChatGeneration,
    GenerationMessage,
    GenerationMessageRole,
)
from literalai.observability.step import Step, StepType

if TYPE_CHECKING:
    from literalai.client import LiteralClient


def convert_message_role(role: MessageRole) -> GenerationMessageRole:
    mapping = {
        MessageRole.SYSTEM: "system",
        MessageRole.USER: "user",
        MessageRole.ASSISTANT: "assistant",
        MessageRole.FUNCTION: "function",
        MessageRole.TOOL: "tool",
        MessageRole.CHATBOT: "assistant",
        MessageRole.MODEL: "assistant",
    }
    return cast(GenerationMessageRole, mapping.get(role, "user"))


def extract_query_from_bundle(str_or_query_bundle: Union[str, QueryBundle]):
    # Not called within this module; the handlers below use extract_query.
    if isinstance(str_or_query_bundle, QueryBundle):
        return str_or_query_bundle.query_str
    return str_or_query_bundle


def extract_document_info(nodes: List[NodeWithScore]):
    if not nodes:
        return []
    return [
        {
            "id_": node.id_,
            "metadata": node.metadata,
            "text": node.get_text(),
            "mimetype": node.node.mimetype,
            "start_char_idx": node.node.start_char_idx,
            "end_char_idx": node.node.end_char_idx,
            "char_length": (
                node.node.end_char_idx - node.node.start_char_idx
                if node.node.end_char_idx is not None
                and node.node.start_char_idx is not None
                else None
            ),
            "score": node.get_score(),
        }
        for node in nodes
        if isinstance(node.node, TextNode)
    ]


def build_message_dict(message: ChatMessage):
    message_dict: GenerationMessage = {
        "role": convert_message_role(message.role),
        "content": message.content,
    }
    kwargs = message.additional_kwargs
    if kwargs:
        if kwargs.get("tool_call_id"):
            message_dict["tool_call_id"] = kwargs.get("tool_call_id")
        if kwargs.get("name"):
            message_dict["name"] = kwargs.get("name")
        tool_calls = kwargs.get("tool_calls", [])
        if len(tool_calls) > 0:
            message_dict["tool_calls"] = [
                tool_call.to_dict() for tool_call in tool_calls
            ]
    return message_dict


def create_generation(event: LLMChatStartEvent):
    model_dict = event.model_dict
    return ChatGeneration(
        provider=model_dict.get("class_name"),
        model=model_dict.get("model"),
        settings={
            "model": model_dict.get("model"),
            "temperature": model_dict.get("temperature"),
            "max_tokens": model_dict.get("max_tokens"),
            "logprobs": model_dict.get("logprobs"),
            "top_logprobs": model_dict.get("top_logprobs"),
        },
        messages=[build_message_dict(message) for message in event.messages],
    )


def extract_query(x: Union[str, QueryBundle]):
    return x.query_str if isinstance(x, QueryBundle) else x


class LiteralEventHandler(BaseEventHandler):
    """This class handles events coming from LlamaIndex.

    See the usage sketch at the bottom of this module for how it is
    registered on the LlamaIndex dispatcher together with LiteralSpanHandler.
    """

    _client: "LiteralClient" = PrivateAttr()
    _span_handler: "LiteralSpanHandler" = PrivateAttr()
    runs: Dict[str, List[Step]] = {}
    streaming_run_ids: List[str] = []
    _standalone_step_id: Optional[str] = None
    open_runs: List[Step] = []

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        literal_client: "LiteralClient",
        llama_index_span_handler: "LiteralSpanHandler",
    ):
        super().__init__()
        # Private attributes are set through object.__setattr__ to bypass
        # pydantic's field validation.
        object.__setattr__(self, "_client", literal_client)
        object.__setattr__(self, "_span_handler", llama_index_span_handler)

    def _convert_message(self, message: ChatMessage):
        tool_calls = message.additional_kwargs.get("tool_calls")
        msg: GenerationMessage = {
            "name": getattr(message, "name", None),
            "role": convert_message_role(message.role),
            "content": message.content,
            "tool_calls": (
                [tool_call.to_dict() for tool_call in tool_calls]
                if tool_calls
                else None
            ),
        }
        return msg

    def handle(self, event: BaseEvent, **kwargs) -> None:
        """Logic for handling event."""
        try:
            thread_id = self._span_handler.get_thread_id(event.span_id)
            run_id = self._span_handler.get_run_id(event.span_id)

            # AgentChatWithStep wraps several AgentRunStep events,
            # as the agent may want to perform multiple tool calls in a row.
            if isinstance(
                event, (AgentChatWithStepStartEvent, AgentRunStepStartEvent)
            ):
                run_name = (
                    "Agent Chat"
                    if isinstance(event, AgentChatWithStepStartEvent)
                    else "Agent Step"
                )
                parent_run_id = None
                if len(self.open_runs) > 0:
                    parent_run_id = self.open_runs[-1].id
                agent_run_id = str(uuid.uuid4())
                run = self._client.start_step(
                    name=run_name,
                    type="run",
                    id=agent_run_id,
                    parent_id=parent_run_id,
                )
                self.open_runs.append(run)

            if isinstance(
                event, (AgentChatWithStepEndEvent, AgentRunStepEndEvent)
            ):
                # Initialize so the name is bound even when no run is open.
                step = None
                try:
                    step = self.open_runs.pop()
                except IndexError:
                    logging.error(
                        "[Literal] Error in Llamaindex instrumentation: "
                        "AgentRunStepEndEvent called without an open run."
                    )
                if step:
                    step.end()

            if isinstance(event, QueryStartEvent):
                active_thread = active_thread_var.get()
                query = extract_query(event.query)
                # Name the thread after the first query if it has no name yet.
                if not active_thread or not active_thread.name:
                    self._client.api.upsert_thread(id=thread_id, name=query)
                self._client.message(
                    name="User query",
                    id=str(event.id_),
                    type="user_message",
                    thread_id=thread_id,
                    content=query,
                )

            # Retrieval wraps the Embedding step in LlamaIndex
            if isinstance(event, RetrievalStartEvent):
                run = self._client.start_step(
                    name="RAG",
                    type="run",
                    id=run_id,
                    thread_id=thread_id,
                )
                self.store_step(run_id=run_id, step=run)
                retrieval_step = self._client.start_step(
                    type="retrieval",
                    name="Retrieval",
                    parent_id=run_id,
                    thread_id=thread_id,
                )
                self.store_step(run_id=run_id, step=retrieval_step)

            if isinstance(event, EmbeddingStartEvent):
                retrieval_step = self.get_first_step_of_type(
                    run_id=run_id, step_type="retrieval"
                )
                if run_id and retrieval_step:
                    embedding_step = self._client.start_step(
                        type="embedding",
                        name="Embedding",
                        parent_id=retrieval_step.id,
                        thread_id=thread_id,
                    )
                    embedding_step.metadata = event.model_dict
                    self.store_step(run_id=run_id, step=embedding_step)

            if isinstance(event, EmbeddingEndEvent):
                embedding_step = self.get_first_step_of_type(
                    run_id=run_id, step_type="embedding"
                )
                if run_id and embedding_step:
                    embedding_step.input = {"query": event.chunks}
                    embedding_step.output = {"embeddings": event.embeddings}
                    embedding_step.end()

            if isinstance(event, RetrievalEndEvent):
                retrieval_step = self.get_first_step_of_type(
                    run_id=run_id, step_type="retrieval"
                )
                if run_id and retrieval_step:
                    retrieved_documents = extract_document_info(event.nodes)
                    query = extract_query(event.str_or_query_bundle)
                    retrieval_step.input = {"query": query}
                    retrieval_step.output = {
                        "retrieved_documents": retrieved_documents
                    }
                    retrieval_step.end()

            # Only event where we create LLM steps
            if isinstance(event, LLMChatStartEvent):
                if run_id:
                    self._client.step()
                    llm_step = self._client.start_step(
                        type="llm",
                        parent_id=run_id,
                        thread_id=thread_id,
                    )
                    self.store_step(run_id=run_id, step=llm_step)
                llm_step = self.get_first_step_of_type(run_id=run_id, step_type="llm")
                # LLM calls made outside of a run get a standalone step.
                if not run_id and not llm_step:
                    self._standalone_step_id = str(uuid.uuid4())
                    llm_step = self._client.start_step(
                        name=event.model_dict.get("model", "LLM"),
                        type="llm",
                        id=self._standalone_step_id,
                        # Remove thread_id for standalone runs
                    )
                    self.store_step(run_id=self._standalone_step_id, step=llm_step)
                if llm_step:
                    generation = create_generation(event=event)
                    llm_step.generation = generation
                    llm_step.name = event.model_dict.get("model")

            # The generation is completed when the chat call ends.
            if isinstance(event, LLMChatEndEvent):
                llm_step = self.get_first_step_of_type(run_id=run_id, step_type="llm")
                if not llm_step and self._standalone_step_id:
                    llm_step = self.get_first_step_of_type(
                        run_id=self._standalone_step_id, step_type="llm"
                    )
                response = event.response
                if llm_step and response:
                    chat_completion = response.raw
                    # ChatCompletionChunk needed for chat stream methods
                    if isinstance(
                        chat_completion, (ChatCompletion, ChatCompletionChunk)
                    ):
                        usage = chat_completion.usage
                        if isinstance(usage, CompletionUsage):
                            llm_step.generation.input_token_count = usage.prompt_tokens
                            llm_step.generation.output_token_count = (
                                usage.completion_tokens
                            )
                    # Standalone steps end here; run-based LLM steps end on
                    # SynthesizeEndEvent.
                    if self._standalone_step_id:
                        llm_step.generation.message_completion = (
                            self._convert_message(response.message)
                        )
                        llm_step.end()
                        self._standalone_step_id = None

            if isinstance(event, SynthesizeEndEvent):
                llm_step = self.get_first_step_of_type(run_id=run_id, step_type="llm")
                run = self.get_first_step_of_type(run_id=run_id, step_type="run")
                if llm_step and run:
                    synthesized_response = event.response
                    text_response = ""
                    if isinstance(synthesized_response, StreamingResponse):
                        text_response = str(synthesized_response.get_response())
                    if isinstance(synthesized_response, Response):
                        text_response = str(synthesized_response)
                    llm_step.generation.message_completion = {
                        "role": "assistant",
                        "content": text_response,
                    }
                    llm_step.end()
                    run.end()
                    self._client.message(
                        type="assistant_message",
                        thread_id=thread_id,
                        content=text_response,
                    )

            if isinstance(event, QueryEndEvent):
                if run_id in self.runs:
                    del self.runs[run_id]
        except Exception as e:
            logging.error(
                "[Literal] Error in Llamaindex instrumentation: %s",
                str(e),
                exc_info=True,
            )

    def store_step(self, run_id: str, step: Step):
        if run_id not in self.runs:
            self.runs[run_id] = []
        self.runs[run_id].append(step)

    def get_first_step_of_type(
        self, run_id: Optional[str], step_type: StepType
    ) -> Optional[Step]:
        if not run_id:
            return None
        if run_id not in self.runs:
            return None
        for step in self.runs[run_id]:
            if step.type == step_type:
                return step
        return None

    @classmethod
    def class_name(cls) -> str:
        """Class name."""
        return "LiteralEventHandler"