315 lines
10 KiB
Python
315 lines
10 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright 2024 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from __future__ import annotations
|
|
|
|
import datetime
|
|
import textwrap
|
|
from typing import Iterable, Optional
|
|
|
|
from google.generativeai import protos
|
|
from google.generativeai.types import caching_types
|
|
from google.generativeai.types import content_types
|
|
from google.generativeai.client import get_default_cache_client
|
|
|
|
from google.protobuf import field_mask_pb2
|
|
|
|
# Values used for the `role` field of content turns. `_USER_ROLE` is assigned
# to the final cached turn when the caller left its role empty.
_USER_ROLE, _MODEL_ROLE = "user", "model"
|
|
|
|
|
|
class CachedContent:
    """Cached content resource.

    Wraps a `protos.CachedContent` message (kept in `self._proto`) and provides
    helpers that call the client returned by `get_default_cache_client()` to
    create, fetch, list, update and delete cached content.
    """

    def __init__(self, name):
        """Fetches a `CachedContent` resource.

        Equivalent to `CachedContent.get`, except that the fetched proto is
        stored as-is instead of being copied through `_from_obj`.

        Args:
            name: The resource name referring to the cached content. A value
                that does not contain ``"cachedContents/"`` is treated as a bare
                ID and prefixed with ``"cachedContents/"``.
        """
        self._proto = self._fetch_proto(name)

    @staticmethod
    def _fetch_proto(name: str) -> protos.CachedContent:
        """Normalizes `name` and fetches the raw cached-content proto from the API.

        Shared by `__init__` and `get` so both resolve names identically.
        """
        client = get_default_cache_client()

        # Accept bare IDs as well as full resource names.
        if "cachedContents/" not in name:
            name = "cachedContents/" + name

        request = protos.GetCachedContentRequest(name=name)
        return client.get_cached_content(request)

    @property
    def name(self) -> str:
        return self._proto.name

    @property
    def model(self) -> str:
        return self._proto.model

    @property
    def display_name(self) -> str:
        return self._proto.display_name

    @property
    def usage_metadata(self) -> protos.CachedContent.UsageMetadata:
        return self._proto.usage_metadata

    @property
    def create_time(self) -> datetime.datetime:
        return self._proto.create_time

    @property
    def update_time(self) -> datetime.datetime:
        return self._proto.update_time

    @property
    def expire_time(self) -> datetime.datetime:
        return self._proto.expire_time

    def __str__(self):
        # dedent strips the common source indentation so the rendered text
        # starts at column 0 with nested fields indented by four spaces.
        return textwrap.dedent(
            f"""\
            CachedContent(
                name='{self.name}',
                model='{self.model}',
                display_name='{self.display_name}',
                usage_metadata={{
                    'total_token_count': {self.usage_metadata.total_token_count},
                }},
                create_time={self.create_time},
                update_time={self.update_time},
                expire_time={self.expire_time}
            )"""
        )

    __repr__ = __str__

    @classmethod
    def _from_obj(cls, obj: CachedContent | protos.CachedContent | dict) -> CachedContent:
        """Creates an instance of CachedContent from an object, without calling `get`.

        Args:
            obj: Another `CachedContent`, a `protos.CachedContent` message, or a
                dict of field values. Its fields are copied into a fresh proto.

        Returns:
            A new instance of `cls` wrapping the copied proto.
        """
        # Bypass __init__, which would make an API call.
        self = cls.__new__(cls)
        self._proto = protos.CachedContent()
        self._update(obj)
        return self

    def _update(self, updates):
        """Updates this instance in place; does not call the API's `update` method.

        Args:
            updates: A `CachedContent`, a proto message, or a dict. Messages are
                converted with `to_dict(..., including_default_value_fields=False)`,
                so fields left at their default value are not copied.
        """
        if isinstance(updates, CachedContent):
            updates = updates._proto

        if not isinstance(updates, dict):
            # NOTE(review): newer proto-plus releases may rename this keyword;
            # confirm against the pinned proto-plus version.
            updates = type(updates).to_dict(updates, including_default_value_fields=False)

        for key, value in updates.items():
            setattr(self._proto, key, value)

    @staticmethod
    def _prepare_create_request(
        model: str,
        *,
        display_name: str | None = None,
        system_instruction: Optional[content_types.ContentType] = None,
        contents: Optional[content_types.ContentsType] = None,
        tools: Optional[content_types.FunctionLibraryType] = None,
        tool_config: Optional[content_types.ToolConfigType] = None,
        ttl: Optional[caching_types.TTLTypes] = None,
        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
    ) -> protos.CreateCachedContentRequest:
        """Prepares a CreateCachedContentRequest.

        Raises:
            ValueError: If both `ttl` and `expire_time` are given, or if
                `display_name` is longer than 128 characters.
        """
        # Explicit `None` checks: falsy values such as `0` or `timedelta(0)`
        # still count as "provided" for the exclusivity rule.
        if ttl is not None and expire_time is not None:
            raise ValueError(
                "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
            )

        # Bare model IDs are qualified with the "models/" collection prefix.
        if "/" not in model:
            model = "models/" + model

        if display_name and len(display_name) > 128:
            raise ValueError("`display_name` must be no more than 128 unicode characters.")

        if system_instruction:
            system_instruction = content_types.to_content(system_instruction)

        tools_lib = content_types.to_function_library(tools)
        if tools_lib:
            tools_lib = tools_lib.to_proto()

        if tool_config:
            tool_config = content_types.to_tool_config(tool_config)

        if contents:
            contents = content_types.to_contents(contents)
            # Default the role of the last turn to "user" when it was left empty.
            if not contents[-1].role:
                contents[-1].role = _USER_ROLE

        ttl = caching_types.to_optional_ttl(ttl)
        expire_time = caching_types.to_optional_expire_time(expire_time)

        cached_content = protos.CachedContent(
            model=model,
            display_name=display_name,
            system_instruction=system_instruction,
            contents=contents,
            tools=tools_lib,
            tool_config=tool_config,
            ttl=ttl,
            expire_time=expire_time,
        )

        return protos.CreateCachedContentRequest(cached_content=cached_content)

    @classmethod
    def create(
        cls,
        model: str,
        *,
        display_name: str | None = None,
        system_instruction: Optional[content_types.ContentType] = None,
        contents: Optional[content_types.ContentsType] = None,
        tools: Optional[content_types.FunctionLibraryType] = None,
        tool_config: Optional[content_types.ToolConfigType] = None,
        ttl: Optional[caching_types.TTLTypes] = None,
        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
    ) -> CachedContent:
        """Creates `CachedContent` resource.

        Args:
            model: The name of the `model` to use for cached content creation.
                Any `CachedContent` resource can be only used with the
                `model` it was created for. A bare ID is prefixed with "models/".
            display_name: The user-generated meaningful display name
                of the cached content. `display_name` must be no
                more than 128 unicode characters.
            system_instruction: Developer set system instruction.
            contents: Contents to cache.
            tools: A list of `Tools` the model may use to generate response.
            tool_config: Config to apply to all tools.
            ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
                `ttl` and `expire_time` are exclusive arguments.
            expire_time: Expiration time for cached resource.
                `ttl` and `expire_time` are exclusive arguments.

        Returns:
            The created resource, as returned by the API, wrapped in `cls`.

        Raises:
            ValueError: If both `ttl` and `expire_time` are given, or if
                `display_name` is too long.
        """
        client = get_default_cache_client()

        request = cls._prepare_create_request(
            model=model,
            display_name=display_name,
            system_instruction=system_instruction,
            contents=contents,
            tools=tools,
            tool_config=tool_config,
            ttl=ttl,
            expire_time=expire_time,
        )

        response = client.create_cached_content(request)
        return cls._from_obj(response)

    @classmethod
    def get(cls, name: str) -> CachedContent:
        """Fetches required `CachedContent` resource.

        Args:
            name: The resource name referring to the cached content. A bare ID
                is prefixed with ``"cachedContents/"``.

        Returns:
            `CachedContent` resource with specified `name`, wrapped in `cls`.
        """
        return cls._from_obj(cls._fetch_proto(name))

    @classmethod
    def list(cls, page_size: Optional[int] = 1) -> Iterable[CachedContent]:
        """Lists `CachedContent` objects associated with the project.

        Args:
            page_size: The maximum number of `CachedContent` objects to request
                per page. The service may return fewer.
                NOTE(review): if the client's pager fetches further pages
                automatically, the default of 1 means one request per item;
                consider passing a larger value.

        Yields:
            One `CachedContent` (wrapped in `cls`) per item produced by
            iterating the client's list response.
        """
        client = get_default_cache_client()

        request = protos.ListCachedContentsRequest(page_size=page_size)
        for cached_content in client.list_cached_contents(request):
            yield cls._from_obj(cached_content)

    def delete(self) -> None:
        """Deletes the `CachedContent` resource named by `self.name` via the API."""
        client = get_default_cache_client()

        request = protos.DeleteCachedContentRequest(name=self.name)
        client.delete_cached_content(request)

    def update(
        self,
        *,
        ttl: Optional[caching_types.TTLTypes] = None,
        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
    ) -> None:
        """Updates requested `CachedContent` resource.

        Exactly one of `ttl` or `expire_time` must be provided. On success this
        instance is refreshed in place from the API's response.

        Args:
            ttl: TTL for cached resource (in seconds).
                `ttl` and `expire_time` are exclusive arguments.
            expire_time: Expiration time for cached resource.
                `ttl` and `expire_time` are exclusive arguments.

        Raises:
            ValueError: If both or neither of `ttl` and `expire_time` are given.
        """
        # Validate before touching the client so bad calls fail fast. Use
        # explicit `None` checks so falsy values such as `0` still count.
        if ttl is not None and expire_time is not None:
            raise ValueError(
                "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
            )
        if ttl is None and expire_time is None:
            raise ValueError(
                "Bad update name: Only `ttl` or `expire_time` can be updated for `CachedContent`."
            )

        client = get_default_cache_client()

        ttl = caching_types.to_optional_ttl(ttl)
        expire_time = caching_types.to_optional_expire_time(expire_time)

        updates = protos.CachedContent(
            name=self.name,
            ttl=ttl,
            expire_time=expire_time,
        )

        # Only the supplied field goes in the mask, so the server leaves the
        # other one untouched.
        field_mask = field_mask_pb2.FieldMask(
            paths=["ttl" if ttl is not None else "expire_time"]
        )

        request = protos.UpdateCachedContentRequest(cached_content=updates, update_mask=field_mask)
        updated_cc = client.update_cached_content(request)
        self._update(updated_cc)
|