# 822 lines · 30 KiB · Python
import asyncio
|
|
import contextvars
|
|
from functools import partial
|
|
from typing import Any, Coroutine, Dict, Literal, Optional, Union, cast, overload
|
|
|
|
import httpx
|
|
|
|
import litellm
|
|
from litellm import Logging, client, exception_type, get_litellm_params
|
|
from litellm.constants import DEFAULT_IMAGE_ENDPOINT_MODEL
|
|
from litellm.constants import request_timeout as DEFAULT_REQUEST_TIMEOUT
|
|
from litellm.exceptions import LiteLLMUnknownProvider
|
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
|
from litellm.litellm_core_utils.mock_functions import mock_image_generation
|
|
from litellm.llms.base_llm import BaseImageEditConfig, BaseImageGenerationConfig
|
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
|
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
|
from litellm.llms.custom_llm import CustomLLM
|
|
|
|
#################### Initialize provider clients ####################
|
|
# Module-level HTTP handler instance; `image_generation` below uses it for the
# Recraft / Gemini route (`llm_http_handler.image_generation_handler`).
llm_http_handler: BaseLLMHTTPHandler = BaseLLMHTTPHandler()
|
|
from litellm.main import (
|
|
azure_chat_completions,
|
|
base_llm_aiohttp_handler,
|
|
base_llm_http_handler,
|
|
bedrock_image_generation,
|
|
openai_chat_completions,
|
|
openai_image_variations,
|
|
vertex_image_generation,
|
|
)
|
|
|
|
###########################################
|
|
from litellm.secret_managers.main import get_secret_str
|
|
from litellm.types.images.main import ImageEditOptionalRequestParams
|
|
from litellm.types.llms.openai import ImageGenerationRequestQuality
|
|
from litellm.types.router import GenericLiteLLMParams
|
|
from litellm.types.utils import (
|
|
LITELLM_IMAGE_VARIATION_PROVIDERS,
|
|
FileTypes,
|
|
LlmProviders,
|
|
all_litellm_params,
|
|
)
|
|
from litellm.utils import (
|
|
ImageResponse,
|
|
ProviderConfigManager,
|
|
get_llm_provider,
|
|
get_optional_params_image_gen,
|
|
)
|
|
|
|
from .utils import ImageEditRequestUtils
|
|
|
|
|
|
##### Image Generation #######################
|
|
@client
async def aimage_generation(*args, **kwargs) -> ImageResponse:
    """
    Asynchronously calls the `image_generation` function with the given arguments and keyword arguments.

    `image_generation` is invoked with `aimg_generation=True` in the event
    loop's default executor, inside a copy of the current `contextvars`
    context. If it hands back a coroutine, that coroutine is awaited here.

    Parameters:
    - `args` (tuple): Positional arguments to be passed to the `image_generation` function.
    - `kwargs` (dict): Keyword arguments to be passed to the `image_generation` function.

    Returns:
    - `response` (ImageResponse): The response returned by the `image_generation` function.

    Raises:
    - Whatever `exception_type` maps the underlying error to. A `ValueError`
      is the underlying error when `image_generation` returns something that
      is neither a dict, an `ImageResponse`, nor a coroutine.
    """
    loop = asyncio.get_running_loop()

    # `image_generation(prompt, model=None, ...)`: the model is the SECOND
    # positional argument and is optional, so it must not be read from args[0]
    # (that is the prompt) and its absence must not raise a KeyError.
    model = kwargs.get("model", None)
    if model is None and len(args) > 1:
        model = args[1]

    ### PASS ARGS TO Image Generation ###
    kwargs["aimg_generation"] = True
    custom_llm_provider = kwargs.get("custom_llm_provider", None)
    try:
        # Use a partial function to pass your keyword arguments
        func = partial(image_generation, *args, **kwargs)

        # Add the context to the function
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)

        # Resolve the provider up front so that errors can be mapped to the
        # right provider. An explicitly supplied provider is forwarded, the
        # same way `image_generation` itself resolves it.
        if model is not None:
            _, custom_llm_provider, _, _ = get_llm_provider(
                model=model,
                custom_llm_provider=custom_llm_provider,
                api_base=kwargs.get("api_base", None),
            )

        # Await normally
        init_response = await loop.run_in_executor(None, func_with_context)

        response: Optional[ImageResponse] = None
        if isinstance(init_response, dict):
            response = ImageResponse(**init_response)
        elif isinstance(init_response, ImageResponse):  ## CACHING SCENARIO
            response = init_response
        elif asyncio.iscoroutine(init_response):
            response = await init_response  # type: ignore

        if response is None:
            raise ValueError(
                "Unable to get Image Response. Please pass a valid llm_provider."
            )

        return response
    except Exception as e:
        custom_llm_provider = custom_llm_provider or "openai"
        raise exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=args,
            extra_kwargs=kwargs,
        )
