import asyncio
import copy
import hashlib
import json
import os
import smtplib
import threading
import time
import traceback
from datetime import datetime, timedelta
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Union,
    cast,
    overload,
)

from litellm.constants import MAX_TEAM_LIST_LIMIT, DEFAULT_MODEL_CREATED_AT_TIME
from litellm.proxy._types import (
    DB_CONNECTION_ERROR_TYPES,
    CommonProxyErrors,
    ProxyErrorTypes,
    ProxyException,
    SpendLogsMetadata,
    SpendLogsPayload,
)
from litellm.types.guardrails import GuardrailEventHooks

try:
    import backoff
except ImportError:
    raise ImportError(
        "backoff is not installed. Please install it via 'pip install backoff'"
    )

from fastapi import HTTPException, status

import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm import (
    EmbeddingResponse,
    ImageResponse,
    ModelResponse,
    ModelResponseStream,
    Router,
)
from litellm._logging import verbose_proxy_logger
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm.caching.caching import DualCache, RedisCache
from litellm.exceptions import (
    BlockedPiiEntityError,
    GuardrailRaisedException,
    RejectedRequestError,
)
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import (
    AlertType,
    CallInfo,
    LiteLLM_VerificationTokenView,
    Member,
    UserAPIKeyAuth,
)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.db.create_views import (
    create_missing_views,
    should_create_missing_views,
)
from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
from litellm.proxy.db.log_db_metrics import log_db_metrics
from litellm.proxy.db.prisma_client import PrismaWrapper
from litellm.proxy.hooks import PROXY_HOOKS, get_proxy_hook
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.parallel_request_limiter import (
    _PROXY_MaxParallelRequestsHandler,
)
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.secret_managers.main import str_to_bool
from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
from litellm.types.mcp import (
    MCPDuringCallResponseObject,
    MCPPreCallRequestObject,
    MCPPreCallResponseObject,
)
from litellm.types.utils import CallTypes, LLMResponseTypes, LoggedLiteLLMParams

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    Span = Union[_Span, Any]
else:
    Span = Any


def print_verbose(print_statement):
    """
    Prints the given `print_statement` to the console if `litellm.set_verbose` is True.
    Also logs the `print_statement` at the debug level using `verbose_proxy_logger`.

    :param print_statement: The statement to be printed and logged.
    :type print_statement: Any
    """
    import traceback

    verbose_proxy_logger.debug("{}\n{}".format(print_statement, traceback.format_exc()))
    if litellm.set_verbose:
        print(f"LiteLLM Proxy: {print_statement}")  # noqa


class InternalUsageCache:
    def __init__(self, dual_cache: DualCache):
        self.dual_cache: DualCache = dual_cache

    async def async_get_cache(
        self,
        key,
        litellm_parent_otel_span: Union[Span, None],
        local_only: bool = False,
        **kwargs,
    ) -> Any:
        return await self.dual_cache.async_get_cache(
            key=key,
            local_only=local_only,
            parent_otel_span=litellm_parent_otel_span,
            **kwargs,
        )

    async def async_set_cache(
        self,
        key,
        value,
        litellm_parent_otel_span: Union[Span, None],
        local_only: bool = False,
        **kwargs,
    ) -> None:
        return await self.dual_cache.async_set_cache(
            key=key,
            value=value,
            local_only=local_only,
            litellm_parent_otel_span=litellm_parent_otel_span,
            **kwargs,
        )

    async def async_batch_set_cache(
        self,
        cache_list: List,
        litellm_parent_otel_span: Union[Span, None],
        local_only: bool = False,
        **kwargs,
    ) -> None:
        return await self.dual_cache.async_set_cache_pipeline(
            cache_list=cache_list,
            local_only=local_only,
            litellm_parent_otel_span=litellm_parent_otel_span,
            **kwargs,
        )

    async def async_batch_get_cache(
        self,
        keys: list,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
    ):
        return await self.dual_cache.async_batch_get_cache(
            keys=keys,
            parent_otel_span=parent_otel_span,
            local_only=local_only,
        )

    async def async_increment_cache(
        self,
        key,
        value: float,
        litellm_parent_otel_span: Union[Span, None],
        local_only: bool = False,
        **kwargs,
    ):
        return await self.dual_cache.async_increment_cache(
            key=key,
            value=value,
            local_only=local_only,
            parent_otel_span=litellm_parent_otel_span,
            **kwargs,
        )

    def set_cache(
        self,
        key,
        value,
        local_only: bool = False,
        **kwargs,
    ) -> None:
        return self.dual_cache.set_cache(
            key=key,
            value=value,
            local_only=local_only,
            **kwargs,
        )

    def get_cache(
        self,
        key,
        local_only: bool = False,
        **kwargs,
    ) -> Any:
        return self.dual_cache.get_cache(
            key=key,
            local_only=local_only,
            **kwargs,
        )
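

# Illustrative usage sketch (not executed; the key name and TTL below are made-up
# examples, not values the proxy actually uses). InternalUsageCache simply forwards
# to the wrapped DualCache while threading through the parent OTEL span:
#
#   usage_cache = InternalUsageCache(dual_cache=DualCache(default_in_memory_ttl=1))
#   await usage_cache.async_set_cache(
#       key="example:rpm:sk-123", value=1, litellm_parent_otel_span=None, ttl=60
#   )
#   current = await usage_cache.async_get_cache(
#       key="example:rpm:sk-123", litellm_parent_otel_span=None
#   )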


### LOGGING ###
class ProxyLogging:
    """
    Logging/Custom Handlers for proxy.

    Implemented mainly to:
    - log successful/failed db read/writes
    - support the max parallel request integration
    """

    def __init__(
        self,
        user_api_key_cache: DualCache,
        premium_user: bool = False,
    ):
        ## INITIALIZE LITELLM CALLBACKS ##
        self.call_details: dict = {}
        self.call_details["user_api_key_cache"] = user_api_key_cache
        self.internal_usage_cache: InternalUsageCache = InternalUsageCache(
            dual_cache=DualCache(default_in_memory_ttl=1)  # ping redis cache every 1s
        )
        self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler(
            self.internal_usage_cache
        )
        self.max_budget_limiter = _PROXY_MaxBudgetLimiter()
        self.cache_control_check = _PROXY_CacheControlCheck()
        self.alerting: Optional[List] = None
        self.alerting_threshold: float = 300  # default to 5 min. threshold
        self.alert_types: List[AlertType] = DEFAULT_ALERT_TYPES
        self.alert_to_webhook_url: Optional[dict] = None
        self.slack_alerting_instance: SlackAlerting = SlackAlerting(
            alerting_threshold=self.alerting_threshold,
            alerting=self.alerting,
            internal_usage_cache=self.internal_usage_cache.dual_cache,
        )
        self.premium_user = premium_user
        self.service_logging_obj = ServiceLogging()
        self.db_spend_update_writer = DBSpendUpdateWriter()
        self.proxy_hook_mapping: Dict[str, CustomLogger] = {}

        # Guard flags to prevent duplicate background tasks
        self.daily_report_started: bool = False
        self.hanging_requests_check_started: bool = False

    def startup_event(
        self,
        llm_router: Optional[Router],
        redis_usage_cache: Optional[RedisCache],
    ):
        """Initialize logging and alerting on proxy startup"""
        ## UPDATE SLACK ALERTING ##
        self.slack_alerting_instance.update_values(llm_router=llm_router)

        ## UPDATE INTERNAL USAGE CACHE ##
        self.update_values(
            redis_cache=redis_usage_cache
        )  # used by parallel request limiter for rate limiting keys across instances

        self._init_litellm_callbacks(
            llm_router=llm_router
        )  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made

        if (
            self.slack_alerting_instance is not None
            and "daily_reports" in self.slack_alerting_instance.alert_types
            and not self.daily_report_started
        ):
            asyncio.create_task(
                self.slack_alerting_instance._run_scheduled_daily_report(
                    llm_router=llm_router
                )
            )  # RUN DAILY REPORT (if scheduled)
            self.daily_report_started = True

        if (
            self.slack_alerting_instance is not None
            and AlertType.llm_requests_hanging
            in self.slack_alerting_instance.alert_types
            and not self.hanging_requests_check_started
        ):
            asyncio.create_task(
                self.slack_alerting_instance.hanging_request_check.check_for_hanging_requests()
            )  # RUN HANGING REQUEST CHECK (if user wants to alert on hanging requests)
            self.hanging_requests_check_started = True

    def update_values(
        self,
        alerting: Optional[List] = None,
        alerting_threshold: Optional[float] = None,
        redis_cache: Optional[RedisCache] = None,
        alert_types: Optional[List[AlertType]] = None,
        alerting_args: Optional[dict] = None,
        alert_to_webhook_url: Optional[dict] = None,
    ):
        updated_slack_alerting: bool = False
        if alerting is not None:
            self.alerting = alerting
            updated_slack_alerting = True
        if alerting_threshold is not None:
            self.alerting_threshold = alerting_threshold
            updated_slack_alerting = True
        if alert_types is not None:
            self.alert_types = alert_types
            updated_slack_alerting = True
        if alert_to_webhook_url is not None:
            self.alert_to_webhook_url = alert_to_webhook_url
            updated_slack_alerting = True

        if updated_slack_alerting is True:
            self.slack_alerting_instance.update_values(
                alerting=self.alerting,
                alerting_threshold=self.alerting_threshold,
                alert_types=self.alert_types,
                alerting_args=alerting_args,
                alert_to_webhook_url=self.alert_to_webhook_url,
            )

            if self.alerting is not None and "slack" in self.alerting:
                # NOTE: ENSURE we only add callbacks when alerting is on
                # We should NOT add callbacks when alerting is off
                if (
                    "daily_reports" in self.alert_types
                    or "outage_alerts" in self.alert_types
                    or "region_outage_alerts" in self.alert_types
                ):
                    litellm.logging_callback_manager.add_litellm_callback(self.slack_alerting_instance)  # type: ignore
                litellm.logging_callback_manager.add_litellm_success_callback(
                    self.slack_alerting_instance.response_taking_too_long_callback
                )

        if redis_cache is not None:
            self.internal_usage_cache.dual_cache.redis_cache = redis_cache
            self.db_spend_update_writer.redis_update_buffer.redis_cache = redis_cache
            self.db_spend_update_writer.pod_lock_manager.redis_cache = redis_cache

    def _add_proxy_hooks(self, llm_router: Optional[Router] = None):
        """
        Add proxy hooks to litellm.callbacks
        """
        from litellm.proxy.proxy_server import prisma_client

        for hook in PROXY_HOOKS:
            proxy_hook = get_proxy_hook(hook)
            import inspect

            expected_args = inspect.getfullargspec(proxy_hook).args
            passed_in_args: Dict[str, Any] = {}
            if "internal_usage_cache" in expected_args:
                passed_in_args["internal_usage_cache"] = self.internal_usage_cache
            if "prisma_client" in expected_args:
                passed_in_args["prisma_client"] = prisma_client
            proxy_hook_obj = cast(CustomLogger, proxy_hook(**passed_in_args))
            litellm.logging_callback_manager.add_litellm_callback(proxy_hook_obj)

            self.proxy_hook_mapping[hook] = proxy_hook_obj

    def get_proxy_hook(self, hook: str) -> Optional[CustomLogger]:
        """
        Get a proxy hook from the proxy_hook_mapping
        """
        return self.proxy_hook_mapping.get(hook)

    def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
        self._add_proxy_hooks(llm_router)
        litellm.logging_callback_manager.add_litellm_callback(self.service_logging_obj)  # type: ignore
        for callback in litellm.callbacks:
            if isinstance(callback, str):
                callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
                    callback,
                    internal_usage_cache=self.internal_usage_cache.dual_cache,
                    llm_router=llm_router,
                )
                if callback is None:
                    continue
            if callback not in litellm.input_callback:
                litellm.input_callback.append(callback)  # type: ignore
            if callback not in litellm.success_callback:
                litellm.logging_callback_manager.add_litellm_success_callback(callback)  # type: ignore
            if callback not in litellm.failure_callback:
                litellm.logging_callback_manager.add_litellm_failure_callback(callback)  # type: ignore
            if callback not in litellm._async_success_callback:
                litellm.logging_callback_manager.add_litellm_async_success_callback(callback)  # type: ignore
            if callback not in litellm._async_failure_callback:
                litellm.logging_callback_manager.add_litellm_async_failure_callback(callback)  # type: ignore
            if callback not in litellm.service_callback:
                litellm.service_callback.append(callback)  # type: ignore

        if (
            len(litellm.input_callback) > 0
            or len(litellm.success_callback) > 0
            or len(litellm.failure_callback) > 0
        ):
            callback_list = list(
                set(
                    litellm.input_callback
                    + litellm.success_callback
                    + litellm.failure_callback
                )
            )
            litellm.litellm_core_utils.litellm_logging.set_callbacks(
                callback_list=callback_list
            )

    async def update_request_status(
        self, litellm_call_id: str, status: Literal["success", "fail"]
    ):
        # only use this if slack alerting is being used
        if self.alerting is None:
            return

        # current alerting threshold
        alerting_threshold: float = self.alerting_threshold

        # add a 100 second buffer to the alerting threshold
        # ensures we don't send errant hanging request slack alerts
        alerting_threshold += 100

        await self.internal_usage_cache.async_set_cache(
            key="request_status:{}".format(litellm_call_id),
            value=status,
            local_only=True,
            ttl=alerting_threshold,
            litellm_parent_otel_span=None,
        )
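
    # Illustrative note (assumed example values): for litellm_call_id="abc-123" with the
    # default alerting_threshold of 300s, the write above stores the local-only entry
    #   key="request_status:abc-123", value="success" | "fail", ttl=400
    # The extra 100s buffer keeps the status alive long enough for the hanging-request
    # check to read it.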

    async def async_pre_mcp_tool_call_hook(
        self,
        kwargs: dict,
        request_obj: Any,
        start_time: datetime,
        end_time: datetime,
    ) -> Optional[Any]:
        """
        Pre MCP Tool Call Hook

        Use this to validate and modify MCP tool calls before execution.
        Reuses existing LLM guardrail logic by converting MCP calls to message format.
        """
        from litellm.types.llms.base import HiddenParams
        from litellm.types.mcp import MCPPreCallRequestObject

        callbacks = self.get_combined_callback_list(
            dynamic_success_callbacks=getattr(self, "dynamic_success_callbacks", None),
            global_callbacks=litellm.success_callback,
        )

        # Create the request object if it's not already one
        if not isinstance(request_obj, MCPPreCallRequestObject):
            # Convert UserAPIKeyAuth object to dict if needed
            user_api_key_auth_dict = self._convert_user_api_key_auth_to_dict(
                kwargs.get("user_api_key_auth")
            )

            request_obj = MCPPreCallRequestObject(
                tool_name=kwargs.get("name", ""),
                arguments=kwargs.get("arguments", {}),
                server_name=kwargs.get("server_name"),
                user_api_key_auth=user_api_key_auth_dict,
                hidden_params=HiddenParams(),
            )

        for callback in callbacks:
            try:
                _callback: Optional[CustomLogger] = None
                if isinstance(callback, str):
                    from typing import cast

                    from litellm import _custom_logger_compatible_callbacks_literal

                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
                        cast(_custom_logger_compatible_callbacks_literal, callback)
                    )
                else:
                    _callback = callback  # type: ignore

                if _callback is not None and isinstance(_callback, CustomGuardrail):
                    from litellm.types.guardrails import GuardrailEventHooks

                    # Check if guardrail should be run for pre_call hook (reusing existing logic)
                    if (
                        _callback.should_run_guardrail(
                            data=kwargs, event_type=GuardrailEventHooks.pre_mcp_call
                        )
                        is not True
                    ):
                        continue

                    # Convert MCP tool call to LLM message format for existing guardrail logic
                    synthetic_llm_data = self._convert_mcp_to_llm_format(
                        request_obj, kwargs
                    )
                    # Reuse existing LLM guardrail logic
                    user_api_key_auth_dict = self._convert_user_api_key_auth_to_dict(
                        kwargs.get("user_api_key_auth")
                    )

                    result = await _callback.async_pre_call_hook(
                        user_api_key_dict=user_api_key_auth_dict,  # type: ignore
                        cache=self.call_details["user_api_key_cache"],
                        data=synthetic_llm_data,
                        call_type="mcp_call",
                    )

                    # Convert result back to MCP response format if blocked/modified
                    if result is not None:
                        mcp_response = self._convert_llm_result_to_mcp_response(
                            result, request_obj
                        )
                        if mcp_response is not None:
                            return self._parse_pre_mcp_call_hook_response(
                                response=mcp_response, original_request=request_obj
                            )

            except (
                BlockedPiiEntityError,
                GuardrailRaisedException,
                HTTPException,
            ) as e:
                # Re-raise guardrail exceptions so they can be properly handled
                raise e
            except Exception as e:
                verbose_proxy_logger.exception(
                    "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {}".format(
                        str(e)
                    )
                )
        return None

    def _convert_user_api_key_auth_to_dict(self, user_api_key_auth_obj):
        """
        Helper function to convert UserAPIKeyAuth object to dictionary.
        Handles both Pydantic models and regular objects.
        """
        if user_api_key_auth_obj is not None:
            if hasattr(user_api_key_auth_obj, "model_dump"):
                # If it's a Pydantic model, convert to dict
                return user_api_key_auth_obj.model_dump()
            elif hasattr(user_api_key_auth_obj, "__dict__"):
                # If it's a regular object, convert to dict
                return user_api_key_auth_obj.__dict__
        return user_api_key_auth_obj

    def _convert_mcp_to_llm_format(self, request_obj, kwargs: dict) -> dict:
        """
        Convert MCP tool call to LLM message format for existing guardrail validation.
        """
        from litellm.types.llms.openai import ChatCompletionUserMessage

        # Create a synthetic message that represents the tool call
        tool_call_content = (
            f"Tool: {request_obj.tool_name}\nArguments: {request_obj.arguments}"
        )

        synthetic_message = ChatCompletionUserMessage(
            role="user", content=tool_call_content
        )

        # Create synthetic LLM data that guardrails can process
        synthetic_data = {
            "messages": [synthetic_message],
            "model": kwargs.get("model", "mcp-tool-call"),
            "user_api_key_user_id": kwargs.get("user_api_key_user_id"),
            "user_api_key_team_id": kwargs.get("user_api_key_team_id"),
            "user_api_key_end_user_id": kwargs.get("user_api_key_end_user_id"),
            "user_api_key_hash": kwargs.get("user_api_key_hash"),
            "user_api_key_request_route": kwargs.get("user_api_key_request_route"),
            "mcp_tool_name": request_obj.tool_name,  # Keep original for reference
            "mcp_arguments": request_obj.arguments,  # Keep original for reference
        }

        return synthetic_data
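
    # Sketch of the synthetic payload this produces, for a hypothetical tool call
    # (example values only, not taken from a real request):
    #
    #   {
    #       "messages": [{"role": "user",
    #                     "content": "Tool: get_weather\nArguments: {'city': 'Paris'}"}],
    #       "model": "mcp-tool-call",
    #       "mcp_tool_name": "get_weather",
    #       "mcp_arguments": {"city": "Paris"},
    #       ...user_api_key_* fields copied from kwargs...
    #   }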

    def _convert_llm_result_to_mcp_response(
        self, llm_result, request_obj
    ) -> Optional[Any]:
        """
        Convert LLM guardrail result back to MCP response format.
        """
        from litellm.types.mcp import MCPPreCallResponseObject

        # If result is an exception, it means the guardrail blocked the request
        if isinstance(llm_result, Exception):
            return MCPPreCallResponseObject(
                should_proceed=False,
                error_message=str(llm_result),
                modified_arguments=None,
            )

        # If result is a dict with modified messages, check for content filtering
        if isinstance(llm_result, dict):
            modified_messages = llm_result.get("messages")
            if modified_messages:
                # Check if content was blocked/modified
                original_content = (
                    f"Tool: {request_obj.tool_name}\nArguments: {request_obj.arguments}"
                )
                new_content = (
                    modified_messages[0].get("content", "") if modified_messages else ""
                )

                if new_content != original_content:
                    # Content was modified - could be masking, redaction, or blocking
                    if (
                        not new_content
                        or "blocked" in new_content.lower()
                        or "violation" in new_content.lower()
                    ):
                        # Content was blocked completely
                        return MCPPreCallResponseObject(
                            should_proceed=False,
                            error_message="Content blocked by guardrail",
                            modified_arguments=None,
                        )
                    else:
                        # Content was masked/redacted - extract the modified arguments
                        try:
                            # Try to parse the modified arguments from the masked content
                            modified_args = (
                                self._extract_modified_arguments_from_content(
                                    new_content, request_obj
                                )
                            )
                            if modified_args is not None:
                                # Return the masked/redacted arguments for the MCP call to use
                                return MCPPreCallResponseObject(
                                    should_proceed=True,
                                    error_message=None,
                                    modified_arguments=modified_args,
                                )
                            else:
                                # Could not parse modified arguments, allow original call but warn
                                verbose_proxy_logger.warning(
                                    f"Could not parse modified arguments from guardrail response: {new_content}"
                                )
                                return None
                        except Exception as e:
                            verbose_proxy_logger.error(
                                f"Error parsing modified arguments: {e}"
                            )
                            # Fallback: allow original call
                            return None

        # If result is a string, it's likely an error message
        if isinstance(llm_result, str):
            return MCPPreCallResponseObject(
                should_proceed=False, error_message=llm_result, modified_arguments=None
            )

        return None

    def _extract_modified_arguments_from_content(
        self, masked_content: str, request_obj
    ) -> Optional[dict]:
        """
        Extract modified/masked arguments from the guardrail response content.
        """
        import json

        verbose_proxy_logger.debug(
            f"Extracting modified args from content: {masked_content}"
        )

        try:
            # The format should be: "Tool: <tool_name>\nArguments: <json_arguments>"
            # Parse the arguments section
            lines = masked_content.strip().split("\n")
            for i, line in enumerate(lines):
                if line.startswith("Arguments:"):
                    # Get the arguments part - everything after "Arguments: "
                    args_text = line[len("Arguments:") :].strip()

                    verbose_proxy_logger.debug(f"Found arguments text: {args_text}")

                    # Try to parse as JSON first
                    try:
                        modified_args = json.loads(args_text)
                        verbose_proxy_logger.debug(
                            f"Successfully parsed JSON args: {modified_args}"
                        )
                        return modified_args
                    except json.JSONDecodeError as e:
                        # If JSON parsing fails, try to extract key-value pairs manually
                        verbose_proxy_logger.debug(
                            f"Failed to parse JSON arguments: {args_text}, error: {e}"
                        )
                        return self._parse_arguments_manually(
                            args_text, request_obj.arguments
                        )

            # If we can't find the Arguments: line, return None
            verbose_proxy_logger.warning(
                "Could not find 'Arguments:' line in masked content"
            )
            return None

        except Exception as e:
            verbose_proxy_logger.error(f"Error extracting modified arguments: {e}")
            return None
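
    # Sketch of the content format the parser above expects (hypothetical example):
    #
    #   masked_content = 'Tool: send_email\nArguments: {"to": "[REDACTED]", "body": "hi"}'
    #   -> json.loads succeeds and {"to": "[REDACTED]", "body": "hi"} is returned.
    #
    # If the Arguments section is not valid JSON, _parse_arguments_manually() is used as a
    # best-effort regex fallback.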

    def _parse_arguments_manually(
        self, args_text: str, original_args: dict
    ) -> Optional[dict]:
        """
        Try to manually parse arguments when JSON parsing fails.
        This is a fallback for cases where the guardrail modifies the format.
        """
        import re

        try:
            # Start with original arguments and try to apply modifications
            modified_args = original_args.copy()

            # Look for simple key-value patterns
            # This is a basic implementation - can be enhanced based on specific guardrail formats
            for key, original_value in original_args.items():
                if isinstance(original_value, str):
                    # Look for the key in the masked content and try to extract its value
                    pattern = (
                        rf"['\"]?{re.escape(key)}['\"]?\s*:\s*['\"]?([^,'\"]*)['\"]?"
                    )
                    match = re.search(pattern, args_text, re.IGNORECASE)
                    if match:
                        new_value = match.group(1).strip()
                        if new_value:
                            modified_args[key] = new_value

            return modified_args

        except Exception as e:
            verbose_proxy_logger.error(f"Error in manual argument parsing: {e}")
            return None

    def _convert_llm_result_to_mcp_during_response(
        self, llm_result, request_obj
    ) -> Optional[Any]:
        """
        Convert LLM guardrail result back to MCP during call response format.
        """
        from litellm.types.mcp import MCPDuringCallResponseObject

        # If result is an exception, it means the guardrail wants to stop execution
        if isinstance(llm_result, Exception):
            return MCPDuringCallResponseObject(
                should_continue=False, error_message=str(llm_result)
            )

        # If result is a dict with modified messages, check for content filtering
        if isinstance(llm_result, dict):
            modified_messages = llm_result.get("messages")
            if modified_messages:
                # Check if content was blocked/modified
                original_content = (
                    f"Tool: {request_obj.tool_name}\nArguments: {request_obj.arguments}"
                )
                new_content = (
                    modified_messages[0].get("content", "") if modified_messages else ""
                )

                if new_content != original_content:
                    # Content was modified, could be masking or blocking
                    if not new_content or "blocked" in new_content.lower():
                        # Content was blocked
                        return MCPDuringCallResponseObject(
                            should_continue=False,
                            error_message="Content blocked by guardrail during execution",
                        )
                    else:
                        # Content was masked/modified - for now, stop execution
                        return MCPDuringCallResponseObject(
                            should_continue=False,
                            error_message="Content modified by guardrail during execution",
                        )

        # If result is a string, it's likely an error message
        if isinstance(llm_result, str):
            return MCPDuringCallResponseObject(
                should_continue=False, error_message=llm_result
            )

        return None

    def get_combined_callback_list(
        self, dynamic_success_callbacks: Optional[List], global_callbacks: List
    ) -> List:
        if dynamic_success_callbacks is None:
            return global_callbacks
        return list(set(dynamic_success_callbacks + global_callbacks))

    def _parse_pre_mcp_call_hook_response(
        self,
        response: MCPPreCallResponseObject,
        original_request: MCPPreCallRequestObject,
    ) -> Dict[str, Any]:
        """
        Parse the response from the pre_mcp_tool_call_hook

        1. Check if the call should proceed
        2. Apply any argument modifications
        3. Handle validation errors
        """
        result = {
            "should_proceed": response.should_proceed,
            "modified_arguments": response.modified_arguments
            or original_request.arguments,
            "error_message": response.error_message,
            "hidden_params": response.hidden_params,
        }
        return result
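
    # Example return shapes (illustrative only): a blocked call yields
    #   {"should_proceed": False, "modified_arguments": <original arguments>,
    #    "error_message": "Content blocked by guardrail", "hidden_params": ...}
    # while an allowed-but-masked call yields should_proceed=True with the
    # guardrail-modified arguments substituted in.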

    async def async_during_mcp_tool_call_hook(
        self,
        kwargs: dict,
        request_obj: Any,
        start_time: datetime,
        end_time: datetime,
    ) -> Optional[Any]:
        """
        During MCP Tool Call Hook

        Use this for concurrent monitoring and validation during tool execution.
        Reuses existing LLM guardrail logic by converting MCP calls to message format.
        """
        from litellm.types.llms.base import HiddenParams
        from litellm.types.mcp import MCPDuringCallRequestObject

        callbacks = self.get_combined_callback_list(
            dynamic_success_callbacks=getattr(self, "dynamic_success_callbacks", None),
            global_callbacks=litellm.success_callback,
        )

        # Create the request object if it's not already one
        if not isinstance(request_obj, MCPDuringCallRequestObject):
            request_obj = MCPDuringCallRequestObject(
                tool_name=kwargs.get("name", ""),
                arguments=kwargs.get("arguments", {}),
                server_name=kwargs.get("server_name"),
                start_time=start_time.timestamp() if start_time else None,
                hidden_params=HiddenParams(),
            )

        for callback in callbacks:
            try:
                _callback: Optional[CustomLogger] = None
                if isinstance(callback, str):
                    from typing import cast

                    from litellm import _custom_logger_compatible_callbacks_literal

                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
                        cast(_custom_logger_compatible_callbacks_literal, callback)
                    )
                else:
                    _callback = callback  # type: ignore

                if _callback is not None and isinstance(_callback, CustomGuardrail):
                    from litellm.types.guardrails import GuardrailEventHooks

                    # Check if guardrail should be run for during_call hook (reusing existing logic)
                    if (
                        _callback.should_run_guardrail(
                            data=kwargs, event_type=GuardrailEventHooks.during_mcp_call
                        )
                        is not True
                    ):
                        continue
                    # Convert MCP tool call to LLM message format for existing guardrail logic
                    synthetic_llm_data = self._convert_mcp_to_llm_format(
                        request_obj, kwargs
                    )

                    # Reuse existing LLM guardrail logic for during call
                    user_api_key_auth_dict = self._convert_user_api_key_auth_to_dict(
                        kwargs.get("user_api_key_auth")
                    )

                    result = await _callback.async_moderation_hook(
                        data=synthetic_llm_data,
                        user_api_key_dict=user_api_key_auth_dict,  # type: ignore
                        call_type="mcp_call",
                    )
                    # Convert result back to MCP response format if blocked/modified
                    if result is not None:
                        mcp_response = self._convert_llm_result_to_mcp_during_response(
                            result, request_obj
                        )
                        if mcp_response is not None:
                            return self._parse_during_mcp_call_hook_response(
                                response=mcp_response
                            )
            except Exception as e:
                verbose_proxy_logger.exception(
                    "LiteLLM.LoggingError: Exception occurred while running the during-MCP-call hook: {}".format(
                        str(e)
                    )
                )
                raise e
        return None

    def _parse_during_mcp_call_hook_response(
        self, response: MCPDuringCallResponseObject
    ) -> Dict[str, Any]:
        """
        Parse the response from the during_mcp_tool_call_hook

        1. Check if execution should continue
        2. Handle any error messages
        3. Apply any hidden parameter updates
        """
        result = {
            "should_continue": response.should_continue,
            "error_message": response.error_message,
            "hidden_params": response.hidden_params,
        }
        return result

    async def process_pre_call_hook_response(self, response, data, call_type):
        if isinstance(response, Exception):
            raise response
        if isinstance(response, dict):
            return response
        if isinstance(response, str):
            if call_type in ["completion", "text_completion"]:
                raise RejectedRequestError(
                    message=response,
                    model=data.get("model", ""),
                    llm_provider="",
                    request_data=data,
                )
            else:
                raise HTTPException(status_code=400, detail={"error": response})
        return data
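
    # Illustrative behaviour (assumed example string): if a pre-call hook returns
    # "Rejected by content policy" for call_type="completion", this raises
    # RejectedRequestError so the proxy can return a mock completion; for any other
    # call_type the same string becomes an HTTP 400 with
    # detail={"error": "Rejected by content policy"}.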

    # The actual implementation of the function
    @overload
    async def pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: None,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> None:
        pass

    @overload
    async def pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> dict:
        pass

    async def pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: Optional[dict],
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Optional[dict]:
        """
        Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.

        Covers:
        1. /chat/completions
        2. /embeddings
        3. /image/generation
        """
        from litellm.utils import get_non_default_completion_params

        verbose_proxy_logger.debug("Inside Proxy Logging Pre-call hook!")

        self._init_response_taking_too_long_task(data=data)

        if data is None:
            return None

        litellm_logging_obj = cast(
            Optional["LiteLLMLoggingObj"], data.get("litellm_logging_obj", None)
        )
        prompt_id = data.get("prompt_id", None)

        ## PROMPT TEMPLATE CHECK ##
        if (
            litellm_logging_obj is not None
            and prompt_id is not None
            and (call_type == "completion" or call_type == "acompletion")
        ):
            from litellm.proxy.prompts.prompt_registry import IN_MEMORY_PROMPT_REGISTRY

            custom_logger = IN_MEMORY_PROMPT_REGISTRY.get_prompt_callback_by_id(
                prompt_id
            )
            prompt_spec = IN_MEMORY_PROMPT_REGISTRY.get_prompt_by_id(prompt_id)
            litellm_prompt_id: Optional[str] = None
            if prompt_spec is not None:
                litellm_prompt_id = prompt_spec.litellm_params.prompt_id

            if custom_logger and litellm_prompt_id is not None:
                (
                    model,
                    messages,
                    optional_params,
                ) = litellm_logging_obj.get_chat_completion_prompt(
                    model=data.get("model", ""),
                    messages=data.get("messages", []),
                    non_default_params=get_non_default_completion_params(kwargs=data),
                    prompt_id=litellm_prompt_id,
                    prompt_management_logger=custom_logger,
                    prompt_variables=data.get("prompt_variables", None),
                    prompt_label=data.get("prompt_label", None),
                    prompt_version=data.get("prompt_version", None),
                )

                data.update(optional_params)
                data["model"] = model
                data["messages"] = messages

        try:
            for callback in litellm.callbacks:
                _callback = None
                if isinstance(callback, str):
                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
                        callback
                    )
                else:
                    _callback = callback  # type: ignore
                if _callback is not None and isinstance(_callback, CustomGuardrail):
                    from litellm.types.guardrails import GuardrailEventHooks

                    if (
                        _callback.should_run_guardrail(
                            data=data, event_type=GuardrailEventHooks.pre_call
                        )
                        is not True
                    ):
                        continue

                    response = await _callback.async_pre_call_hook(
                        user_api_key_dict=user_api_key_dict,
                        cache=self.call_details["user_api_key_cache"],
                        data=data,  # type: ignore
                        call_type=call_type,
                    )
                    if response is not None:
                        data = await self.process_pre_call_hook_response(
                            response=response, data=data, call_type=call_type
                        )

                elif (
                    _callback is not None
                    and isinstance(_callback, CustomLogger)
                    and "async_pre_call_hook" in vars(_callback.__class__)
                    and _callback.__class__.async_pre_call_hook
                    != CustomLogger.async_pre_call_hook
                ):
                    response = await _callback.async_pre_call_hook(
                        user_api_key_dict=user_api_key_dict,
                        cache=self.call_details["user_api_key_cache"],
                        data=data,  # type: ignore
                        call_type=call_type,
                    )
                    if response is not None:
                        data = await self.process_pre_call_hook_response(
                            response=response, data=data, call_type=call_type
                        )

            return data
        except Exception as e:
            raise e

    async def during_call_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "responses",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ):
        """
        Runs the CustomGuardrail's async_moderation_hook()
        """
        for callback in litellm.callbacks:
            try:
                if isinstance(callback, CustomGuardrail):
                    ################################################################
                    # Check if guardrail should be run for GuardrailEventHooks.during_call hook
                    ################################################################

                    # V1 implementation - backwards compatibility
                    if callback.event_hook is None and hasattr(
                        callback, "moderation_check"
                    ):
                        if callback.moderation_check == "pre_call":  # type: ignore
                            return
                    else:
                        # Main - V2 Guardrails implementation
                        from litellm.types.guardrails import GuardrailEventHooks

                        if (
                            callback.should_run_guardrail(
                                data=data, event_type=GuardrailEventHooks.during_call
                            )
                            is not True
                        ):
                            continue
                    await callback.async_moderation_hook(
                        data=data,
                        user_api_key_dict=user_api_key_dict,
                        call_type=call_type,
                    )
            except Exception as e:
                raise e
        return data

    async def failed_tracking_alert(
        self,
        error_message: str,
        failing_model: str,
    ):
        if self.alerting is None:
            return

        if self.slack_alerting_instance:
            await self.slack_alerting_instance.failed_tracking_alert(
                error_message=error_message,
                failing_model=failing_model,
            )

    async def budget_alerts(
        self,
        type: Literal[
            "token_budget",
            "user_budget",
            "soft_budget",
            "team_budget",
            "proxy_budget",
            "projected_limit_exceeded",
        ],
        user_info: CallInfo,
    ):
        if self.alerting is None:
            # do nothing if alerting is not switched on
            return
        await self.slack_alerting_instance.budget_alerts(
            type=type,
            user_info=user_info,
        )

    async def alerting_handler(
        self,
        message: str,
        level: Literal["Low", "Medium", "High"],
        alert_type: AlertType,
        request_data: Optional[dict] = None,
    ):
        """
        Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298

        - Responses taking too long
        - Requests are hanging
        - Calls are failing
        - DB Read/Writes are failing
        - Proxy Close to max budget
        - Key Close to max budget

        Parameters:
        level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'.
        message: str - what is the alert about
        """
        if self.alerting is None:
            return

        from datetime import datetime

        # Get the current timestamp
        current_time = datetime.now().strftime("%H:%M:%S")
        _proxy_base_url = os.getenv("PROXY_BASE_URL", None)
        formatted_message = (
            f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
        )
        if _proxy_base_url is not None:
            formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"

        extra_kwargs = {}
        alerting_metadata = {}
        if request_data is not None:
            _url = await _add_langfuse_trace_id_to_alert(request_data=request_data)

            if _url is not None:
                extra_kwargs["🪢 Langfuse Trace"] = _url
                formatted_message += "\n\n🪢 Langfuse Trace: {}".format(_url)
            if (
                "metadata" in request_data
                and request_data["metadata"].get("alerting_metadata", None) is not None
                and isinstance(request_data["metadata"]["alerting_metadata"], dict)
            ):
                alerting_metadata = request_data["metadata"]["alerting_metadata"]
        for client in self.alerting:
            if client == "slack":
                await self.slack_alerting_instance.send_alert(
                    message=message,
                    level=level,
                    alert_type=alert_type,
                    user_info=None,
                    alerting_metadata=alerting_metadata,
                    **extra_kwargs,
                )
            elif client == "sentry":
                if litellm.utils.sentry_sdk_instance is not None:
                    litellm.utils.sentry_sdk_instance.capture_message(formatted_message)
                else:
                    raise Exception("Missing SENTRY_DSN from environment")

    async def failure_handler(
        self, original_exception, duration: float, call_type: str, traceback_str=""
    ):
        """
        Log failed db read/writes

        Currently only logs exceptions to sentry
        """
        ### ALERTING ###
        if AlertType.db_exceptions not in self.alert_types:
            return
        if isinstance(original_exception, HTTPException):
            if isinstance(original_exception.detail, str):
                error_message = original_exception.detail
            elif isinstance(original_exception.detail, dict):
                error_message = json.dumps(original_exception.detail)
            else:
                error_message = str(original_exception)
        else:
            error_message = str(original_exception)
        if isinstance(traceback_str, str):
            error_message += traceback_str[:1000]
        asyncio.create_task(
            self.alerting_handler(
                message=f"DB read/write call failed: {error_message}",
                level="High",
                alert_type=AlertType.db_exceptions,
                request_data={},
            )
        )

        if hasattr(self, "service_logging_obj"):
            await self.service_logging_obj.async_service_failure_hook(
                service=ServiceTypes.DB,
                duration=duration,
                error=error_message,
                call_type=call_type,
            )

        if litellm.utils.capture_exception:
            litellm.utils.capture_exception(error=original_exception)

    async def post_call_failure_hook(
        self,
        request_data: dict,
        original_exception: Exception,
        user_api_key_dict: UserAPIKeyAuth,
        error_type: Optional[ProxyErrorTypes] = None,
        route: Optional[str] = None,
        traceback_str: Optional[str] = None,
    ):
        """
        Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.

        Covers:
        1. /chat/completions
        2. /embeddings
        3. /image/generation

        Args:
        - request_data: dict - The request data.
        - original_exception: Exception - The original exception.
        - user_api_key_dict: UserAPIKeyAuth - The user api key dict.
        - error_type: Optional[ProxyErrorTypes] - The error type.
        - route: Optional[str] - The route.
        - traceback_str: Optional[str] - The traceback string; upstream endpoints sometimes send their own traceback, in which case it is used here.
        """

        ### ALERTING ###
        await self.update_request_status(
            litellm_call_id=request_data.get("litellm_call_id", ""), status="fail"
        )
        if AlertType.llm_exceptions in self.alert_types and not isinstance(
            original_exception, HTTPException
        ):
            """
            Just alert on LLM API exceptions. Do not alert on user errors

            Related issue - https://github.com/BerriAI/litellm/issues/3395
            """
            litellm_debug_info = getattr(original_exception, "litellm_debug_info", None)
            exception_str = str(original_exception)
            if litellm_debug_info is not None:
                exception_str += litellm_debug_info

            asyncio.create_task(
                self.alerting_handler(
                    message=f"LLM API call failed: `{exception_str}`",
                    level="High",
                    alert_type=AlertType.llm_exceptions,
                    request_data=request_data,
                )
            )

        ### LOGGING ###
        if self._is_proxy_only_llm_api_error(
            original_exception=original_exception,
            error_type=error_type,
            route=user_api_key_dict.request_route,
        ):
            await self._handle_logging_proxy_only_error(
                request_data=request_data,
                user_api_key_dict=user_api_key_dict,
                route=route,
                original_exception=original_exception,
            )

        for callback in litellm.callbacks:
            try:
                _callback: Optional[CustomLogger] = None
                if isinstance(callback, str):
                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
                        callback
                    )
                else:
                    _callback = callback  # type: ignore
                if _callback is not None and isinstance(_callback, CustomLogger):
                    asyncio.create_task(
                        _callback.async_post_call_failure_hook(
                            request_data=request_data,
                            user_api_key_dict=user_api_key_dict,
                            original_exception=original_exception,
                            traceback_str=traceback_str,
                        )
                    )
            except Exception as e:
                verbose_proxy_logger.exception(
                    f"[Non-Blocking] Error in post_call_failure_hook: {e}"
                )
        return

    def _is_proxy_only_llm_api_error(
        self,
        original_exception: Exception,
        error_type: Optional[ProxyErrorTypes] = None,
        route: Optional[str] = None,
    ) -> bool:
        """
        Return True if the error is a Proxy Only LLM API Error

        Prevents double logging of LLM API exceptions

        e.g. should only return True for:
        - Authentication Errors from user_api_key_auth
        - HTTP HTTPException (rate limit errors)
        """

        #########################################################
        # Only log LLM API errors for proxy level hooks
        # eg. Authentication errors, rate limit errors, etc.
        # Note: This fixes a security issue where we
        # would log temporary keys/auth info
        # from management endpoints
        #########################################################
        if route is None:
            return False
        if RouteChecks.is_llm_api_route(route) is not True:
            return False

        return isinstance(original_exception, HTTPException) or (
            error_type == ProxyErrorTypes.auth_error
        )
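
    # Illustrative examples (hypothetical routes): an HTTPException raised on
    # "/chat/completions" returns True (so the proxy-only error is logged), while the same
    # exception on "/key/generate" returns False because it is not an LLM API route; a
    # missing route always returns False.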
|
||
|
|
|
||
|
|
async def _handle_logging_proxy_only_error(
|
||
|
|
self,
|
||
|
|
request_data: dict,
|
||
|
|
user_api_key_dict: UserAPIKeyAuth,
|
||
|
|
route: Optional[str] = None,
|
||
|
|
original_exception: Optional[Exception] = None,
|
||
|
|
):
|
||
|
|
"""
|
||
|
|
Handle logging for proxy only errors by calling `litellm_logging_obj.async_failure_handler`
|
||
|
|
|
||
|
|
Is triggered when self._is_proxy_only_error() returns True
|
||
|
|
"""
|
||
|
|
litellm_logging_obj: Optional[Logging] = request_data.get(
|
||
|
|
"litellm_logging_obj", None
|
||
|
|
)
|
||
|
|
if litellm_logging_obj is None:
|
||
|
|
import uuid
|
||
|
|
|
||
|
|
request_data["litellm_call_id"] = str(uuid.uuid4())
|
||
|
|
user_api_key_logged_metadata = (
|
||
|
|
LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
|
||
|
|
user_api_key_dict=user_api_key_dict
|
||
|
|
)
|
||
|
|
)
|
||
|
|
|
||
|
|
litellm_logging_obj, data = litellm.utils.function_setup(
|
||
|
|
original_function=route or "IGNORE_THIS",
|
||
|
|
rules_obj=litellm.utils.Rules(),
|
||
|
|
start_time=datetime.now(),
|
||
|
|
**request_data,
|
||
|
|
)
|
||
|
|
if "metadata" not in request_data:
|
||
|
|
request_data["metadata"] = {}
|
||
|
|
request_data["metadata"].update(user_api_key_logged_metadata)
|
||
|
|
|
||
|
|
if litellm_logging_obj is not None:
|
||
|
|
## UPDATE LOGGING INPUT
|
||
|
|
_optional_params = {}
|
||
|
|
_litellm_params = {}
|
||
|
|
|
||
|
|
litellm_param_keys = LoggedLiteLLMParams.__annotations__.keys()
|
||
|
|
for k, v in request_data.items():
|
||
|
|
if k in litellm_param_keys:
|
||
|
|
_litellm_params[k] = v
|
||
|
|
elif k != "model" and k != "user":
|
||
|
|
_optional_params[k] = v
|
||
|
|
|
||
|
|
litellm_logging_obj.update_environment_variables(
|
||
|
|
model=request_data.get("model", ""),
|
||
|
|
user=request_data.get("user", ""),
|
||
|
|
optional_params=_optional_params,
|
||
|
|
litellm_params=_litellm_params,
|
||
|
|
)
|
||
|
|
|
||
|
|
input: Union[list, str, dict] = ""
|
||
|
|
if "messages" in request_data and isinstance(
|
||
|
|
request_data["messages"], list
|
||
|
|
):
|
||
|
|
input = request_data["messages"]
|
||
|
|
litellm_logging_obj.model_call_details["messages"] = input
|
||
|
|
litellm_logging_obj.call_type = CallTypes.acompletion.value
|
||
|
|
elif "prompt" in request_data and isinstance(request_data["prompt"], str):
|
||
|
|
input = request_data["prompt"]
|
||
|
|
litellm_logging_obj.model_call_details["prompt"] = input
|
||
|
|
litellm_logging_obj.call_type = CallTypes.atext_completion.value
|
||
|
|
elif "input" in request_data and isinstance(request_data["input"], list):
|
||
|
|
input = request_data["input"]
|
||
|
|
litellm_logging_obj.model_call_details["input"] = input
|
||
|
|
litellm_logging_obj.call_type = CallTypes.aembedding.value
|
||
|
|
litellm_logging_obj.pre_call(
|
||
|
|
input=input,
|
||
|
|
api_key="",
|
||
|
|
)
|
||
|
|
|
||
|
|
# log the custom exception
|
||
|
|
await litellm_logging_obj.async_failure_handler(
|
||
|
|
exception=original_exception,
|
||
|
|
traceback_exception=traceback.format_exc(),
|
||
|
|
)
|
||
|
|
|
||
|
|
threading.Thread(
|
||
|
|
target=litellm_logging_obj.failure_handler,
|
||
|
|
args=(
|
||
|
|
original_exception,
|
||
|
|
traceback.format_exc(),
|
||
|
|
),
|
||
|
|
).start()
|
||
|
|
|
||
|
|
    async def post_call_success_hook(
        self,
        data: dict,
        response: LLMResponseTypes,
        user_api_key_dict: UserAPIKeyAuth,
    ):
        """
        Allow user to modify outgoing data

        Covers:
        1. /chat/completions
        2. /embeddings
        3. /image/generation
        4. /files
        """

        for callback in litellm.callbacks:
            try:
                _callback: Optional[CustomLogger] = None
                if isinstance(callback, str):
                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
                        callback
                    )
                else:
                    _callback = callback  # type: ignore

                if _callback is not None:
                    ############## Handle Guardrails ########################################
                    #############################################################################
                    if isinstance(callback, CustomGuardrail):
                        # Main - V2 Guardrails implementation
                        from litellm.types.guardrails import GuardrailEventHooks

                        if (
                            callback.should_run_guardrail(
                                data=data, event_type=GuardrailEventHooks.post_call
                            )
                            is not True
                        ):
                            continue

                        await callback.async_post_call_success_hook(
                            user_api_key_dict=user_api_key_dict,
                            data=data,
                            response=response,
                        )

                    ############ Handle CustomLogger ###############################
                    #################################################################
                    elif isinstance(_callback, CustomLogger):
                        await _callback.async_post_call_success_hook(
                            user_api_key_dict=user_api_key_dict,
                            data=data,
                            response=response,
                        )
            except Exception as e:
                raise e
        return response

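    # A minimal sketch of a callback this hook would dispatch to; `TitleRedactor` is a
    # hypothetical example, not a class shipped with litellm:
    #
    #   class TitleRedactor(CustomLogger):
    #       async def async_post_call_success_hook(self, data, user_api_key_dict, response):
    #           # mutate or replace `response` before it is returned to the client
    #           return response
    #
    #   litellm.callbacks.append(TitleRedactor())
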
    async def async_post_call_streaming_hook(
        self,
        data: dict,
        response: Union[
            ModelResponse, EmbeddingResponse, ImageResponse, ModelResponseStream
        ],
        user_api_key_dict: UserAPIKeyAuth,
        str_so_far: Optional[str] = None,
    ):
        """
        Allow user to modify outgoing streaming data -> per chunk

        Covers:
        1. /chat/completions
        """
        from litellm.proxy.proxy_server import llm_router

        response_str: Optional[str] = None
        if isinstance(response, (ModelResponse, ModelResponseStream)):
            response_str = litellm.get_response_string(response_obj=response)
        if response_str is not None:
            for callback in litellm.callbacks:
                try:
                    _callback: Optional[CustomLogger] = None
                    if isinstance(callback, CustomGuardrail):
                        # Main - V2 Guardrails implementation
                        from litellm.types.guardrails import GuardrailEventHooks

                        ## CHECK FOR MODEL-LEVEL GUARDRAILS
                        modified_data = _check_and_merge_model_level_guardrails(
                            data=data, llm_router=llm_router
                        )

                        if (
                            callback.should_run_guardrail(
                                data=modified_data,
                                event_type=GuardrailEventHooks.post_call,
                            )
                            is not True
                        ):
                            continue
                    if isinstance(callback, str):
                        _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
                            callback
                        )
                    else:
                        _callback = callback  # type: ignore
                    if _callback is not None and isinstance(_callback, CustomLogger):
                        if str_so_far is not None:
                            complete_response = str_so_far + response_str
                        else:
                            complete_response = response_str
                        potential_error_response = (
                            await _callback.async_post_call_streaming_hook(
                                user_api_key_dict=user_api_key_dict,
                                response=complete_response,
                            )
                        )
                        if isinstance(
                            potential_error_response, str
                        ) and potential_error_response.startswith("data: "):
                            return potential_error_response
                except Exception as e:
                    raise e
        return response

    def async_post_call_streaming_iterator_hook(
        self,
        response,
        user_api_key_dict: UserAPIKeyAuth,
        request_data: dict,
    ):
        """
        Allow user to modify outgoing streaming data -> Given a whole response iterator.
        This hook is best used when you need to modify multiple chunks of the response at once.

        Covers:
        1. /chat/completions
        """
        for callback in litellm.callbacks:
            _callback: Optional[CustomLogger] = None
            if isinstance(callback, str):
                _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
                    callback
                )
            else:
                _callback = callback  # type: ignore
            if _callback is not None and isinstance(_callback, CustomLogger):
                if not isinstance(
                    _callback, CustomGuardrail
                ) or _callback.should_run_guardrail(
                    data=request_data, event_type=GuardrailEventHooks.post_call
                ):
                    response = _callback.async_post_call_streaming_iterator_hook(
                        user_api_key_dict=user_api_key_dict,
                        response=response,
                        request_data=request_data,
                    )
        return response

    def _init_response_taking_too_long_task(self, data: Optional[dict] = None):
        """
        Initialize the response-taking-too-long task if the user is using slack alerting.

        Only run the task if slack alerting is configured.

        This handles checking whether a request has been hanging for too long.
        """
        ## ALERTING ###
        if (
            self.slack_alerting_instance
            and self.slack_alerting_instance.alerting is not None
        ):
            asyncio.create_task(
                self.slack_alerting_instance.response_taking_too_long(request_data=data)
            )


### DB CONNECTOR ###
# Define the retry decorator with backoff strategy
# Function to be called whenever a retry is about to happen
def on_backoff(details):
    # The 'tries' key in the details dictionary contains the number of completed tries
    print_verbose(f"Backing off... this was attempt #{details['tries']}")


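# Illustrative sketch of how `on_backoff` is consumed by the retry decorators applied to
# the PrismaClient methods below (the decorated function here is a made-up placeholder):
#
#   @backoff.on_exception(backoff.expo, Exception, max_tries=3, max_time=10, on_backoff=on_backoff)
#   async def flaky_db_call():
#       ...
#
# The `backoff` library passes a `details` dict (including "tries") on every retry.
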
def jsonify_object(data: dict) -> dict:
    db_data = copy.deepcopy(data)

    for k, v in db_data.items():
        if isinstance(v, dict):
            try:
                db_data[k] = json.dumps(v)
            except Exception:
                # This avoids Prisma retrying this 5 times, and making 5 clients
                db_data[k] = "failed-to-serialize-json"
    return db_data


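# Example of the transformation `jsonify_object` performs (values are illustrative):
#   jsonify_object({"metadata": {"team": "a"}, "spend": 1.0})
#   -> {"metadata": '{"team": "a"}', "spend": 1.0}
# Only dict values are JSON-serialized; all other values pass through unchanged.
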
class PrismaClient:
    spend_log_transactions: List = []

    def __init__(
        self,
        database_url: str,
        proxy_logging_obj: ProxyLogging,
        http_client: Optional[Any] = None,
    ):
        ## init logging object
        self.proxy_logging_obj = proxy_logging_obj
        self.iam_token_db_auth: Optional[bool] = str_to_bool(
            os.getenv("IAM_TOKEN_DB_AUTH")
        )
        verbose_proxy_logger.debug("Creating Prisma Client..")
        try:
            from prisma import Prisma  # type: ignore
        except Exception as e:
            verbose_proxy_logger.error(f"Failed to import Prisma client: {e}")
            verbose_proxy_logger.error(
                "This usually means 'prisma generate' hasn't been run yet."
            )
            verbose_proxy_logger.error(
                "Please run 'prisma generate' to generate the Prisma client."
            )
            raise Exception(
                "Unable to find Prisma binaries. Please run 'prisma generate' first."
            )
        if http_client is not None:
            self.db = PrismaWrapper(
                original_prisma=Prisma(http=http_client),
                iam_token_db_auth=(
                    self.iam_token_db_auth
                    if self.iam_token_db_auth is not None
                    else False
                ),
            )
        else:
            self.db = PrismaWrapper(
                original_prisma=Prisma(),
                iam_token_db_auth=(
                    self.iam_token_db_auth
                    if self.iam_token_db_auth is not None
                    else False
                ),
            )  # Client to connect to Prisma db
        verbose_proxy_logger.debug("Success - Created Prisma Client")

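    # Hedged usage sketch (surrounding variable names are assumptions, not part of this
    # class): the proxy typically builds one client at startup and connects it, e.g.
    #
    #   prisma_client = PrismaClient(
    #       database_url=os.environ["DATABASE_URL"],
    #       proxy_logging_obj=proxy_logging_obj,
    #   )
    #   await prisma_client.connect()
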
    def get_request_status(
        self, payload: Union[dict, SpendLogsPayload]
    ) -> Literal["success", "failure"]:
        """
        Determine if a request was successful or failed based on payload metadata.

        Args:
            payload (Union[dict, SpendLogsPayload]): Request payload containing metadata

        Returns:
            Literal["success", "failure"]: Request status
        """
        try:
            # Get metadata and convert to dict if it's a JSON string
            payload_metadata: Union[Dict, SpendLogsMetadata, str] = payload.get(
                "metadata", {}
            )
            if isinstance(payload_metadata, str):
                payload_metadata_json: Union[Dict, SpendLogsMetadata] = cast(
                    Dict, json.loads(payload_metadata)
                )
            else:
                payload_metadata_json = payload_metadata

            # Check status in metadata dict
            return (
                "failure"
                if payload_metadata_json.get("status") == "failure"
                else "success"
            )

        except (json.JSONDecodeError, AttributeError):
            # Default to success if metadata parsing fails
            return "success"

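    # Illustrative examples of the metadata-driven status check above:
    #   get_request_status({"metadata": '{"status": "failure"}'}) -> "failure"
    #   get_request_status({"metadata": {"status": "success"}})   -> "success"
    #   get_request_status({})                                     -> "success" (default)
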
    def hash_token(self, token: str):
        # Hash the string using SHA-256
        hashed_token = hashlib.sha256(token.encode()).hexdigest()

        return hashed_token

    def jsonify_object(self, data: dict) -> dict:
        db_data = copy.deepcopy(data)

        for k, v in db_data.items():
            if isinstance(v, dict):
                try:
                    db_data[k] = json.dumps(v)
                except Exception:
                    # This avoids Prisma retrying this 5 times, and making 5 clients
                    db_data[k] = "failed-to-serialize-json"
        return db_data

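    # Example: hash_token("sk-1234") returns the 64-character hex SHA-256 digest of the
    # raw key; keys written through this client are stored in that hashed form.
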
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def check_view_exists(self):
        """
        Checks if the LiteLLM_VerificationTokenView and MonthlyGlobalSpend views exist in the user's db.

        LiteLLM_VerificationTokenView: This view is used for getting the token + team data in user_api_key_auth

        MonthlyGlobalSpend: This view is used for the admin view to see global spend for this month

        If a view doesn't exist, one will be created.
        """

        # Check to see if all of the necessary views exist and if they do, simply return
        # This is more efficient because it lets us check for all views in one
        # query instead of multiple queries.
        try:
            expected_views = [
                "LiteLLM_VerificationTokenView",
                "MonthlyGlobalSpend",
                "Last30dKeysBySpend",
                "Last30dModelsBySpend",
                "MonthlyGlobalSpendPerKey",
                "MonthlyGlobalSpendPerUserPerKey",
                "Last30dTopEndUsersSpend",
                "DailyTagSpend",
            ]
            required_view = "LiteLLM_VerificationTokenView"
            expected_views_str = ", ".join(f"'{view}'" for view in expected_views)
            pg_schema = os.getenv("DATABASE_SCHEMA", "public")
            ret = await self.db.query_raw(
                f"""
                WITH existing_views AS (
                    SELECT viewname
                    FROM pg_views
                    WHERE schemaname = '{pg_schema}' AND viewname IN (
                        {expected_views_str}
                    )
                )
                SELECT
                    (SELECT COUNT(*) FROM existing_views) AS view_count,
                    ARRAY_AGG(viewname) AS view_names
                FROM existing_views
                """
            )
            expected_total_views = len(expected_views)
            if ret[0]["view_count"] == expected_total_views:
                verbose_proxy_logger.info("All necessary views exist!")
                return
            else:
                ## check if required view exists ##
                if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
                    await self.health_check()  # make sure we can connect to db
                    await self.db.execute_raw(
                        """
                        CREATE VIEW "LiteLLM_VerificationTokenView" AS
                        SELECT
                            v.*,
                            t.spend AS team_spend,
                            t.max_budget AS team_max_budget,
                            t.tpm_limit AS team_tpm_limit,
                            t.rpm_limit AS team_rpm_limit
                        FROM "LiteLLM_VerificationToken" v
                        LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
                        """
                    )

                    verbose_proxy_logger.info(
                        "LiteLLM_VerificationTokenView Created in DB!"
                    )
                else:
                    should_create_views = await should_create_missing_views(db=self.db)
                    if should_create_views:
                        await create_missing_views(db=self.db)
                    else:
                        # don't block execution if these views are missing
                        # Convert lists to sets for efficient difference calculation
                        ret_view_names_set = (
                            set(ret[0]["view_names"]) if ret[0]["view_names"] else set()
                        )
                        expected_views_set = set(expected_views)
                        # Find missing views
                        missing_views = expected_views_set - ret_view_names_set

                        verbose_proxy_logger.warning(
                            "\n\n\033[93mNot all views exist in db, needed for UI 'Usage' tab. Missing={}.\nRun 'create_views.py' from https://github.com/BerriAI/litellm/tree/main/db_scripts to create missing views.\033[0m\n".format(
                                missing_views
                            )
                        )

        except Exception:
            raise
        return

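    # The pg_views query in check_view_exists returns one row shaped roughly like
    #   [{"view_count": 8, "view_names": ["LiteLLM_VerificationTokenView", ...]}]
    # (shape inferred from how the result is indexed above; not an exact schema).
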
    @log_db_metrics
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=1,  # maximum number of retries
        max_time=2,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def get_generic_data(
        self,
        key: str,
        value: Any,
        table_name: Literal["users", "keys", "config", "spend"],
    ):
        """
        Generic implementation of get data
        """
        start_time = time.time()
        try:
            if table_name == "users":
                response = await self.db.litellm_usertable.find_first(
                    where={key: value}  # type: ignore
                )
            elif table_name == "keys":
                response = await self.db.litellm_verificationtoken.find_first(  # type: ignore
                    where={key: value}  # type: ignore
                )
            elif table_name == "config":
                response = await self.db.litellm_config.find_first(  # type: ignore
                    where={key: value}  # type: ignore
                )
            elif table_name == "spend":
                # assumed fix: the original referenced `self.db.l`, which is not a Prisma
                # table; the spend-logs table used elsewhere in this client is litellm_spendlogs
                response = await self.db.litellm_spendlogs.find_first(  # type: ignore
                    where={key: value}  # type: ignore
                )
            return response
        except Exception as e:
            import traceback

            error_msg = f"LiteLLM Prisma Client Exception get_generic_data: {str(e)}"
            verbose_proxy_logger.error(error_msg)
            error_msg = error_msg + "\nException Type: {}".format(type(e))
            error_traceback = error_msg + "\n" + traceback.format_exc()
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(
                    original_exception=e,
                    duration=_duration,
                    traceback_str=error_traceback,
                    call_type="get_generic_data",
                )
            )

            raise e

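    # Hedged usage sketch for the generic getter above (argument values are examples):
    #   await prisma_client.get_generic_data(key="user_id", value="user-123", table_name="users")
    # returns the first matching row object from the corresponding table, or None.
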
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    @log_db_metrics
    async def get_data(  # noqa: PLR0915
        self,
        token: Optional[Union[str, list]] = None,
        user_id: Optional[str] = None,
        user_id_list: Optional[list] = None,
        team_id: Optional[str] = None,
        team_id_list: Optional[list] = None,
        key_val: Optional[dict] = None,
        table_name: Optional[
            Literal[
                "user",
                "key",
                "config",
                "spend",
                "enduser",
                "budget",
                "team",
                "user_notification",
                "combined_view",
            ]
        ] = None,
        query_type: Literal["find_unique", "find_all"] = "find_unique",
        expires: Optional[datetime] = None,
        reset_at: Optional[datetime] = None,
        offset: Optional[int] = None,  # pagination, what row number to start from
        limit: Optional[
            int
        ] = None,  # pagination, number of rows to fetch when find_all==True
        parent_otel_span: Optional[Span] = None,
        proxy_logging_obj: Optional[ProxyLogging] = None,
        budget_id_list: Optional[List[str]] = None,
    ):
        args_passed_in = locals()
        start_time = time.time()
        hashed_token: Optional[str] = None
        try:
            response: Any = None
            if (token is not None and table_name is None) or (
                table_name is not None and table_name == "key"
            ):
                # check if plain text or hash
                if token is not None:
                    if isinstance(token, str):
                        hashed_token = _hash_token_if_needed(token=token)
                        verbose_proxy_logger.debug(
                            f"PrismaClient: find_unique for token: {hashed_token}"
                        )
                if query_type == "find_unique" and hashed_token is not None:
                    if token is None:
                        raise HTTPException(
                            status_code=400,
                            detail={"error": f"No token passed in. Token={token}"},
                        )
                    response = await self.db.litellm_verificationtoken.find_unique(
                        where={"token": hashed_token},  # type: ignore
                        include={"litellm_budget_table": True},
                    )
                    if response is not None:
                        # for prisma we need to cast the expires time to str
                        if response.expires is not None and isinstance(
                            response.expires, datetime
                        ):
                            response.expires = response.expires.isoformat()
                    else:
                        # Token does not exist.
                        raise HTTPException(
                            status_code=status.HTTP_401_UNAUTHORIZED,
                            detail=f"Authentication Error: invalid user key - user key does not exist in db. User Key={token}",
                        )
                elif query_type == "find_all" and user_id is not None:
                    response = await self.db.litellm_verificationtoken.find_many(
                        where={"user_id": user_id},
                        include={"litellm_budget_table": True},
                    )
                    if response is not None and len(response) > 0:
                        for r in response:
                            if isinstance(r.expires, datetime):
                                r.expires = r.expires.isoformat()
                elif query_type == "find_all" and team_id is not None:
                    response = await self.db.litellm_verificationtoken.find_many(
                        where={"team_id": team_id},
                        include={"litellm_budget_table": True},
                    )
                    if response is not None and len(response) > 0:
                        for r in response:
                            if isinstance(r.expires, datetime):
                                r.expires = r.expires.isoformat()
                elif (
                    query_type == "find_all"
                    and expires is not None
                    and reset_at is not None
                ):
                    response = await self.db.litellm_verificationtoken.find_many(
                        where={  # type:ignore
                            "OR": [
                                {"expires": None},
                                {"expires": {"gt": expires}},
                            ],
                            "budget_reset_at": {"lt": reset_at},
                        }
                    )
                    if response is not None and len(response) > 0:
                        for r in response:
                            if isinstance(r.expires, datetime):
                                r.expires = r.expires.isoformat()
                elif query_type == "find_all":
                    where_filter: dict = {}
                    if token is not None:
                        where_filter["token"] = {}
                        if isinstance(token, str):
                            token = _hash_token_if_needed(token=token)
                            where_filter["token"]["in"] = [token]
                        elif isinstance(token, list):
                            hashed_tokens = []
                            for t in token:
                                assert isinstance(t, str)
                                if t.startswith("sk-"):
                                    new_token = self.hash_token(token=t)
                                    hashed_tokens.append(new_token)
                                else:
                                    hashed_tokens.append(t)
                            where_filter["token"]["in"] = hashed_tokens
                    response = await self.db.litellm_verificationtoken.find_many(
                        order={"spend": "desc"},
                        where=where_filter,  # type: ignore
                        include={"litellm_budget_table": True},
                    )
                if response is not None:
                    return response
                else:
                    # Token does not exist.
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="Authentication Error: invalid user key - token does not exist",
                    )
            elif (user_id is not None and table_name is None) or (
                table_name is not None and table_name == "user"
            ):
                if query_type == "find_unique":
                    if key_val is None:
                        key_val = {"user_id": user_id}

                    response = await self.db.litellm_usertable.find_unique(  # type: ignore
                        where=key_val,  # type: ignore
                        include={"organization_memberships": True},
                    )

                elif query_type == "find_all" and key_val is not None:
                    response = await self.db.litellm_usertable.find_many(
                        where=key_val  # type: ignore
                    )  # type: ignore
                elif query_type == "find_all" and reset_at is not None:
                    response = await self.db.litellm_usertable.find_many(
                        where={  # type:ignore
                            "budget_reset_at": {"lt": reset_at},
                        }
                    )
                elif query_type == "find_all" and user_id_list is not None:
                    response = await self.db.litellm_usertable.find_many(
                        where={"user_id": {"in": user_id_list}}
                    )
                elif query_type == "find_all":
                    if expires is not None:
                        response = await self.db.litellm_usertable.find_many(  # type: ignore
                            order={"spend": "desc"},
                            where={  # type:ignore
                                "OR": [
                                    {"expires": None},  # type:ignore
                                    {"expires": {"gt": expires}},  # type:ignore
                                ],
                            },
                        )
                    else:
                        # return all users in the table, get their key aliases ordered by spend
                        sql_query = """
                        SELECT
                            u.*,
                            json_agg(v.key_alias) AS key_aliases
                        FROM
                            "LiteLLM_UserTable" u
                        LEFT JOIN "LiteLLM_VerificationToken" v ON u.user_id = v.user_id
                        GROUP BY
                            u.user_id
                        ORDER BY u.spend DESC
                        LIMIT $1
                        OFFSET $2
                        """
                        response = await self.db.query_raw(sql_query, limit, offset)
                return response
            elif table_name == "spend":
                verbose_proxy_logger.debug(
                    "PrismaClient: get_data: table_name == 'spend'"
                )
                if key_val is not None:
                    if query_type == "find_unique":
                        response = await self.db.litellm_spendlogs.find_unique(  # type: ignore
                            where={  # type: ignore
                                key_val["key"]: key_val["value"],  # type: ignore
                            }
                        )
                    elif query_type == "find_all":
                        response = await self.db.litellm_spendlogs.find_many(  # type: ignore
                            where={
                                key_val["key"]: key_val["value"],  # type: ignore
                            }
                        )
                    return response
                else:
                    response = await self.db.litellm_spendlogs.find_many(  # type: ignore
                        order={"startTime": "desc"},
                    )
                    return response
            elif table_name == "budget" and reset_at is not None:
                if query_type == "find_all":
                    response = await self.db.litellm_budgettable.find_many(
                        where={  # type:ignore
                            "OR": [
                                {
                                    "AND": [
                                        {"budget_reset_at": None},
                                        {"NOT": {"budget_duration": None}},
                                    ]
                                },
                                {"budget_reset_at": {"lt": reset_at}},
                            ]
                        }
                    )
                    return response

            elif table_name == "enduser" and budget_id_list is not None:
                if query_type == "find_all":
                    response = await self.db.litellm_endusertable.find_many(
                        where={"budget_id": {"in": budget_id_list}}
                    )
                return response
            elif table_name == "team":
                if query_type == "find_unique":
                    response = await self.db.litellm_teamtable.find_unique(
                        where={"team_id": team_id},  # type: ignore
                        include={"litellm_model_table": True},  # type: ignore
                    )
                elif query_type == "find_all" and reset_at is not None:
                    response = await self.db.litellm_teamtable.find_many(
                        where={  # type:ignore
                            "budget_reset_at": {"lt": reset_at},
                        }
                    )
                elif query_type == "find_all" and user_id is not None:
                    response = await self.db.litellm_teamtable.find_many(
                        where={
                            "members": {"has": user_id},
                        },
                        include={"litellm_budget_table": True},
                    )
                elif query_type == "find_all" and team_id_list is not None:
                    response = await self.db.litellm_teamtable.find_many(
                        where={"team_id": {"in": team_id_list}}
                    )
                elif query_type == "find_all" and team_id_list is None:
                    response = await self.db.litellm_teamtable.find_many(
                        take=MAX_TEAM_LIST_LIMIT
                    )
                return response
            elif table_name == "user_notification":
                if query_type == "find_unique":
                    response = await self.db.litellm_usernotifications.find_unique(  # type: ignore
                        where={"user_id": user_id}  # type: ignore
                    )
                elif query_type == "find_all":
                    response = await self.db.litellm_usernotifications.find_many()  # type: ignore
                return response
            elif table_name == "combined_view":
                # check if plain text or hash
                if token is not None:
                    if isinstance(token, str):
                        hashed_token = _hash_token_if_needed(token=token)
                        verbose_proxy_logger.debug(
                            f"PrismaClient: find_unique for token: {hashed_token}"
                        )
                if query_type == "find_unique":
                    if token is None:
                        raise HTTPException(
                            status_code=400,
                            detail={"error": f"No token passed in. Token={token}"},
                        )

                    sql_query = f"""
                        SELECT
                            v.*,
                            t.spend AS team_spend,
                            t.max_budget AS team_max_budget,
                            t.tpm_limit AS team_tpm_limit,
                            t.rpm_limit AS team_rpm_limit,
                            t.models AS team_models,
                            t.metadata AS team_metadata,
                            t.blocked AS team_blocked,
                            t.team_alias AS team_alias,
                            t.metadata AS team_metadata,
                            t.members_with_roles AS team_members_with_roles,
                            t.organization_id as org_id,
                            tm.spend AS team_member_spend,
                            m.aliases AS team_model_aliases,
                            -- Added comma to separate b.* columns
                            b.max_budget AS litellm_budget_table_max_budget,
                            b.tpm_limit AS litellm_budget_table_tpm_limit,
                            b.rpm_limit AS litellm_budget_table_rpm_limit,
                            b.model_max_budget as litellm_budget_table_model_max_budget,
                            b.soft_budget as litellm_budget_table_soft_budget
                        FROM "LiteLLM_VerificationToken" AS v
                        LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
                        LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id
                        LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id
                        LEFT JOIN "LiteLLM_BudgetTable" AS b ON v.budget_id = b.budget_id
                        WHERE v.token = '{token}'
                        """

                    response = await self.db.query_first(query=sql_query)

                    if response is not None:
                        if response["team_models"] is None:
                            response["team_models"] = []
                        if response["team_blocked"] is None:
                            response["team_blocked"] = False

                        team_member: Optional[Member] = None
                        if (
                            response["team_members_with_roles"] is not None
                            and response["user_id"] is not None
                        ):
                            ## find the team member corresponding to user id
                            """
                            [
                                {
                                    "role": "admin",
                                    "user_id": "default_user_id",
                                    "user_email": null
                                },
                                {
                                    "role": "user",
                                    "user_id": null,
                                    "user_email": "test@email.com"
                                }
                            ]
                            """
                            for tm in response["team_members_with_roles"]:
                                if tm.get("user_id") is not None and response[
                                    "user_id"
                                ] == tm.get("user_id"):
                                    team_member = Member(**tm)
                        response["team_member"] = team_member
                        response = LiteLLM_VerificationTokenView(
                            **response, last_refreshed_at=time.time()
                        )
                        # for prisma we need to cast the expires time to str
                        if response.expires is not None and isinstance(
                            response.expires, datetime
                        ):
                            response.expires = response.expires.isoformat()
                    return response
        except Exception as e:
            import traceback

            prisma_query_info = f"LiteLLM Prisma Client Exception: Error with `get_data`. Args passed in: {args_passed_in}"
            error_msg = prisma_query_info + str(e)
            print_verbose(error_msg)
            error_traceback = error_msg + "\n" + traceback.format_exc()
            verbose_proxy_logger.debug(error_traceback)
            end_time = time.time()
            _duration = end_time - start_time

            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(
                    original_exception=e,
                    duration=_duration,
                    call_type="get_data",
                    traceback_str=error_traceback,
                )
            )
            raise e

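    # Hedged usage sketches for `get_data` (argument values are illustrative):
    #   await prisma_client.get_data(token="sk-...", query_type="find_unique")
    #       -> verification-token row, with litellm_budget_table included
    #   await prisma_client.get_data(table_name="team", query_type="find_all", user_id="u1")
    #       -> teams whose members list contains that user
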
    def jsonify_team_object(self, db_data: dict):
        db_data = self.jsonify_object(data=db_data)
        if db_data.get("members_with_roles", None) is not None and isinstance(
            db_data["members_with_roles"], list
        ):
            db_data["members_with_roles"] = json.dumps(db_data["members_with_roles"])
        return db_data

    # Define a retrying strategy with exponential backoff
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def insert_data(  # noqa: PLR0915
        self,
        data: dict,
        table_name: Literal[
            "user", "key", "config", "spend", "team", "user_notification"
        ],
    ):
        """
        Add a key to the database. If it already exists, do nothing.
        """
        start_time = time.time()
        try:
            verbose_proxy_logger.debug("PrismaClient: insert_data: %s", data)
            if table_name == "key":
                token = data["token"]
                hashed_token = self.hash_token(token=token)
                db_data = self.jsonify_object(data=data)
                db_data["token"] = hashed_token
                print_verbose(
                    "PrismaClient: Before upsert into litellm_verificationtoken"
                )
                new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
                    where={
                        "token": hashed_token,
                    },
                    data={
                        "create": {**db_data},  # type: ignore
                        "update": {},  # don't do anything if it already exists
                    },
                    include={"litellm_budget_table": True},
                )
                verbose_proxy_logger.info("Data Inserted into Keys Table")
                return new_verification_token
            elif table_name == "user":
                db_data = self.jsonify_object(data=data)
                try:
                    new_user_row = await self.db.litellm_usertable.upsert(
                        where={"user_id": data["user_id"]},
                        data={
                            "create": {**db_data},  # type: ignore
                            "update": {},  # don't do anything if it already exists
                        },
                    )
                except Exception as e:
                    if (
                        "Foreign key constraint failed on the field: `LiteLLM_UserTable_organization_id_fkey (index)`"
                        in str(e)
                    ):
                        raise HTTPException(
                            status_code=400,
                            detail={
                                "error": f"Foreign Key Constraint failed. Organization ID={db_data['organization_id']} does not exist in LiteLLM_OrganizationTable. Create via `/organization/new`."
                            },
                        )
                    raise e
                verbose_proxy_logger.info("Data Inserted into User Table")
                return new_user_row
            elif table_name == "team":
                db_data = self.jsonify_team_object(db_data=data)
                new_team_row = await self.db.litellm_teamtable.upsert(
                    where={"team_id": data["team_id"]},
                    data={
                        "create": {**db_data},  # type: ignore
                        "update": {},  # don't do anything if it already exists
                    },
                )
                verbose_proxy_logger.info("Data Inserted into Team Table")
                return new_team_row
            elif table_name == "config":
                """
                For each param,
                get the existing table values

                Add the new values

                Update DB
                """
                tasks = []
                for k, v in data.items():
                    updated_data = v
                    updated_data = json.dumps(updated_data)
                    updated_table_row = self.db.litellm_config.upsert(
                        where={"param_name": k},  # type: ignore
                        data={
                            "create": {"param_name": k, "param_value": updated_data},  # type: ignore
                            "update": {"param_value": updated_data},
                        },
                    )

                    tasks.append(updated_table_row)
                await asyncio.gather(*tasks)
                verbose_proxy_logger.info("Data Inserted into Config Table")
            elif table_name == "spend":
                db_data = self.jsonify_object(data=data)
                new_spend_row = await self.db.litellm_spendlogs.upsert(
                    where={"request_id": data["request_id"]},
                    data={
                        "create": {**db_data},  # type: ignore
                        "update": {},  # don't do anything if it already exists
                    },
                )
                verbose_proxy_logger.info("Data Inserted into Spend Table")
                return new_spend_row
            elif table_name == "user_notification":
                db_data = self.jsonify_object(data=data)
                new_user_notification_row = (
                    await self.db.litellm_usernotifications.upsert(  # type: ignore
                        where={"request_id": data["request_id"]},
                        data={
                            "create": {**db_data},  # type: ignore
                            "update": {},  # don't do anything if it already exists
                        },
                    )
                )
                verbose_proxy_logger.info("Data Inserted into Model Request Table")
                return new_user_notification_row

        except Exception as e:
            import traceback

            error_msg = f"LiteLLM Prisma Client Exception in insert_data: {str(e)}"
            print_verbose(error_msg)
            error_traceback = error_msg + "\n" + traceback.format_exc()
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(
                    original_exception=e,
                    duration=_duration,
                    call_type="insert_data",
                    traceback_str=error_traceback,
                )
            )
            raise e

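    # Hedged usage sketch for `insert_data` (payload fields are illustrative):
    #   await prisma_client.insert_data(
    #       data={"token": "sk-...", "team_id": "t1", "spend": 0.0},
    #       table_name="key",
    #   )
    # Existing rows are left untouched, since every upsert above uses an empty "update" block.
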
    # Define a retrying strategy with exponential backoff
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def update_data(  # noqa: PLR0915
        self,
        token: Optional[str] = None,
        data: dict = {},
        data_list: Optional[List] = None,
        user_id: Optional[str] = None,
        team_id: Optional[str] = None,
        query_type: Literal["update", "update_many"] = "update",
        table_name: Optional[
            Literal["user", "key", "config", "spend", "team", "enduser", "budget"]
        ] = None,
        update_key_values: Optional[dict] = None,
        update_key_values_custom_query: Optional[dict] = None,
    ):
        """
        Update existing data
        """
        verbose_proxy_logger.debug(
            f"PrismaClient: update_data, table_name: {table_name}"
        )
        start_time = time.time()
        try:
            db_data = self.jsonify_object(data=data)
            if update_key_values is not None:
                update_key_values = self.jsonify_object(data=update_key_values)
            if token is not None:
                print_verbose(f"token: {token}")
                # check if plain text or hash
                token = _hash_token_if_needed(token=token)
                db_data["token"] = token
                response = await self.db.litellm_verificationtoken.update(
                    where={"token": token},  # type: ignore
                    data={**db_data},  # type: ignore
                )
                verbose_proxy_logger.debug(
                    "\033[91m"
                    + f"DB Token Table update succeeded {response}"
                    + "\033[0m"
                )
                _data: dict = {}
                if response is not None:
                    try:
                        _data = response.model_dump()  # type: ignore
                    except Exception:
                        _data = response.dict()
                return {"token": token, "data": _data}
            elif (
                user_id is not None
                or (table_name is not None and table_name == "user")
                and query_type == "update"
            ):
                """
                If data['spend'] + data['user'], update the user table with spend info as well
                """
                if user_id is None:
                    user_id = db_data["user_id"]
                if update_key_values is None:
                    if update_key_values_custom_query is not None:
                        update_key_values = update_key_values_custom_query
                    else:
                        update_key_values = db_data
                update_user_row = await self.db.litellm_usertable.upsert(
                    where={"user_id": user_id},  # type: ignore
                    data={
                        "create": {**db_data},  # type: ignore
                        "update": {
                            **update_key_values  # type: ignore
                        },  # just update user-specified values, if it already exists
                    },
                )
                verbose_proxy_logger.info(
                    "\033[91m"
                    + f"DB User Table - update succeeded {update_user_row}"
                    + "\033[0m"
                )
                return {"user_id": user_id, "data": update_user_row}
            elif (
                team_id is not None
                or (table_name is not None and table_name == "team")
                and query_type == "update"
            ):
                """
                If data['spend'] + data['user'], update the user table with spend info as well
                """
                if team_id is None:
                    team_id = db_data["team_id"]
                if update_key_values is None:
                    update_key_values = db_data
                if "team_id" not in db_data and team_id is not None:
                    db_data["team_id"] = team_id
                if "members_with_roles" in db_data and isinstance(
                    db_data["members_with_roles"], list
                ):
                    db_data["members_with_roles"] = json.dumps(
                        db_data["members_with_roles"]
                    )
                if "members_with_roles" in update_key_values and isinstance(
                    update_key_values["members_with_roles"], list
                ):
                    update_key_values["members_with_roles"] = json.dumps(
                        update_key_values["members_with_roles"]
                    )
                update_team_row = await self.db.litellm_teamtable.upsert(
                    where={"team_id": team_id},  # type: ignore
                    data={
                        "create": {**db_data},  # type: ignore
                        "update": {
                            **update_key_values  # type: ignore
                        },  # just update user-specified values, if it already exists
                    },
                )
                verbose_proxy_logger.info(
                    "\033[91m"
                    + f"DB Team Table - update succeeded {update_team_row}"
                    + "\033[0m"
                )
                return {"team_id": team_id, "data": update_team_row}
            elif (
                table_name is not None
                and table_name == "key"
                and query_type == "update_many"
                and data_list is not None
                and isinstance(data_list, list)
            ):
                """
                Batch write update queries
                """
                batcher = self.db.batch_()
                for idx, t in enumerate(data_list):
                    # check if plain text or hash
                    if t.token.startswith("sk-"):  # type: ignore
                        t.token = self.hash_token(token=t.token)  # type: ignore
                    try:
                        data_json = self.jsonify_object(
                            data=t.model_dump(exclude_none=True)
                        )
                    except Exception:
                        data_json = self.jsonify_object(data=t.dict(exclude_none=True))
                    batcher.litellm_verificationtoken.update(
                        where={"token": t.token},  # type: ignore
                        data={**data_json},  # type: ignore
                    )
                await batcher.commit()
                print_verbose(
                    "\033[91m" + "DB Token Table update succeeded" + "\033[0m"
                )
            elif (
                table_name is not None
                and table_name == "user"
                and query_type == "update_many"
                and data_list is not None
                and isinstance(data_list, list)
            ):
                """
                Batch write update queries
                """
                batcher = self.db.batch_()
                for idx, user in enumerate(data_list):
                    try:
                        data_json = self.jsonify_object(
                            data=user.model_dump(exclude_none=True)
                        )
                    except Exception:
                        data_json = self.jsonify_object(data=user.dict())
                    batcher.litellm_usertable.upsert(
                        where={"user_id": user.user_id},  # type: ignore
                        data={
                            "create": {**data_json},  # type: ignore
                            "update": {
                                **data_json  # type: ignore
                            },  # just update user-specified values, if it already exists
                        },
                    )
                await batcher.commit()
                verbose_proxy_logger.info(
                    "\033[91m" + "DB User Table Batch update succeeded" + "\033[0m"
                )
            elif (
                table_name is not None
                and table_name == "enduser"
                and query_type == "update_many"
                and data_list is not None
                and isinstance(data_list, list)
            ):
                """
                Batch write update queries
                """
                batcher = self.db.batch_()
                for enduser in data_list:
                    try:
                        data_json = self.jsonify_object(
                            data=enduser.model_dump(exclude_none=True)
                        )
                    except Exception:
                        data_json = self.jsonify_object(data=enduser.dict())
                    batcher.litellm_endusertable.upsert(
                        where={"user_id": enduser.user_id},  # type: ignore
                        data={
                            "create": {**data_json},  # type: ignore
                            "update": {
                                **data_json  # type: ignore
                            },  # just update end-user-specified values, if it already exists
                        },
                    )
                await batcher.commit()
                verbose_proxy_logger.info(
                    "\033[91m" + "DB End User Table Batch update succeeded" + "\033[0m"
                )
            elif (
                table_name is not None
                and table_name == "budget"
                and query_type == "update_many"
                and data_list is not None
                and isinstance(data_list, list)
            ):
                """
                Batch write update queries
                """
                batcher = self.db.batch_()
                for budget in data_list:
                    try:
                        data_json = self.jsonify_object(
                            data=budget.model_dump(exclude_none=True)
                        )
                    except Exception:
                        data_json = self.jsonify_object(data=budget.dict())
                    batcher.litellm_budgettable.upsert(
                        where={"budget_id": budget.budget_id},  # type: ignore
                        data={
                            "create": {**data_json},  # type: ignore
                            "update": {
                                **data_json  # type: ignore
                            },  # just update end-user-specified values, if it already exists
                        },
                    )
                await batcher.commit()
                verbose_proxy_logger.info(
                    "\033[91m" + "DB Budget Table Batch update succeeded" + "\033[0m"
                )
            elif (
                table_name is not None
                and table_name == "team"
                and query_type == "update_many"
                and data_list is not None
                and isinstance(data_list, list)
            ):
                # Batch write update queries
                batcher = self.db.batch_()
                for idx, team in enumerate(data_list):
                    try:
                        data_json = self.jsonify_team_object(
                            db_data=team.model_dump(exclude_none=True)
                        )
                    except Exception:
                        data_json = self.jsonify_object(
                            data=team.dict(exclude_none=True)
                        )
                    batcher.litellm_teamtable.upsert(
                        where={"team_id": team.team_id},  # type: ignore
                        data={
                            "create": {**data_json},  # type: ignore
                            "update": {
                                **data_json  # type: ignore
                            },  # just update user-specified values, if it already exists
                        },
                    )
                await batcher.commit()
                verbose_proxy_logger.info(
                    "\033[91m" + "DB Team Table Batch update succeeded" + "\033[0m"
                )

        except Exception as e:
            import traceback

            error_msg = f"LiteLLM Prisma Client Exception - update_data: {str(e)}"
            print_verbose(error_msg)
            error_traceback = error_msg + "\n" + traceback.format_exc()
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(
                    original_exception=e,
                    duration=_duration,
                    call_type="update_data",
                    traceback_str=error_traceback,
                )
            )
            raise e

    # Define a retrying strategy with exponential backoff
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def delete_data(
        self,
        tokens: Optional[List] = None,
        team_id_list: Optional[List] = None,
        table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None,
        user_id: Optional[str] = None,
    ):
        """
        Allow user to delete a key(s)

        Ensure user owns that key, unless admin.
        """
        start_time = time.time()
        try:
            if tokens is not None and isinstance(tokens, List):
                hashed_tokens = []
                for token in tokens:
                    if isinstance(token, str) and token.startswith("sk-"):
                        hashed_token = self.hash_token(token=token)
                    else:
                        hashed_token = token
                    hashed_tokens.append(hashed_token)
                filter_query: dict = {}
                if user_id is not None:
                    filter_query = {
                        "AND": [{"token": {"in": hashed_tokens}}, {"user_id": user_id}]
                    }
                else:
                    filter_query = {"token": {"in": hashed_tokens}}

                deleted_tokens = await self.db.litellm_verificationtoken.delete_many(
                    where=filter_query  # type: ignore
                )
                verbose_proxy_logger.debug("deleted_tokens: %s", deleted_tokens)
                return {"deleted_keys": deleted_tokens}
            elif (
                table_name == "team"
                and team_id_list is not None
                and isinstance(team_id_list, List)
            ):
                # admin only endpoint -> `/team/delete`
                await self.db.litellm_teamtable.delete_many(
                    where={"team_id": {"in": team_id_list}}
                )
                return {"deleted_teams": team_id_list}
            elif (
                table_name == "key"
                and team_id_list is not None
                and isinstance(team_id_list, List)
            ):
                # admin only endpoint -> `/team/delete`
                await self.db.litellm_verificationtoken.delete_many(
                    where={"team_id": {"in": team_id_list}}
                )
        except Exception as e:
            import traceback

            error_msg = f"LiteLLM Prisma Client Exception - delete_data: {str(e)}"
            print_verbose(error_msg)
            error_traceback = error_msg + "\n" + traceback.format_exc()
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(
                    original_exception=e,
                    duration=_duration,
                    call_type="delete_data",
                    traceback_str=error_traceback,
                )
            )
            raise e

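    # Hedged usage sketch for `delete_data` (values are illustrative):
    #   await prisma_client.delete_data(tokens=["sk-abc"], user_id="user-123")
    # plaintext "sk-" tokens are hashed first, and when a user_id is passed the filter
    # restricts deletion to rows owned by that user.
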
    # Define a retrying strategy with exponential backoff
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def connect(self):
        start_time = time.time()
        try:
            verbose_proxy_logger.debug(
                "PrismaClient: connect() called Attempting to Connect to DB"
            )
            if self.db.is_connected() is False:
                verbose_proxy_logger.debug(
                    "PrismaClient: DB not connected, Attempting to Connect to DB"
                )
                await self.db.connect()
        except Exception as e:
            import traceback

            error_msg = f"LiteLLM Prisma Client Exception connect(): {str(e)}"
            print_verbose(error_msg)
            error_traceback = error_msg + "\n" + traceback.format_exc()
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(
                    original_exception=e,
                    duration=_duration,
                    call_type="connect",
                    traceback_str=error_traceback,
                )
            )
            raise e

    # Define a retrying strategy with exponential backoff
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def disconnect(self):
        start_time = time.time()
        try:
            await self.db.disconnect()
        except Exception as e:
            import traceback

            error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}"
            print_verbose(error_msg)
            error_traceback = error_msg + "\n" + traceback.format_exc()
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(
                    original_exception=e,
                    duration=_duration,
                    call_type="disconnect",
                    traceback_str=error_traceback,
                )
            )
            raise e

    async def health_check(self):
        """
        Health check endpoint for the prisma client
        """
        start_time = time.time()
        try:
            sql_query = "SELECT 1"

            # Execute the raw query
            response = await self.db.query_raw(sql_query)
            return response
        except Exception as e:
            import traceback

            error_msg = f"LiteLLM Prisma Client Exception health_check(): {str(e)}"
            print_verbose(error_msg)
            error_traceback = error_msg + "\n" + traceback.format_exc()
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(
                    original_exception=e,
                    duration=_duration,
                    call_type="health_check",
                    traceback_str=error_traceback,
                )
            )
            raise e

    async def _get_spend_logs_row_count(self) -> int:
        try:
            sql_query = """
            SELECT reltuples::BIGINT
            FROM pg_class
            WHERE oid = '"LiteLLM_SpendLogs"'::regclass;
            """
            result = await self.db.query_raw(query=sql_query)
            return result[0]["reltuples"]
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error getting LiteLLM_SpendLogs row count: {e}"
            )
            return 0

    async def _set_spend_logs_row_count_in_proxy_state(self) -> None:
        """
        Set the `LiteLLM_SpendLogs` row count in proxy state.

        This is used later to determine if we should run expensive UI Usage queries.
        """
        from litellm.proxy.proxy_server import proxy_state

        _num_spend_logs_rows = await self._get_spend_logs_row_count()
        proxy_state.set_proxy_state_variable(
            variable_name="spend_logs_row_count",
            value=_num_spend_logs_rows,
        )

    # Health Check Database Methods
    def _validate_response_time(
        self, response_time_ms: Optional[float]
    ) -> Optional[float]:
        """Validate and clean response time value"""
        if response_time_ms is None:
            return None
        try:
            value = float(response_time_ms)
            return (
                value
                if value == value and value not in (float("inf"), float("-inf"))
                else None
            )
        except (ValueError, TypeError):
            verbose_proxy_logger.warning(
                f"Invalid response_time_ms value: {response_time_ms}"
            )
            return None

    def _clean_details(self, details: Optional[dict]) -> Optional[dict]:
        """Clean and validate details JSON"""
        if not isinstance(details, dict):
            return None
        try:
            return safe_json_loads(safe_dumps(details))
        except Exception as e:
            verbose_proxy_logger.warning(f"Failed to clean details JSON: {e}")
            return None

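    # Illustrative behaviour of the two cleaners above:
    #   _validate_response_time(12.5)          -> 12.5
    #   _validate_response_time(float("nan"))  -> None  (NaN fails the `value == value` check)
    #   _clean_details({"a": 1})               -> {"a": 1}; non-serializable content yields None
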
    async def save_health_check_result(
        self,
        model_name: str,
        status: str,
        healthy_count: int = 0,
        unhealthy_count: int = 0,
        error_message: Optional[str] = None,
        response_time_ms: Optional[float] = None,
        details: Optional[dict] = None,
        checked_by: Optional[str] = None,
        model_id: Optional[str] = None,
    ):
        """Save health check result to database"""
        try:
            # Build base data with required fields
            health_check_data = {
                "model_name": str(model_name),
                "status": str(status),
                "healthy_count": int(healthy_count),
                "unhealthy_count": int(unhealthy_count),
            }

            # Add optional fields using dict comprehension and helper methods
            optional_fields = {
                "error_message": str(error_message)[:500] if error_message else None,
                "response_time_ms": self._validate_response_time(response_time_ms),
                "details": self._clean_details(details),
                "checked_by": str(checked_by) if checked_by else None,
                "model_id": str(model_id) if model_id else None,
            }

            # Add only non-None optional fields
            health_check_data.update(
                {k: v for k, v in optional_fields.items() if v is not None}
            )

            verbose_proxy_logger.debug(f"Saving health check data: {health_check_data}")
            return await self.db.litellm_healthchecktable.create(data=health_check_data)

        except Exception as e:
            verbose_proxy_logger.error(
                f"Error saving health check result for model {model_name}: {e}"
            )
            return None

    async def get_health_check_history(
        self,
        model_name: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
        status_filter: Optional[str] = None,
    ):
        """
        Get health check history with optional filtering
        """
        try:
            where_clause = {}
            if model_name:
                where_clause["model_name"] = model_name
            if status_filter:
                where_clause["status"] = status_filter

            results = await self.db.litellm_healthchecktable.find_many(
                where=where_clause,
                order={"checked_at": "desc"},
                take=limit,
                skip=offset,
            )
            return results
        except Exception as e:
            verbose_proxy_logger.error(f"Error getting health check history: {e}")
            return []

    async def get_all_latest_health_checks(self):
        """
        Get the latest health check for each model
        """
        try:
            # Get all unique model names first
            all_checks = await self.db.litellm_healthchecktable.find_many(
                order={"checked_at": "desc"}
            )

            # Group by model_name and get the latest for each
            latest_checks = {}
            for check in all_checks:
                if check.model_name not in latest_checks:
                    latest_checks[check.model_name] = check

            return list(latest_checks.values())
        except Exception as e:
            verbose_proxy_logger.error(f"Error getting all latest health checks: {e}")
            return []


### HELPER FUNCTIONS ###


async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
    """
    Check if a user_id exists in cache,
    if not retrieve it.
    """
    cache_key = f"{user_id}_user_api_key_user_id"
    response = cache.get_cache(key=cache_key)
    if response is None:  # Cache miss
        user_row = await db.get_data(user_id=user_id)
        if user_row is not None:
            print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
            if hasattr(user_row, "model_dump_json") and callable(
                getattr(user_row, "model_dump_json")
            ):
                cache_value = user_row.model_dump_json()
                cache.set_cache(
                    key=cache_key, value=cache_value, ttl=600
                )  # store for 10 minutes
    return

async def send_email(
    receiver_email: Optional[str] = None,
    subject: Optional[str] = None,
    html: Optional[str] = None,
):
    """
    Send an HTML email over SMTP.

    Reads the SMTP configuration from the environment:
    smtp_host, smtp_port, smtp_username, smtp_password, sender_email.
    """
    ## SERVER SETUP ##

    smtp_host = os.getenv("SMTP_HOST")
    smtp_port = int(os.getenv("SMTP_PORT", "587"))  # default to port 587
    smtp_username = os.getenv("SMTP_USERNAME")
    smtp_password = os.getenv("SMTP_PASSWORD")
    sender_email = os.getenv("SMTP_SENDER_EMAIL", None)
    if sender_email is None:
        raise ValueError("Trying to use SMTP, but SMTP_SENDER_EMAIL is not set")
    if receiver_email is None:
        raise ValueError(f"No receiver email provided for SMTP email. {receiver_email}")
    if subject is None:
        raise ValueError(f"No subject provided for SMTP email. {subject}")
    if html is None:
        raise ValueError(f"No HTML body provided for SMTP email. {html}")

    ## EMAIL SETUP ##
    email_message = MIMEMultipart()
    email_message["From"] = sender_email
    email_message["To"] = receiver_email
    email_message["Subject"] = subject
    verbose_proxy_logger.debug(
        "sending email from %s to %s", sender_email, receiver_email
    )

    if smtp_host is None:
        raise ValueError("Trying to use SMTP, but SMTP_HOST is not set")

    # Attach the body to the email
    email_message.attach(MIMEText(html, "html"))

    try:
        # Establish a secure connection with the SMTP server
        with smtplib.SMTP(
            host=smtp_host,
            port=smtp_port,
        ) as server:
            if os.getenv("SMTP_TLS", "True") != "False":
                server.starttls()

            # Login to your email account only if smtp_username and smtp_password are provided
            if smtp_username and smtp_password:
                server.login(
                    user=smtp_username,
                    password=smtp_password,
                )

            # Send the email
            server.send_message(
                msg=email_message,
                from_addr=sender_email,
                to_addrs=receiver_email,
            )

    except Exception as e:
        verbose_proxy_logger.exception(
            "An error occurred while sending the email: " + str(e)
        )

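# Illustrative sketch (assumption: run manually against a configured mailbox, not
# invoked by the proxy). send_email reads SMTP_HOST / SMTP_PORT / SMTP_USERNAME /
# SMTP_PASSWORD / SMTP_SENDER_EMAIL from the environment, so set those first.
async def _example_send_email() -> None:
    await send_email(
        receiver_email="ops@example.com",
        subject="LiteLLM Proxy alert",
        html="<p>Budget threshold crossed.</p>",
    )
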
def hash_token(token: str):
    import hashlib

    # Hash the string using SHA-256
    hashed_token = hashlib.sha256(token.encode()).hexdigest()

    return hashed_token


def _hash_token_if_needed(token: str) -> str:
    """
    Hash the token if it's a string and starts with "sk-"

    Else return the token as is
    """
    if token.startswith("sk-"):
        return hash_token(token=token)
    else:
        return token

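# Illustrative sketch (assumption: standalone helper for documentation purposes only).
# "sk-" virtual keys are stored as their SHA-256 digest; already-hashed values pass through.
def _example_hash_token_if_needed() -> None:
    raw_key = "sk-1234"
    hashed = _hash_token_if_needed(raw_key)  # 64-char SHA-256 hex digest
    assert hashed == hash_token(raw_key)
    assert _hash_token_if_needed(hashed) == hashed  # no double-hashing
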
class ProxyUpdateSpend:
    @staticmethod
    async def update_end_user_spend(
        n_retry_times: int,
        prisma_client: PrismaClient,
        proxy_logging_obj: ProxyLogging,
        end_user_list_transactions: Dict[str, float],
    ):
        for i in range(n_retry_times + 1):
            start_time = time.time()
            try:
                async with prisma_client.db.tx(
                    timeout=timedelta(seconds=60)
                ) as transaction:
                    async with transaction.batch_() as batcher:
                        for (
                            end_user_id,
                            response_cost,
                        ) in end_user_list_transactions.items():
                            if litellm.max_end_user_budget is not None:
                                pass
                            batcher.litellm_endusertable.upsert(
                                where={"user_id": end_user_id},
                                data={
                                    "create": {
                                        "user_id": end_user_id,
                                        "spend": response_cost,
                                        "blocked": False,
                                    },
                                    "update": {"spend": {"increment": response_cost}},
                                },
                            )

                break
            except DB_CONNECTION_ERROR_TYPES as e:
                if i >= n_retry_times:  # If we've reached the maximum number of retries
                    _raise_failed_update_spend_exception(
                        e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
                    )
                # Optionally, sleep for a bit before retrying
                await asyncio.sleep(2**i)  # Exponential backoff
            except Exception as e:
                _raise_failed_update_spend_exception(
                    e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
                )

    @staticmethod
    async def update_spend_logs(
        n_retry_times: int,
        prisma_client: PrismaClient,
        db_writer_client: Optional[HTTPHandler],
        proxy_logging_obj: ProxyLogging,
    ):
        BATCH_SIZE = 100  # Preferred size of each batch to write to the database
        MAX_LOGS_PER_INTERVAL = (
            1000  # Maximum number of logs to flush in a single interval
        )
        # Get initial logs to process
        logs_to_process = prisma_client.spend_log_transactions[:MAX_LOGS_PER_INTERVAL]
        start_time = time.time()
        try:
            for i in range(n_retry_times + 1):
                try:
                    base_url = os.getenv("SPEND_LOGS_URL", None)
                    if (
                        len(logs_to_process) > 0
                        and base_url is not None
                        and db_writer_client is not None
                    ):
                        if not base_url.endswith("/"):
                            base_url += "/"
                        verbose_proxy_logger.debug("base_url: {}".format(base_url))
                        response = await db_writer_client.post(
                            url=base_url + "spend/update",
                            data=json.dumps(logs_to_process),
                            headers={"Content-Type": "application/json"},
                        )
                        if response.status_code == 200:
                            prisma_client.spend_log_transactions = (
                                prisma_client.spend_log_transactions[
                                    len(logs_to_process) :
                                ]
                            )
                    else:
                        for j in range(0, len(logs_to_process), BATCH_SIZE):
                            batch = logs_to_process[j : j + BATCH_SIZE]
                            batch_with_dates = [
                                prisma_client.jsonify_object({**entry})
                                for entry in batch
                            ]
                            await prisma_client.db.litellm_spendlogs.create_many(
                                data=batch_with_dates, skip_duplicates=True
                            )
                            verbose_proxy_logger.debug(
                                f"Flushed {len(batch)} logs to the DB."
                            )

                        prisma_client.spend_log_transactions = (
                            prisma_client.spend_log_transactions[len(logs_to_process) :]
                        )
                        verbose_proxy_logger.debug(
                            f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}"
                        )
                    break
                except DB_CONNECTION_ERROR_TYPES:
                    if i is None:
                        i = 0
                    if i >= n_retry_times:
                        raise
                    await asyncio.sleep(2**i)
        except Exception as e:
            prisma_client.spend_log_transactions = prisma_client.spend_log_transactions[
                len(logs_to_process) :
            ]
            _raise_failed_update_spend_exception(
                e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
            )

    @staticmethod
    def disable_spend_updates() -> bool:
        """
        Returns True if spend should not be updated in the DB.

        Skips writing spend logs and spend updates for keys, teams, and users.
        """
        from litellm.proxy.proxy_server import general_settings

        if general_settings.get("disable_spend_updates") is True:
            return True
        return False

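# Illustrative sketch (assumption: standalone numbers mirroring update_spend_logs above).
# With 1,000 queued spend logs and BATCH_SIZE = 100, one flush interval issues ten
# create_many() calls of 100 rows each.
def _example_spend_log_batching() -> None:
    queued = list(range(1000))
    batches = [queued[j : j + 100] for j in range(0, len(queued), 100)]
    assert len(batches) == 10
    assert all(len(batch) == 100 for batch in batches)
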
async def update_spend(  # noqa: PLR0915
    prisma_client: PrismaClient,
    db_writer_client: Optional[HTTPHandler],
    proxy_logging_obj: ProxyLogging,
):
    """
    Batch write updates to db.

    Triggered every minute.

    Requires:
    user_id_list: dict,
    keys_list: list,
    team_list: list,
    spend_logs: list,
    """
    n_retry_times = 3
    await proxy_logging_obj.db_spend_update_writer.db_update_spend_transaction_handler(
        prisma_client=prisma_client,
        n_retry_times=n_retry_times,
        proxy_logging_obj=proxy_logging_obj,
    )

    ### UPDATE SPEND LOGS ###
    verbose_proxy_logger.debug(
        "Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions))
    )

    if len(prisma_client.spend_log_transactions) > 0:
        await ProxyUpdateSpend.update_spend_logs(
            n_retry_times=n_retry_times,
            prisma_client=prisma_client,
            proxy_logging_obj=proxy_logging_obj,
            db_writer_client=db_writer_client,
        )

def _raise_failed_update_spend_exception(
    e: Exception, start_time: float, proxy_logging_obj: ProxyLogging
):
    """
    Raise an exception for failed update spend logs

    - Calls proxy_logging_obj.failure_handler to log the error
    - Ensures the error message says "Non-Blocking"
    """
    import traceback

    error_msg = (
        f"[Non-Blocking] LiteLLM Prisma Client Exception - update spend logs: {str(e)}"
    )
    error_traceback = error_msg + "\n" + traceback.format_exc()
    end_time = time.time()
    _duration = end_time - start_time
    asyncio.create_task(
        proxy_logging_obj.failure_handler(
            original_exception=e,
            duration=_duration,
            call_type="update_spend",
            traceback_str=error_traceback,
        )
    )
    raise e

def _is_projected_spend_over_limit(
    current_spend: float, soft_budget_limit: Optional[float]
):
    from datetime import date

    if soft_budget_limit is None:
        # If there's no limit, we can't exceed it.
        return False

    today = date.today()

    # Finding the first day of the next month, then subtracting one day to get the end of the current month.
    if today.month == 12:  # December edge case
        end_month = date(today.year + 1, 1, 1) - timedelta(days=1)
    else:
        end_month = date(today.year, today.month + 1, 1) - timedelta(days=1)

    remaining_days = (end_month - today).days

    # Check for the start of the month to avoid division by zero
    if today.day == 1:
        daily_spend_estimate = current_spend
    else:
        daily_spend_estimate = current_spend / (today.day - 1)

    # Total projected spend for the month
    projected_spend = current_spend + (daily_spend_estimate * remaining_days)

    if projected_spend > soft_budget_limit:
        print_verbose("Projected spend exceeds soft budget limit!")
        return True
    return False

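# Illustrative sketch (assumption: standalone numbers, not real proxy spend data).
# On day 11 of a month with $50 already spent, the run rate is 50 / 10 = $5/day, and the
# projection adds that rate for each remaining day on top of the current spend.
def _example_projected_spend() -> None:
    current_spend = 50.0
    days_elapsed = 10  # today.day - 1
    remaining_days = 19
    daily_spend_estimate = current_spend / days_elapsed
    projected = current_spend + (daily_spend_estimate * remaining_days)
    assert projected == 145.0  # would exceed a $100 soft budget
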
def _get_projected_spend_over_limit(
    current_spend: float, soft_budget_limit: Optional[float]
) -> Optional[tuple]:
    import datetime

    if soft_budget_limit is None:
        return None

    today = datetime.date.today()
    # Find the end of the current month (handle the December rollover).
    if today.month == 12:
        end_month = datetime.date(today.year + 1, 1, 1) - datetime.timedelta(days=1)
    else:
        end_month = datetime.date(today.year, today.month + 1, 1) - datetime.timedelta(
            days=1
        )
    remaining_days = (end_month - today).days

    # Guard against division by zero on the first day of the month.
    days_elapsed = max(today.day - 1, 1)
    daily_spend = current_spend / (
        days_elapsed
    )  # assuming the current spend till today (not including today)
    projected_spend = daily_spend * remaining_days

    if projected_spend > soft_budget_limit:
        approx_days = soft_budget_limit / daily_spend
        limit_exceed_date = today + datetime.timedelta(days=approx_days)

        # return the projected spend and the date it will be exceeded
        return projected_spend, limit_exceed_date

    return None

def _is_valid_team_configs(team_id=None, team_config=None, request_data=None):
    if team_id is None or team_config is None or request_data is None:
        return
    # check if valid model called for team
    if "models" in team_config:
        valid_models = team_config.pop("models")
        model_in_request = request_data["model"]
        if model_in_request not in valid_models:
            raise Exception(
                f"Invalid model for team {team_id}: {model_in_request}. Valid models for team are: {valid_models}\n"
            )
    return


def _to_ns(dt):
    return int(dt.timestamp() * 1e9)

def _check_and_merge_model_level_guardrails(
    data: dict, llm_router: Optional[Router]
) -> dict:
    """
    Check if the model has guardrails defined and merge them with existing guardrails in the request data.

    Args:
        data: The request data dict
        llm_router: The LLM router instance to get deployment info from

    Returns:
        Modified data dict with merged guardrails (if any model-level guardrails exist)
    """
    if llm_router is None:
        return data

    # Get the model ID from the data
    metadata = data.get("metadata") or {}
    model_info = metadata.get("model_info") or {}
    model_id = model_info.get("id", None)

    if model_id is None:
        return data

    # Check if the model has guardrails
    deployment = llm_router.get_deployment(model_id=model_id)
    if deployment is None:
        return data

    model_level_guardrails = deployment.litellm_params.get("guardrails")

    if model_level_guardrails is None:
        return data

    # Merge model-level guardrails with existing ones
    return _merge_guardrails_with_existing(data, model_level_guardrails)

def _merge_guardrails_with_existing(data: dict, model_level_guardrails: Any) -> dict:
    """
    Merge model-level guardrails with any existing guardrails in the request data.

    Args:
        data: The request data dict
        model_level_guardrails: Guardrails defined at the model level

    Returns:
        Modified data dict with merged guardrails in metadata
    """
    modified_data = data.copy()
    metadata = modified_data.setdefault("metadata", {})
    existing_guardrails = metadata.get("guardrails", [])

    # Ensure existing_guardrails is a list
    if not isinstance(existing_guardrails, list):
        existing_guardrails = [existing_guardrails] if existing_guardrails else []

    # Ensure model_level_guardrails is a list
    if not isinstance(model_level_guardrails, list):
        model_level_guardrails = (
            [model_level_guardrails] if model_level_guardrails else []
        )

    # Combine existing and model-level guardrails
    metadata["guardrails"] = list(set(existing_guardrails + model_level_guardrails))
    return modified_data

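# Illustrative sketch (assumption: standalone request dict with hypothetical guardrail names).
def _example_merge_guardrails() -> None:
    data = {"model": "gpt-4o", "metadata": {"guardrails": ["pii_mask"]}}
    merged = _merge_guardrails_with_existing(data, ["prompt_injection"])
    assert set(merged["metadata"]["guardrails"]) == {"pii_mask", "prompt_injection"}
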
def get_error_message_str(e: Exception) -> str:
    error_message = ""
    if isinstance(e, HTTPException):
        if isinstance(e.detail, str):
            error_message = e.detail
        elif isinstance(e.detail, dict):
            error_message = json.dumps(e.detail)
        elif hasattr(e, "message"):
            _error = getattr(e, "message", None)
            if isinstance(_error, str):
                error_message = _error
            elif isinstance(_error, dict):
                error_message = json.dumps(_error)
        else:
            error_message = str(e)
    else:
        error_message = str(e)
    return error_message

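# Illustrative sketch (assumption: standalone; uses the fastapi HTTPException imported above).
def _example_get_error_message_str() -> None:
    exc = HTTPException(status_code=404, detail={"error": "model not found"})
    assert get_error_message_str(exc) == json.dumps({"error": "model not found"})
    assert get_error_message_str(ValueError("bad input")) == "bad input"
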
def _get_redoc_url() -> Optional[str]:
    """
    Get the Redoc URL from the environment variables.

    - If REDOC_URL is set, return it.
    - If NO_REDOC is True, return None.
    - Otherwise, default to "/redoc".
    """
    if redoc_url := os.getenv("REDOC_URL"):
        return redoc_url

    if str_to_bool(os.getenv("NO_REDOC")) is True:
        return None

    return "/redoc"


def _get_docs_url() -> Optional[str]:
    """
    Get the docs (Swagger UI) URL from the environment variables.

    - If DOCS_URL is set, return it.
    - If NO_DOCS is True, return None.
    - Otherwise, default to "/".
    """
    if docs_url := os.getenv("DOCS_URL"):
        return docs_url

    if str_to_bool(os.getenv("NO_DOCS")) is True:
        return None

    return "/"

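# Illustrative sketch (assumption: run in a throwaway process, since it mutates os.environ).
def _example_docs_url_toggle() -> None:
    os.environ.pop("DOCS_URL", None)
    os.environ["NO_DOCS"] = "True"
    assert _get_docs_url() is None  # Swagger UI disabled
    os.environ.pop("NO_DOCS", None)
    assert _get_docs_url() == "/"  # default mount point
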
def handle_exception_on_proxy(e: Exception) -> ProxyException:
    """
    Returns an Exception as a ProxyException, ensuring all exceptions are OpenAI API compatible
    """
    from fastapi import status

    verbose_proxy_logger.exception(f"Exception: {e}")

    if isinstance(e, HTTPException):
        return ProxyException(
            message=getattr(e, "detail", f"error({str(e)})"),
            type=ProxyErrorTypes.internal_server_error,
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
        )
    elif isinstance(e, ProxyException):
        return e
    return ProxyException(
        message="Internal Server Error, " + str(e),
        type=ProxyErrorTypes.internal_server_error,
        param=getattr(e, "param", "None"),
        code=status.HTTP_500_INTERNAL_SERVER_ERROR,
    )

def _premium_user_check():
    """
    Raises an HTTPException if the user is not a premium user
    """
    from litellm.proxy.proxy_server import premium_user

    if not premium_user:
        raise HTTPException(
            status_code=403,
            detail={
                "error": f"This feature is only available for LiteLLM Enterprise users. {CommonProxyErrors.not_premium_user.value}"
            },
        )


def is_known_model(model: Optional[str], llm_router: Optional[Router]) -> bool:
    """
    Returns True if the model is in the llm_router model names
    """
    if model is None or llm_router is None:
        return False
    model_names = llm_router.get_model_names()

    is_in_list = False
    if model in model_names:
        is_in_list = True

    return is_in_list

def join_paths(base_path: str, route: str) -> str:
    # Remove trailing slashes from base_path and leading slashes from route
    base_path = base_path.rstrip("/")
    route = route.lstrip("/")

    # If base_path is empty, return route with leading slash
    if not base_path:
        return f"/{route}" if route else "/"

    # If route is empty, return just base_path
    if not route:
        return base_path

    # Join with single slash
    return f"{base_path}/{route}"

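# Illustrative sketch (assumption: standalone URLs chosen for demonstration).
def _example_join_paths() -> None:
    assert join_paths("https://proxy.internal", "/v1/models") == "https://proxy.internal/v1/models"
    assert join_paths("https://proxy.internal/", "") == "https://proxy.internal"
    assert join_paths("", "health") == "/health"
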
def get_custom_url(request_base_url: str, route: Optional[str] = None) -> str:
    # Use environment variable value, otherwise use URL from request
    server_base_url = get_proxy_base_url()
    if server_base_url is not None:
        base_url = server_base_url
    else:
        base_url = request_base_url

    server_root_path = get_server_root_path()
    if route is not None:
        if server_root_path != "":
            # First join base_url with server_root_path, then with route
            intermediate_url = join_paths(base_url, server_root_path)
            return join_paths(intermediate_url, route)
        else:
            return join_paths(base_url, route)
    else:
        return join_paths(base_url, server_root_path)


def get_proxy_base_url() -> Optional[str]:
    """
    Get the proxy base url from the environment variables.
    """
    return os.getenv("PROXY_BASE_URL")


def get_server_root_path() -> str:
    """
    Get the server root path from the environment variables.

    - If SERVER_ROOT_PATH is set, return it.
    - Otherwise, default to "/".
    """
    return os.getenv("SERVER_ROOT_PATH", "/")


def get_prisma_client_or_throw(message: str):
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": message},
        )
    return prisma_client

def is_valid_api_key(key: str) -> bool:
    """
    Validates API key format:
    - sk- keys: must match ^sk-[A-Za-z0-9_-]+$
    - hashed keys: must match ^[a-fA-F0-9]{64}$
    - Length between 3 and 100 characters
    """
    import re

    if not isinstance(key, str):
        return False
    if 3 <= len(key) <= 100:
        if re.match(r"^sk-[A-Za-z0-9_-]+$", key):
            return True
        if re.match(r"^[a-fA-F0-9]{64}$", key):
            return True
    return False

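# Illustrative sketch (assumption: standalone, hypothetical key values).
def _example_is_valid_api_key() -> None:
    assert is_valid_api_key("sk-" + "a" * 20)  # virtual key format
    assert is_valid_api_key(hash_token("sk-" + "a" * 20))  # 64-char SHA-256 hex digest
    assert not is_valid_api_key("Bearer sk-123")  # wrong prefix / charset
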
def construct_database_url_from_env_vars() -> Optional[str]:
    """
    Construct a DATABASE_URL from individual environment variables.

    Returns:
        Optional[str]: The constructed DATABASE_URL or None if required variables are missing
    """
    import urllib.parse

    # Check if all required variables are provided
    database_host = os.getenv("DATABASE_HOST")
    database_username = os.getenv("DATABASE_USERNAME")
    database_password = os.getenv("DATABASE_PASSWORD")
    database_name = os.getenv("DATABASE_NAME")

    if database_host and database_username and database_name:
        # Handle the problem of special character escaping in the database URL
        database_username_enc = urllib.parse.quote_plus(database_username)
        database_password_enc = (
            urllib.parse.quote_plus(database_password) if database_password else ""
        )
        database_name_enc = urllib.parse.quote_plus(database_name)

        # Construct DATABASE_URL from the provided variables
        if database_password:
            database_url = f"postgresql://{database_username_enc}:{database_password_enc}@{database_host}/{database_name_enc}"
        else:
            database_url = f"postgresql://{database_username_enc}@{database_host}/{database_name_enc}"

        return database_url

    return None

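# Illustrative sketch (assumption: run in a throwaway process, since it mutates os.environ).
# quote_plus percent-encodes special characters in the credentials before they are
# embedded in the connection string.
def _example_construct_database_url() -> None:
    os.environ.update(
        {
            "DATABASE_HOST": "localhost:5432",
            "DATABASE_USERNAME": "litellm",
            "DATABASE_PASSWORD": "p@ss word",
            "DATABASE_NAME": "litellm_db",
        }
    )
    assert (
        construct_database_url_from_env_vars()
        == "postgresql://litellm:p%40ss+word@localhost:5432/litellm_db"
    )
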
async def get_available_models_for_user(
    user_api_key_dict: "UserAPIKeyAuth",
    llm_router: Optional["Router"],
    general_settings: dict,
    user_model: Optional[str],
    prisma_client: Optional["PrismaClient"] = None,
    proxy_logging_obj: Optional["ProxyLogging"] = None,
    team_id: Optional[str] = None,
    include_model_access_groups: bool = False,
    only_model_access_groups: bool = False,
    return_wildcard_routes: bool = False,
    user_api_key_cache: Optional["DualCache"] = None,
) -> List[str]:
    """
    Get the list of models available to a user based on their API key and team permissions.

    Args:
        user_api_key_dict: User API key authentication object
        llm_router: LiteLLM router instance
        general_settings: General settings from config
        user_model: User-specific model
        prisma_client: Prisma client for database operations
        proxy_logging_obj: Proxy logging object
        team_id: Specific team ID to check (optional)
        include_model_access_groups: Whether to include model access groups
        only_model_access_groups: Whether to only return model access groups
        return_wildcard_routes: Whether to return wildcard routes

    Returns:
        List of model names available to the user
    """
    from litellm.proxy.auth.model_checks import (
        get_key_models,
        get_team_models,
        get_complete_model_list,
    )
    from litellm.proxy.auth.auth_checks import get_team_object
    from litellm.proxy.management_endpoints.team_endpoints import validate_membership

    # Get proxy model list and access groups
    if llm_router is None:
        proxy_model_list = []
        model_access_groups = {}
    else:
        proxy_model_list = llm_router.get_model_names()
        model_access_groups = llm_router.get_model_access_groups()

    # Get key models
    key_models = get_key_models(
        user_api_key_dict=user_api_key_dict,
        proxy_model_list=proxy_model_list,
        model_access_groups=model_access_groups,
        include_model_access_groups=include_model_access_groups,
    )

    # Get team models
    team_models: List[str] = user_api_key_dict.team_models

    # If specific team_id is provided, validate and get team models
    if team_id and prisma_client and proxy_logging_obj and user_api_key_cache:
        key_models = []
        team_object = await get_team_object(
            team_id=team_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
        validate_membership(user_api_key_dict=user_api_key_dict, team_table=team_object)
        team_models = team_object.models

    team_models = get_team_models(
        team_models=team_models,
        proxy_model_list=proxy_model_list,
        model_access_groups=model_access_groups,
        include_model_access_groups=include_model_access_groups,
    )

    # Get complete model list
    all_models = get_complete_model_list(
        key_models=key_models,
        team_models=team_models,
        proxy_model_list=proxy_model_list,
        user_model=user_model,
        infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
        return_wildcard_routes=return_wildcard_routes,
        llm_router=llm_router,
        model_access_groups=model_access_groups,
        include_model_access_groups=include_model_access_groups,
        only_model_access_groups=only_model_access_groups,
    )

    return all_models

def create_model_info_response(
    model_id: str,
    provider: str,
    include_metadata: bool = False,
    fallback_type: Optional[str] = None,
    llm_router: Optional["Router"] = None,
) -> dict:
    """
    Create a standardized model info response.

    Args:
        model_id: The model ID
        provider: The model provider
        include_metadata: Whether to include metadata
        fallback_type: Type of fallbacks to include
        llm_router: LiteLLM router instance

    Returns:
        Dictionary containing model information
    """
    from litellm.proxy.auth.model_checks import get_all_fallbacks

    model_info = {
        "id": model_id,
        "object": "model",
        "created": DEFAULT_MODEL_CREATED_AT_TIME,
        "owned_by": provider,
    }

    # Add metadata if requested
    if include_metadata:
        metadata = {}

        # Default fallback_type to "general" if include_metadata is true
        effective_fallback_type = (
            fallback_type if fallback_type is not None else "general"
        )

        # Validate fallback_type
        valid_fallback_types = ["general", "context_window", "content_policy"]
        if effective_fallback_type not in valid_fallback_types:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid fallback_type. Must be one of: {valid_fallback_types}",
            )

        fallbacks = get_all_fallbacks(
            model=model_id,
            llm_router=llm_router,
            fallback_type=effective_fallback_type,
        )
        metadata["fallbacks"] = fallbacks

        model_info["metadata"] = metadata

    return model_info

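# Illustrative sketch (assumption: metadata/fallback resolution is skipped because
# include_metadata defaults to False, so no router is required).
def _example_create_model_info_response() -> None:
    info = create_model_info_response(model_id="gpt-4o", provider="openai")
    assert info["object"] == "model"
    assert info["owned_by"] == "openai"
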
def validate_model_access(
    model_id: str,
    available_models: List[str],
) -> None:
    """
    Validate that a model is accessible to the user.

    Args:
        model_id: The model ID to validate
        available_models: List of models available to the user

    Raises:
        HTTPException: If the model is not accessible
    """
    if model_id not in available_models:
        raise HTTPException(
            status_code=404,
            detail="The model `{}` does not exist or is not accessible".format(model_id),
        )

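# Illustrative sketch (assumption: standalone model list with hypothetical names).
def _example_validate_model_access() -> None:
    validate_model_access(model_id="gpt-4o", available_models=["gpt-4o", "claude-3-haiku"])
    try:
        validate_model_access(model_id="some-other-model", available_models=["gpt-4o"])
    except HTTPException as e:
        assert e.status_code == 404
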