# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model that uses the Text service."""
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from google.api_core import retry
|
||
|
|
import google.generativeai as genai
|
||
|
|
from google.generativeai.types import generation_types
|
||
|
|
from google.generativeai.notebook.lib import model as model_lib
|
||
|
|
|
||
|
|
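# Fallback model used when the caller does not specify one.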
_DEFAULT_MODEL = "models/gemini-1.5-flash"


class TextModel(model_lib.AbstractModel):
  """Concrete model that uses the generate_content service."""

  def _generate_text(
      self,
      prompt: str,
      model: str | None = None,
      temperature: float | None = None,
      candidate_count: int | None = None,
  ) -> generation_types.GenerateContentResponse:
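    """Calls the generate_content service with the given arguments."""
    # Only forward the generation settings that were explicitly provided, so
    # the service's own defaults apply to anything left as None.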
    gen_config = {}
    if temperature is not None:
      gen_config["temperature"] = temperature
    if candidate_count is not None:
      gen_config["candidate_count"] = candidate_count

    model_name = model or _DEFAULT_MODEL
    gen_model = genai.GenerativeModel(model_name=model_name)
    config = generation_types.GenerationConfig(**gen_config)
    return gen_model.generate_content(prompt, generation_config=config)

  def call_model(
      self,
      model_input: str,
      model_args: model_lib.ModelArguments | None = None,
  ) -> model_lib.ModelResults:
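    """Runs `model_input` through the model and returns the text results."""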
    if model_args is None:
      model_args = model_lib.ModelArguments()

    # Wrap the generation function here, rather than decorating it, so that
    # the retry also covers `_generate_text` overrides in subclasses.
    retryable_fn = retry.Retry(retry.if_transient_error)(self._generate_text)
    response = retryable_fn(
        prompt=model_input,
        model=model_args.model,
        temperature=model_args.temperature,
        candidate_count=model_args.candidate_count,
    )

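    # A candidate's content may arrive split across several parts; join the
    # parts so each candidate yields a single text string.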
    text_outputs = []
    for c in response.candidates:
      text_outputs.append("".join(p.text for p in c.content.parts))

    return model_lib.ModelResults(
        model_input=model_input,
        text_results=text_outputs,
    )
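

# Example usage, as a minimal sketch: it assumes an API key has already been
# configured (e.g. via `genai.configure(api_key=...)`) and that
# `model_lib.ModelArguments` accepts these keyword arguments.
#
#   model = TextModel()
#   results = model.call_model(
#       "Write a haiku about the ocean.",
#       model_lib.ModelArguments(temperature=0.7, candidate_count=2),
#   )
#   for text in results.text_results:
#     print(text)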