|
|
|
|
|
|
# Overload for when aimg_generation=True (returns Coroutine)
|
|
@overload
def image_generation(
    prompt: str,
    model: Optional[str] = None,
    n: Optional[int] = None,
    quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
    response_format: Optional[str] = None,
    size: Optional[str] = None,
    style: Optional[str] = None,
    user: Optional[str] = None,
    input_fidelity: Optional[str] = None,
    timeout=600,  # default to 10 minutes
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    custom_llm_provider=None,
    *,
    aimg_generation: Literal[True],
    **kwargs,
) -> Coroutine[Any, Any, ImageResponse]:
    ...


# Overload for when aimg_generation=False or not specified (returns ImageResponse)
@overload
def image_generation(
    prompt: str,
    model: Optional[str] = None,
    n: Optional[int] = None,
    quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
    response_format: Optional[str] = None,
    size: Optional[str] = None,
    style: Optional[str] = None,
    user: Optional[str] = None,
    input_fidelity: Optional[str] = None,
    timeout=600,  # default to 10 minutes
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    custom_llm_provider=None,
    *,
    aimg_generation: Literal[False] = False,
    **kwargs,
) -> ImageResponse:
    ...


@client
def image_generation(  # noqa: PLR0915
    prompt: str,
    model: Optional[str] = None,
    n: Optional[int] = None,
    quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
    response_format: Optional[str] = None,
    size: Optional[str] = None,
    style: Optional[str] = None,
    user: Optional[str] = None,
    input_fidelity: Optional[str] = None,
    timeout=600,  # default to 10 minutes
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    custom_llm_provider=None,
    **kwargs,
) -> Union[
    ImageResponse,
    Coroutine[Any, Any, ImageResponse],
]:
    """
    Maps the https://api.openai.com/v1/images/generations endpoint.

    The request is routed by `custom_llm_provider` to one of: Azure,
    OpenAI / OpenAI-compatible providers, Bedrock, Vertex AI, the
    `llm_http_handler` route (Recraft, Gemini), or a handler registered in
    `litellm.custom_provider_map`.

    Parameters:
    - `prompt`: Text prompt forwarded to the provider.
    - `model`: Model name. When both `model` and `custom_llm_provider` are
      None, defaults to "dall-e-2" on "openai".
    - `n`, `quality`, `response_format`, `size`, `style`, `user`,
      `input_fidelity`: Passed to `get_optional_params_image_gen` to build the
      provider request params.
    - `timeout`: Forwarded to the provider call (default 600).
    - `api_key`, `api_base`, `api_version`: Connection overrides.
    - `custom_llm_provider`: Explicit provider name.
    - `kwargs`: LiteLLM-internal settings (`litellm_logging_obj`, `client`,
      `headers`, `extra_headers`, `mock_response`, `aimg_generation`, ...);
      any key that is not an OpenAI/LiteLLM param is passed through as a
      provider-specific param.

    Returns:
    - The provider call's result (an `ImageResponse`, or whatever the provider
      handler returns when `aimg_generation` is set - the overloads declare a
      coroutine in that case).

    Raises:
    - The exception produced by `exception_type` for any error raised while
      building or dispatching the request.
    """
    try:
        aimg_generation = kwargs.get("aimg_generation", False)
        litellm_call_id = kwargs.get("litellm_call_id", None)
        logger_fn = kwargs.get("logger_fn", None)
        mock_response: Optional[str] = kwargs.get("mock_response", None)  # type: ignore
        proxy_server_request = kwargs.get("proxy_server_request", None)
        azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
        model_info = kwargs.get("model_info", None)
        metadata = kwargs.get("metadata", {})
        litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
        client = kwargs.get("client", None)
        extra_headers = kwargs.get("extra_headers", None)
        # Work on a copy: extra headers and (on Azure) the api-key are merged
        # into this dict below and must not leak into the caller's object.
        headers: dict = dict(kwargs.get("headers", None) or {})
        base_model = kwargs.get("base_model", None)
        if extra_headers is not None:
            headers.update(extra_headers)
        model_response: ImageResponse = litellm.utils.ImageResponse()
        dynamic_api_key: Optional[str] = None
        if model is not None or custom_llm_provider is not None:
            model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
                model=model,  # type: ignore
                custom_llm_provider=custom_llm_provider,
                api_base=api_base,
            )
        else:
            model = "dall-e-2"
            custom_llm_provider = "openai"  # default to dall-e-2 on openai
        model_response._hidden_params["model"] = model
        openai_params = [
            "user",
            "request_timeout",
            "api_base",
            "api_version",
            "api_key",
            "deployment_id",
            "organization",
            "base_url",
            "default_headers",
            "timeout",
            "max_retries",
            "n",
            "quality",
            "size",
            "style",
            "input_fidelity",
        ]
        litellm_params = all_litellm_params
        default_params = openai_params + litellm_params
        non_default_params = {
            k: v for k, v in kwargs.items() if k not in default_params
        }  # model-specific params - pass them straight to the model/provider

        # A provider config is only looked up for providers known to LlmProviders
        image_generation_config: Optional[BaseImageGenerationConfig] = None
        if (
            custom_llm_provider is not None
            and custom_llm_provider in LlmProviders._member_map_.values()
        ):
            image_generation_config = (
                ProviderConfigManager.get_provider_image_generation_config(
                    model=base_model or model,
                    provider=LlmProviders(custom_llm_provider),
                )
            )

        optional_params = get_optional_params_image_gen(
            model=base_model or model,
            n=n,
            quality=quality,
            response_format=response_format,
            size=size,
            style=style,
            user=user,
            input_fidelity=input_fidelity,
            custom_llm_provider=custom_llm_provider,
            provider_config=image_generation_config,
            **non_default_params,
        )

        litellm_params_dict = get_litellm_params(**kwargs)

        logging: Logging = litellm_logging_obj
        logging.update_environment_variables(
            model=model,
            user=user,
            optional_params=optional_params,
            litellm_params={
                "timeout": timeout,
                "azure": False,
                "litellm_call_id": litellm_call_id,
                "logger_fn": logger_fn,
                "proxy_server_request": proxy_server_request,
                "model_info": model_info,
                "metadata": metadata,
                "preset_cache_key": None,
                "stream_response": {},
            },
            custom_llm_provider=custom_llm_provider,
        )
        if "custom_llm_provider" not in logging.model_call_details:
            logging.model_call_details["custom_llm_provider"] = custom_llm_provider
        if mock_response is not None:
            return mock_image_generation(model=model, mock_response=mock_response)

        if custom_llm_provider == "azure":
            # azure configs: explicit argument > litellm module setting > secret
            api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")

            api_version = (
                api_version
                or litellm.api_version
                or get_secret_str("AZURE_API_VERSION")
            )

            api_key = (
                api_key
                or litellm.api_key
                or litellm.azure_key
                or get_secret_str("AZURE_OPENAI_API_KEY")
                or get_secret_str("AZURE_API_KEY")
            )

            azure_ad_token = optional_params.pop(
                "azure_ad_token", None
            ) or get_secret_str("AZURE_AD_TOKEN")

            default_headers = {
                "Content-Type": "application/json;",
                "api-key": api_key,
            }
            # Caller-supplied headers win over the defaults
            for k, v in default_headers.items():
                headers.setdefault(k, v)

            model_response = azure_chat_completions.image_generation(
                model=model,
                prompt=prompt,
                timeout=timeout,
                api_key=api_key,
                api_base=api_base,
                azure_ad_token=azure_ad_token,
                azure_ad_token_provider=azure_ad_token_provider,
                logging_obj=litellm_logging_obj,
                optional_params=optional_params,
                model_response=model_response,
                api_version=api_version,
                aimg_generation=aimg_generation,
                client=client,
                headers=headers,
                litellm_params=litellm_params_dict,
            )
        elif (
            custom_llm_provider == "openai"
            or custom_llm_provider in litellm.openai_compatible_providers
        ):
            model_response = openai_chat_completions.image_generation(
                model=model,
                prompt=prompt,
                timeout=timeout,
                api_key=api_key or dynamic_api_key,
                api_base=api_base,
                logging_obj=litellm_logging_obj,
                optional_params=optional_params,
                model_response=model_response,
                aimg_generation=aimg_generation,
                client=client,
            )
        elif custom_llm_provider == "bedrock":
            if model is None:
                raise Exception("Model needs to be set for bedrock")
            model_response = bedrock_image_generation.image_generation(  # type: ignore
                model=model,
                prompt=prompt,
                timeout=timeout,
                logging_obj=litellm_logging_obj,
                optional_params=optional_params,
                model_response=model_response,
                aimg_generation=aimg_generation,
                client=client,
                api_base=api_base,
                api_key=api_key,
            )
        elif custom_llm_provider == "vertex_ai":
            # Vertex settings are popped out of optional_params so they are not
            # sent as request params; fall back to module settings / secrets.
            vertex_ai_project = (
                optional_params.pop("vertex_project", None)
                or optional_params.pop("vertex_ai_project", None)
                or litellm.vertex_project
                or get_secret_str("VERTEXAI_PROJECT")
            )
            vertex_ai_location = (
                optional_params.pop("vertex_location", None)
                or optional_params.pop("vertex_ai_location", None)
                or litellm.vertex_location
                or get_secret_str("VERTEXAI_LOCATION")
            )
            vertex_credentials = (
                optional_params.pop("vertex_credentials", None)
                or optional_params.pop("vertex_ai_credentials", None)
                or get_secret_str("VERTEXAI_CREDENTIALS")
            )

            api_base = (
                api_base
                or litellm.api_base
                or get_secret_str("VERTEXAI_API_BASE")
                or get_secret_str("VERTEX_API_BASE")
            )

            model_response = vertex_image_generation.image_generation(
                model=model,
                prompt=prompt,
                timeout=timeout,
                logging_obj=litellm_logging_obj,
                optional_params=optional_params,
                model_response=model_response,
                vertex_project=vertex_ai_project,
                vertex_location=vertex_ai_location,
                vertex_credentials=vertex_credentials,
                aimg_generation=aimg_generation,
                api_base=api_base,
                client=client,
            )
        #########################################################
        # Providers using llm_http_handler
        #########################################################
        elif custom_llm_provider in (
            litellm.LlmProviders.RECRAFT,
            litellm.LlmProviders.GEMINI,
        ):
            if image_generation_config is None:
                raise ValueError(f"image generation config is not supported for {custom_llm_provider}")

            return llm_http_handler.image_generation_handler(
                model=model,
                prompt=prompt,
                image_generation_provider_config=image_generation_config,
                image_generation_optional_request_params=optional_params,
                custom_llm_provider=custom_llm_provider,
                litellm_params=litellm_params_dict,
                logging_obj=litellm_logging_obj,
                timeout=timeout,
                client=client,
            )
        elif (
            custom_llm_provider in litellm._custom_providers
        ):  # Assume custom LLM provider
            # Get the Custom Handler (if several entries match, the last one wins)
            custom_handler: Optional[CustomLLM] = None
            for item in litellm.custom_provider_map:
                if item["provider"] == custom_llm_provider:
                    custom_handler = item["custom_handler"]

            if custom_handler is None:
                raise LiteLLMUnknownProvider(
                    model=model, custom_llm_provider=custom_llm_provider
                )

            ## ROUTE LLM CALL ##
            if aimg_generation is True:
                # Only hand over the client if it is of the async flavour
                async_custom_client: Optional[AsyncHTTPHandler] = None
                if client is not None and isinstance(client, AsyncHTTPHandler):
                    async_custom_client = client

                ## CALL FUNCTION
                model_response = custom_handler.aimage_generation(  # type: ignore
                    model=model,
                    prompt=prompt,
                    api_key=api_key,
                    api_base=api_base,
                    model_response=model_response,
                    optional_params=optional_params,
                    logging_obj=litellm_logging_obj,
                    timeout=timeout,
                    client=async_custom_client,
                )
            else:
                # Only hand over the client if it is of the sync flavour
                custom_client: Optional[HTTPHandler] = None
                if client is not None and isinstance(client, HTTPHandler):
                    custom_client = client

                ## CALL FUNCTION
                model_response = custom_handler.image_generation(
                    model=model,
                    prompt=prompt,
                    api_key=api_key,
                    api_base=api_base,
                    model_response=model_response,
                    optional_params=optional_params,
                    logging_obj=litellm_logging_obj,
                    timeout=timeout,
                    client=custom_client,
                )

        # NOTE(review): when no branch above matches the provider, the empty
        # ImageResponse created at the top is returned unchanged.
        return model_response
    except Exception as e:
        ## Map to OpenAI Exception
        raise exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=locals(),
            extra_kwargs=kwargs,
        )
