from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Type, Union, get_args

from litellm._logging import verbose_logger
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.guardrails import (
    DynamicGuardrailParams,
    GuardrailEventHooks,
    LitellmParams,
    Mode,
    PiiEntityType,
)
from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel
from litellm.types.utils import (
    CallTypes,
    LLMResponseTypes,
    StandardLoggingGuardrailInformation,
)

dc = DualCache()


class CustomGuardrail(CustomLogger):
    def __init__(
        self,
        guardrail_name: Optional[str] = None,
        supported_event_hooks: Optional[List[GuardrailEventHooks]] = None,
        event_hook: Optional[
            Union[GuardrailEventHooks, List[GuardrailEventHooks], Mode]
        ] = None,
        default_on: bool = False,
        mask_request_content: bool = False,
        mask_response_content: bool = False,
        **kwargs,
    ):
        """
        Initialize the CustomGuardrail class

        Args:
            guardrail_name: The name of the guardrail. This is the name used in your requests.
            supported_event_hooks: The event hooks that the guardrail supports
            event_hook: The event hook to run the guardrail on
            default_on: If True, the guardrail will be run by default on all requests
            mask_request_content: If True, the guardrail will mask the request content
            mask_response_content: If True, the guardrail will mask the response content
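
        Example (an illustrative sketch; the subclass name and hook choices are
        placeholders, not part of this module):

        ```
        from litellm.integrations.custom_guardrail import CustomGuardrail
        from litellm.types.guardrails import GuardrailEventHooks

        class MyGuardrail(CustomGuardrail):
            pass

        guardrail = MyGuardrail(
            guardrail_name="my-guardrail",
            supported_event_hooks=[
                GuardrailEventHooks.pre_call,
                GuardrailEventHooks.post_call,
            ],
            event_hook=GuardrailEventHooks.pre_call,
            default_on=True,  # run on every request, even if not explicitly requested
        )
        ```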
"""
|
|
self.guardrail_name = guardrail_name
|
|
self.supported_event_hooks = supported_event_hooks
|
|
self.event_hook: Optional[
|
|
Union[GuardrailEventHooks, List[GuardrailEventHooks], Mode]
|
|
] = event_hook
|
|
self.default_on: bool = default_on
|
|
self.mask_request_content: bool = mask_request_content
|
|
self.mask_response_content: bool = mask_response_content
|
|
|
|
if supported_event_hooks:
|
|
|
|
## validate event_hook is in supported_event_hooks
|
|
self._validate_event_hook(event_hook, supported_event_hooks)
|
|
super().__init__(**kwargs)
|
|
|
|
    @staticmethod
    def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
        """
        Returns the config model for the guardrail

        This is used to render the config model in the UI.
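
        Example (a minimal sketch; `MyGuardrailConfigModel` and its field are
        hypothetical, shown only to illustrate the extension point):

        ```
        from litellm.types.proxy.guardrails.guardrail_hooks.base import (
            GuardrailConfigModel,
        )

        class MyGuardrailConfigModel(GuardrailConfigModel):
            api_base: Optional[str] = None  # hypothetical config field

        class MyGuardrail(CustomGuardrail):
            @staticmethod
            def get_config_model() -> Optional[Type[GuardrailConfigModel]]:
                return MyGuardrailConfigModel
        ```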
"""
|
|
return None
|
|
|
|
    def _validate_event_hook(
        self,
        event_hook: Optional[
            Union[GuardrailEventHooks, List[GuardrailEventHooks], Mode]
        ],
        supported_event_hooks: List[GuardrailEventHooks],
    ) -> None:
        def _validate_event_hook_list_is_in_supported_event_hooks(
            event_hook: Union[List[GuardrailEventHooks], List[str]],
            supported_event_hooks: List[GuardrailEventHooks],
        ) -> None:
            for hook in event_hook:
                if isinstance(hook, str):
                    hook = GuardrailEventHooks(hook)
                if hook not in supported_event_hooks:
                    raise ValueError(
                        f"Event hook {hook} is not in the supported event hooks {supported_event_hooks}"
                    )

        if event_hook is None:
            return
        if isinstance(event_hook, str):
            event_hook = GuardrailEventHooks(event_hook)
        if isinstance(event_hook, list):
            _validate_event_hook_list_is_in_supported_event_hooks(
                event_hook, supported_event_hooks
            )
        elif isinstance(event_hook, Mode):
            _validate_event_hook_list_is_in_supported_event_hooks(
                list(event_hook.tags.values()), supported_event_hooks
            )
            if event_hook.default:
                _validate_event_hook_list_is_in_supported_event_hooks(
                    [event_hook.default], supported_event_hooks
                )
        elif isinstance(event_hook, GuardrailEventHooks):
            if event_hook not in supported_event_hooks:
                raise ValueError(
                    f"Event hook {event_hook} is not in the supported event hooks {supported_event_hooks}"
                )

    def get_guardrail_from_metadata(
        self, data: dict
    ) -> Union[List[str], List[Dict[str, DynamicGuardrailParams]]]:
        """
        Returns the guardrail(s) to be run from the metadata or root
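
        Example (illustrative request bodies; `guardrail` is an instance of this
        class and the guardrail names are placeholders):

        ```
        # "guardrails" at the root of the request body takes precedence
        guardrail.get_guardrail_from_metadata({"guardrails": ["my-guardrail"]})
        # -> ["my-guardrail"]

        # otherwise, guardrails are read from metadata
        guardrail.get_guardrail_from_metadata(
            {"metadata": {"guardrails": [{"my-guardrail": {"extra_body": {"foo": "bar"}}}]}}
        )
        # -> [{"my-guardrail": {"extra_body": {"foo": "bar"}}}]

        # nothing requested -> empty list
        guardrail.get_guardrail_from_metadata({})  # -> []
        ```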
"""
|
|
if "guardrails" in data:
|
|
return data["guardrails"]
|
|
metadata = data.get("metadata") or {}
|
|
requested_guardrails = metadata.get("guardrails") or []
|
|
if requested_guardrails:
|
|
return requested_guardrails
|
|
return requested_guardrails
|
|
|
|
    def _guardrail_is_in_requested_guardrails(
        self,
        requested_guardrails: Union[List[str], List[Dict[str, DynamicGuardrailParams]]],
    ) -> bool:
        for _guardrail in requested_guardrails:
            if isinstance(_guardrail, dict):
                if self.guardrail_name in _guardrail:
                    return True
            elif isinstance(_guardrail, str):
                if self.guardrail_name == _guardrail:
                    return True

        return False

    async def async_pre_call_deployment_hook(
        self, kwargs: Dict[str, Any], call_type: Optional[CallTypes]
    ) -> Optional[dict]:
        from litellm.proxy._types import UserAPIKeyAuth

        # only run if guardrails were requested for this call
        litellm_guardrails = kwargs.get("guardrails")
        if litellm_guardrails is None or not isinstance(litellm_guardrails, list):
            return kwargs

        if (
            self.should_run_guardrail(
                data=kwargs, event_type=GuardrailEventHooks.pre_call
            )
            is not True
        ):
            return kwargs

        # CHECK IF GUARDRAIL REJECTS THE REQUEST
        if call_type == CallTypes.completion or call_type == CallTypes.acompletion:
            result = await self.async_pre_call_hook(
                user_api_key_dict=UserAPIKeyAuth(
                    user_id=kwargs.get("user_api_key_user_id"),
                    team_id=kwargs.get("user_api_key_team_id"),
                    end_user_id=kwargs.get("user_api_key_end_user_id"),
                    api_key=kwargs.get("user_api_key_hash"),
                    request_route=kwargs.get("user_api_key_request_route"),
                ),
                cache=dc,
                data=kwargs,
                call_type=call_type.value or "acompletion",  # type: ignore
            )

            if result is not None and isinstance(result, dict):
                result_messages = result.get("messages")
                if result_messages is not None:  # update for any pii / masking logic
                    kwargs["messages"] = result_messages

        return kwargs

    async def async_post_call_success_deployment_hook(
        self,
        request_data: dict,
        response: LLMResponseTypes,
        call_type: Optional[CallTypes],
    ) -> Optional[LLMResponseTypes]:
        """
        Allow modifying / reviewing the response just after it's received from the deployment.
        """
        from litellm.proxy._types import UserAPIKeyAuth

        # only run if guardrails were requested for this call
        litellm_guardrails = request_data.get("guardrails")
        if litellm_guardrails is None or not isinstance(litellm_guardrails, list):
            return response

        if (
            self.should_run_guardrail(
                data=request_data, event_type=GuardrailEventHooks.post_call
            )
            is not True
        ):
            return response

        # CHECK IF GUARDRAIL REJECTS THE RESPONSE
        result = await self.async_post_call_success_hook(
            user_api_key_dict=UserAPIKeyAuth(
                user_id=request_data.get("user_api_key_user_id"),
                team_id=request_data.get("user_api_key_team_id"),
                end_user_id=request_data.get("user_api_key_end_user_id"),
                api_key=request_data.get("user_api_key_hash"),
                request_route=request_data.get("user_api_key_request_route"),
            ),
            data=request_data,
            response=response,
        )

        if result is None or not isinstance(result, get_args(LLMResponseTypes)):
            return response

        return result

    def should_run_guardrail(
        self,
        data,
        event_type: GuardrailEventHooks,
    ) -> bool:
        """
        Returns True if the guardrail should be run on the event_type
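
        Example (a sketch; assumes a guardrail constructed with
        guardrail_name="my-guardrail", event_hook=GuardrailEventHooks.pre_call,
        and default_on=False):

        ```
        data = {"metadata": {"guardrails": ["my-guardrail"]}}

        guardrail.should_run_guardrail(data=data, event_type=GuardrailEventHooks.pre_call)
        # -> True: the guardrail was requested and the event type matches

        guardrail.should_run_guardrail(data=data, event_type=GuardrailEventHooks.post_call)
        # -> False: the configured event_hook does not match

        guardrail.should_run_guardrail(data={}, event_type=GuardrailEventHooks.pre_call)
        # -> False: not requested (would be True if default_on=True)
        ```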
"""
|
|
requested_guardrails = self.get_guardrail_from_metadata(data)
|
|
verbose_logger.debug(
|
|
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s self.default_on= %s",
|
|
self.guardrail_name,
|
|
event_type,
|
|
self.event_hook,
|
|
requested_guardrails,
|
|
self.default_on,
|
|
)
|
|
if self.default_on is True:
|
|
if self._event_hook_is_event_type(event_type):
|
|
if isinstance(self.event_hook, Mode):
|
|
try:
|
|
from litellm_enterprise.integrations.custom_guardrail import (
|
|
EnterpriseCustomGuardrailHelper,
|
|
)
|
|
except ImportError:
|
|
raise ImportError(
|
|
"Setting tag-based guardrails is only available in litellm-enterprise. You must be a premium user to use this feature."
|
|
)
|
|
result = EnterpriseCustomGuardrailHelper._should_run_if_mode_by_tag(
|
|
data, self.event_hook
|
|
)
|
|
if result is not None:
|
|
return result
|
|
return True
|
|
return False
|
|
|
|
if (
|
|
self.event_hook
|
|
and not self._guardrail_is_in_requested_guardrails(requested_guardrails)
|
|
and event_type.value != "logging_only"
|
|
):
|
|
return False
|
|
|
|
if not self._event_hook_is_event_type(event_type):
|
|
return False
|
|
|
|
if isinstance(self.event_hook, Mode):
|
|
try:
|
|
from litellm_enterprise.integrations.custom_guardrail import (
|
|
EnterpriseCustomGuardrailHelper,
|
|
)
|
|
except ImportError:
|
|
raise ImportError(
|
|
"Setting tag-based guardrails is only available in litellm-enterprise. You must be a premium user to use this feature."
|
|
)
|
|
result = EnterpriseCustomGuardrailHelper._should_run_if_mode_by_tag(
|
|
data, self.event_hook
|
|
)
|
|
if result is not None:
|
|
return result
|
|
return True
|
|
|
|
    def _event_hook_is_event_type(self, event_type: GuardrailEventHooks) -> bool:
        """
        Returns True if the event_hook is the same as the event_type

        eg. if `self.event_hook == "pre_call" and event_type == "pre_call"` -> then True
        eg. if `self.event_hook == "pre_call" and event_type == "post_call"` -> then False
        """
        if self.event_hook is None:
            return True
        if isinstance(self.event_hook, list):
            return event_type.value in self.event_hook
        if isinstance(self.event_hook, Mode):
            return event_type.value in self.event_hook.tags.values()
        return self.event_hook == event_type.value

    def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
        """
        Returns `extra_body` to be added to the request body for the Guardrail API call

        Use this to pass dynamic params to the guardrail API call - eg. success_threshold, failure_threshold, etc.

        ```
        [{"lakera_guard": {"extra_body": {"foo": "bar"}}}]
        ```

        For the guardrail `lakera_guard`, this will return:
        {
            "foo": "bar"
        }

        Args:
            request_data: The original `request_data` passed to LiteLLM Proxy
        """
        requested_guardrails = self.get_guardrail_from_metadata(request_data)

        # Look for the guardrail configuration matching self.guardrail_name
        for guardrail in requested_guardrails:
            if isinstance(guardrail, dict) and self.guardrail_name in guardrail:
                # Get the configuration for this guardrail
                guardrail_config: DynamicGuardrailParams = DynamicGuardrailParams(
                    **guardrail[self.guardrail_name]
                )
                if self._validate_premium_user() is not True:
                    return {}

                # Return the extra_body if it exists, otherwise an empty dict
                return guardrail_config.get("extra_body", {})

        return {}

    def _validate_premium_user(self) -> bool:
        """
        Returns True if the user is a premium user
        """
        from litellm.proxy.proxy_server import CommonProxyErrors, premium_user

        if premium_user is not True:
            verbose_logger.warning(
                f"Trying to use premium guardrail without premium user {CommonProxyErrors.not_premium_user.value}"
            )
            return False
        return True

    def add_standard_logging_guardrail_information_to_request_data(
        self,
        guardrail_json_response: Union[Exception, str, dict, List[dict]],
        request_data: dict,
        guardrail_status: Literal["success", "failure"],
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        duration: Optional[float] = None,
        masked_entity_count: Optional[Dict[str, int]] = None,
    ) -> None:
        """
        Builds `StandardLoggingGuardrailInformation` and adds it to the request metadata so it can be used for logging to DataDog, Langfuse, etc.
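
        Example (a sketch of how a guardrail hook might record its result; the
        response payload and timing are illustrative):

        ```
        self.add_standard_logging_guardrail_information_to_request_data(
            guardrail_json_response={"flagged": False, "categories": []},
            request_data=data,
            guardrail_status="success",
            duration=0.12,
        )
        ```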
"""
|
|
if isinstance(guardrail_json_response, Exception):
|
|
guardrail_json_response = str(guardrail_json_response)
|
|
from litellm.types.utils import GuardrailMode
|
|
|
|
slg = StandardLoggingGuardrailInformation(
|
|
guardrail_name=self.guardrail_name,
|
|
guardrail_mode=(
|
|
GuardrailMode(**self.event_hook.model_dump()) # type: ignore
|
|
if isinstance(self.event_hook, Mode)
|
|
else self.event_hook
|
|
),
|
|
guardrail_response=guardrail_json_response,
|
|
guardrail_status=guardrail_status,
|
|
start_time=start_time,
|
|
end_time=end_time,
|
|
duration=duration,
|
|
masked_entity_count=masked_entity_count,
|
|
)
|
|
if "metadata" in request_data:
|
|
if request_data["metadata"] is None:
|
|
request_data["metadata"] = {}
|
|
request_data["metadata"]["standard_logging_guardrail_information"] = slg
|
|
elif "litellm_metadata" in request_data:
|
|
request_data["litellm_metadata"][
|
|
"standard_logging_guardrail_information"
|
|
] = slg
|
|
else:
|
|
verbose_logger.warning(
|
|
"unable to log guardrail information. No metadata found in request_data"
|
|
)
|
|
|
|
    async def apply_guardrail(
        self,
        text: str,
        language: Optional[str] = None,
        entities: Optional[List[PiiEntityType]] = None,
    ) -> str:
        """
        Apply your guardrail logic to the given text

        Args:
            text: The text to apply the guardrail to
            language: The language of the text
            entities: The entities to mask, optional

        Custom guardrails can override this method to provide their own guardrail logic

        Returns the text with the guardrail applied

        Raises:
            Exception:
                - If the guardrail rejects or fails to process the text
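
        Example (a minimal sketch of an override; the blocking and redaction
        logic shown is a placeholder):

        ```
        class MyGuardrail(CustomGuardrail):
            async def apply_guardrail(
                self,
                text: str,
                language: Optional[str] = None,
                entities: Optional[List[PiiEntityType]] = None,
            ) -> str:
                if "forbidden phrase" in text:
                    raise ValueError("Request blocked by my-guardrail")
                return text.replace("jane@example.com", "[EMAIL]")
        ```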
"""
|
|
return text
|
|
|
|
    def _process_response(
        self,
        response: Optional[Dict],
        request_data: dict,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        duration: Optional[float] = None,
    ):
        """
        Add StandardLoggingGuardrailInformation to the request data

        This gets logged downstream on Langfuse, DataDog, etc.
        """
        # Convert None to an empty dict to satisfy type requirements
        guardrail_response = {} if response is None else response
        self.add_standard_logging_guardrail_information_to_request_data(
            guardrail_json_response=guardrail_response,
            request_data=request_data,
            guardrail_status="success",
            duration=duration,
            start_time=start_time,
            end_time=end_time,
        )
        return response

    def _process_error(
        self,
        e: Exception,
        request_data: dict,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        duration: Optional[float] = None,
    ):
        """
        Add StandardLoggingGuardrailInformation to the request data

        This gets logged downstream on Langfuse, DataDog, etc.
        """
        self.add_standard_logging_guardrail_information_to_request_data(
            guardrail_json_response=e,
            request_data=request_data,
            guardrail_status="failure",
            duration=duration,
            start_time=start_time,
            end_time=end_time,
        )
        raise e

    def mask_content_in_string(
        self,
        content_string: str,
        mask_string: str,
        start_index: int,
        end_index: int,
    ) -> str:
        """
        Mask the content in the string between the start and end indices.
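
        Example (illustrative values; `guardrail` is an instance of this class):

        ```
        guardrail.mask_content_in_string(
            content_string="My email is jane@example.com",
            mask_string="[EMAIL]",
            start_index=12,
            end_index=28,
        )
        # -> "My email is [EMAIL]"
        ```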
"""
|
|
|
|
# Do nothing if the start or end are not valid
|
|
if not (0 <= start_index < end_index <= len(content_string)):
|
|
return content_string
|
|
|
|
# Mask the content
|
|
return content_string[:start_index] + mask_string + content_string[end_index:]
|
|
|
|
    def update_in_memory_litellm_params(self, litellm_params: LitellmParams) -> None:
        """
        Update the guardrail's litellm params in memory
        """
        pass


def log_guardrail_information(func):
    """
    Decorator to add standard logging guardrail information to any function

    Add this decorator to ensure your guardrail response is logged to DataDog, OTEL, s3, GCS etc.

    Logs for:
    - pre_call
    - during_call
    - TODO: log post_call. This is more involved since the logs are sent to DD, s3 before the guardrail is even run
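
    Example (a sketch; the hook body is illustrative):

    ```
    class MyGuardrail(CustomGuardrail):
        @log_guardrail_information
        async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
            # run checks here; raising an exception rejects the request, and the
            # decorator records the failure before re-raising
            return data
    ```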
"""
|
|
import asyncio
|
|
import functools
|
|
|
|
@functools.wraps(func)
|
|
async def async_wrapper(*args, **kwargs):
|
|
start_time = datetime.now() # Move start_time inside the wrapper
|
|
self: CustomGuardrail = args[0]
|
|
request_data: dict = kwargs.get("data") or kwargs.get("request_data") or {}
|
|
try:
|
|
response = await func(*args, **kwargs)
|
|
return self._process_response(
|
|
response=response,
|
|
request_data=request_data,
|
|
start_time=start_time.timestamp(),
|
|
end_time=datetime.now().timestamp(),
|
|
duration=(datetime.now() - start_time).total_seconds(),
|
|
)
|
|
except Exception as e:
|
|
return self._process_error(
|
|
e=e,
|
|
request_data=request_data,
|
|
start_time=start_time.timestamp(),
|
|
end_time=datetime.now().timestamp(),
|
|
duration=(datetime.now() - start_time).total_seconds(),
|
|
)
|
|
|
|
    @functools.wraps(func)
    def sync_wrapper(*args, **kwargs):
        start_time = datetime.now()  # capture start time when the hook actually runs
        self: CustomGuardrail = args[0]
        request_data: dict = kwargs.get("data") or kwargs.get("request_data") or {}
        try:
            response = func(*args, **kwargs)
            return self._process_response(
                response=response,
                request_data=request_data,
                duration=(datetime.now() - start_time).total_seconds(),
            )
        except Exception as e:
            return self._process_error(
                e=e,
                request_data=request_data,
                duration=(datetime.now() - start_time).total_seconds(),
            )

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if asyncio.iscoroutinefunction(func):
            return async_wrapper(*args, **kwargs)
        return sync_wrapper(*args, **kwargs)

    return wrapper