355 lines
11 KiB
Python
355 lines
11 KiB
Python
# Helpers for resolving which models a key / team may see.
## Common checks shared by the `/v1/models` and `/model/info` endpoints
|
|
from typing import Dict, List, Optional, Set
|
|
|
|
import litellm
|
|
from litellm._logging import verbose_proxy_logger
|
|
from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
|
|
from litellm.router import Router
|
|
from litellm.router_utils.fallback_event_handlers import get_fallback_model_group
|
|
from litellm.types.router import LiteLLM_Params
|
|
from litellm.utils import get_valid_models
|
|
|
|
|
|
def _check_wildcard_routing(model: str) -> bool:
|
|
"""
|
|
Returns True if a model is a provider wildcard.
|
|
|
|
eg:
|
|
- anthropic/*
|
|
- openai/*
|
|
- *
|
|
"""
|
|
if "*" in model:
|
|
return True
|
|
return False
|
|
|
|
|
|
def get_provider_models(
    provider: str, litellm_params: Optional[LiteLLM_Params] = None
) -> Optional[List[str]]:
    """
    Look up the models known for a provider.

    - provider == "*": `get_valid_models` is called without a provider filter
    - provider listed in `litellm.models_by_provider`: `get_valid_models` is
      called with `custom_llm_provider=provider`
    - any other provider: None
    """
    if provider == "*":
        return get_valid_models(litellm_params=litellm_params)

    if provider not in litellm.models_by_provider:
        return None

    return get_valid_models(
        custom_llm_provider=provider, litellm_params=litellm_params
    )