|
|
|
|
|
|
@client
async def aimage_variation(*args, **kwargs) -> ImageResponse:
    """
    Asynchronously calls the `image_variation` function with the given arguments and keyword arguments.

    `image_variation` is invoked once, with `async_call=True`, in the event
    loop's default executor inside a copy of the current `contextvars`
    context. A coroutine result is awaited; a dict result is converted to an
    `ImageResponse`; anything else is returned as received.

    Parameters:
    - `args` (tuple): Positional arguments to be passed to the `image_variation` function.
    - `kwargs` (dict): Keyword arguments to be passed to the `image_variation` function.

    Returns:
    - `response` (Any): The response returned by the `image_variation` function.

    Raises:
    - Whatever `exception_type` maps the underlying error to.
    """
    loop = asyncio.get_running_loop()
    model = kwargs.get("model", None)
    custom_llm_provider = kwargs.get("custom_llm_provider", None)
    ### PASS ARGS TO Image Generation ###
    kwargs["async_call"] = True
    try:
        # Use a partial function to pass your keyword arguments
        func = partial(image_variation, *args, **kwargs)

        # Add the context to the function
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)

        # Only needed for exception mapping below
        if custom_llm_provider is None and model is not None:
            _, custom_llm_provider, _, _ = get_llm_provider(
                model=model, api_base=kwargs.get("api_base", None)
            )

        # Await normally
        init_response = await loop.run_in_executor(None, func_with_context)

        if isinstance(init_response, dict):
            return ImageResponse(**init_response)
        if asyncio.iscoroutine(init_response):
            return await init_response  # type: ignore
        # ImageResponse (e.g. caching scenario) or any other finished result:
        # hand it back as-is. The request has already been executed once, so it
        # must not be submitted to the executor a second time.
        return init_response
    except Exception as e:
        custom_llm_provider = custom_llm_provider or "openai"
        raise exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=args,
            extra_kwargs=kwargs,
        )
