125 lines
5.3 KiB
Python
125 lines
5.3 KiB
Python
|
|
# -*- coding: utf-8 -*-
|
||
|
|
# Copyright 2023 Google LLC
|
||
|
|
#
|
||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
# you may not use this file except in compliance with the License.
|
||
|
|
# You may obtain a copy of the License at
|
||
|
|
#
|
||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
#
|
||
|
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
# See the License for the specific language governing permissions and
|
||
|
|
# limitations under the License.
|
||
|
|
"""MagicsEngine class."""
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from typing import AbstractSet, Sequence
|
||
|
|
|
||
|
|
from google.generativeai.notebook import argument_parser
|
||
|
|
from google.generativeai.notebook import cmd_line_parser
|
||
|
|
from google.generativeai.notebook import command
|
||
|
|
from google.generativeai.notebook import compare_cmd
|
||
|
|
from google.generativeai.notebook import compile_cmd
|
||
|
|
from google.generativeai.notebook import eval_cmd
|
||
|
|
from google.generativeai.notebook import ipython_env
|
||
|
|
from google.generativeai.notebook import model_registry
|
||
|
|
from google.generativeai.notebook import parsed_args_lib
|
||
|
|
from google.generativeai.notebook import post_process_utils
|
||
|
|
from google.generativeai.notebook import run_cmd
|
||
|
|
from google.generativeai.notebook.lib import prompt_utils
|
||
|
|
|
||
|
|
|
||
|
|
class MagicsEngine:
|
||
|
|
"""Implementation of functionality used by Magics.
|
||
|
|
|
||
|
|
This class provides the implementation for Magics, decoupled from the
|
||
|
|
details of integrating with Colab Magics such as registration.
|
||
|
|
"""
|
||
|
|
|
||
|
|
def __init__(
|
||
|
|
self,
|
||
|
|
registry: model_registry.ModelRegistry | None = None,
|
||
|
|
env: ipython_env.IPythonEnv | None = None,
|
||
|
|
):
|
||
|
|
self._ipython_env = env
|
||
|
|
models = registry or model_registry.ModelRegistry()
|
||
|
|
self._cmd_handlers: dict[parsed_args_lib.CommandName, command.Command] = {
|
||
|
|
parsed_args_lib.CommandName.RUN_CMD: run_cmd.RunCommand(models=models, env=env),
|
||
|
|
parsed_args_lib.CommandName.COMPILE_CMD: compile_cmd.CompileCommand(
|
||
|
|
models=models, env=env
|
||
|
|
),
|
||
|
|
parsed_args_lib.CommandName.COMPARE_CMD: compare_cmd.CompareCommand(env=env),
|
||
|
|
parsed_args_lib.CommandName.EVAL_CMD: eval_cmd.EvalCommand(models=models, env=env),
|
||
|
|
}
|
||
|
|
|
||
|
|
def parse_line(
|
||
|
|
self,
|
||
|
|
line: str,
|
||
|
|
placeholders: AbstractSet[str],
|
||
|
|
) -> tuple[parsed_args_lib.ParsedArgs, parsed_args_lib.PostProcessingTokens]:
|
||
|
|
return cmd_line_parser.CmdLineParser().parse_line(line, placeholders)
|
||
|
|
|
||
|
|
def _get_handler(self, line: str, placeholders: AbstractSet[str]) -> tuple[
|
||
|
|
command.Command,
|
||
|
|
parsed_args_lib.ParsedArgs,
|
||
|
|
Sequence[post_process_utils.ParsedPostProcessExpr],
|
||
|
|
]:
|
||
|
|
"""Given the command line, parse and return all components.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
line: The LLM Magics command line.
|
||
|
|
placeholders: Placeholders from prompts in the cell contents.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
A three-tuple containing:
|
||
|
|
- The command (e.g. "run")
|
||
|
|
- Parsed arguments for the command,
|
||
|
|
- Parsed post-processing expressions
|
||
|
|
"""
|
||
|
|
parsed_args, post_processing_tokens = self.parse_line(line, placeholders)
|
||
|
|
cmd_name = parsed_args.cmd
|
||
|
|
handler = self._cmd_handlers[cmd_name]
|
||
|
|
post_processing_fns = handler.parse_post_processing_tokens(post_processing_tokens)
|
||
|
|
return handler, parsed_args, post_processing_fns
|
||
|
|
|
||
|
|
def execute_cell(self, line: str, cell_content: str):
|
||
|
|
"""Executes the supplied magic line and cell payload."""
|
||
|
|
cell = _clean_cell(cell_content)
|
||
|
|
placeholders = prompt_utils.get_placeholders(cell)
|
||
|
|
|
||
|
|
try:
|
||
|
|
handler, parsed_args, post_processing_fns = self._get_handler(line, placeholders)
|
||
|
|
return handler.execute(parsed_args, cell, post_processing_fns)
|
||
|
|
except argument_parser.ParserNormalExit as e:
|
||
|
|
if self._ipython_env is not None:
|
||
|
|
e.set_ipython_env(self._ipython_env)
|
||
|
|
# ParserNormalExit implements the _ipython_display_ method so it can
|
||
|
|
# be returned as the output of this cell for display.
|
||
|
|
return e
|
||
|
|
except argument_parser.ParserError as e:
|
||
|
|
e.display(self._ipython_env)
|
||
|
|
# Raise an exception to indicate that execution for this cell has
|
||
|
|
# failed.
|
||
|
|
# The exception is re-raised as SystemExit because Colab automatically
|
||
|
|
# suppresses traceback for SystemExit but not other exceptions. Because
|
||
|
|
# ParserErrors are usually due to user error (e.g. a missing required
|
||
|
|
# flag or an invalid flag value), we want to hide the traceback to
|
||
|
|
# avoid detracting the user from the error message, and we want to
|
||
|
|
# reserve exceptions-with-traceback for actual bugs and unexpected
|
||
|
|
# errors.
|
||
|
|
error_msg = "Got parser error: {}".format(e.msgs()[-1]) if e.msgs() else ""
|
||
|
|
raise SystemExit(error_msg) from e
|
||
|
|
|
||
|
|
|
||
|
|
def _clean_cell(cell_content: str) -> str:
|
||
|
|
# Colab includes a trailing newline in cell_content. Remove only the last
|
||
|
|
# line break from cell contents (i.e. not rstrip), so that multi-line and
|
||
|
|
# intentional line breaks are preserved, but single-line prompts don't have
|
||
|
|
# a trailing line break.
|
||
|
|
cell = cell_content
|
||
|
|
if cell.endswith("\n"):
|
||
|
|
cell = cell[:-1]
|
||
|
|
return cell
|