import json
import time
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

from langchain_core.callbacks import (
    BaseCallbackHandler,
    CallbackManager,
    AsyncCallbackManager,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
    ToolMessage,
    ToolMessageChunk,
)
from langchain_core.outputs import (
    ChatGeneration,
    ChatGenerationChunk,
    Generation,
    GenerationChunk,
    LLMResult,
)
from opentelemetry import context as context_api
from opentelemetry.instrumentation.langchain.config import Config
from opentelemetry.instrumentation.langchain.event_emitter import emit_event
from opentelemetry.instrumentation.langchain.event_models import (
    ChoiceEvent,
    MessageEvent,
    ToolCall,
)
from opentelemetry.instrumentation.langchain.span_utils import (
    SpanHolder,
    _set_span_attribute,
    extract_model_name_from_response_metadata,
    _extract_model_name_from_association_metadata,
    set_chat_request,
    set_chat_response,
    set_chat_response_usage,
    set_llm_request,
    set_request_params,
)
from opentelemetry.instrumentation.langchain.vendor_detection import (
    detect_vendor_from_class,
)
from opentelemetry.instrumentation.langchain.utils import (
    CallbackFilteredJSONEncoder,
    dont_throw,
    should_emit_events,
    should_send_prompts,
)
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.metrics import Histogram
from opentelemetry.semconv._incubating.attributes import (
    gen_ai_attributes as GenAIAttributes,
)
from opentelemetry.semconv_ai import (
    SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY,
    LLMRequestTypeValues,
    SpanAttributes,
    TraceloopSpanKindValues,
)
from opentelemetry.trace import SpanKind, Tracer, set_span_in_context
from opentelemetry.trace.span import Span
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE


def _extract_class_name_from_serialized(serialized: Optional[dict[str, Any]]) -> str:
    """
    Extract class name from serialized model information.

    Args:
        serialized: Serialized model information from LangChain callback

    Returns:
        Class name string, or empty string if not found
    """
    class_id = (serialized or {}).get("id", [])
    if isinstance(class_id, list) and len(class_id) > 0:
        return class_id[-1]
    elif class_id:
        return str(class_id)
    else:
        return ""


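# A minimal usage sketch (the payload below is hypothetical): LangChain
# typically serializes the class identifier under "id" as a module-path
# list, so the last element is the class name.
#
#     >>> _extract_class_name_from_serialized(
#     ...     {"id": ["langchain_openai", "chat_models", "ChatOpenAI"]}
#     ... )
#     'ChatOpenAI'
#     >>> _extract_class_name_from_serialized(None)
#     ''

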
def _sanitize_metadata_value(value: Any) -> Any:
    """Convert metadata values to OpenTelemetry-compatible types."""
    if value is None:
        return None
    if isinstance(value, (bool, str, bytes, int, float)):
        return value
    if isinstance(value, (list, tuple)):
        return [str(_sanitize_metadata_value(v)) for v in value]
    # Convert other types to strings
    return str(value)


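# Behavior sketch of the conversions above (values are illustrative):
#
#     >>> _sanitize_metadata_value(3.5)
#     3.5
#     >>> _sanitize_metadata_value([1, {"a": 2}])
#     ['1', "{'a': 2}"]
#
# OpenTelemetry attribute values must be primitives (or sequences of
# primitives), hence the str() fallback for everything else.

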
def valid_role(role: str) -> bool:
    return role in ["user", "assistant", "system", "tool"]


def get_message_role(message: BaseMessage) -> str:
    if isinstance(message, (SystemMessage, SystemMessageChunk)):
        return "system"
    elif isinstance(message, (HumanMessage, HumanMessageChunk)):
        return "user"
    elif isinstance(message, (AIMessage, AIMessageChunk)):
        return "assistant"
    elif isinstance(message, (ToolMessage, ToolMessageChunk)):
        return "tool"
    else:
        return "unknown"


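# Role-mapping sketch: get_message_role(HumanMessage(content="hi")) returns
# "user" and get_message_role(AIMessage(content="ok")) returns "assistant".
# Every role this can return except "unknown" also satisfies valid_role().

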
def _extract_tool_call_data(
    tool_calls: Optional[List[dict[str, Any]]],
) -> Union[List[ToolCall], None]:
    if tool_calls is None:
        return tool_calls

    response = []

    for tool_call in tool_calls:
        tool_call_function = {"name": tool_call.get("name", "")}

        if tool_call.get("arguments"):
            tool_call_function["arguments"] = tool_call["arguments"]
        elif tool_call.get("args"):
            tool_call_function["arguments"] = tool_call["args"]
        response.append(
            ToolCall(
                id=tool_call.get("id", ""),
                function=tool_call_function,
                type="function",
            )
        )

    return response


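# Input-shape sketch (a hypothetical LangChain tool-call dict, which carries
# its payload under either "args" or "arguments"):
#
#     _extract_tool_call_data(
#         [{"name": "get_weather", "args": {"city": "Paris"}, "id": "call_1"}]
#     )
#
# yields one ToolCall with id="call_1", type="function", and
# function={"name": "get_weather", "arguments": {"city": "Paris"}}.

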
class TraceloopCallbackHandler(BaseCallbackHandler):
    def __init__(
        self, tracer: Tracer, duration_histogram: Histogram, token_histogram: Histogram
    ) -> None:
        super().__init__()
        self.tracer = tracer
        self.duration_histogram = duration_histogram
        self.token_histogram = token_histogram
        self.spans: dict[UUID, SpanHolder] = {}
        self.run_inline = True
        self._callback_manager: Optional[
            Union[CallbackManager, AsyncCallbackManager]
        ] = None

    @staticmethod
    def _get_name_from_callback(
        serialized: dict[str, Any],
        _tags: Optional[list[str]] = None,
        _metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> str:
        """Get the name to be used for the span. Based on a heuristic. Can be extended."""
        if serialized and "kwargs" in serialized and serialized["kwargs"].get("name"):
            return serialized["kwargs"]["name"]
        if kwargs.get("name"):
            return kwargs["name"]
        if serialized.get("name"):
            return serialized["name"]
        if "id" in serialized:
            return serialized["id"][-1]

        return "unknown"

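    # Name-resolution sketch (hypothetical inputs): with
    # serialized={"kwargs": {"name": "my_chain"}, "id": ["langchain", "chains", "LLMChain"]}
    # the nested kwargs name wins ("my_chain"); otherwise the fallbacks are
    # kwargs["name"], then serialized["name"], then serialized["id"][-1]
    # ("LLMChain"), and finally "unknown".
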
    def _get_span(self, run_id: UUID) -> Span:
        return self.spans[run_id].span

    def _end_span(self, span: Span, run_id: UUID) -> None:
        for child_id in self.spans[run_id].children:
            if child_id in self.spans:
                child_span = self.spans[child_id].span
                if child_span.end_time is None:  # avoid warning on ended spans
                    child_span.end()
        span.end()
        token = self.spans[run_id].token
        if token:
            self._safe_detach_context(token)

        del self.spans[run_id]

    def _safe_attach_context(self, span: Span):
        """
        Safely attach a span to the context, handling potential failures in async scenarios.

        Returns the context token for later detachment, or None if attachment fails.
        """
        try:
            return context_api.attach(set_span_in_context(span))
        except Exception:
            # Context attachment can fail in some edge cases, particularly in
            # complex async scenarios or when the context is corrupted.
            # Return None to indicate no token needs to be detached later.
            return None

    def _safe_detach_context(self, token):
        """
        Safely detach a context token without crashing the application.

        This method implements a fail-safe approach to context detachment that handles
        all known edge cases in async/concurrent scenarios where context tokens may
        become invalid or be detached in different execution contexts.

        We use the runtime context directly to avoid logging errors from
        context_api.detach().
        """
        if not token:
            return

        try:
            # Use the runtime context directly to avoid error logging from
            # context_api.detach()
            from opentelemetry.context import _RUNTIME_CONTEXT

            _RUNTIME_CONTEXT.detach(token)
        except Exception:
            # Context detach can fail in async scenarios when tokens are created in
            # different contexts. This includes ValueError, RuntimeError, and other
            # context-related exceptions; it is expected and doesn't affect the
            # correct span hierarchy.
            #
            # Common scenarios where this happens:
            # 1. Token created in one async task/thread, detached in another
            # 2. Context was already detached by another process
            # 3. Token became invalid due to context switching
            # 4. Race conditions in highly concurrent scenarios
            #
            # This is safe to ignore as the span itself was properly ended
            # and the tracing data is correctly captured.
            pass

    def _create_span(
        self,
        run_id: UUID,
        parent_run_id: Optional[UUID],
        span_name: str,
        kind: SpanKind = SpanKind.INTERNAL,
        workflow_name: str = "",
        entity_name: str = "",
        entity_path: str = "",
        metadata: Optional[dict[str, Any]] = None,
    ) -> Span:
        if metadata is not None:
            current_association_properties = (
                context_api.get_value("association_properties") or {}
            )
            # Sanitize metadata values to ensure they're compatible with OpenTelemetry
            sanitized_metadata = {
                k: _sanitize_metadata_value(v)
                for k, v in metadata.items()
                if v is not None
            }
            try:
                context_api.attach(
                    context_api.set_value(
                        "association_properties",
                        {**current_association_properties, **sanitized_metadata},
                    )
                )
            except Exception:
                # If setting association properties fails, continue without them.
                # This doesn't affect the core span functionality.
                pass

        if parent_run_id is not None and parent_run_id in self.spans:
            span = self.tracer.start_span(
                span_name,
                context=set_span_in_context(self.spans[parent_run_id].span),
                kind=kind,
            )
        else:
            span = self.tracer.start_span(span_name, kind=kind)

        token = self._safe_attach_context(span)

        _set_span_attribute(span, SpanAttributes.TRACELOOP_WORKFLOW_NAME, workflow_name)
        _set_span_attribute(span, SpanAttributes.TRACELOOP_ENTITY_PATH, entity_path)

        # Set metadata as span attributes if available
        if metadata is not None:
            for key, value in sanitized_metadata.items():
                _set_span_attribute(
                    span,
                    f"{Config.metadata_key_prefix}.{key}",
                    value,
                )

        self.spans[run_id] = SpanHolder(
            span, token, None, [], workflow_name, entity_name, entity_path
        )

        if parent_run_id is not None and parent_run_id in self.spans:
            self.spans[parent_run_id].children.append(run_id)

        return span

    def _create_task_span(
        self,
        run_id: UUID,
        parent_run_id: Optional[UUID],
        name: str,
        kind: TraceloopSpanKindValues,
        workflow_name: str,
        entity_name: str = "",
        entity_path: str = "",
        metadata: Optional[dict[str, Any]] = None,
    ) -> Span:
        span_name = f"{name}.{kind.value}"
        span = self._create_span(
            run_id,
            parent_run_id,
            span_name,
            workflow_name=workflow_name,
            entity_name=entity_name,
            entity_path=entity_path,
            metadata=metadata,
        )

        _set_span_attribute(span, SpanAttributes.TRACELOOP_SPAN_KIND, kind.value)
        _set_span_attribute(span, SpanAttributes.TRACELOOP_ENTITY_NAME, entity_name)

        return span

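    # Span-naming sketch: span_name is f"{name}.{kind.value}", so a workflow
    # named "qa_pipeline" (hypothetical) becomes "qa_pipeline.workflow" and a
    # tool named "retriever" becomes "retriever.tool".
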
    def _create_llm_span(
        self,
        run_id: UUID,
        parent_run_id: Optional[UUID],
        name: str,
        request_type: LLMRequestTypeValues,
        metadata: Optional[dict[str, Any]] = None,
        serialized: Optional[dict[str, Any]] = None,
    ) -> Span:
        workflow_name = self.get_workflow_name(parent_run_id)
        entity_path = self.get_entity_path(parent_run_id)

        span = self._create_span(
            run_id,
            parent_run_id,
            f"{name}.{request_type.value}",
            kind=SpanKind.CLIENT,
            workflow_name=workflow_name,
            entity_path=entity_path,
            metadata=metadata,
        )

        vendor = detect_vendor_from_class(
            _extract_class_name_from_serialized(serialized)
        )

        _set_span_attribute(span, GenAIAttributes.GEN_AI_SYSTEM, vendor)
        _set_span_attribute(span, SpanAttributes.LLM_REQUEST_TYPE, request_type.value)

        # We already have an LLM span by this point,
        # so skip any downstream instrumentation from here
        try:
            token = context_api.attach(
                context_api.set_value(SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY, True)
            )
        except Exception:
            # If context setting fails, continue without a suppression token
            token = None

        self.spans[run_id] = SpanHolder(
            span, token, None, [], workflow_name, None, entity_path
        )

        return span

    @dont_throw
    def on_chain_start(
        self,
        serialized: dict[str, Any],
        inputs: dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[list[str]] = None,
        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when chain starts running."""
        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
            return

        workflow_name = ""
        entity_path = ""

        name = self._get_name_from_callback(serialized, **kwargs)
        kind = (
            TraceloopSpanKindValues.WORKFLOW
            if parent_run_id is None or parent_run_id not in self.spans
            else TraceloopSpanKindValues.TASK
        )

        if kind == TraceloopSpanKindValues.WORKFLOW:
            workflow_name = name
        else:
            workflow_name = self.get_workflow_name(parent_run_id)
            entity_path = self.get_entity_path(parent_run_id)

        span = self._create_task_span(
            run_id,
            parent_run_id,
            name,
            kind,
            workflow_name,
            name,
            entity_path,
            metadata,
        )
        if not should_emit_events() and should_send_prompts():
            span.set_attribute(
                SpanAttributes.TRACELOOP_ENTITY_INPUT,
                json.dumps(
                    {
                        "inputs": inputs,
                        "tags": tags,
                        "metadata": metadata,
                        "kwargs": kwargs,
                    },
                    cls=CallbackFilteredJSONEncoder,
                ),
            )

        # The start_time is now automatically set when creating the SpanHolder

    @dont_throw
    def on_chain_end(
        self,
        outputs: dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Run when chain ends running."""
        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
            return

        span_holder = self.spans[run_id]
        span = span_holder.span
        if not should_emit_events() and should_send_prompts():
            span.set_attribute(
                SpanAttributes.TRACELOOP_ENTITY_OUTPUT,
                json.dumps(
                    {"outputs": outputs, "kwargs": kwargs},
                    cls=CallbackFilteredJSONEncoder,
                ),
            )

        self._end_span(span, run_id)
        if parent_run_id is None:
            try:
                context_api.attach(
                    context_api.set_value(
                        SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY, False
                    )
                )
            except Exception:
                # If context reset fails, it's not critical for functionality
                pass

    @dont_throw
    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[BaseMessage]],
        *,
        run_id: UUID,
        tags: Optional[list[str]] = None,
        parent_run_id: Optional[UUID] = None,
        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when Chat Model starts running."""
        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
            return

        name = self._get_name_from_callback(serialized, kwargs=kwargs)
        span = self._create_llm_span(
            run_id,
            parent_run_id,
            name,
            LLMRequestTypeValues.CHAT,
            metadata=metadata,
            serialized=serialized,
        )
        set_request_params(span, kwargs, self.spans[run_id])
        if should_emit_events():
            self._emit_chat_input_events(messages)
        else:
            set_chat_request(span, serialized, messages, kwargs, self.spans[run_id])

    @dont_throw
    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        tags: Optional[list[str]] = None,
        parent_run_id: Optional[UUID] = None,
        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when LLM starts running."""
        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
            return

        name = self._get_name_from_callback(serialized, kwargs=kwargs)
        span = self._create_llm_span(
            run_id,
            parent_run_id,
            name,
            LLMRequestTypeValues.COMPLETION,
            serialized=serialized,
        )
        set_request_params(span, kwargs, self.spans[run_id])
        if should_emit_events():
            for prompt in prompts:
                emit_event(MessageEvent(content=prompt, role="user"))
        else:
            set_llm_request(span, serialized, prompts, kwargs, self.spans[run_id])

    @dont_throw
    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Union[UUID, None] = None,
        **kwargs: Any,
    ):
        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
            return

        span = self._get_span(run_id)

        model_name = None
        if response.llm_output is not None:
            model_name = response.llm_output.get(
                "model_name"
            ) or response.llm_output.get("model_id")
            if model_name is not None:
                _set_span_attribute(
                    span, GenAIAttributes.GEN_AI_RESPONSE_MODEL, model_name or "unknown"
                )

                if self.spans[run_id].request_model is None:
                    _set_span_attribute(
                        span, GenAIAttributes.GEN_AI_REQUEST_MODEL, model_name
                    )
            id = response.llm_output.get("id")
            if id is not None and id != "":
                _set_span_attribute(span, GenAIAttributes.GEN_AI_RESPONSE_ID, id)
        if model_name is None:
            model_name = extract_model_name_from_response_metadata(response)
        if model_name is None and hasattr(context_api, "get_value"):
            association_properties = (
                context_api.get_value("association_properties") or {}
            )
            model_name = _extract_model_name_from_association_metadata(
                association_properties
            )
        token_usage = (response.llm_output or {}).get("token_usage") or (
            response.llm_output or {}
        ).get("usage")
        if token_usage is not None:
            prompt_tokens = (
                token_usage.get("prompt_tokens")
                or token_usage.get("input_token_count")
                or token_usage.get("input_tokens")
            )
            completion_tokens = (
                token_usage.get("completion_tokens")
                or token_usage.get("generated_token_count")
                or token_usage.get("output_tokens")
            )
            # Guard against payloads that report only some of the counts:
            # fall back to 0 so the arithmetic below cannot raise on None.
            total_tokens = token_usage.get("total_tokens") or (
                (prompt_tokens or 0) + (completion_tokens or 0)
            )

            _set_span_attribute(
                span, GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS, prompt_tokens
            )
            _set_span_attribute(
                span, GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, completion_tokens
            )
            _set_span_attribute(
                span, SpanAttributes.LLM_USAGE_TOTAL_TOKENS, total_tokens
            )

            # Record token usage metrics
            vendor = span.attributes.get(GenAIAttributes.GEN_AI_SYSTEM, "Langchain")
            if prompt_tokens and prompt_tokens > 0:
                self.token_histogram.record(
                    prompt_tokens,
                    attributes={
                        GenAIAttributes.GEN_AI_SYSTEM: vendor,
                        GenAIAttributes.GEN_AI_TOKEN_TYPE: "input",
                        GenAIAttributes.GEN_AI_RESPONSE_MODEL: model_name or "unknown",
                    },
                )

            if completion_tokens and completion_tokens > 0:
                self.token_histogram.record(
                    completion_tokens,
                    attributes={
                        GenAIAttributes.GEN_AI_SYSTEM: vendor,
                        GenAIAttributes.GEN_AI_TOKEN_TYPE: "output",
                        GenAIAttributes.GEN_AI_RESPONSE_MODEL: model_name or "unknown",
                    },
                )
        set_chat_response_usage(
            span, response, self.token_histogram, token_usage is None, model_name
        )
        if should_emit_events():
            self._emit_llm_end_events(response)
        # Set span attributes in both modes, kept for backward compatibility
        set_chat_response(span, response)

        # Record duration before ending the span
        duration = time.time() - self.spans[run_id].start_time
        vendor = span.attributes.get(GenAIAttributes.GEN_AI_SYSTEM, "Langchain")
        self.duration_histogram.record(
            duration,
            attributes={
                GenAIAttributes.GEN_AI_SYSTEM: vendor,
                GenAIAttributes.GEN_AI_RESPONSE_MODEL: model_name or "unknown",
            },
        )

        self._end_span(span, run_id)

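    # Token-usage key aliases handled above, with hypothetical payloads:
    #   OpenAI-style:    {"prompt_tokens": 10, "completion_tokens": 3, "total_tokens": 13}
    #   Anthropic-style: {"input_tokens": 10, "output_tokens": 3}
    #   watsonx-style:   {"input_token_count": 10, "generated_token_count": 3}
    # When "total_tokens" is absent, it is derived from the two partial counts.
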
    @dont_throw
    def on_tool_start(
        self,
        serialized: dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[list[str]] = None,
        metadata: Optional[dict[str, Any]] = None,
        inputs: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Run when tool starts running."""
        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
            return

        name = self._get_name_from_callback(serialized, kwargs=kwargs)
        workflow_name = self.get_workflow_name(parent_run_id)
        entity_path = self.get_entity_path(parent_run_id)

        span = self._create_task_span(
            run_id,
            parent_run_id,
            name,
            TraceloopSpanKindValues.TOOL,
            workflow_name,
            name,
            entity_path,
        )
        if not should_emit_events() and should_send_prompts():
            span.set_attribute(
                SpanAttributes.TRACELOOP_ENTITY_INPUT,
                json.dumps(
                    {
                        "input_str": input_str,
                        "tags": tags,
                        "metadata": metadata,
                        "inputs": inputs,
                        "kwargs": kwargs,
                    },
                    cls=CallbackFilteredJSONEncoder,
                ),
            )

    @dont_throw
    def on_tool_end(
        self,
        output: Any,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Run when tool ends running."""
        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
            return

        span = self._get_span(run_id)
        if not should_emit_events() and should_send_prompts():
            span.set_attribute(
                SpanAttributes.TRACELOOP_ENTITY_OUTPUT,
                json.dumps(
                    {"output": output, "kwargs": kwargs},
                    cls=CallbackFilteredJSONEncoder,
                ),
            )
        self._end_span(span, run_id)

    def get_parent_span(self, parent_run_id: Optional[UUID] = None):
        if parent_run_id is None:
            return None
        # Use .get() so an untracked parent run doesn't raise a KeyError
        return self.spans.get(parent_run_id)

    def get_workflow_name(self, parent_run_id: Optional[UUID]) -> str:
        parent_span = self.get_parent_span(parent_run_id)

        if parent_span is None:
            return ""

        return parent_span.workflow_name

    def get_entity_path(self, parent_run_id: Optional[UUID]) -> str:
        parent_span = self.get_parent_span(parent_run_id)

        if parent_span is None:
            return ""
        elif (
            parent_span.entity_path == ""
            and parent_span.entity_name == parent_span.workflow_name
        ):
            return ""
        elif parent_span.entity_path == "":
            return f"{parent_span.entity_name}"
        else:
            return f"{parent_span.entity_path}.{parent_span.entity_name}"

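    # Entity-path sketch (hypothetical names): children of a root workflow
    # "pipeline" get "" (its entity_name equals its workflow_name); children
    # of a task "step" under it get "step"; children of a nested task
    # "substep" under that get "step.substep".
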
    def _handle_error(
        self,
        error: BaseException,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Common error handling logic for all components."""
        if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
            return

        span = self._get_span(run_id)
        # Pass the description inside the Status object; set_status() ignores
        # the separate description argument when given a Status instance.
        span.set_status(Status(StatusCode.ERROR, str(error)))
        span.record_exception(error)
        self._end_span(span, run_id)

    @dont_throw
    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Run when LLM errors."""
        self._handle_error(error, run_id, parent_run_id, **kwargs)

    @dont_throw
    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Run when chain errors."""
        self._handle_error(error, run_id, parent_run_id, **kwargs)

    @dont_throw
    def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Run when tool errors."""
        span = self._get_span(run_id)
        span.set_attribute(ERROR_TYPE, type(error).__name__)
        self._handle_error(error, run_id, parent_run_id, **kwargs)

    @dont_throw
    def on_agent_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Run when agent errors."""
        self._handle_error(error, run_id, parent_run_id, **kwargs)

    @dont_throw
    def on_retriever_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Run when retriever errors."""
        self._handle_error(error, run_id, parent_run_id, **kwargs)

    def _emit_chat_input_events(self, messages):
        for message_list in messages:
            for message in message_list:
                if hasattr(message, "tool_calls") and message.tool_calls:
                    tool_calls = _extract_tool_call_data(message.tool_calls)
                else:
                    tool_calls = None
                emit_event(
                    MessageEvent(
                        content=message.content,
                        role=get_message_role(message),
                        tool_calls=tool_calls,
                    )
                )

    def _emit_llm_end_events(self, response):
        for generation_list in response.generations:
            for i, generation in enumerate(generation_list):
                self._emit_generation_choice_event(index=i, generation=generation)

    def _emit_generation_choice_event(
        self,
        index: int,
        generation: Union[
            ChatGeneration, ChatGenerationChunk, Generation, GenerationChunk
        ],
    ):
        if isinstance(generation, (ChatGeneration, ChatGenerationChunk)):
            # Get finish reason
            if hasattr(generation, "generation_info") and generation.generation_info:
                finish_reason = generation.generation_info.get(
                    "finish_reason", "unknown"
                )
            else:
                finish_reason = "unknown"

            # Get tool calls
            if (
                hasattr(generation.message, "tool_calls")
                and generation.message.tool_calls
            ):
                tool_calls = _extract_tool_call_data(generation.message.tool_calls)
            elif hasattr(
                generation.message, "additional_kwargs"
            ) and generation.message.additional_kwargs.get("function_call"):
                tool_calls = _extract_tool_call_data(
                    [generation.message.additional_kwargs.get("function_call")]
                )
            else:
                tool_calls = None

            # Emit the event
            if hasattr(generation, "text") and generation.text != "":
                emit_event(
                    ChoiceEvent(
                        index=index,
                        message={"content": generation.text, "role": "assistant"},
                        finish_reason=finish_reason,
                        tool_calls=tool_calls,
                    )
                )
            else:
                emit_event(
                    ChoiceEvent(
                        index=index,
                        message={
                            "content": generation.message.content,
                            "role": "assistant",
                        },
                        finish_reason=finish_reason,
                        tool_calls=tool_calls,
                    )
                )
        elif isinstance(generation, (Generation, GenerationChunk)):
            # Get finish reason
            if hasattr(generation, "generation_info") and generation.generation_info:
                finish_reason = generation.generation_info.get(
                    "finish_reason", "unknown"
                )
            else:
                finish_reason = "unknown"

            # Emit the event
            emit_event(
                ChoiceEvent(
                    index=index,
                    message={"content": generation.text, "role": "assistant"},
                    finish_reason=finish_reason,
                )
            )