|
|
|
|
|
|
def _get_models_from_access_groups(
|
|
model_access_groups: Dict[str, List[str]],
|
|
all_models: List[str],
|
|
include_model_access_groups: Optional[bool] = False,
|
|
) -> List[str]:
|
|
idx_to_remove = []
|
|
new_models = []
|
|
for idx, model in enumerate(all_models):
|
|
if model in model_access_groups:
|
|
if (
|
|
not include_model_access_groups
|
|
): # remove access group, unless requested - e.g. when creating a key
|
|
idx_to_remove.append(idx)
|
|
new_models.extend(model_access_groups[model])
|
|
|
|
for idx in sorted(idx_to_remove, reverse=True):
|
|
all_models.pop(idx)
|
|
|
|
all_models.extend(new_models)
|
|
return all_models
|
|
|
|
|
|
async def get_mcp_server_ids(
    user_api_key_dict: UserAPIKeyAuth,
) -> List[str]:
    """
    Returns the list of MCP server ids for a given key by querying the object_permission table

    This is a best-effort lookup. An empty list is returned when:
    - no prisma client is configured
    - the key has no `object_permission_id`
    - no matching row is found, or the row has no `mcp_servers`
    - the lookup raises (the error is logged, not propagated)
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        return []

    object_permission_id = user_api_key_dict.object_permission_id
    if object_permission_id is None:
        return []

    # Fetch the object-permission row for this key by its unique id
    try:
        result = await prisma_client.db.litellm_objectpermissiontable.find_unique(
            where={"object_permission_id": object_permission_id},
        )
        if result and result.mcp_servers:
            return result.mcp_servers
        return []
    except Exception as e:
        # Keep the lookup best-effort, but leave a trace instead of failing silently
        verbose_proxy_logger.warning(
            f"Error getting MCP server ids for object_permission_id {object_permission_id}: {e}"
        )
        return []
|
|
|
|
|
|
def get_key_models(
    user_api_key_dict: UserAPIKeyAuth,
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
    include_model_access_groups: Optional[bool] = False,
    only_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """
    Returns:
    - List of model name strings
    - Empty list if no models set
    - If model_access_groups is provided, access-group names in the key's models are
      expanded into the models of that group
    - If include_model_access_groups is True, it includes the 'keys' of the model_access_groups
      in the response - {"beta-models": ["gpt-4", "claude-v1"]} -> returns 'beta-models'

    Selection order:
    - start from the key's own `models`
    - if that list contains the "all team models" marker, use the key's `team_models`
    - if the resulting list contains the "all proxy models" marker, use `proxy_model_list`

    `only_model_access_groups` is accepted for signature parity but is not used here.

    The returned list is always a new list: the key's `models` / `team_models` and
    `proxy_model_list` are never modified.
    """
    all_models: List[str] = []
    if user_api_key_dict.models:
        all_models = user_api_key_dict.models
        if SpecialModelNames.all_team_models.value in all_models:
            all_models = user_api_key_dict.team_models
        if SpecialModelNames.all_proxy_models.value in all_models:
            all_models = proxy_model_list

    all_models = _get_models_from_access_groups(
        model_access_groups=model_access_groups,
        # the helper edits its argument in place - hand it a copy so the key object
        # and the proxy model list are left untouched
        all_models=list(all_models),
        include_model_access_groups=include_model_access_groups,
    )

    verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(all_models)))
    return all_models
|
|
|
|
|
|
def get_team_models(
    team_models: List[str],
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
    include_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """
    Returns:
    - List of model name strings (de-duplicated, in first-seen order)
    - Empty list if no models set
    - If model_access_groups is provided, access-group names in the team's models are
      expanded into the models of that group
    - If include_model_access_groups is True, the access-group names themselves are
      kept in the response as well

    If `team_models` contains the "all proxy models" marker, every entry of
    `proxy_model_list` is added too.
    """
    # dict keys give an order-preserving de-dupe, so the output order is stable
    ordered_models: Dict[str, None] = {}
    if team_models:
        ordered_models.update(dict.fromkeys(team_models))
        if SpecialModelNames.all_proxy_models.value in ordered_models:
            ordered_models.update(dict.fromkeys(proxy_model_list))

    all_models = _get_models_from_access_groups(
        model_access_groups=model_access_groups,
        all_models=list(ordered_models),
        include_model_access_groups=include_model_access_groups,
    )

    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
    return all_models
|
|
|
|
|
|
def get_complete_model_list(
    key_models: List[str],
    team_models: List[str],
    proxy_model_list: List[str],
    user_model: Optional[str],
    infer_model_from_keys: Optional[bool],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
    model_access_groups: Optional[Dict[str, List[str]]] = None,
    include_model_access_groups: Optional[bool] = False,
    only_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """
    Logic for returning complete model list for a given key + team pair

    - If key list is empty -> defer to team list
    - If team list is empty -> defer to proxy model list
      (and, if include_model_access_groups is set, add the access group names)
    - If user_model is set -> it is added
    - If infer_model_from_keys is set -> the result of `get_valid_models()` is added
    - If only_model_access_groups is set -> return only the collected names that are
      access group names

    If list contains wildcard -> return known provider models

    Args:
        model_access_groups: access group name -> models in that group.
            None is treated as "no access groups".

    Returns:
        The de-duplicated model names followed by the models resolved from wildcard routes.
    """
    if model_access_groups is None:
        model_access_groups = {}

    unique_models: List[str] = []
    # companion set so the "already added?" check does not rescan the list
    seen_models: Set[str] = set()

    def append_unique(models):
        for model in models:
            if model not in seen_models:
                seen_models.add(model)
                unique_models.append(model)

    if key_models:
        append_unique(key_models)
    elif team_models:
        append_unique(team_models)
    else:
        append_unique(proxy_model_list)
        if include_model_access_groups:
            append_unique(list(model_access_groups.keys()))  # TODO: keys order

    if user_model:
        append_unique([user_model])

    if infer_model_from_keys:
        valid_models = get_valid_models()
        append_unique(valid_models)

    if only_model_access_groups:
        return [model for model in unique_models if model in model_access_groups]

    # NOTE: this call may also remove wildcard entries from `unique_models` in place
    all_wildcard_models = _get_wildcard_models(
        unique_models=unique_models,
        return_wildcard_routes=return_wildcard_routes,
        llm_router=llm_router,
    )

    complete_model_list = unique_models + all_wildcard_models

    return complete_model_list
|
|
|
|
|
|
def get_known_models_from_wildcard(
    wildcard_model: str, litellm_params: Optional[LiteLLM_Params] = None
) -> List[str]:
    """
    Resolve a `<prefix>/<pattern>` wildcard route into concrete model names.

    - No "/" in `wildcard_model`, or no `litellm_params` -> []
    - The provider is taken from the part of `litellm_params.model` before the first "/"
    - Provider unknown to `get_provider_models` -> []
    - Pattern other than "*" (e.g. `gemini-*`): the pattern with "*" stripped is used as a
      name prefix. If any provider model starts with it, only those models are kept;
      otherwise it is prepended to every provider model.
    - Every returned name that does not already start with the wildcard's prefix
      gets `<prefix>/` prepended.
    """
    route_prefix, slash, route_pattern = wildcard_model.partition("/")
    if not slash:  # safely fail
        return []

    # need litellm params to extract litellm model name
    if litellm_params is None:
        return []

    try:
        provider = litellm_params.model.split("/", 1)[0]
    except ValueError:
        provider = route_prefix

    # get all known provider models
    known_models = get_provider_models(
        provider=provider, litellm_params=litellm_params
    )
    if known_models is None:
        return []

    if route_pattern != "*":
        ## CHECK IF PARTIAL FILTER e.g. `gemini-*`
        name_prefix = route_pattern.replace("*", "")
        matching = [name for name in known_models if name.startswith(name_prefix)]
        if matching:
            known_models = matching
        else:
            # add model prefix to wildcard models
            known_models = [f"{name_prefix}{name}" for name in known_models]

    return [
        name if name.startswith(route_prefix) else f"{route_prefix}/{name}"
        for name in known_models
    ]
|
|
|
|
|
|
def _get_wildcard_models(
    unique_models: List[str],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
) -> List[str]:
    """
    Collect the concrete models behind every wildcard entry in `unique_models`.

    For each entry that `_check_wildcard_routing` flags as a wildcard:
    - if `return_wildcard_routes` is set, the wildcard entry itself is added to the result
    - with a router: each deployment returned by `llm_router.get_model_list(model_name=...)`
      is expanded via `get_known_models_from_wildcard` using that deployment's litellm params
    - without a router: `get_known_models_from_wildcard` is called without litellm params
      and the wildcard entry is marked for removal from `unique_models`

    Side effect: entries marked for removal are removed from `unique_models` in place.

    Returns:
        The list of models collected from the wildcard entries.

    NOTE(review): the nesting of the `else:` branch below was reconstructed from a
    whitespace-mangled source - confirm against the original file.
    """
    models_to_remove = set()
    all_wildcard_models = []
    for model in unique_models:
        if _check_wildcard_routing(model=model):
            if (
                return_wildcard_routes
            ):  # will add the wildcard route to the list eg: anthropic/*.
                all_wildcard_models.append(model)

            ## get litellm params from model
            if llm_router is not None:
                model_list = llm_router.get_model_list(model_name=model)
                if model_list is not None:
                    # expand the wildcard once per deployment registered under this name
                    for router_model in model_list:
                        wildcard_models = get_known_models_from_wildcard(
                            wildcard_model=model,
                            litellm_params=LiteLLM_Params(
                                **router_model["litellm_params"]  # type: ignore
                            ),
                        )
                        all_wildcard_models.extend(wildcard_models)
            else:
                # get all known provider models
                # (called without litellm_params; as written in this file the helper
                # returns [] in that case)
                wildcard_models = get_known_models_from_wildcard(wildcard_model=model)

                if wildcard_models is not None:
                    models_to_remove.add(model)
                    all_wildcard_models.extend(wildcard_models)

    # removal is deferred until after the loop so the list is not changed while iterating
    for model in models_to_remove:
        unique_models.remove(model)

    return all_wildcard_models
|
|
|
|
|
|
def get_all_fallbacks(
    model: str,
    llm_router: Optional[Router] = None,
    fallback_type: str = "general",
) -> List[str]:
    """
    Look up the fallbacks configured on the router for a model.

    Args:
        model: The model name to get fallbacks for
        llm_router: The LiteLLM router instance
        fallback_type: Which fallback configuration to read
            ("general", "context_window", "content_policy")

    Returns:
        List of fallback model names. Empty list if there is no router, the
        fallback type is unknown, nothing is configured, no fallback matches,
        or the lookup raises.
    """
    if llm_router is None:
        return []

    # fallback_type -> name of the router attribute holding that configuration
    router_attr_by_type = {
        "general": "fallbacks",
        "context_window": "context_window_fallbacks",
        "content_policy": "content_policy_fallbacks",
    }
    router_attr = router_attr_by_type.get(fallback_type)
    if router_attr is None:
        verbose_proxy_logger.warning(f"Unknown fallback_type: {fallback_type}")
        return []

    fallbacks_config: list = getattr(llm_router, router_attr, [])
    if not fallbacks_config:
        return []

    try:
        # Use existing function to get fallback model group
        fallback_model_group, _ = get_fallback_model_group(
            fallbacks=fallbacks_config, model_group=model
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Error getting fallbacks for model {model}: {e}")
        return []

    return [] if fallback_model_group is None else fallback_model_group
|