|
|
|
|
|
|
@client
def image_variation(
    image: FileTypes,
    model: str = "dall-e-2",  # set to dall-e-2 by default - like OpenAI.
    n: int = 1,
    response_format: Literal["url", "b64_json"] = "url",
    size: Optional[str] = None,
    user: Optional[str] = None,
    **kwargs,
) -> ImageResponse:
    """
    Create a variation of `image` through the OpenAI or Topaz provider.

    The provider is resolved with `get_llm_provider`, its API key / base are
    taken from the provider's model-info config, and the request is dispatched
    to the matching handler.

    Parameters:
    - `image`: Source image forwarded to the provider handler.
    - `model`: Model name used for provider resolution and the request.
    - `n`, `response_format`, `size`, `user`: Accepted for API compatibility;
      this function does not reference them (handlers receive an empty
      `optional_params` dict).
    - `kwargs`: LiteLLM settings; `client` and `litellm_logging_obj` are read
      directly, the rest goes through `get_litellm_params`.

    Returns:
    - The `ImageResponse` produced by the provider handler.

    Raises:
    - `ValueError`: unsupported provider, missing provider config, missing API
      key / API base, or a handler that produced no response.
    """
    http_client = kwargs.get("client", None)
    logging_obj = cast(LiteLLMLoggingObj, kwargs.get("litellm_logging_obj"))

    # LiteLLM-level params drive provider resolution and are handed to the handler
    litellm_params = get_litellm_params(**kwargs)
    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
        model=model,
        custom_llm_provider=litellm_params.get("custom_llm_provider", None),
        api_base=litellm_params.get("api_base", None),
        api_key=litellm_params.get("api_key", None),
    )

    # Both enum constructors raise ValueError for a provider they do not know
    try:
        llm_provider = LlmProviders(custom_llm_provider)
        variation_provider = LITELLM_IMAGE_VARIATION_PROVIDERS(llm_provider)
    except ValueError:
        raise ValueError(
            f"Invalid image variation provider: {custom_llm_provider}. Supported providers are: {LITELLM_IMAGE_VARIATION_PROVIDERS}"
        )

    empty_response = ImageResponse()
    result: Optional[ImageResponse] = None

    provider_config = ProviderConfigManager.get_provider_model_info(
        model=model or "",  # openai defaults to dall-e-2
        provider=llm_provider,
    )
    if provider_config is None:
        raise ValueError(
            f"image variation provider has no known model info config - required for getting api keys, etc.: {custom_llm_provider}. Supported providers are: {LITELLM_IMAGE_VARIATION_PROVIDERS}"
        )

    # Credentials come from the provider config, seeded with any explicit overrides
    api_key = provider_config.get_api_key(litellm_params.get("api_key", None))
    api_base = provider_config.get_api_base(litellm_params.get("api_base", None))

    requested_timeout = litellm_params.get("timeout", None)

    if variation_provider == LITELLM_IMAGE_VARIATION_PROVIDERS.OPENAI:
        if api_key is None:
            raise ValueError("API key is required for OpenAI image variations")
        if api_base is None:
            raise ValueError("API base is required for OpenAI image variations")

        result = openai_image_variations.image_variations(
            model_response=empty_response,
            api_key=api_key,
            api_base=api_base,
            model=model,
            image=image,
            timeout=requested_timeout,
            custom_llm_provider=custom_llm_provider,
            logging_obj=logging_obj,
            optional_params={},
            litellm_params=litellm_params,
        )
    elif variation_provider == LITELLM_IMAGE_VARIATION_PROVIDERS.TOPAZ:
        if api_key is None:
            raise ValueError("API key is required for Topaz image variations")
        if api_base is None:
            raise ValueError("API base is required for Topaz image variations")

        result = base_llm_aiohttp_handler.image_variations(
            model_response=empty_response,
            api_key=api_key,
            api_base=api_base,
            model=model,
            image=image,
            timeout=requested_timeout or DEFAULT_REQUEST_TIMEOUT,
            custom_llm_provider=custom_llm_provider,
            logging_obj=logging_obj,
            optional_params={},
            litellm_params=litellm_params,
            client=http_client,
        )

    # No branch matched (or the handler returned nothing)
    if result is None:
        raise ValueError(
            f"Invalid image variation provider: {custom_llm_provider}. Supported providers are: {LITELLM_IMAGE_VARIATION_PROVIDERS}"
        )
    return result
