313 lines
12 KiB
Python
313 lines
12 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright 2023 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from __future__ import annotations
|
|
|
|
import itertools
|
|
from typing import Any, Iterable, overload, TypeVar, Union, Mapping
|
|
|
|
import google.ai.generativelanguage as glm
|
|
from google.generativeai import protos
|
|
|
|
from google.generativeai.client import get_default_generative_client
|
|
from google.generativeai.client import get_default_generative_async_client
|
|
|
|
from google.generativeai.types import helper_types
|
|
from google.generativeai.types import model_types
|
|
from google.generativeai.types import text_types
|
|
from google.generativeai.types import content_types
|
|
|
|
DEFAULT_EMB_MODEL = "models/embedding-001"
|
|
EMBEDDING_MAX_BATCH_SIZE = 100
|
|
|
|
EmbeddingTaskType = protos.TaskType
|
|
|
|
EmbeddingTaskTypeOptions = Union[int, str, EmbeddingTaskType]
|
|
|
|
_EMBEDDING_TASK_TYPE: dict[EmbeddingTaskTypeOptions, EmbeddingTaskType] = {
|
|
EmbeddingTaskType.TASK_TYPE_UNSPECIFIED: EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
|
|
0: EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
|
|
"task_type_unspecified": EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
|
|
"unspecified": EmbeddingTaskType.TASK_TYPE_UNSPECIFIED,
|
|
EmbeddingTaskType.RETRIEVAL_QUERY: EmbeddingTaskType.RETRIEVAL_QUERY,
|
|
1: EmbeddingTaskType.RETRIEVAL_QUERY,
|
|
"retrieval_query": EmbeddingTaskType.RETRIEVAL_QUERY,
|
|
"query": EmbeddingTaskType.RETRIEVAL_QUERY,
|
|
EmbeddingTaskType.RETRIEVAL_DOCUMENT: EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
|
2: EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
|
"retrieval_document": EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
|
"document": EmbeddingTaskType.RETRIEVAL_DOCUMENT,
|
|
EmbeddingTaskType.SEMANTIC_SIMILARITY: EmbeddingTaskType.SEMANTIC_SIMILARITY,
|
|
3: EmbeddingTaskType.SEMANTIC_SIMILARITY,
|
|
"semantic_similarity": EmbeddingTaskType.SEMANTIC_SIMILARITY,
|
|
"similarity": EmbeddingTaskType.SEMANTIC_SIMILARITY,
|
|
EmbeddingTaskType.CLASSIFICATION: EmbeddingTaskType.CLASSIFICATION,
|
|
4: EmbeddingTaskType.CLASSIFICATION,
|
|
"classification": EmbeddingTaskType.CLASSIFICATION,
|
|
EmbeddingTaskType.CLUSTERING: EmbeddingTaskType.CLUSTERING,
|
|
5: EmbeddingTaskType.CLUSTERING,
|
|
"clustering": EmbeddingTaskType.CLUSTERING,
|
|
6: EmbeddingTaskType.QUESTION_ANSWERING,
|
|
"question_answering": EmbeddingTaskType.QUESTION_ANSWERING,
|
|
"qa": EmbeddingTaskType.QUESTION_ANSWERING,
|
|
EmbeddingTaskType.QUESTION_ANSWERING: EmbeddingTaskType.QUESTION_ANSWERING,
|
|
7: EmbeddingTaskType.FACT_VERIFICATION,
|
|
"fact_verification": EmbeddingTaskType.FACT_VERIFICATION,
|
|
"verification": EmbeddingTaskType.FACT_VERIFICATION,
|
|
EmbeddingTaskType.FACT_VERIFICATION: EmbeddingTaskType.FACT_VERIFICATION,
|
|
}
|
|
|
|
|
|
def to_task_type(x: EmbeddingTaskTypeOptions) -> EmbeddingTaskType:
|
|
if isinstance(x, str):
|
|
x = x.lower()
|
|
return _EMBEDDING_TASK_TYPE[x]
|
|
|
|
|
|
try:
|
|
# python 3.12+
|
|
_batched = itertools.batched # type: ignore
|
|
except AttributeError:
|
|
T = TypeVar("T")
|
|
|
|
def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
|
|
if n < 1:
|
|
raise ValueError(
|
|
f"Invalid input: The batch size 'n' must be a positive integer. You entered: {n}. Please enter a number greater than 0."
|
|
)
|
|
batch = []
|
|
for item in iterable:
|
|
batch.append(item)
|
|
if len(batch) == n:
|
|
yield batch
|
|
batch = []
|
|
|
|
if batch:
|
|
yield batch
|
|
|
|
|
|
@overload
|
|
def embed_content(
|
|
model: model_types.BaseModelNameOptions,
|
|
content: content_types.ContentType,
|
|
task_type: EmbeddingTaskTypeOptions | None = None,
|
|
title: str | None = None,
|
|
output_dimensionality: int | None = None,
|
|
client: glm.GenerativeServiceClient | None = None,
|
|
request_options: helper_types.RequestOptionsType | None = None,
|
|
) -> text_types.EmbeddingDict: ...
|
|
|
|
|
|
@overload
|
|
def embed_content(
|
|
model: model_types.BaseModelNameOptions,
|
|
content: Iterable[content_types.ContentType],
|
|
task_type: EmbeddingTaskTypeOptions | None = None,
|
|
title: str | None = None,
|
|
output_dimensionality: int | None = None,
|
|
client: glm.GenerativeServiceClient | None = None,
|
|
request_options: helper_types.RequestOptionsType | None = None,
|
|
) -> text_types.BatchEmbeddingDict: ...
|
|
|
|
|
|
def embed_content(
|
|
model: model_types.BaseModelNameOptions,
|
|
content: content_types.ContentType | Iterable[content_types.ContentType],
|
|
task_type: EmbeddingTaskTypeOptions | None = None,
|
|
title: str | None = None,
|
|
output_dimensionality: int | None = None,
|
|
client: glm.GenerativeServiceClient = None,
|
|
request_options: helper_types.RequestOptionsType | None = None,
|
|
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
|
|
"""Calls the API to create embeddings for content passed in.
|
|
|
|
Args:
|
|
model:
|
|
Which [model](https://ai.google.dev/models/gemini#embedding) to
|
|
call, as a string or a `types.Model`.
|
|
|
|
content:
|
|
Content to embed.
|
|
|
|
task_type:
|
|
Optional task type for which the embeddings will be used. Can only
|
|
be set for `models/embedding-001`.
|
|
|
|
title:
|
|
An optional title for the text. Only applicable when task_type is
|
|
`RETRIEVAL_DOCUMENT`.
|
|
|
|
output_dimensionality:
|
|
Optional reduced dimensionality for the output embeddings. If set,
|
|
excessive values from the output embeddings will be truncated from
|
|
the end.
|
|
|
|
request_options:
|
|
Options for the request.
|
|
|
|
Return:
|
|
Dictionary containing the embedding (list of float values) for the
|
|
input content.
|
|
"""
|
|
model = model_types.make_model_name(model)
|
|
|
|
if request_options is None:
|
|
request_options = {}
|
|
|
|
if client is None:
|
|
client = get_default_generative_client()
|
|
|
|
if title and to_task_type(task_type) is not EmbeddingTaskType.RETRIEVAL_DOCUMENT:
|
|
raise ValueError(
|
|
f"Invalid task type: When a title is specified, the task must be of a 'retrieval document' type. Received task type: {task_type} and title: {title}."
|
|
)
|
|
|
|
if output_dimensionality and output_dimensionality < 0:
|
|
raise ValueError(
|
|
f"Invalid value: `output_dimensionality` must be a non-negative integer. Received: {output_dimensionality}."
|
|
)
|
|
|
|
if task_type:
|
|
task_type = to_task_type(task_type)
|
|
|
|
if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)):
|
|
result = {"embedding": []}
|
|
requests = (
|
|
protos.EmbedContentRequest(
|
|
model=model,
|
|
content=content_types.to_content(c),
|
|
task_type=task_type,
|
|
title=title,
|
|
output_dimensionality=output_dimensionality,
|
|
)
|
|
for c in content
|
|
)
|
|
for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE):
|
|
embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch)
|
|
embedding_response = client.batch_embed_contents(
|
|
embedding_request,
|
|
**request_options,
|
|
)
|
|
embedding_dict = type(embedding_response).to_dict(embedding_response)
|
|
result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"])
|
|
return result
|
|
else:
|
|
embedding_request = protos.EmbedContentRequest(
|
|
model=model,
|
|
content=content_types.to_content(content),
|
|
task_type=task_type,
|
|
title=title,
|
|
output_dimensionality=output_dimensionality,
|
|
)
|
|
embedding_response = client.embed_content(
|
|
embedding_request,
|
|
**request_options,
|
|
)
|
|
embedding_dict = type(embedding_response).to_dict(embedding_response)
|
|
embedding_dict["embedding"] = embedding_dict["embedding"]["values"]
|
|
return embedding_dict
|
|
|
|
|
|
@overload
|
|
async def embed_content_async(
|
|
model: model_types.BaseModelNameOptions,
|
|
content: content_types.ContentType,
|
|
task_type: EmbeddingTaskTypeOptions | None = None,
|
|
title: str | None = None,
|
|
output_dimensionality: int | None = None,
|
|
client: glm.GenerativeServiceAsyncClient | None = None,
|
|
request_options: helper_types.RequestOptionsType | None = None,
|
|
) -> text_types.EmbeddingDict: ...
|
|
|
|
|
|
@overload
|
|
async def embed_content_async(
|
|
model: model_types.BaseModelNameOptions,
|
|
content: Iterable[content_types.ContentType],
|
|
task_type: EmbeddingTaskTypeOptions | None = None,
|
|
title: str | None = None,
|
|
output_dimensionality: int | None = None,
|
|
client: glm.GenerativeServiceAsyncClient | None = None,
|
|
request_options: helper_types.RequestOptionsType | None = None,
|
|
) -> text_types.BatchEmbeddingDict: ...
|
|
|
|
|
|
async def embed_content_async(
|
|
model: model_types.BaseModelNameOptions,
|
|
content: content_types.ContentType | Iterable[content_types.ContentType],
|
|
task_type: EmbeddingTaskTypeOptions | None = None,
|
|
title: str | None = None,
|
|
output_dimensionality: int | None = None,
|
|
client: glm.GenerativeServiceAsyncClient = None,
|
|
request_options: helper_types.RequestOptionsType | None = None,
|
|
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
|
|
"""Calls the API to create async embeddings for content passed in."""
|
|
|
|
model = model_types.make_model_name(model)
|
|
|
|
if request_options is None:
|
|
request_options = {}
|
|
|
|
if client is None:
|
|
client = get_default_generative_async_client()
|
|
|
|
if title and to_task_type(task_type) is not EmbeddingTaskType.RETRIEVAL_DOCUMENT:
|
|
raise ValueError(
|
|
f"Invalid task type: When a title is specified, the task must be of a 'retrieval document' type. Received task type: {task_type} and title: {title}."
|
|
)
|
|
if output_dimensionality and output_dimensionality < 0:
|
|
raise ValueError(
|
|
f"Invalid value: `output_dimensionality` must be a non-negative integer. Received: {output_dimensionality}."
|
|
)
|
|
|
|
if task_type:
|
|
task_type = to_task_type(task_type)
|
|
|
|
if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)):
|
|
result = {"embedding": []}
|
|
requests = (
|
|
protos.EmbedContentRequest(
|
|
model=model,
|
|
content=content_types.to_content(c),
|
|
task_type=task_type,
|
|
title=title,
|
|
output_dimensionality=output_dimensionality,
|
|
)
|
|
for c in content
|
|
)
|
|
for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE):
|
|
embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch)
|
|
embedding_response = await client.batch_embed_contents(
|
|
embedding_request,
|
|
**request_options,
|
|
)
|
|
embedding_dict = type(embedding_response).to_dict(embedding_response)
|
|
result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"])
|
|
return result
|
|
else:
|
|
embedding_request = protos.EmbedContentRequest(
|
|
model=model,
|
|
content=content_types.to_content(content),
|
|
task_type=task_type,
|
|
title=title,
|
|
output_dimensionality=output_dimensionality,
|
|
)
|
|
embedding_response = await client.embed_content(
|
|
embedding_request,
|
|
**request_options,
|
|
)
|
|
embedding_dict = type(embedding_response).to_dict(embedding_response)
|
|
embedding_dict["embedding"] = embedding_dict["embedding"]["values"]
|
|
return embedding_dict
|