ai-station/.venv/lib/python3.12/site-packages/literalai/instrumentation/openai.py

import logging
import time
from typing import TYPE_CHECKING, Dict, Union

from literalai.instrumentation import OPENAI_PROVIDER
from literalai.requirements import check_all_requirements

if TYPE_CHECKING:
    from literalai.client import LiteralClient

from literalai.context import active_steps_var, active_thread_var
from literalai.helper import ensure_values_serializable
from literalai.observability.generation import (
    ChatGeneration,
    CompletionGeneration,
    GenerationMessage,
    GenerationType,
)
from literalai.wrappers import AfterContext, BeforeContext, wrap_all

logger = logging.getLogger(__name__)

REQUIREMENTS = ["openai>=1.0.0"]
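
# Each entry below names one OpenAI SDK method to patch: the module and object
# that hold it, the method name, whether it is async, and the Literal AI
# generation type recorded for calls made through it.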
TO_WRAP = [
    {
        "module": "openai.resources.chat.completions",
        "object": "Completions",
        "method": "create",
        "metadata": {
            "type": GenerationType.CHAT,
        },
        "async": False,
    },
    {
        "module": "openai.resources.completions",
        "object": "Completions",
        "method": "create",
        "metadata": {
            "type": GenerationType.COMPLETION,
        },
        "async": False,
    },
    {
        "module": "openai.resources.chat.completions",
        "object": "AsyncCompletions",
        "method": "create",
        "metadata": {
            "type": GenerationType.CHAT,
        },
        "async": True,
    },
    {
        "module": "openai.resources.completions",
        "object": "AsyncCompletions",
        "method": "create",
        "metadata": {
            "type": GenerationType.COMPLETION,
        },
        "async": True,
    },
]
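
# Module-level guard so the OpenAI client classes are only patched once per process.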
is_openai_instrumented = False


def instrument_openai(client: "LiteralClient", on_new_generation=None):
    """
    Instruments all OpenAI LLM calls to automatically send logs to Literal AI.
    """
    global is_openai_instrumented
    if is_openai_instrumented:
        return

    if not check_all_requirements(REQUIREMENTS):
        raise Exception(
            f"OpenAI instrumentation requirements not satisfied: {REQUIREMENTS}"
        )

    import inspect

    if callable(on_new_generation):
        sig = inspect.signature(on_new_generation)
        parameters = list(sig.parameters.values())
        if len(parameters) != 2:
            raise ValueError(
                "on_new_generation should take 2 parameters: generation and timing"
            )

    from openai import AsyncStream, Stream
    from openai.types.chat.chat_completion_chunk import ChoiceDelta
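
    # Build a ChatGeneration or CompletionGeneration from the keyword arguments
    # of the wrapped OpenAI call, capturing the model, request settings, the
    # messages or prompt, and any Literal AI prompt metadata attached to messages.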
    def init_generation(generation_type: "GenerationType", kwargs):
        model = kwargs.get("model")
        tools = kwargs.get("tools")
        if generation_type == GenerationType.CHAT:
            orig_messages = kwargs.get("messages")
            messages = ensure_values_serializable(orig_messages)
            prompt_id = None
            variables = None

            for index, message in enumerate(messages):
                orig_message = orig_messages[index]
                if literal_prompt := getattr(orig_message, "__literal_prompt__", None):
                    prompt_id = literal_prompt.get("prompt_id")
                    variables = literal_prompt.get("variables")
                    message["uuid"] = literal_prompt.get("uuid")
                    message["templated"] = True

            settings = {
                "model": model,
                "frequency_penalty": kwargs.get("frequency_penalty"),
                "logit_bias": kwargs.get("logit_bias"),
                "logprobs": kwargs.get("logprobs"),
                "top_logprobs": kwargs.get("top_logprobs"),
                "max_tokens": kwargs.get("max_tokens"),
                "n": kwargs.get("n"),
                "presence_penalty": kwargs.get("presence_penalty"),
                "response_format": kwargs.get("response_format"),
                "seed": kwargs.get("seed"),
                "stop": kwargs.get("stop"),
                "stream": kwargs.get("stream"),
                "temperature": kwargs.get("temperature"),
                "top_p": kwargs.get("top_p"),
                "tool_choice": kwargs.get("tool_choice"),
            }
            settings = {k: v for k, v in settings.items() if v is not None}
            return ChatGeneration(
                prompt_id=prompt_id,
                variables=variables,
                provider=OPENAI_PROVIDER,
                model=model,
                tools=tools,
                settings=settings,
                messages=messages,
                metadata=kwargs.get("literalai_metadata"),
                tags=kwargs.get("literalai_tags"),
            )
        elif generation_type == GenerationType.COMPLETION:
            settings = {
                "model": model,
                "best_of": kwargs.get("best_of"),
                "echo": kwargs.get("echo"),
                "frequency_penalty": kwargs.get("frequency_penalty"),
                "logit_bias": kwargs.get("logit_bias"),
                "logprobs": kwargs.get("logprobs"),
                "max_tokens": kwargs.get("max_tokens"),
                "n": kwargs.get("n"),
                "presence_penalty": kwargs.get("presence_penalty"),
                "seed": kwargs.get("seed"),
                "stop": kwargs.get("stop"),
                "stream": kwargs.get("stream"),
                "suffix": kwargs.get("suffix"),
                "temperature": kwargs.get("temperature"),
                "top_p": kwargs.get("top_p"),
            }
            settings = {k: v for k, v in settings.items() if v is not None}
            return CompletionGeneration(
                provider=OPENAI_PROVIDER,
                model=model,
                settings=settings,
                prompt=kwargs.get("prompt"),
                metadata=kwargs.get("literalai_metadata"),
                tags=kwargs.get("literalai_tags"),
            )
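
    # Copy the completion (message or text) and token usage from a non-streaming
    # OpenAI response onto the generation.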
    def update_step_after(
        generation: Union[ChatGeneration, CompletionGeneration], result
    ):
        if generation and isinstance(generation, ChatGeneration):
            generation.message_completion = result.choices[0].message.model_dump()
        elif generation and isinstance(generation, CompletionGeneration):
            if generation and generation.type == GenerationType.COMPLETION:
                generation.completion = result.choices[0].text

        if generation:
            generation.input_token_count = result.usage.prompt_tokens
            generation.output_token_count = result.usage.completion_tokens
            generation.token_count = result.usage.total_tokens
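
    # Runs before each wrapped sync call: builds the generation and, if there is
    # an active thread or step (and no custom callback), opens an "llm" step
    # whose input mirrors the outgoing request.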
    def before_wrapper(metadata: Dict):
        def before(context: BeforeContext, *args, **kwargs):
            active_thread = active_thread_var.get()
            active_steps = active_steps_var.get()
            generation = init_generation(metadata["type"], kwargs)

            if (active_thread or active_steps) and not callable(on_new_generation):
                step = client.start_step(
                    name=context["original_func"].__name__, type="llm"
                )
                step.name = generation.model or OPENAI_PROVIDER
                if isinstance(generation, ChatGeneration):
                    step.input = {"content": generation.messages}
                else:
                    step.input = {"content": generation.prompt}
                context["step"] = step

            context["generation"] = generation
            context["start"] = time.time()

        return before
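
    # Async twin of before_wrapper; the body is identical, but the hook is
    # defined as a coroutine so it can wrap the async client methods.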
    def async_before_wrapper(metadata: Dict):
        async def before(context: BeforeContext, *args, **kwargs):
            active_thread = active_thread_var.get()
            active_steps = active_steps_var.get()
            generation = init_generation(metadata["type"], kwargs)

            if (active_thread or active_steps) and not callable(on_new_generation):
                step = client.start_step(
                    name=context["original_func"].__name__, type="llm"
                )
                step.name = generation.model or OPENAI_PROVIDER
                if isinstance(generation, ChatGeneration):
                    step.input = {"content": generation.messages}
                else:
                    step.input = {"content": generation.prompt}
                context["step"] = step

            context["generation"] = generation
            context["start"] = time.time()

        return before
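
    # Merge one streamed ChoiceDelta into the accumulated assistant message,
    # handling function_call, tool_calls, and plain content chunks. Returns
    # False when the delta carries nothing worth recording.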
    def process_delta(new_delta: ChoiceDelta, message_completion: GenerationMessage):
        if new_delta.function_call:
            if new_delta.function_call.name:
                message_completion["function_call"] = {
                    "name": new_delta.function_call.name
                }
            if not message_completion["function_call"]:
                return False
            if new_delta.function_call.arguments:
                if "arguments" not in message_completion["function_call"]:
                    message_completion["function_call"]["arguments"] = ""
                message_completion["function_call"][
                    "arguments"
                ] += new_delta.function_call.arguments
            return True
        elif new_delta.tool_calls:
            if "tool_calls" not in message_completion:
                message_completion["tool_calls"] = []
            delta_tool_call = new_delta.tool_calls[0]
            delta_function = delta_tool_call.function
            if not delta_function:
                return False
            if delta_function.name:
                message_completion["tool_calls"].append(  # type: ignore
                    {
                        "id": delta_tool_call.id,
                        "type": "function",
                        "function": {
                            "name": delta_function.name,
                            "arguments": "",
                        },
                    }
                )
            if delta_function.arguments:
                message_completion["tool_calls"][delta_tool_call.index]["function"][  # type: ignore
                    "arguments"
                ] += delta_function.arguments
            return True
        elif new_delta.content:
            if isinstance(message_completion["content"], str):
                message_completion["content"] += new_delta.content
            return True
        else:
            return False
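
    # Wrap a sync Stream: re-yield every chunk to the caller while accumulating
    # the completion and timing the first token, then finalise the generation
    # (callback, step output, or direct API call) once the stream is exhausted.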
    def streaming_response(
        generation: Union[ChatGeneration, CompletionGeneration],
        result,
        context: AfterContext,
    ):
        completion = ""
        message_completion = {
            "role": "assistant",
            "content": "",
        }  # type: GenerationMessage
        token_count = 0

        for chunk in result:
            if generation and isinstance(generation, ChatGeneration):
                if len(chunk.choices) > 0:
                    ok = process_delta(chunk.choices[0].delta, message_completion)
                    if not ok:
                        yield chunk
                        continue
                    if generation.tt_first_token is None:
                        generation.tt_first_token = (
                            time.time() - context["start"]
                        ) * 1000
                    token_count += 1
            elif generation and isinstance(generation, CompletionGeneration):
                if len(chunk.choices) > 0 and chunk.choices[0].text is not None:
                    if generation.tt_first_token is None:
                        generation.tt_first_token = (
                            time.time() - context["start"]
                        ) * 1000
                    token_count += 1
                    completion += chunk.choices[0].text

            if (
                generation
                and getattr(chunk, "model", None)
                and generation.model != chunk.model
            ):
                generation.model = chunk.model
                if generation.settings:
                    generation.settings["model"] = chunk.model

            yield chunk

        if generation:
            generation.duration = time.time() - context["start"]
            if generation.duration and token_count:
                generation.token_throughput_in_s = token_count / generation.duration
            if isinstance(generation, ChatGeneration):
                generation.message_completion = message_completion
            else:
                generation.completion = completion

        step = context.get("step")

        if callable(on_new_generation):
            on_new_generation(
                generation,
                {
                    "start": context["start"],
                    "end": time.time(),
                },
            )
        elif step:
            if isinstance(generation, ChatGeneration):
                step.output = generation.message_completion  # type: ignore
            else:
                step.output = {"content": generation.completion}
            step.generation = generation
            step.end()
        else:
            client.api.create_generation(generation)
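
    # Runs after each wrapped sync call: records the resolved model, then either
    # hands a streaming result to streaming_response or finalises the generation
    # immediately for non-streaming responses.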
    def after_wrapper(metadata: Dict):
        # Needs to be done in a separate function to avoid transforming all returned data into generators
        def after(result, context: AfterContext, *args, **kwargs):
            step = context.get("step")
            generation = context.get("generation")

            if not generation:
                return result

            if model := getattr(result, "model", None):
                generation.model = model
                if generation.settings:
                    generation.settings["model"] = model

            if isinstance(result, Stream):
                return streaming_response(generation, result, context)
            else:
                generation.duration = time.time() - context["start"]
                update_step_after(generation, result)

                if callable(on_new_generation):
                    on_new_generation(
                        generation,
                        {
                            "start": context["start"],
                            "end": time.time(),
                        },
                    )
                elif step:
                    if isinstance(generation, ChatGeneration):
                        step.output = generation.message_completion  # type: ignore
                    else:
                        step.output = {"content": generation.completion}
                    step.generation = generation
                    step.end()
                else:
                    client.api.create_generation(generation)

            return result

        return after
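
    # Async counterpart of streaming_response, consuming an AsyncStream with
    # "async for" and finalising the generation once the stream ends.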
    async def async_streaming_response(
        generation: Union[ChatGeneration, CompletionGeneration],
        result,
        context: AfterContext,
    ):
        completion = ""
        message_completion = {
            "role": "assistant",
            "content": "",
        }  # type: GenerationMessage
        token_count = 0

        async for chunk in result:
            if generation and isinstance(generation, ChatGeneration):
                if len(chunk.choices) > 0:
                    ok = process_delta(chunk.choices[0].delta, message_completion)
                    if not ok:
                        yield chunk
                        continue
                    if generation.tt_first_token is None:
                        generation.tt_first_token = (
                            time.time() - context["start"]
                        ) * 1000
                    token_count += 1
            elif generation and isinstance(generation, CompletionGeneration):
                if len(chunk.choices) > 0 and chunk.choices[0].text is not None:
                    if generation.tt_first_token is None:
                        generation.tt_first_token = (
                            time.time() - context["start"]
                        ) * 1000
                    token_count += 1
                    completion += chunk.choices[0].text

            if (
                generation
                and getattr(chunk, "model", None)
                and generation.model != chunk.model
            ):
                generation.model = chunk.model

            yield chunk

        if generation:
            generation.duration = time.time() - context["start"]
            if generation.duration and token_count:
                generation.token_throughput_in_s = token_count / generation.duration
            if isinstance(generation, ChatGeneration):
                generation.message_completion = message_completion
            else:
                generation.completion = completion

        step = context.get("step")

        if callable(on_new_generation):
            on_new_generation(
                generation,
                {
                    "start": context["start"],
                    "end": time.time(),
                },
            )
        elif step:
            if isinstance(generation, ChatGeneration):
                step.output = generation.message_completion  # type: ignore
            else:
                step.output = {"content": generation.completion}
            step.generation = generation
            step.end()
        else:
            client.api.create_generation(generation)
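
    # Async counterpart of after_wrapper: returns async_streaming_response for
    # AsyncStream results and otherwise finalises the generation in place.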
    def async_after_wrapper(metadata: Dict):
        async def after(result, context: AfterContext, *args, **kwargs):
            step = context.get("step")
            generation = context.get("generation")

            if not generation:
                return result

            if model := getattr(result, "model", None):
                generation.model = model
                if generation.settings:
                    generation.settings["model"] = model

            if isinstance(result, AsyncStream):
                return async_streaming_response(generation, result, context)
            else:
                generation.duration = time.time() - context["start"]
                update_step_after(generation, result)

                if callable(on_new_generation):
                    on_new_generation(
                        generation,
                        {
                            "start": context["start"],
                            "end": time.time(),
                        },
                    )
                elif step:
                    if isinstance(generation, ChatGeneration):
                        step.output = generation.message_completion  # type: ignore
                    else:
                        step.output = {"content": generation.completion}
                    step.generation = generation
                    step.end()
                else:
                    client.api.create_generation(generation)

            return result

        return after
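
    # Patch every method listed in TO_WRAP with the hooks defined above, then
    # record that instrumentation is in place.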
    wrap_all(
        TO_WRAP,
        before_wrapper,
        after_wrapper,
        async_before_wrapper,
        async_after_wrapper,
    )

    is_openai_instrumented = True
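
# Usage sketch (not part of the original module): one way to enable this
# instrumentation from application code. It assumes a Literal AI API key is
# available (for example via the LITERAL_API_KEY environment variable); the
# exact client configuration may differ in your setup.
#
#     from literalai import LiteralClient
#     from openai import OpenAI
#
#     literalai_client = LiteralClient()
#     literalai_client.instrument_openai()  # patches the methods listed in TO_WRAP once
#
#     openai_client = OpenAI()
#     response = openai_client.chat.completions.create(
#         model="gpt-4o-mini",
#         messages=[{"role": "user", "content": "Hello"}],
#     )
#     # The call above is now logged to Literal AI as a ChatGeneration.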