# ai-station/.venv/lib/python3.12/site-packages/litellm/llms/ollama_chat.py

import json
import time
import uuid
from typing import Any, List, Optional, Union

import aiohttp
import httpx
from pydantic import BaseModel

import litellm
from litellm import verbose_logger
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    get_async_httpx_client,
)
from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
from litellm.types.llms.openai import ChatCompletionAssistantToolCall
from litellm.types.utils import ModelResponse, StreamingChoices


class OllamaError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(method="POST", url="http://localhost:11434")
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


# ollama implementation
def get_ollama_response(  # noqa: PLR0915
    model_response: ModelResponse,
    messages: list,
    optional_params: dict,
    model: str,
    logging_obj: Any,
    api_base="http://localhost:11434",
    api_key: Optional[str] = None,
    acompletion: bool = False,
    encoding=None,
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
    if api_base.endswith("/api/chat"):
        url = api_base
    else:
        url = f"{api_base}/api/chat"

    ## Load Config
    config = litellm.OllamaChatConfig.get_config()
    for k, v in config.items():
        if (
            k not in optional_params
        ):  # completion(top_k=3) > ollama_chat_config(top_k=3) <- user-passed params take precedence over config defaults
            optional_params[k] = v

    stream = optional_params.pop("stream", False)
    format = optional_params.pop("format", None)
    keep_alive = optional_params.pop("keep_alive", None)
    function_name = optional_params.pop("function_name", None)
    tools = optional_params.pop("tools", None)

    new_messages = []
    for m in messages:
        if isinstance(
            m, BaseModel
        ):  # avoid message serialization issues - https://github.com/BerriAI/litellm/issues/5319
            m = m.model_dump(exclude_none=True)
        if m.get("tool_calls") is not None and isinstance(m["tool_calls"], list):
            new_tools: List[OllamaToolCall] = []
            for tool in m["tool_calls"]:
                typed_tool = ChatCompletionAssistantToolCall(**tool)  # type: ignore
                if typed_tool["type"] == "function":
                    arguments = {}
                    if "arguments" in typed_tool["function"]:
                        arguments = json.loads(typed_tool["function"]["arguments"])
                    ollama_tool_call = OllamaToolCall(
                        function=OllamaToolCallFunction(
                            name=typed_tool["function"].get("name") or "",
                            arguments=arguments,
                        )
                    )
                    new_tools.append(ollama_tool_call)
            m["tool_calls"] = new_tools
        new_messages.append(m)

    data = {
        "model": model,
        "messages": new_messages,
        "options": optional_params,
        "stream": stream,
    }
    if format is not None:
        data["format"] = format
    if tools is not None:
        data["tools"] = tools
    if keep_alive is not None:
        data["keep_alive"] = keep_alive
    ## LOGGING
    logging_obj.pre_call(
        input=None,
        api_key=None,
        additional_args={
            "api_base": url,
            "complete_input_dict": data,
            "headers": {},
            "acompletion": acompletion,
        },
    )
    if acompletion is True:
        if stream is True:
            response = ollama_async_streaming(
                url=url,
                api_key=api_key,
                data=data,
                model_response=model_response,
                encoding=encoding,
                logging_obj=logging_obj,
            )
        else:
            response = ollama_acompletion(
                url=url,
                api_key=api_key,
                data=data,
                model_response=model_response,
                encoding=encoding,
                logging_obj=logging_obj,
                function_name=function_name,
            )
        return response
    elif stream is True:
        return ollama_completion_stream(
            url=url, api_key=api_key, data=data, logging_obj=logging_obj
        )

    headers: Optional[dict] = None
    if api_key is not None:
        headers = {"Authorization": "Bearer {}".format(api_key)}

    sync_client = litellm.module_level_client
    if client is not None and isinstance(client, HTTPHandler):
        sync_client = client

    response = sync_client.post(
        url=url,
        json=data,
        headers=headers,
    )
    if response.status_code != 200:
        raise OllamaError(status_code=response.status_code, message=response.text)

    ## LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response=response.text,
        additional_args={
            "headers": None,
            "api_base": api_base,
        },
    )

    response_json = response.json()

    ## RESPONSE OBJECT
    model_response.choices[0].finish_reason = "stop"
    if data.get("format", "") == "json" and function_name is not None:
        function_call = json.loads(response_json["message"]["content"])
        message = litellm.Message(
            content=None,
            tool_calls=[
                {
                    "id": f"call_{str(uuid.uuid4())}",
                    "function": {
                        "name": function_call.get("name", function_name),
                        "arguments": json.dumps(
                            function_call.get("arguments", function_call)
                        ),
                    },
                    "type": "function",
                }
            ],
        )
        model_response.choices[0].message = message  # type: ignore
        model_response.choices[0].finish_reason = "tool_calls"
    else:
        _message = litellm.Message(**response_json["message"])
        model_response.choices[0].message = _message  # type: ignore
    model_response.created = int(time.time())
    model_response.model = "ollama_chat/" + model
    prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages))  # type: ignore
    completion_tokens = response_json.get(
        "eval_count", litellm.token_counter(text=response_json["message"]["content"])
    )
    setattr(
        model_response,
        "usage",
        litellm.Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
    return model_response
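

# Illustrative sketch, not part of the upstream module: get_ollama_response() is
# normally reached indirectly via litellm.completion() when the model name carries
# the "ollama_chat/" prefix. The model name and api_base below are assumptions for
# the example, not values defined in this file.
def _example_sync_completion():
    return litellm.completion(
        model="ollama_chat/llama3",
        messages=[{"role": "user", "content": "Say hi in one word."}],
        api_base="http://localhost:11434",
    )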


def ollama_completion_stream(url, api_key, data, logging_obj):
    _request = {
        "url": f"{url}",
        "json": data,
        "method": "POST",
        "timeout": litellm.request_timeout,
        "follow_redirects": True,
    }
    if api_key is not None:
        _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
    with httpx.stream(**_request) as response:
        try:
            if response.status_code != 200:
                # read the streamed body so response.text holds the error payload
                response.read()
                raise OllamaError(
                    status_code=response.status_code, message=response.text
                )

            streamwrapper = litellm.CustomStreamWrapper(
                completion_stream=response.iter_lines(),
                model=data["model"],
                custom_llm_provider="ollama_chat",
                logging_obj=logging_obj,
            )

            # If format is JSON, this was a function call
            # Gather all chunks and return the function call as one delta to simplify parsing
            if data.get("format", "") == "json":
                first_chunk = None
                content_chunks = []
                for chunk in streamwrapper:
                    if first_chunk is None:
                        # keep the first streamed chunk object; it is reused below as
                        # the carrier for the aggregated tool-call delta
                        first_chunk = chunk
                    chunk_choice = chunk.choices[0]
                    if (
                        isinstance(chunk_choice, StreamingChoices)
                        and hasattr(chunk_choice, "delta")
                        and hasattr(chunk_choice.delta, "content")
                    ):
                        content_chunks.append(chunk_choice.delta.content)
                response_content = "".join(content_chunks)

                function_call = json.loads(response_content)
                delta = litellm.utils.Delta(
                    content=None,
                    tool_calls=[
                        {
                            "id": f"call_{str(uuid.uuid4())}",
                            "function": {
                                "name": function_call["name"],
                                "arguments": json.dumps(function_call["arguments"]),
                            },
                            "type": "function",
                        }
                    ],
                )
                model_response = first_chunk
                model_response.choices[0].delta = delta  # type: ignore
                model_response.choices[0].finish_reason = "tool_calls"
                yield model_response
            else:
                for transformed_chunk in streamwrapper:
                    yield transformed_chunk
        except Exception as e:
            raise e


async def ollama_async_streaming(
    url, api_key, data, model_response, encoding, logging_obj
):
    try:
        _async_http_client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.OLLAMA
        )
        client = _async_http_client.client
        _request = {
            "url": f"{url}",
            "json": data,
            "method": "POST",
            "timeout": litellm.request_timeout,
        }
        if api_key is not None:
            _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
        async with client.stream(**_request) as response:
            if response.status_code != 200:
                # read the streamed body so response.text holds the error payload
                await response.aread()
                raise OllamaError(
                    status_code=response.status_code, message=response.text
                )

            streamwrapper = litellm.CustomStreamWrapper(
                completion_stream=response.aiter_lines(),
                model=data["model"],
                custom_llm_provider="ollama_chat",
                logging_obj=logging_obj,
            )

            # If format is JSON, this was a function call
            # Gather all chunks and return the function call as one delta to simplify parsing
            if data.get("format", "") == "json":
                first_chunk = await anext(streamwrapper)  # noqa F821
                chunk_choice = first_chunk.choices[0]
                if (
                    isinstance(chunk_choice, StreamingChoices)
                    and hasattr(chunk_choice, "delta")
                    and hasattr(chunk_choice.delta, "content")
                ):
                    first_chunk_content = chunk_choice.delta.content or ""
                else:
                    first_chunk_content = ""

                content_chunks = []
                async for chunk in streamwrapper:
                    chunk_choice = chunk.choices[0]
                    if (
                        isinstance(chunk_choice, StreamingChoices)
                        and hasattr(chunk_choice, "delta")
                        and hasattr(chunk_choice.delta, "content")
                    ):
                        content_chunks.append(chunk_choice.delta.content)
                response_content = first_chunk_content + "".join(content_chunks)

                function_call = json.loads(response_content)
                delta = litellm.utils.Delta(
                    content=None,
                    tool_calls=[
                        {
                            "id": f"call_{str(uuid.uuid4())}",
                            "function": {
                                "name": function_call.get(
                                    "name", function_call.get("function", None)
                                ),
                                "arguments": json.dumps(function_call["arguments"]),
                            },
                            "type": "function",
                        }
                    ],
                )
                model_response = first_chunk
                model_response.choices[0].delta = delta  # type: ignore
                model_response.choices[0].finish_reason = "tool_calls"
                yield model_response
            else:
                async for transformed_chunk in streamwrapper:
                    yield transformed_chunk
    except Exception as e:
        verbose_logger.exception(
            "LiteLLM.ollama(): Exception occurred - {}".format(str(e))
        )
        raise e
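

# Illustrative sketch, not part of the upstream module: the async streaming generator
# above is normally consumed through litellm.acompletion() with stream=True. The
# model name and api_base are placeholder assumptions.
async def _example_async_streaming():
    stream = await litellm.acompletion(
        model="ollama_chat/llama3",
        messages=[{"role": "user", "content": "Count to three."}],
        api_base="http://localhost:11434",
        stream=True,
    )
    async for chunk in stream:
        # each chunk is an OpenAI-style streaming delta produced by CustomStreamWrapper
        print(chunk.choices[0].delta.content or "", end="")
    print()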


async def ollama_acompletion(
    url,
    api_key: Optional[str],
    data,
    model_response: litellm.ModelResponse,
    encoding,
    logging_obj,
    function_name,
):
    data["stream"] = False
    try:
        timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # bounded by litellm.request_timeout
        async with aiohttp.ClientSession(timeout=timeout) as session:
            _request = {
                "url": f"{url}",
                "json": data,
            }
            if api_key is not None:
                _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
            resp = await session.post(**_request)

            if resp.status != 200:
                text = await resp.text()
                raise OllamaError(status_code=resp.status, message=text)

            response_json = await resp.json()

            ## LOGGING
            logging_obj.post_call(
                input=data,
                api_key="",
                original_response=response_json,
                additional_args={
                    "headers": None,
                    "api_base": url,
                },
            )

            ## RESPONSE OBJECT
            model_response.choices[0].finish_reason = "stop"
            if data.get("format", "") == "json" and function_name is not None:
                function_call = json.loads(response_json["message"]["content"])
                message = litellm.Message(
                    content=None,
                    tool_calls=[
                        {
                            "id": f"call_{str(uuid.uuid4())}",
                            "function": {
                                "name": function_call.get("name", function_name),
                                "arguments": json.dumps(
                                    function_call.get("arguments", function_call)
                                ),
                            },
                            "type": "function",
                        }
                    ],
                )
                model_response.choices[0].message = message  # type: ignore
                model_response.choices[0].finish_reason = "tool_calls"
            else:
                _message = litellm.Message(**response_json["message"])
                model_response.choices[0].message = _message  # type: ignore

            model_response.created = int(time.time())
            model_response.model = "ollama_chat/" + data["model"]
            prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=data["messages"]))  # type: ignore
            completion_tokens = response_json.get(
                "eval_count",
                litellm.token_counter(
                    text=response_json["message"]["content"], count_response_tokens=True
                ),
            )
            setattr(
                model_response,
                "usage",
                litellm.Usage(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                ),
            )
            return model_response
    except Exception as e:
        raise e  # don't log via verbose_logger.exception here; just re-raise for the caller
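

# Illustrative smoke test for the sketches above; assumes a local Ollama server at
# http://localhost:11434 with a pulled "llama3" model (assumptions, not guaranteed
# by this module).
if __name__ == "__main__":
    import asyncio

    print(_example_sync_completion().choices[0].message.content)
    asyncio.run(_example_async_streaming())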