|
|
|
|
|
|
@client
def image_edit(
    image: FileTypes,
    prompt: str,
    model: Optional[str] = None,
    mask: Optional[str] = None,
    n: Optional[int] = None,
    quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
    response_format: Optional[str] = None,
    size: Optional[str] = None,
    user: Optional[str] = None,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Union[ImageResponse, Coroutine[Any, Any, ImageResponse]]:
    """
    Maps the image edit functionality, similar to OpenAI's images/edits endpoint.

    Resolves the provider (falling back to `DEFAULT_IMAGE_ENDPOINT_MODEL` when
    no model is given), builds the provider request params from the call
    arguments, records them on the logging object and delegates to
    `base_llm_http_handler.image_edit_handler`.

    Parameters:
    - `image`, `prompt`: Forwarded to the handler.
    - `model`, `custom_llm_provider`: Used for provider resolution.
    - `mask`, `n`, `quality`, `response_format`, `size`, `user`: Picked up from
      the call-argument snapshot when the optional request params are built.
    - `extra_headers`, `extra_body`, `timeout`: Forwarded to the handler
      (`timeout` falls back to `DEFAULT_REQUEST_TIMEOUT`).
    - `kwargs`: LiteLLM params; `async_call` is popped and becomes `_is_async`.

    Returns:
    - Whatever `image_edit_handler` returns (an `ImageResponse`, or a coroutine
      resolving to one per the return annotation).

    Raises:
    - The exception produced by `litellm.exception_type` for any error raised
      in the body, including the `ValueError` for providers without an image
      edit config.
    """
    # Snapshot of the call arguments; must be taken before any other local exists
    call_snapshot = locals()
    try:
        logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
        call_id: Optional[str] = kwargs.get("litellm_call_id", None)
        async_flag = kwargs.pop("async_call", False)
        _is_async = async_flag is True

        # get llm provider logic
        litellm_params = GenericLiteLLMParams(**kwargs)
        requested_model = model or DEFAULT_IMAGE_ENDPOINT_MODEL
        model, custom_llm_provider, _, _ = get_llm_provider(
            model=requested_model,
            custom_llm_provider=custom_llm_provider,
        )

        # get provider config
        provider_enum = litellm.LlmProviders(custom_llm_provider)
        image_edit_provider_config: Optional[BaseImageEditConfig] = (
            ProviderConfigManager.get_provider_image_edit_config(
                model=model,
                provider=provider_enum,
            )
        )
        if image_edit_provider_config is None:
            raise ValueError(f"image edit is not supported for {custom_llm_provider}")

        # kwargs entries join the snapshot so they are considered as request params too
        call_snapshot.update(kwargs)

        # Get ImageEditOptionalRequestParams with only valid parameters
        requested_optional_params: ImageEditOptionalRequestParams = (
            ImageEditRequestUtils.get_requested_image_edit_optional_param(
                call_snapshot
            )
        )

        # Map them to what the provider config supports
        request_params: Dict = ImageEditRequestUtils.get_optional_params_image_edit(
            model=model,
            image_edit_provider_config=image_edit_provider_config,
            image_edit_optional_params=requested_optional_params,
        )

        # Pre Call logging
        logged_litellm_params = {
            "litellm_call_id": call_id,
            **request_params,
        }
        logging_obj.update_environment_variables(
            model=model,
            user=user,
            optional_params=dict(request_params),
            litellm_params=logged_litellm_params,
            custom_llm_provider=custom_llm_provider,
        )

        # Call the handler with _is_async flag instead of directly calling the async handler
        effective_timeout = timeout or DEFAULT_REQUEST_TIMEOUT
        return base_llm_http_handler.image_edit_handler(
            model=model,
            image=image,
            prompt=prompt,
            image_edit_provider_config=image_edit_provider_config,
            image_edit_optional_request_params=request_params,
            custom_llm_provider=custom_llm_provider,
            litellm_params=litellm_params,
            logging_obj=logging_obj,
            extra_headers=extra_headers,
            extra_body=extra_body,
            timeout=effective_timeout,
            _is_async=_is_async,
            client=kwargs.get("client"),
        )

    except Exception as e:
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=call_snapshot,
            extra_kwargs=kwargs,
        )
