# ai-station/.venv/lib/python3.12/site-packages/chainlit/llama_index/callbacks.py
from typing import Any, Dict, List, Optional

from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
from llama_index.core.callbacks import TokenCountingHandler
from llama_index.core.callbacks.schema import CBEventType, EventPayload
from llama_index.core.llms import ChatMessage, ChatResponse, CompletionResponse
from llama_index.core.tools.types import ToolMetadata

from chainlit.context import context_var
from chainlit.element import Text
from chainlit.step import Step, StepType
from chainlit.utils import utc_now
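
# Event types suppressed by default: they are passed to the parent handler as
# both starts and ends to ignore, so they never surface as Chainlit steps.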
DEFAULT_IGNORE = [
    CBEventType.CHUNKING,
    CBEventType.SYNTHESIZE,
    CBEventType.EMBEDDING,
    CBEventType.NODE_PARSING,
    CBEventType.TREE,
]


class LlamaIndexCallbackHandler(TokenCountingHandler):
    """Callback handler that renders LlamaIndex events as Chainlit steps."""

    steps: Dict[str, Step]

    def __init__(
        self,
        event_starts_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
        event_ends_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
    ) -> None:
        """Initialize the handler and forward the ignore lists to the parent."""
        super().__init__(
            event_starts_to_ignore=event_starts_to_ignore,
            event_ends_to_ignore=event_ends_to_ignore,
        )
        self.steps = {}

    def _get_parent_id(self, event_parent_id: Optional[str] = None) -> Optional[str]:
        # Prefer a parent step this handler is already tracking; otherwise fall
        # back to the current step of the Chainlit context, if any.
        if event_parent_id and event_parent_id in self.steps:
            return event_parent_id
        elif context_var.get().current_step:
            return context_var.get().current_step.id
        else:
            return None

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Run when an event starts and return the id of the event."""
        step_type: StepType = "undefined"
        step_name: str = event_type.value
        step_input: Optional[Dict[str, Any]] = payload

        if event_type == CBEventType.FUNCTION_CALL:
            step_type = "tool"
            if payload:
                metadata: Optional[ToolMetadata] = payload.get(EventPayload.TOOL)
                if metadata:
                    step_name = getattr(metadata, "name", step_name)
                step_input = payload.get(EventPayload.FUNCTION_CALL)
        elif event_type == CBEventType.RETRIEVE:
            step_type = "tool"
        elif event_type == CBEventType.QUERY:
            step_type = "tool"
        elif event_type == CBEventType.LLM:
            step_type = "llm"
        else:
            # Any other event type is not rendered as a Chainlit step.
            return event_id

        step = Step(
            name=step_name,
            type=step_type,
            parent_id=self._get_parent_id(parent_id),
            id=event_id,
        )
        self.steps[event_id] = step
        step.start = utc_now()
        step.input = step_input or {}
        # step.send() is a coroutine; schedule it on the context's event loop.
        context_var.get().loop.create_task(step.send())
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Run when an event ends."""
        step = self.steps.get(event_id, None)

        if payload is None or step is None:
            return

        step.end = utc_now()

        if event_type == CBEventType.FUNCTION_CALL:
            response = payload.get(EventPayload.FUNCTION_OUTPUT)
            if response:
                step.output = f"{response}"
            context_var.get().loop.create_task(step.update())
        elif event_type == CBEventType.QUERY:
            response = payload.get(EventPayload.RESPONSE)
            source_nodes = getattr(response, "source_nodes", None)
            if source_nodes:
                source_refs = ", ".join(
                    [f"Source {idx}" for idx, _ in enumerate(source_nodes)]
                )
                step.elements = [
                    Text(
                        name=f"Source {idx}",
                        content=source.text or "Empty node",
                        display="side",
                    )
                    for idx, source in enumerate(source_nodes)
                ]
                step.output = f"Retrieved the following sources: {source_refs}"
            context_var.get().loop.create_task(step.update())
        elif event_type == CBEventType.RETRIEVE:
            sources = payload.get(EventPayload.NODES)
            if sources:
                source_refs = ", ".join(
                    [f"Source {idx}" for idx, _ in enumerate(sources)]
                )
                step.elements = [
                    Text(
                        name=f"Source {idx}",
                        display="side",
                        content=source.node.get_text() or "Empty node",
                    )
                    for idx, source in enumerate(sources)
                ]
                step.output = f"Retrieved the following sources: {source_refs}"
            context_var.get().loop.create_task(step.update())
        elif event_type == CBEventType.LLM:
            formatted_messages = payload.get(
                EventPayload.MESSAGES
            )  # type: Optional[List[ChatMessage]]
            formatted_prompt = payload.get(EventPayload.PROMPT)
            response = payload.get(EventPayload.RESPONSE)

            if formatted_messages:
                messages = [
                    GenerationMessage(
                        role=m.role.value,  # type: ignore
                        content=m.content or "",
                    )
                    for m in formatted_messages
                ]
            else:
                messages = None

            if isinstance(response, ChatResponse):
                content = response.message.content or ""
            elif isinstance(response, CompletionResponse):
                content = response.text
            else:
                content = ""

            step.output = content

            token_count = self.total_llm_token_count or None
            raw_response = response.raw if response else None
            model = getattr(raw_response, "model", None)

            if messages and isinstance(response, ChatResponse):
                msg: ChatMessage = response.message
                step.generation = ChatGeneration(
                    model=model,
                    messages=messages,
                    message_completion=GenerationMessage(
                        role=msg.role.value,  # type: ignore
                        content=content,
                    ),
                    token_count=token_count,
                )
            elif formatted_prompt:
                step.generation = CompletionGeneration(
                    model=model,
                    prompt=formatted_prompt,
                    completion=content,
                    token_count=token_count,
                )

            context_var.get().loop.create_task(step.update())
        else:
            step.output = payload
            context_var.get().loop.create_task(step.update())

        self.steps.pop(event_id, None)

    # Tracing hooks required by the BaseCallbackHandler interface; this
    # handler tracks individual events only, so traces are no-ops.
    def _noop(self, *args, **kwargs):
        pass

    start_trace = _noop
    end_trace = _noop
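
# Example wiring (a minimal sketch following Chainlit's LlamaIndex integration
# docs, which expose this class as ``cl.LlamaIndexCallbackHandler``):
#
#     import chainlit as cl
#     from llama_index.core import Settings
#     from llama_index.core.callbacks import CallbackManager
#
#     @cl.on_chat_start
#     async def start():
#         Settings.callback_manager = CallbackManager(
#             [cl.LlamaIndexCallbackHandler()]
#         )
#
# Registering the handler on the global Settings means every LlamaIndex query
# run inside a Chainlit session renders its retrieval and LLM events as steps.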