""" This module is used to pass through requests to the LLM APIs. """ import asyncio import contextvars from functools import partial from typing import ( TYPE_CHECKING, Any, AsyncGenerator, Coroutine, Generator, List, Optional, Union, cast, ) import httpx from httpx._types import CookieTypes, QueryParamTypes, RequestFiles import litellm from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler from litellm.utils import client base_llm_http_handler = BaseLLMHTTPHandler() from .utils import BasePassthroughUtils if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig @client async def allm_passthrough_route( *, method: str, endpoint: str, model: str, custom_llm_provider: Optional[str] = None, api_base: Optional[str] = None, api_key: Optional[str] = None, request_query_params: Optional[dict] = None, request_headers: Optional[dict] = None, content: Optional[Any] = None, data: Optional[dict] = None, files: Optional[RequestFiles] = None, json: Optional[Any] = None, params: Optional[QueryParamTypes] = None, cookies: Optional[CookieTypes] = None, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, **kwargs, ) -> Union[ httpx.Response, Coroutine[Any, Any, httpx.Response], Generator[Any, Any, Any], AsyncGenerator[Any, Any], ]: """ Async: Reranks a list of documents based on their relevance to the query """ try: loop = asyncio.get_event_loop() kwargs["allm_passthrough_route"] = True model, custom_llm_provider, api_key, api_base = get_llm_provider( model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key, ) from litellm.types.utils import LlmProviders from litellm.utils import ProviderConfigManager provider_config = cast( Optional["BasePassthroughConfig"], kwargs.get("provider_config") ) or ProviderConfigManager.get_provider_passthrough_config( provider=LlmProviders(custom_llm_provider), model=model, ) if provider_config is None: raise Exception(f"Provider {custom_llm_provider} not found") func = partial( llm_passthrough_route, method=method, endpoint=endpoint, model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key, request_query_params=request_query_params, request_headers=request_headers, content=content, data=data, files=files, json=json, params=params, cookies=cookies, client=client, **kwargs, ) ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) init_response = await loop.run_in_executor(None, func_with_context) if asyncio.iscoroutine(init_response): response = await init_response try: response.raise_for_status() except httpx.HTTPStatusError as e: error_text = await e.response.aread() error_text_str = error_text.decode("utf-8") raise Exception(error_text_str) else: response = init_response return response except Exception as e: # For passthrough routes, we need to get the provider config to properly handle errors from litellm.types.utils import LlmProviders from litellm.utils import ProviderConfigManager # Get the provider using the same logic as llm_passthrough_route _, resolved_custom_llm_provider, _, _ = get_llm_provider( model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key, ) # Get provider config if available provider_config = None if 
@client
def llm_passthrough_route(
    *,
    method: str,
    endpoint: str,
    model: str,
    custom_llm_provider: Optional[str] = None,
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    request_query_params: Optional[dict] = None,
    request_headers: Optional[dict] = None,
    allm_passthrough_route: bool = False,
    content: Optional[Any] = None,
    data: Optional[dict] = None,
    files: Optional[RequestFiles] = None,
    json: Optional[Any] = None,
    params: Optional[QueryParamTypes] = None,
    cookies: Optional[CookieTypes] = None,
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    **kwargs,
) -> Union[
    httpx.Response,
    Coroutine[Any, Any, httpx.Response],
    Generator[Any, Any, Any],
    AsyncGenerator[Any, Any],
]:
    """
    Pass through requests to the LLM APIs.

    Step 1. Build the request
    Step 2. Send the request
    Step 3. Return the response
    """
    from litellm.litellm_core_utils.get_litellm_params import get_litellm_params
    from litellm.types.utils import LlmProviders
    from litellm.utils import ProviderConfigManager

    if client is None:
        if allm_passthrough_route:
            client = litellm.module_level_aclient
        else:
            client = litellm.module_level_client

    litellm_logging_obj = cast("LiteLLMLoggingObj", kwargs.get("litellm_logging_obj"))
    model, custom_llm_provider, api_key, api_base = get_llm_provider(
        model=model,
        custom_llm_provider=custom_llm_provider,
        api_base=api_base,
        api_key=api_key,
    )

    litellm_params_dict = get_litellm_params(**kwargs)
    litellm_logging_obj.update_environment_variables(
        model=model,
        litellm_params=litellm_params_dict,
        optional_params={},
        endpoint=endpoint,
        custom_llm_provider=custom_llm_provider,
        request_data=data if data else json,
    )

    provider_config = cast(
        Optional["BasePassthroughConfig"], kwargs.get("provider_config")
    ) or ProviderConfigManager.get_provider_passthrough_config(
        provider=LlmProviders(custom_llm_provider),
        model=model,
    )
    if provider_config is None:
        raise Exception(f"Provider {custom_llm_provider} not found")

    updated_url, base_target_url = provider_config.get_complete_url(
        api_base=api_base,
        api_key=api_key,
        model=model,
        endpoint=endpoint,
        request_query_params=request_query_params,
        litellm_params=litellm_params_dict,
    )

    # Add or update query parameters
    provider_api_key = provider_config.get_api_key(api_key)

    auth_headers = provider_config.validate_environment(
        headers={},
        model=model,
        messages=[],
        optional_params={},
        litellm_params={},
        api_key=provider_api_key,
        api_base=base_target_url,
    )

    headers = BasePassthroughUtils.forward_headers_from_request(
        request_headers=request_headers or {},
        headers=auth_headers,
        forward_headers=False,
    )

    headers, signed_json_body = provider_config.sign_request(
        headers=headers,
        litellm_params=litellm_params_dict,
        request_data=data if data else json,
        api_base=str(updated_url),
        model=model,
    )

    ## SWAP MODEL IN JSON BODY [TODO: REFACTOR TO A provider_config.transform_request method]
    if json and isinstance(json, dict) and "model" in json:
        json["model"] = model

    request = client.client.build_request(
        method=method,
        url=updated_url,
        content=signed_json_body,
        data=data if signed_json_body is None else None,
        files=files,
        json=json if signed_json_body is None else None,
        params=params,
        headers=headers,
        cookies=cookies,
    )

    ## IS STREAMING REQUEST
    is_streaming_request = provider_config.is_streaming_request(
        endpoint=endpoint,
        request_data=data or json or {},
    )

    # Update logging object with streaming status
    litellm_logging_obj.stream = is_streaming_request

    try:
        response = client.client.send(request=request, stream=is_streaming_request)

        if asyncio.iscoroutine(response):
            if is_streaming_request:
                return _async_streaming(response, litellm_logging_obj, provider_config)
            else:
                return response

        response.raise_for_status()

        if (
            hasattr(response, "iter_bytes") and is_streaming_request
        ):  # yield the chunk, so we can store it in the logging object
            return _sync_streaming(response, litellm_logging_obj, provider_config)
        else:
            # For non-streaming responses, yield the entire response
            return response
    except Exception as e:
        if provider_config is None:
            raise e
        raise base_llm_http_handler._handle_error(
            e=e,
            provider_config=provider_config,
        )
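
# --- Usage sketch (illustrative, not part of the original module) -----------
# A minimal sync example, assuming hypothetical endpoint/model values. Passing
# an explicit HTTPHandler via `client` skips the module-level client selection
# at the top of llm_passthrough_route.
def _example_llm_passthrough_usage() -> None:
    handler = HTTPHandler()
    response = llm_passthrough_route(
        method="GET",
        endpoint="v1/models",  # hypothetical provider endpoint
        model="openai/gpt-4o",  # hypothetical model
        client=handler,
    )
    print(response.status_code)  # type: ignore[union-attr]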
def _sync_streaming(
    response: httpx.Response,
    litellm_logging_obj: "LiteLLMLoggingObj",
    provider_config: "BasePassthroughConfig",
):
    from litellm.utils import executor

    try:
        raw_bytes: List[bytes] = []
        for chunk in response.iter_bytes():  # type: ignore
            raw_bytes.append(chunk)
            yield chunk

        executor.submit(
            litellm_logging_obj.flush_passthrough_collected_chunks,
            raw_bytes=raw_bytes,
            provider_config=provider_config,
        )
    except Exception as e:
        raise e


async def _async_streaming(
    response: Coroutine[Any, Any, httpx.Response],
    litellm_logging_obj: "LiteLLMLoggingObj",
    provider_config: "BasePassthroughConfig",
):
    try:
        iter_response = await response
        raw_bytes: List[bytes] = []
        async for chunk in iter_response.aiter_bytes():  # type: ignore
            raw_bytes.append(chunk)
            yield chunk

        asyncio.create_task(
            litellm_logging_obj.async_flush_passthrough_collected_chunks(
                raw_bytes=raw_bytes,
                provider_config=provider_config,
            )
        )
    except Exception as e:
        raise e
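
# --- Streaming consumption sketch (illustrative, not part of the original ---
# module). When provider_config.is_streaming_request(...) returns True, the
# route hands back a generator of raw byte chunks instead of an httpx.Response;
# the collected chunks are flushed to the logging object once iteration ends.
# Request values below are hypothetical placeholders, and whether "stream": True
# triggers streaming depends on the provider's passthrough config.
async def _example_streaming_consumption() -> None:
    stream = await allm_passthrough_route(
        method="POST",
        endpoint="v1/messages",  # hypothetical provider endpoint
        model="anthropic/claude-3-5-sonnet-20240620",  # hypothetical model
        json={
            "model": "claude-3-5-sonnet-20240620",
            "max_tokens": 100,
            "stream": True,
            "messages": [{"role": "user", "content": "Hello"}],
        },
    )
    async for chunk in stream:  # type: ignore[union-attr]
        print(chunk.decode("utf-8", errors="ignore"), end="", flush=True)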