|
|
|
|
|
|
@client
async def aimage_edit(
    image: FileTypes,
    model: str,
    prompt: str,
    mask: Optional[str] = None,
    n: Optional[int] = None,
    quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
    response_format: Optional[str] = None,
    size: Optional[str] = None,
    user: Optional[str] = None,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> ImageResponse:
    """
    Asynchronously calls the `image_edit` function with the given arguments and keyword arguments.

    `image_edit` is invoked with `async_call=True` in the event loop's default
    executor, inside a copy of the current `contextvars` context. If it hands
    back a coroutine, that coroutine is awaited here.

    Parameters:
    - `image`, `model`, `prompt`, `mask`, `n`, `quality`, `response_format`,
      `size`, `user`, `timeout`: Forwarded unchanged to `image_edit`.
    - `extra_headers`, `extra_query`, `extra_body`: Forwarded unchanged to
      `image_edit`.
    - `custom_llm_provider`: Explicit provider; when None it is resolved from
      `model` (and `kwargs["api_base"]`, if present) so errors can be mapped.
    - `kwargs` (dict): Additional keyword arguments passed to `image_edit`.

    Returns:
    - `response` (Any): The response returned by the `image_edit` function.

    Raises:
    - Whatever `litellm.exception_type` maps the underlying error to.
    """
    local_vars = locals()
    try:
        loop = asyncio.get_running_loop()
        kwargs["async_call"] = True

        # get custom llm provider so we can use this for mapping exceptions
        if custom_llm_provider is None:
            # Caller-supplied settings live in `kwargs`; the locals() snapshot
            # only holds `kwargs` as a single entry, so it cannot be queried
            # for individual keys.
            _, custom_llm_provider, _, _ = litellm.get_llm_provider(
                model=model, api_base=kwargs.get("api_base", None)
            )

        func = partial(
            image_edit,
            image=image,
            prompt=prompt,
            mask=mask,
            model=model,
            n=n,
            quality=quality,
            response_format=response_format,
            size=size,
            user=user,
            # These were accepted but silently dropped before reaching image_edit
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            custom_llm_provider=custom_llm_provider,
            **kwargs,
        )

        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)
        init_response = await loop.run_in_executor(None, func_with_context)

        if asyncio.iscoroutine(init_response):
            response = await init_response
        else:
            response = init_response

        return response
    except Exception as e:
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )
|