# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import dataclasses
from collections.abc import Iterable
import itertools
from typing import Any, Union, Mapping, Optional
from typing_extensions import TypedDict

import google.ai.generativelanguage as glm
from google.generativeai import protos

from google.generativeai.client import (
    get_default_generative_client,
    get_default_generative_async_client,
)
from google.generativeai.types import model_types
from google.generativeai.types import helper_types
from google.generativeai.types import safety_types
from google.generativeai.types import content_types
from google.generativeai.types import retriever_types
from google.generativeai.types.retriever_types import MetadataFilter

DEFAULT_ANSWER_MODEL = "models/aqa"

AnswerStyle = protos.GenerateAnswerRequest.AnswerStyle

AnswerStyleOptions = Union[int, str, AnswerStyle]

_ANSWER_STYLES: dict[AnswerStyleOptions, AnswerStyle] = {
    AnswerStyle.ANSWER_STYLE_UNSPECIFIED: AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    0: AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    "answer_style_unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    "unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    AnswerStyle.ABSTRACTIVE: AnswerStyle.ABSTRACTIVE,
    1: AnswerStyle.ABSTRACTIVE,
    "answer_style_abstractive": AnswerStyle.ABSTRACTIVE,
    "abstractive": AnswerStyle.ABSTRACTIVE,
    AnswerStyle.EXTRACTIVE: AnswerStyle.EXTRACTIVE,
    2: AnswerStyle.EXTRACTIVE,
    "answer_style_extractive": AnswerStyle.EXTRACTIVE,
    "extractive": AnswerStyle.EXTRACTIVE,
    AnswerStyle.VERBOSE: AnswerStyle.VERBOSE,
    3: AnswerStyle.VERBOSE,
    "answer_style_verbose": AnswerStyle.VERBOSE,
    "verbose": AnswerStyle.VERBOSE,
}


def to_answer_style(x: AnswerStyleOptions) -> AnswerStyle:
    if isinstance(x, str):
        x = x.lower()
    return _ANSWER_STYLES[x]
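
# Example: all of the following spellings resolve to AnswerStyle.ABSTRACTIVE:
#
#   to_answer_style(1)
#   to_answer_style("ABSTRACTIVE")
#   to_answer_style("answer_style_abstractive")
#   to_answer_style(AnswerStyle.ABSTRACTIVE)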


GroundingPassageOptions = Union[
    protos.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType
]

GroundingPassagesOptions = Union[
    protos.GroundingPassages,
    Iterable[GroundingPassageOptions],
    Mapping[str, content_types.ContentType],
]


def _make_grounding_passages(source: GroundingPassagesOptions) -> protos.GroundingPassages:
    """
    Converts `source` into a `protos.GroundingPassages`. A `GroundingPassages` contains a list of
    `protos.GroundingPassage` objects, each of which holds a `protos.Content` and a string `id`.

    Args:
        source: `Content` or a `GroundingPassagesOptions` that will be converted to `protos.GroundingPassages`.

    Returns:
        A `protos.GroundingPassages` to be passed into the `protos.GenerateAnswerRequest`.
    """
    if isinstance(source, protos.GroundingPassages):
        return source

    if not isinstance(source, Iterable):
        raise TypeError(
            "Invalid input: The 'source' argument must be an instance of 'GroundingPassagesOptions'. "
            f"Received a '{type(source).__name__}' object instead."
        )

    passages = []
    if isinstance(source, Mapping):
        source = source.items()

    for n, data in enumerate(source):
        if isinstance(data, protos.GroundingPassage):
            passages.append(data)
        elif isinstance(data, tuple):
            id, content = data  # The tuple must have exactly 2 items.
            passages.append({"id": id, "content": content_types.to_content(content)})
        else:
            passages.append({"id": str(n), "content": content_types.to_content(data)})

    return protos.GroundingPassages(passages=passages)
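
# Example: each of these accepted shapes yields the same two-passage
# `protos.GroundingPassages` (passage text is illustrative):
#
#   _make_grounding_passages(["passage one", "passage two"])                 # ids default to "0", "1"
#   _make_grounding_passages({"0": "passage one", "1": "passage two"})       # mapping of id -> content
#   _make_grounding_passages([("0", "passage one"), ("1", "passage two")])   # (id, content) pairs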


SourceNameType = Union[
    str, retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document
]


class SemanticRetrieverConfigDict(TypedDict):
    source: SourceNameType
    query: content_types.ContentsType
    metadata_filter: Optional[Iterable[MetadataFilter]]
    max_chunks_count: Optional[int]
    minimum_relevance_score: Optional[float]


SemanticRetrieverConfigOptions = Union[
    SourceNameType,
    SemanticRetrieverConfigDict,
    protos.SemanticRetrieverConfig,
]


def _maybe_get_source_name(source) -> str | None:
    if isinstance(source, str):
        return source
    elif isinstance(
        source, (retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document)
    ):
        return source.name
    else:
        return None


def _make_semantic_retriever_config(
    source: SemanticRetrieverConfigOptions,
    query: content_types.ContentsType,
) -> protos.SemanticRetrieverConfig:
    if isinstance(source, protos.SemanticRetrieverConfig):
        return source

    name = _maybe_get_source_name(source)
    if name is not None:
        source = {"source": name}
    elif isinstance(source, dict):
        source["source"] = _maybe_get_source_name(source["source"])
    else:
        raise TypeError(
            "Invalid input: Failed to create a 'protos.SemanticRetrieverConfig' from the provided source. "
            f"Received type: {type(source).__name__}, "
            f"Received value: {source}"
        )

    # A bare source name carries no query, so fall back to the request's query.
    if source.get("query") is None:
        source["query"] = query
    elif isinstance(source["query"], str):
        source["query"] = content_types.to_content(source["query"])

    return protos.SemanticRetrieverConfig(source)
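
# Example: both of these produce a `protos.SemanticRetrieverConfig` (the corpus
# name is a placeholder):
#
#   _make_semantic_retriever_config("corpora/my-corpus", query)
#   _make_semantic_retriever_config({"source": "corpora/my-corpus", "query": "..."}, query)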


def _make_generate_answer_request(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
    contents: content_types.ContentsType,
    inline_passages: GroundingPassagesOptions | None = None,
    semantic_retriever: SemanticRetrieverConfigOptions | None = None,
    answer_style: AnswerStyle | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    temperature: float | None = None,
) -> protos.GenerateAnswerRequest:
    """
    Constructs a `protos.GenerateAnswerRequest` object by organizing the input parameters for the
    API call to generate a grounded answer from the model.

    Args:
        model: Name of the model used to generate the grounded response.
        contents: Content of the current conversation with the model. For a single-turn query, this
            is a single question to answer. For multi-turn queries, this is a repeated field that
            contains the conversation history followed by the question as the last `Content` in the list.
        inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)` pairs,
            or a `protos.GroundingPassages`) to send inline with the request. Exclusive with
            `semantic_retriever`; one must be set, but not both.
        semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for
            grounding. Exclusive with `inline_passages`; one must be set, but not both.
        answer_style: Style for grounded answers.
        safety_settings: Safety settings for generated output.
        temperature: The temperature for randomness in the output.

    Returns:
        A `protos.GenerateAnswerRequest` populated from the given arguments.
    """
    model = model_types.make_model_name(model)

    contents = content_types.to_contents(contents)

    if safety_settings:
        safety_settings = safety_types.normalize_safety_settings(safety_settings)

    if inline_passages is not None and semantic_retriever is not None:
        raise ValueError(
            "Invalid configuration: Please set either 'inline_passages' or 'semantic_retriever', but not both. "
            f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}."
        )
    elif inline_passages is not None:
        inline_passages = _make_grounding_passages(inline_passages)
    elif semantic_retriever is not None:
        semantic_retriever = _make_semantic_retriever_config(semantic_retriever, contents[-1])
    else:
        raise TypeError(
            "Invalid configuration: Either 'inline_passages' or 'semantic_retriever' must be provided, but currently both are 'None'. "
            f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}."
        )

    if answer_style:
        answer_style = to_answer_style(answer_style)

    return protos.GenerateAnswerRequest(
        model=model,
        contents=contents,
        inline_passages=inline_passages,
        semantic_retriever=semantic_retriever,
        safety_settings=safety_settings,
        temperature=temperature,
        answer_style=answer_style,
    )
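
# Example: a minimal request with inline grounding passages (the question and
# passage text are placeholders):
#
#   request = _make_generate_answer_request(
#       contents=["Why is the sky blue?"],
#       inline_passages=["Rayleigh scattering disperses shorter wavelengths of light."],
#       answer_style="verbose",
#   )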


def generate_answer(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
    contents: content_types.ContentsType,
    inline_passages: GroundingPassagesOptions | None = None,
    semantic_retriever: SemanticRetrieverConfigOptions | None = None,
    answer_style: AnswerStyle | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    temperature: float | None = None,
    client: glm.GenerativeServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
):
    """Calls the GenerateAnswer API and returns a `types.Answer` containing the response.

    You can pass a literal list of text chunks:

    >>> from google.generativeai import answer
    >>> answer.generate_answer(
    ...     contents=question,
    ...     inline_passages=splitter.split(document)
    ... )

    Or pass a reference to a retriever Document or Corpus:

    >>> from google.generativeai import answer
    >>> from google.generativeai import retriever
    >>> my_corpus = retriever.get_corpus('my_corpus')
    >>> answer.generate_answer(
    ...     contents=question,
    ...     semantic_retriever=my_corpus
    ... )

    Args:
        model: Which model to call, as a string or a `types.Model`.
        contents: The question to be answered by the model, grounded in the provided source.
        inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)` pairs,
            or a `protos.GroundingPassages`) to send inline with the request. Exclusive with
            `semantic_retriever`; one must be set, but not both.
        semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for
            grounding. Exclusive with `inline_passages`; one must be set, but not both.
        answer_style: Style in which the grounded answer should be returned.
        safety_settings: Safety settings for generated output. Defaults to None.
        temperature: Controls the randomness of the output.
        client: If you're not relying on a default client, pass a `glm.GenerativeServiceClient` instead.
        request_options: Options for the request.

    Returns:
        A `types.Answer` containing the model's text answer response.
    """
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_generative_client()

    request = _make_generate_answer_request(
        model=model,
        contents=contents,
        inline_passages=inline_passages,
        semantic_retriever=semantic_retriever,
        safety_settings=safety_settings,
        temperature=temperature,
        answer_style=answer_style,
    )

    response = client.generate_answer(request, **request_options)

    return response
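
# Example: answer a question against an existing corpus (the corpus name is a
# placeholder; see the docstring above for an inline-passages variant):
#
#   response = generate_answer(
#       contents="Why is the sky blue?",
#       semantic_retriever="corpora/my-corpus",
#       answer_style=AnswerStyle.EXTRACTIVE,
#   )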


async def generate_answer_async(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
    contents: content_types.ContentsType,
    inline_passages: GroundingPassagesOptions | None = None,
    semantic_retriever: SemanticRetrieverConfigOptions | None = None,
    answer_style: AnswerStyle | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    temperature: float | None = None,
    client: glm.GenerativeServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
):
    """
    Calls the GenerateAnswer API asynchronously and returns a `types.Answer` containing the answer.

    Args:
        model: Which model to call, as a string or a `types.Model`.
        contents: The question to be answered by the model, grounded in the provided source.
        inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)` pairs,
            or a `protos.GroundingPassages`) to send inline with the request. Exclusive with
            `semantic_retriever`; one must be set, but not both.
        semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for
            grounding. Exclusive with `inline_passages`; one must be set, but not both.
        answer_style: Style in which the grounded answer should be returned.
        safety_settings: Safety settings for generated output. Defaults to None.
        temperature: Controls the randomness of the output.
        client: If you're not relying on a default client, pass a `glm.GenerativeServiceClient` instead.
        request_options: Options for the request.

    Returns:
        A `types.Answer` containing the model's text answer response.
    """
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_generative_async_client()

    request = _make_generate_answer_request(
        model=model,
        contents=contents,
        inline_passages=inline_passages,
        semantic_retriever=semantic_retriever,
        safety_settings=safety_settings,
        temperature=temperature,
        answer_style=answer_style,
    )

    response = await client.generate_answer(request, **request_options)

    return response
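
# Example: the async variant must be awaited from a running event loop
# (a sketch; the question and passage text are placeholders):
#
#   import asyncio
#
#   async def main():
#       response = await generate_answer_async(
#           contents="Why is the sky blue?",
#           inline_passages=["Rayleigh scattering disperses shorter wavelengths of light."],
#       )
#       print(response.answer)
#
#   asyncio.run(main())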