ai-station/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_endpoints.py

973 lines
32 KiB
Python
Raw Normal View History

2025-12-25 14:54:33 +00:00
"""
CRUD ENDPOINTS FOR GUARDRAILS
"""
import inspect
import types
from typing import (
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Type,
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel

from litellm._logging import verbose_proxy_logger
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.guardrails.guardrail_registry import GuardrailRegistry
from litellm.types.guardrails import (
    PII_ENTITY_CATEGORIES_MAP,
    BedrockGuardrailConfigModel,
    Guardrail,
    GuardrailEventHooks,
    GuardrailInfoResponse,
    GuardrailUIAddGuardrailSettings,
    LakeraV2GuardrailConfigModel,
    ListGuardrailsResponse,
    LitellmParams,
    PatchGuardrailRequest,
    PiiAction,
    PiiEntityType,
    PresidioPresidioConfigModelUserInterface,
    SupportedGuardrailIntegrations,
)
#### GUARDRAILS ENDPOINTS ####

# Every guardrail route in this module is registered on this router.
router: APIRouter = APIRouter()

# Shared registry the endpoints use to read/write guardrails stored in the DB.
GUARDRAIL_REGISTRY: GuardrailRegistry = GuardrailRegistry()
def _get_guardrails_list_response(
    guardrails_config: List[Dict],
) -> ListGuardrailsResponse:
    """
    Build a ListGuardrailsResponse from raw guardrail config dicts.

    Only ``guardrail_name``, ``litellm_params`` and ``guardrail_info`` are copied
    from each entry; keys missing from an entry are passed through as None.
    """
    return ListGuardrailsResponse(
        guardrails=[
            GuardrailInfoResponse(
                guardrail_name=entry.get("guardrail_name"),
                litellm_params=entry.get("litellm_params"),
                guardrail_info=entry.get("guardrail_info"),
            )
            for entry in guardrails_config
        ]
    )
@router.get(
    "/guardrails/list",
    tags=["Guardrails"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ListGuardrailsResponse,
)
async def list_guardrails():
    """
    List the guardrails that are available on the proxy server

    👉 [Guardrail docs](https://docs.litellm.ai/docs/proxy/guardrails/quick_start)

    Example Request:
    ```bash
    curl -X GET "http://localhost:4000/guardrails/list" -H "Authorization: Bearer <your_api_key>"
    ```

    Example Response:
    ```json
    {
        "guardrails": [
            {
            "guardrail_name": "bedrock-pre-guard",
            "guardrail_info": {
                "params": [
                {
                    "name": "toxicity_score",
                    "type": "float",
                    "description": "Score between 0-1 indicating content toxicity level"
                },
                {
                    "name": "pii_detection",
                    "type": "boolean"
                }
                ]
            }
            }
        ]
    }
    ```
    """
    from litellm.proxy.proxy_server import proxy_config

    # Guardrails from the proxy config file live under its top-level "guardrails" key;
    # a missing key is reported as an empty list.
    configured_guardrails = cast(
        Optional[list[dict]], proxy_config.config.get("guardrails")
    )
    return _get_guardrails_list_response(
        configured_guardrails if configured_guardrails is not None else []
    )
@router.get(
    "/v2/guardrails/list",
    tags=["Guardrails"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ListGuardrailsResponse,
)
async def list_guardrails_v2():
    """
    List the guardrails that are available in the database using GuardrailRegistry

    👉 [Guardrail docs](https://docs.litellm.ai/docs/proxy/guardrails/quick_start)

    Example Request:
    ```bash
    curl -X GET "http://localhost:4000/v2/guardrails/list" -H "Authorization: Bearer <your_api_key>"
    ```

    Example Response:
    ```json
    {
        "guardrails": [
            {
                "guardrail_id": "123e4567-e89b-12d3-a456-426614174000",
                "guardrail_name": "my-bedrock-guard",
                "litellm_params": {
                    "guardrail": "bedrock",
                    "mode": "pre_call",
                    "guardrailIdentifier": "ff6ujrregl1q",
                    "guardrailVersion": "DRAFT",
                    "default_on": true
                },
                "guardrail_info": {
                    "description": "Bedrock content moderation guardrail"
                }
            }
        ]
    }
    ```
    """
    from litellm.proxy.guardrails.guardrail_registry import IN_MEMORY_GUARDRAIL_HANDLER
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Prisma client not initialized")

    try:
        guardrails = await GUARDRAIL_REGISTRY.get_all_guardrails_from_db(
            prisma_client=prisma_client
        )

        guardrail_configs: List[GuardrailInfoResponse] = []
        # IDs already emitted, so a guardrail present both in the DB and in
        # memory is only listed once (the DB copy wins).
        seen_guardrail_ids = set()

        for guardrail in guardrails:
            guardrail_configs.append(
                GuardrailInfoResponse(
                    guardrail_id=guardrail.get("guardrail_id"),
                    guardrail_name=guardrail.get("guardrail_name"),
                    litellm_params=guardrail.get("litellm_params"),
                    guardrail_info=guardrail.get("guardrail_info"),
                    created_at=guardrail.get("created_at"),
                    updated_at=guardrail.get("updated_at"),
                    guardrail_definition_location="db",
                )
            )
            guardrail_id = guardrail.get("guardrail_id")
            if guardrail_id is not None:
                seen_guardrail_ids.add(guardrail_id)

        # get guardrails initialized on litellm config.yaml
        in_memory_guardrails = IN_MEMORY_GUARDRAIL_HANDLER.list_in_memory_guardrails()
        for guardrail in in_memory_guardrails:
            guardrail_id = guardrail.get("guardrail_id")
            # A missing ID cannot identify a duplicate. Tracking None in the
            # seen-set would silently drop every ID-less guardrail after the first.
            if guardrail_id is not None and guardrail_id in seen_guardrail_ids:
                continue
            guardrail_configs.append(
                GuardrailInfoResponse(
                    guardrail_id=guardrail_id,
                    guardrail_name=guardrail.get("guardrail_name"),
                    litellm_params=dict(guardrail.get("litellm_params") or {}),
                    guardrail_info=dict(guardrail.get("guardrail_info") or {}),
                    guardrail_definition_location="config",
                )
            )
            if guardrail_id is not None:
                seen_guardrail_ids.add(guardrail_id)

        return ListGuardrailsResponse(guardrails=guardrail_configs)
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting guardrails from db: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
class CreateGuardrailRequest(BaseModel):
    # Request body of the create endpoint (POST /guardrails): the guardrail
    # definition to persist.
    guardrail: Guardrail
@router.post(
    "/guardrails",
    tags=["Guardrails"],
    dependencies=[Depends(user_api_key_auth)],
)
async def create_guardrail(request: CreateGuardrailRequest):
    """
    Create a new guardrail

    👉 [Guardrail docs](https://docs.litellm.ai/docs/proxy/guardrails/quick_start)

    Example Request:
    ```bash
    curl -X POST "http://localhost:4000/guardrails" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json" \\
        -d '{
            "guardrail": {
                "guardrail_name": "my-bedrock-guard",
                "litellm_params": {
                    "guardrail": "bedrock",
                    "mode": "pre_call",
                    "guardrailIdentifier": "ff6ujrregl1q",
                    "guardrailVersion": "DRAFT",
                    "default_on": true
                },
                "guardrail_info": {
                    "description": "Bedrock content moderation guardrail"
                }
            }
        }'
    ```

    Example Response:
    ```json
    {
        "guardrail_id": "123e4567-e89b-12d3-a456-426614174000",
        "guardrail_name": "my-bedrock-guard",
        "litellm_params": {
            "guardrail": "bedrock",
            "mode": "pre_call",
            "guardrailIdentifier": "ff6ujrregl1q",
            "guardrailVersion": "DRAFT",
            "default_on": true
        },
        "guardrail_info": {
            "description": "Bedrock content moderation guardrail"
        },
        "created_at": "2023-11-09T12:34:56.789Z",
        "updated_at": "2023-11-09T12:34:56.789Z"
    }
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    # Without a DB connection there is nowhere to persist the guardrail.
    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Prisma client not initialized")

    try:
        # The registry's return value is sent back to the client as-is.
        return await GUARDRAIL_REGISTRY.add_guardrail_to_db(
            guardrail=request.guardrail, prisma_client=prisma_client
        )
    except Exception as e:
        verbose_proxy_logger.exception(f"Error adding guardrail to db: {e}")
        raise HTTPException(status_code=500, detail=str(e))
class UpdateGuardrailRequest(BaseModel):
    # Request body of the full-replace endpoint (PUT /guardrails/{guardrail_id}):
    # the complete new guardrail definition.
    guardrail: Guardrail
@router.put(
    "/guardrails/{guardrail_id}",
    tags=["Guardrails"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_guardrail(guardrail_id: str, request: UpdateGuardrailRequest):
    """
    Update an existing guardrail

    👉 [Guardrail docs](https://docs.litellm.ai/docs/proxy/guardrails/quick_start)

    Example Request:
    ```bash
    curl -X PUT "http://localhost:4000/guardrails/123e4567-e89b-12d3-a456-426614174000" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json" \\
        -d '{
            "guardrail": {
                "guardrail_name": "updated-bedrock-guard",
                "litellm_params": {
                    "guardrail": "bedrock",
                    "mode": "pre_call",
                    "guardrailIdentifier": "ff6ujrregl1q",
                    "guardrailVersion": "1.0",
                    "default_on": true
                },
                "guardrail_info": {
                    "description": "Updated Bedrock content moderation guardrail"
                }
            }
        }'
    ```

    Example Response:
    ```json
    {
        "guardrail_id": "123e4567-e89b-12d3-a456-426614174000",
        "guardrail_name": "updated-bedrock-guard",
        "litellm_params": {
            "guardrail": "bedrock",
            "mode": "pre_call",
            "guardrailIdentifier": "ff6ujrregl1q",
            "guardrailVersion": "1.0",
            "default_on": true
        },
        "guardrail_info": {
            "description": "Updated Bedrock content moderation guardrail"
        },
        "created_at": "2023-11-09T12:34:56.789Z",
        "updated_at": "2023-11-09T13:45:12.345Z"
    }
    ```
    """
    from litellm.proxy.guardrails.guardrail_registry import IN_MEMORY_GUARDRAIL_HANDLER
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Prisma client not initialized")

    try:
        # Only guardrails stored in the DB can be replaced through this endpoint.
        existing_guardrail = await GUARDRAIL_REGISTRY.get_guardrail_by_id_from_db(
            guardrail_id=guardrail_id, prisma_client=prisma_client
        )
        if existing_guardrail is None:
            raise HTTPException(
                status_code=404, detail=f"Guardrail with ID {guardrail_id} not found"
            )

        result = await GUARDRAIL_REGISTRY.update_guardrail_in_db(
            guardrail_id=guardrail_id,
            guardrail=request.guardrail,
            prisma_client=prisma_client,
        )

        # Mirror the PATCH/DELETE endpoints and refresh the in-memory copy too;
        # otherwise the running proxy may keep using the old configuration.
        IN_MEMORY_GUARDRAIL_HANDLER.update_in_memory_guardrail(
            guardrail_id=guardrail_id,
            guardrail=request.guardrail,
        )
        return result
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating guardrail: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.delete(
    "/guardrails/{guardrail_id}",
    tags=["Guardrails"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_guardrail(guardrail_id: str):
    """
    Delete a guardrail

    👉 [Guardrail docs](https://docs.litellm.ai/docs/proxy/guardrails/quick_start)

    Example Request:
    ```bash
    curl -X DELETE "http://localhost:4000/guardrails/123e4567-e89b-12d3-a456-426614174000" \\
        -H "Authorization: Bearer <your_api_key>"
    ```

    Example Response:
    ```json
    {
        "message": "Guardrail 123e4567-e89b-12d3-a456-426614174000 deleted successfully"
    }
    ```
    """
    from litellm.proxy.guardrails.guardrail_registry import IN_MEMORY_GUARDRAIL_HANDLER
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Prisma client not initialized")

    try:
        # Only guardrails stored in the DB can be deleted; anything else is a 404.
        existing_guardrail = await GUARDRAIL_REGISTRY.get_guardrail_by_id_from_db(
            guardrail_id=guardrail_id, prisma_client=prisma_client
        )
        if existing_guardrail is None:
            raise HTTPException(
                status_code=404, detail=f"Guardrail with ID {guardrail_id} not found"
            )

        result = await GUARDRAIL_REGISTRY.delete_guardrail_from_db(
            guardrail_id=guardrail_id, prisma_client=prisma_client
        )

        # Drop the in-memory copy as well so the running proxy stops using it.
        IN_MEMORY_GUARDRAIL_HANDLER.delete_in_memory_guardrail(
            guardrail_id=guardrail_id,
        )
        return result
    except HTTPException:
        raise
    except Exception as e:
        # Log server-side like the create/patch endpoints do before returning a 500.
        verbose_proxy_logger.exception(f"Error deleting guardrail: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.patch(
    "/guardrails/{guardrail_id}",
    tags=["Guardrails"],
    dependencies=[Depends(user_api_key_auth)],
)
async def patch_guardrail(guardrail_id: str, request: PatchGuardrailRequest):
    """
    Partially update an existing guardrail

    👉 [Guardrail docs](https://docs.litellm.ai/docs/proxy/guardrails/quick_start)

    This endpoint allows updating specific fields of a guardrail without sending the entire object.
    Only the following fields can be updated:
    - guardrail_name: The name of the guardrail
    - litellm_params: Params (e.g. default_on) merged on top of the existing litellm_params
    - guardrail_info: Additional information about the guardrail (replaces the stored value)

    Example Request:
    ```bash
    curl -X PATCH "http://localhost:4000/guardrails/123e4567-e89b-12d3-a456-426614174000" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json" \\
        -d '{
            "guardrail_name": "updated-name",
            "litellm_params": {
                "default_on": true
            },
            "guardrail_info": {
                "description": "Updated description"
            }
        }'
    ```

    Example Response:
    ```json
    {
        "guardrail_id": "123e4567-e89b-12d3-a456-426614174000",
        "guardrail_name": "updated-name",
        "litellm_params": {
            "guardrail": "bedrock",
            "mode": "pre_call",
            "guardrailIdentifier": "ff6ujrregl1q",
            "guardrailVersion": "DRAFT",
            "default_on": true
        },
        "guardrail_info": {
            "description": "Updated description"
        },
        "created_at": "2023-11-09T12:34:56.789Z",
        "updated_at": "2023-11-09T14:22:33.456Z"
    }
    ```
    """
    from litellm.proxy.guardrails.guardrail_registry import IN_MEMORY_GUARDRAIL_HANDLER
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Prisma client not initialized")

    try:
        # Check if guardrail exists and get current data
        existing_guardrail = await GUARDRAIL_REGISTRY.get_guardrail_by_id_from_db(
            guardrail_id=guardrail_id, prisma_client=prisma_client
        )
        if existing_guardrail is None:
            raise HTTPException(
                status_code=404, detail=f"Guardrail with ID {guardrail_id} not found"
            )

        guardrail_name = (
            request.guardrail_name
            if request.guardrail_name is not None
            else existing_guardrail.get("guardrail_name")
        )

        # Start from the stored params. ``or {}`` also covers a stored null,
        # which ``dict(None)`` would otherwise turn into a TypeError (500).
        litellm_params = LitellmParams(
            **dict(existing_guardrail.get("litellm_params") or {})
        )
        if request.litellm_params is not None:
            # Only the fields the client actually sent override stored values.
            merged_params = litellm_params.model_dump(exclude_unset=True)
            merged_params.update(
                request.litellm_params.model_dump(exclude_unset=True)
            )
            litellm_params = LitellmParams(**merged_params)

        # guardrail_info is replaced wholesale when provided.
        guardrail_info = (
            request.guardrail_info
            if request.guardrail_info is not None
            else existing_guardrail.get("guardrail_info", {})
        )

        guardrail = Guardrail(
            guardrail_id=guardrail_id,
            guardrail_name=guardrail_name or "",
            litellm_params=litellm_params,
            guardrail_info=guardrail_info,
        )

        result = await GUARDRAIL_REGISTRY.update_guardrail_in_db(
            guardrail_id=guardrail_id,
            guardrail=guardrail,
            prisma_client=prisma_client,
        )

        # Keep the running proxy in sync with the DB.
        IN_MEMORY_GUARDRAIL_HANDLER.update_in_memory_guardrail(
            guardrail_id=guardrail_id,
            guardrail=guardrail,
        )
        return result
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating guardrail: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get(
    "/guardrails/{guardrail_id}",
    tags=["Guardrails"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.get(
    "/guardrails/{guardrail_id}/info",
    tags=["Guardrails"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_guardrail_info(guardrail_id: str):
    """
    Get detailed information about a specific guardrail by ID

    👉 [Guardrail docs](https://docs.litellm.ai/docs/proxy/guardrails/quick_start)

    The DB is checked first, then guardrails loaded from the config file. Values in
    litellm_params are masked before being returned.

    Example Request:
    ```bash
    curl -X GET "http://localhost:4000/guardrails/123e4567-e89b-12d3-a456-426614174000/info" \\
        -H "Authorization: Bearer <your_api_key>"
    ```

    Example Response:
    ```json
    {
        "guardrail_id": "123e4567-e89b-12d3-a456-426614174000",
        "guardrail_name": "my-bedrock-guard",
        "litellm_params": {
            "guardrail": "bedrock",
            "mode": "pre_call",
            "guardrailIdentifier": "ff6ujrregl1q",
            "guardrailVersion": "DRAFT",
            "default_on": true
        },
        "guardrail_info": {
            "description": "Bedrock content moderation guardrail"
        },
        "created_at": "2023-11-09T12:34:56.789Z",
        "updated_at": "2023-11-09T12:34:56.789Z"
    }
    ```
    """
    from litellm.litellm_core_utils.litellm_logging import _get_masked_values
    from litellm.proxy.guardrails.guardrail_registry import IN_MEMORY_GUARDRAIL_HANDLER
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Prisma client not initialized")

    try:
        result = await GUARDRAIL_REGISTRY.get_guardrail_by_id_from_db(
            guardrail_id=guardrail_id, prisma_client=prisma_client
        )
        # Report where the guardrail came from, consistent with /v2/guardrails/list.
        guardrail_definition_location = "db"
        if result is None:
            result = IN_MEMORY_GUARDRAIL_HANDLER.get_guardrail_by_id(
                guardrail_id=guardrail_id
            )
            guardrail_definition_location = "config"

        if result is None:
            raise HTTPException(
                status_code=404, detail=f"Guardrail with ID {guardrail_id} not found"
            )

        litellm_params: Optional[Union[LitellmParams, dict]] = result.get(
            "litellm_params"
        )
        result_litellm_params_dict = (
            litellm_params.model_dump(exclude_none=True)
            if isinstance(litellm_params, LitellmParams)
            else litellm_params
        ) or {}
        # Mask the params (keep 4 characters) so credentials are not echoed back in full.
        masked_litellm_params_dict = _get_masked_values(
            result_litellm_params_dict,
            unmasked_length=4,
            number_of_asterisks=4,
        )

        return GuardrailInfoResponse(
            guardrail_id=result.get("guardrail_id"),
            guardrail_name=result.get("guardrail_name"),
            litellm_params=masked_litellm_params_dict,
            guardrail_info=dict(result.get("guardrail_info") or {}),
            created_at=result.get("created_at"),
            updated_at=result.get("updated_at"),
            guardrail_definition_location=guardrail_definition_location,
        )
    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting guardrail info: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get(
    "/guardrails/ui/add_guardrail_settings",
    tags=["Guardrails"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_guardrail_ui_settings():
    """
    Get the UI settings for the guardrails

    Returns:
    - Supported entities for guardrails
    - Supported modes for guardrails
    - PII entity categories for UI organization
    """
    # The UI expects a list of {"category": ..., "entities": ...} objects rather
    # than the raw category -> entities mapping.
    pii_entity_categories = [
        {"category": category_name, "entities": category_entities}
        for category_name, category_entities in PII_ENTITY_CATEGORIES_MAP.items()
    ]

    return GuardrailUIAddGuardrailSettings(
        supported_entities=list(PiiEntityType),
        supported_actions=list(PiiAction),
        supported_modes=list(GuardrailEventHooks),
        pii_entity_categories=pii_entity_categories,
    )
def _get_field_type_from_annotation(field_annotation: Any) -> str:
"""
Convert a Python type annotation to a UI-friendly type string
"""
# Handle Union types (like Optional[T])
if (
hasattr(field_annotation, "__origin__")
and field_annotation.__origin__ is Union
and hasattr(field_annotation, "__args__")
):
# For Optional[T], get the non-None type
args = field_annotation.__args__
non_none_args = [arg for arg in args if arg is not type(None)]
if non_none_args:
field_annotation = non_none_args[0]
# Handle List types
if hasattr(field_annotation, "__origin__") and field_annotation.__origin__ is list:
return "array"
# Handle Dict types
if hasattr(field_annotation, "__origin__") and field_annotation.__origin__ is dict:
return "dict"
# Handle Literal types
if hasattr(field_annotation, "__origin__") and hasattr(
field_annotation, "__args__"
):
# Check for Literal types (Python 3.8+)
origin = field_annotation.__origin__
if hasattr(origin, "__name__") and origin.__name__ == "Literal":
return "select" # For dropdown/select inputs
# Handle basic types
if field_annotation is str:
return "string"
elif field_annotation is int:
return "number"
elif field_annotation is float:
return "number"
elif field_annotation is bool:
return "boolean"
elif field_annotation is dict:
return "object"
elif field_annotation is list:
return "array"
# Default to string for unknown types
return "string"
def _extract_literal_values(annotation: Any) -> List[str]:
"""
Extract literal values from a Literal type annotation
"""
if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"):
origin = annotation.__origin__
if hasattr(origin, "__name__") and origin.__name__ == "Literal":
return list(annotation.__args__)
return []
def _get_dict_key_options(field_annotation: Any) -> Optional[List[str]]:
    """
    Extract the allowed keys from a ``Dict[Literal[...], T]`` annotation.

    Returns None when the annotation is not a parameterized dict with both key
    and value types. Otherwise returns the key type's Literal members, which is
    an empty list when the key type is not a Literal.
    """
    is_parameterized_dict = getattr(
        field_annotation, "__origin__", None
    ) is dict and hasattr(field_annotation, "__args__")
    if not is_parameterized_dict:
        return None

    type_args = field_annotation.__args__
    if len(type_args) < 2:
        return None
    key_annotation = type_args[0]
    return _extract_literal_values(key_annotation)
def _get_dict_value_type(field_annotation: Any) -> str:
    """
    Return the UI type string of the value type in a ``Dict[K, V]`` annotation.

    Falls back to "string" when the annotation is not a parameterized dict
    with both key and value types.
    """
    is_parameterized_dict = getattr(
        field_annotation, "__origin__", None
    ) is dict and hasattr(field_annotation, "__args__")
    if not is_parameterized_dict:
        return "string"

    type_args = field_annotation.__args__
    if len(type_args) < 2:
        return "string"
    value_annotation = type_args[1]
    return _get_field_type_from_annotation(value_annotation)
def _get_list_element_options(field_annotation: Any) -> Optional[List[str]]:
    """
    Extract the allowed elements from a ``List[Literal[...]]`` annotation.

    Returns None when the annotation is not a parameterized list. Otherwise
    returns the element type's Literal members, which is an empty list when
    the element type is not a Literal.
    """
    is_parameterized_list = getattr(
        field_annotation, "__origin__", None
    ) is list and hasattr(field_annotation, "__args__")
    if not is_parameterized_list:
        return None

    type_args = field_annotation.__args__
    if len(type_args) < 1:
        return None
    element_annotation = type_args[0]
    return _extract_literal_values(element_annotation)
def _extract_fields_recursive(
    model: Type[BaseModel],
    depth: int = 0,
) -> Dict[str, Any]:
    """
    Describe the fields of a pydantic model for the guardrail UI.

    Args:
        model: Pydantic model class whose ``model_fields`` are inspected.
        depth: Current nesting level; external callers leave this at 0.

    Returns:
        Mapping of field name -> descriptor. Fields typed as a BaseModel
        subclass become ``{"description", "required", "type": "nested",
        "fields": {...}}``. Every other field carries ``description``,
        ``required`` and ``type``, plus ``dict_key_options`` /
        ``dict_value_type`` / ``options`` / ``default_value`` when they apply.

    Raises:
        HTTPException: status 400 once nesting exceeds DEFAULT_MAX_RECURSE_DEPTH.
    """
    # Check if we've exceeded the maximum recursion depth
    if depth > DEFAULT_MAX_RECURSE_DEPTH:
        raise HTTPException(
            status_code=400,
            detail=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing model fields. Please check the model structure for excessive nesting.",
        )

    # typing.Union covers Optional[T]; types.UnionType covers PEP 604 ``T | None``.
    union_origins = (Union, getattr(types, "UnionType", Union))

    fields: Dict[str, Any] = {}
    for field_name, field in model.model_fields.items():
        field_annotation = field.annotation
        if field_annotation is None:
            continue

        # Skip optional_params when a generic base class left it as an
        # unspecialized TypeVar - there is nothing meaningful to render.
        if field_name == "optional_params" and (
            isinstance(field_annotation, TypeVar)
            or getattr(field_annotation, "__origin__", None) is TypeVar
            or getattr(field_annotation, "__name__", None) in ("T", "TypeVar")
        ):
            continue

        description = field.description or field_name
        required = field.is_required()

        # Unwrap Optional[T] / Union[T, ...] / T | None to the first non-None member.
        if get_origin(field_annotation) in union_origins:
            non_none_args = [
                arg for arg in get_args(field_annotation) if arg is not type(None)
            ]
            if non_none_args:
                field_annotation = non_none_args[0]

        is_basemodel_subclass = (
            inspect.isclass(field_annotation)
            and issubclass(field_annotation, BaseModel)
            and field_annotation is not BaseModel
        )
        if is_basemodel_subclass:
            # Recursively describe the nested model
            nested_fields = _extract_fields_recursive(
                cast(Type[BaseModel], field_annotation), depth + 1
            )
            fields[field_name] = {
                "description": description,
                "required": required,
                "type": "nested",
                "fields": nested_fields,
            }
            continue

        field_type = _get_field_type_from_annotation(field_annotation)

        # json_schema_extra may be a dict or a callable in pydantic v2; only a
        # dict can carry UI hints, and ``"key" in callable`` would raise.
        schema_extra = getattr(field, "json_schema_extra", None)
        if not isinstance(schema_extra, dict):
            schema_extra = {}

        # Custom UI type overrides take precedence over the inferred type.
        if "ui_type" in schema_extra:
            ui_type = schema_extra["ui_type"]
            # Normally an enum member; tolerate a plain value as well.
            field_type = getattr(ui_type, "value", ui_type)
        elif "type" in schema_extra:
            field_type = schema_extra["type"]

        field_dict: Dict[str, Any] = {
            "description": description,
            "required": required,
            "type": field_type,
        }

        if field_type == "dict":
            # For Dict[Literal[...], T] types, expose the allowed keys
            dict_key_options = _get_dict_key_options(field_annotation)
            if dict_key_options:
                field_dict["dict_key_options"] = dict_key_options
            field_dict["dict_value_type"] = _get_dict_value_type(field_annotation)
        elif field_type == "array":
            # For List[Literal[...]] types, render as a multiselect
            list_element_options = _get_list_element_options(field_annotation)
            if list_element_options:
                field_dict["options"] = list_element_options
                field_dict["type"] = "multiselect"

        # Options declared in json_schema_extra take precedence
        if "options" in schema_extra:
            field_dict["options"] = schema_extra["options"]

        # Only report real defaults. In pydantic v2, required fields and fields
        # using default_factory expose the PydanticUndefined sentinel (not
        # Ellipsis) as ``default``, which is not a usable default value.
        if (
            not required
            and field.default_factory is None
            and field.default is not None
            and field.default is not ...
        ):
            field_dict["default_value"] = field.default

        fields[field_name] = field_dict
    return fields
def _get_fields_from_model(model_class: Type[BaseModel]) -> Dict[str, Any]:
    """
    Return the UI field descriptors of a pydantic model, starting at the top level.

    Thin entry point over ``_extract_fields_recursive`` with the depth counter reset.
    """
    top_level_depth = 0
    return _extract_fields_recursive(model=model_class, depth=top_level_depth)
@router.get(
    "/guardrails/ui/provider_specific_params",
    tags=["Guardrails"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_provider_specific_params():
    """
    Get provider-specific parameters for different guardrail types.

    Returns a dictionary mapping guardrail providers to their specific parameters,
    including parameter names, descriptions, and whether they are required.
    Providers discovered through the guardrail class registry also carry a
    top-level "ui_friendly_name" entry. "default_value" is only present for
    fields with a non-None default.

    Example Response:
    ```json
    {
        "bedrock": {
            "guardrailIdentifier": {
                "description": "The ID of your guardrail on Bedrock",
                "required": true,
                "type": "string"
            },
            "guardrailVersion": {
                "description": "The version of your Bedrock guardrail (e.g., DRAFT or version number)",
                "required": true,
                "type": "string"
            }
        },
        "azure_content_safety_text_moderation": {
            "api_key": {
                "description": "API key for the Azure Content Safety Text Moderation guardrail",
                "required": false,
                "type": "string"
            },
            "optional_params": {
                "description": "Optional parameters for the Azure Content Safety Text Moderation guardrail",
                "required": true,
                "type": "nested",
                "fields": {
                    "severity_threshold": {
                        "description": "Severity threshold for the Azure Content Safety Text Moderation guardrail across all categories",
                        "required": false,
                        "type": "number"
                    },
                    "categories": {
                        "description": "Categories to scan for the Azure Content Safety Text Moderation guardrail",
                        "required": false,
                        "type": "multiselect",
                        "options": ["Hate", "SelfHarm", "Sexual", "Violence"]
                    }
                }
            }
        }
    }
    ```
    """
    # Built-in integrations with a dedicated UI config model
    bedrock_fields = _get_fields_from_model(BedrockGuardrailConfigModel)
    presidio_fields = _get_fields_from_model(PresidioPresidioConfigModelUserInterface)
    lakera_v2_fields = _get_fields_from_model(LakeraV2GuardrailConfigModel)

    provider_params = {
        SupportedGuardrailIntegrations.BEDROCK.value: bedrock_fields,
        SupportedGuardrailIntegrations.PRESIDIO.value: presidio_fields,
        SupportedGuardrailIntegrations.LAKERA_V2.value: lakera_v2_fields,
    }

    # Every registered guardrail class that exposes a config model contributes
    # its own entry (and may override one of the built-ins above).
    from litellm.proxy.guardrails.guardrail_registry import guardrail_class_registry

    for guardrail_name, guardrail_class in guardrail_class_registry.items():
        guardrail_config_model = guardrail_class.get_config_model()
        if guardrail_config_model:
            fields = _get_fields_from_model(guardrail_config_model)
            fields["ui_friendly_name"] = guardrail_config_model.ui_friendly_name()
            provider_params[guardrail_name] = fields

    return provider_params