# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
"""Classes that define arguments for populating ArgumentParser.

The argparse module's ArgumentParser.add_argument() takes several parameters and
is quite customizable. However, this flexibility can lead to bugs where
arguments do not behave as expected.

For better ease-of-use and better testability, define a set of classes for the
types of flags used by LLM Magics.

Sample usage:

  str_flag = SingleValueFlagDef(name="title", required=True)
  enum_flag = EnumFlagDef(name="colors", required=True, enum_type=ColorsEnum)

  str_flag.add_argument_to_parser(my_parser)
  enum_flag.add_argument_to_parser(my_parser)
"""
|
|
from __future__ import annotations
|
|
|
|
import abc
|
|
import argparse
|
|
import dataclasses
|
|
import enum
|
|
from typing import Any, Callable, Sequence, Tuple, Union
|
|
|
|
from google.generativeai.notebook.lib import llmfn_inputs_source
|
|
from google.generativeai.notebook.lib import llmfn_outputs
|
|
|
|
# These are the intermediate types that argparse.ArgumentParser.parse_args()
# will pass command line arguments into.
_PARSETYPES = Union[str, int, float]

# These are the final result types that the intermediate parsed values will be
# converted into. It is a superset of _PARSETYPES because we support converting
# the parsed type into a more precise type, e.g. from str to Enum.
_DESTTYPES = Union[
    _PARSETYPES,
    enum.Enum,
    Tuple[str, Callable[[str, str], Any]],  # For --compare_fn
    Sequence[str],  # For --ground_truth
    llmfn_inputs_source.LLMFnInputsSource,  # For --inputs
    llmfn_outputs.LLMFnOutputsSink,  # For --outputs
]

# The signature of a function that converts a command line argument from the
# intermediate parsed type to the result type.
_PARSEFN = Callable[[_PARSETYPES], _DESTTYPES]
|
|
|
|
|
|
def _get_type_name(x: type[Any]) -> str:
    """Returns a human-readable name for `x`.

    Objects exposing `__name__` (classes, functions) are named by it; anything
    else falls back to its `str()` form.
    """
    not_found = object()
    name = getattr(x, "__name__", not_found)
    if name is not_found:
        return str(x)
    return name
|
|
|
|
|
|
def _validate_flag_name(name: str) -> str:
    """Validation for long and short names for flags.

    Args:
      name: The flag name without any leading dashes, e.g. "colors" or "c".

    Returns:
      `name`, unchanged.

    Raises:
      ValueError: If `name` is empty or starts with a dash.
    """
    problem = None
    if not name:
        problem = "Cannot be empty"
    elif name.startswith("-"):
        problem = "Cannot start with dash"

    if problem is not None:
        raise ValueError(problem)
    return name
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class FlagDef(abc.ABC):
    """Abstract base class for flag definitions.

    Attributes:
      name: Long name, e.g. "colors" will define the flag "--colors".
      required: Whether the flag must be provided on the command line.
      short_name: Optional short name.
      parse_type: The type that ArgumentParser should parse the command line
        argument to.
      dest_type: The type that the parsed value is converted to. This is used
        when we want ArgumentParser to parse as one type, then convert to a
        different type. E.g. for enums we parse as "str" then convert to the
        desired enum type in order to provide cleaner help messages.
      parse_to_dest_type_fn: If provided, this function will be used to convert
        the value from `parse_type` to `dest_type`. This can be used for
        validation as well.
      choices: If provided, limit the set of acceptable values to these choices.
      help_msg: If provided, adds help message when -h is used in the command
        line.
    """

    name: str
    required: bool = False

    short_name: str | None = None

    parse_type: type[_PARSETYPES] = str
    dest_type: type[_DESTTYPES] | None = None
    parse_to_dest_type_fn: _PARSEFN | None = None

    choices: list[_PARSETYPES] | None = None
    help_msg: str | None = None

    @abc.abstractmethod
    def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
        """Registers this flag as an argument of `parser`.

        Subclasses implement this as a single parser.add_argument() call with
        the parameters appropriate for their kind of flag.

        Args:
          parser: The parser that receives the new argument.
        """

    @abc.abstractmethod
    def _do_additional_validation(self) -> None:
        """Hook for subclass-specific validation, run from __post_init__."""

    def _get_dest_type(self) -> type[_DESTTYPES]:
        """Returns `dest_type`, or `parse_type` when no `dest_type` was given."""
        if self.dest_type is not None:
            return self.dest_type
        return self.parse_type

    def _get_parse_to_dest_type_fn(
        self,
    ) -> _PARSEFN:
        """Returns the callable that turns a parsed value into the final value.

        Preference order: the explicit `parse_to_dest_type_fn`; an identity
        function when no conversion is needed; otherwise the destination type
        itself, called as a constructor (e.g. `SomeEnum("red")`).
        """
        explicit_fn = self.parse_to_dest_type_fn
        if explicit_fn is not None:
            return explicit_fn

        target_type = self._get_dest_type()
        if target_type == self.parse_type:
            return lambda value: value
        return target_type

    def __post_init__(self):
        """Validates the flag names, then runs subclass validation."""
        _validate_flag_name(self.name)
        if self.short_name is not None:
            _validate_flag_name(self.short_name)

        self._do_additional_validation()
|
|
|
|
|
|
def _has_non_default_value(
    namespace: argparse.Namespace,
    dest: str,
    has_default: bool = False,
    default_value: Any = None,
) -> bool:
    """Reports whether `namespace.<dest>` holds something other than its default.

    Args:
      namespace: The Namespace that is populated by ArgumentParser.
      dest: Name of the Namespace attribute to inspect.
      has_default: Whether `default_value` is meaningful. A separate flag is
        needed because None is itself a legitimate default.
      default_value: The default to compare against when `has_default` is True.

    Returns:
      False if the attribute is absent. Otherwise True when no default is
      known, or the result of comparing the attribute against `default_value`
      with `!=`.
    """
    absent = object()
    current_value = getattr(namespace, dest, absent)
    if current_value is absent:
        return False

    if has_default:
        return current_value != default_value

    # Without a known default, any populated value counts as non-default.
    return True
|
|
|
|
|
|
class _SingleValueStoreAction(argparse.Action):
    """Custom Action for storing a value in an argparse.Namespace.

    The single value argparse collects for the flag is passed through
    `parse_to_dest_type_fn`. The result must be an instance of `dest_type`
    (otherwise RuntimeError) and is stored on the namespace. This action checks
    that the flag is specified at-most once.
    """

    def __init__(
        self,
        option_strings,
        dest,
        dest_type: type[Any],
        parse_to_dest_type_fn: _PARSEFN,
        **kwargs,
    ):
        super().__init__(option_strings, dest, **kwargs)
        self._dest_type = dest_type
        self._parse_to_dest_type_fn = parse_to_dest_type_fn

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: str | Sequence[Any] | None,
        option_string: str | None = None,
    ):
        # With `nargs=1`, argparse hands over a one-element list, never a bare
        # string.
        assert not isinstance(values, (str, bytes))

        # The flag counts as "already given" when the namespace holds a value
        # that differs from this action's default (argparse.Action.__init__
        # always sets `self.default`).
        # NOTE(review): if the first occurrence converts to a value equal to
        # the default, a repeated occurrence is not rejected.
        unset = object()
        existing = getattr(namespace, self.dest, unset)
        if existing is not unset and existing != self.default:
            raise argparse.ArgumentError(self, "Cannot set {} more than once".format(option_string))

        try:
            converted = self._parse_to_dest_type_fn(values[0])
        except Exception as e:
            raise argparse.ArgumentError(
                self,
                'Error with value "{}", got {}: {}'.format(values[0], _get_type_name(type(e)), e),
            )

        if isinstance(converted, self._dest_type):
            setattr(namespace, self.dest, converted)
            return

        raise RuntimeError(
            "Converted to wrong type, expected {} got {}".format(
                _get_type_name(self._dest_type),
                _get_type_name(type(converted)),
            )
        )
|
|
|
|
|
|
class _MultiValuesAppendAction(argparse.Action):
    """Custom Action for collecting several values in an argparse.Namespace.

    Every value supplied after the flag is converted with
    `parse_to_dest_type_fn`. Each converted value must be an instance of
    `dest_type` (otherwise RuntimeError), and repeated converted values are
    rejected. This action checks that the flag is specified at-most once.
    """

    def __init__(
        self,
        option_strings,
        dest,
        dest_type: type[Any],
        parse_to_dest_type_fn: _PARSEFN,
        **kwargs,
    ):
        super().__init__(option_strings, dest, **kwargs)
        self._dest_type = dest_type
        self._parse_to_dest_type_fn = parse_to_dest_type_fn

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: str | Sequence[Any] | None,
        option_string: str | None = None,
    ):
        # Because `nargs` is set to "+", `values` must be a Sequence, rather
        # than a string.
        assert not isinstance(values, str) and not isinstance(values, bytes)

        curr_value = getattr(namespace, self.dest)
        if curr_value:
            raise argparse.ArgumentError(self, "Cannot set {} more than once".format(option_string))

        # Build a fresh list instead of appending to `curr_value`: argparse
        # stores the `default` object itself in the namespace, so mutating it
        # would leak values into later parse_args() calls on the same parser.
        converted_values = []
        for value in values:
            try:
                converted_value = self._parse_to_dest_type_fn(value)
            except Exception as e:
                raise argparse.ArgumentError(
                    self,
                    'Error with value "{}", got {}: {}'.format(value, _get_type_name(type(e)), e),
                )

            if not isinstance(converted_value, self._dest_type):
                raise RuntimeError(
                    "Converted to wrong type, expected {} got {}".format(
                        _get_type_name(self._dest_type),
                        _get_type_name(type(converted_value)),
                    )
                )
            # Linear membership test: converted values may be unhashable.
            if converted_value in converted_values:
                raise argparse.ArgumentError(self, 'Duplicate values "{}"'.format(value))

            converted_values.append(converted_value)

        setattr(namespace, self.dest, converted_values)
|
|
|
|
|
|
class _BooleanValueStoreAction(argparse.Action):
    """Custom Action that switches a boolean destination to True.

    The destination is expected to default to False; each occurrence of the
    flag sets it to True. This action checks that the flag is specified
    at-most once.
    """

    def __init__(
        self,
        option_strings,
        dest,
        **kwargs,
    ):
        super().__init__(option_strings, dest, **kwargs)

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: str | Sequence[Any] | None,
        option_string: str | None = None,
    ):
        # A missing attribute is treated like the False default. Any other
        # value means the flag has already been seen.
        already_seen = getattr(namespace, self.dest, False) != False  # noqa: E712
        if already_seen:
            raise argparse.ArgumentError(self, "Cannot set {} more than once".format(option_string))

        setattr(namespace, self.dest, True)
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class SingleValueFlagDef(FlagDef):
    """Definition for a flag that takes a single value.

    Sample usage:
      # This defines a flag that can be specified on the command line as:
      # --count=10
      flag = SingleValueFlagDef(name="count", parse_type=int, required=True)
      flag.add_argument_to_parser(argument_parser)

    Attributes:
      default_value: Default value for optional flags. Required flags must not
        set it and optional flags must; None is an accepted default.
    """

    class _DefaultValue(enum.Enum):
        """Special value to represent "no value provided".

        "None" can be used as a default value, so in order to differentiate
        between "None" and "no value provided", create a special value for "no
        value provided".
        """

        NOT_SET = None

    default_value: _DESTTYPES | _DefaultValue | None = _DefaultValue.NOT_SET

    def _has_default_value(self) -> bool:
        """Returns whether `default_value` has been provided."""
        # Identity check: NOT_SET is an enum singleton, whereas `!=` would
        # dispatch to the default value's own (possibly unusual) __ne__.
        return self.default_value is not SingleValueFlagDef._DefaultValue.NOT_SET

    def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
        """Registers this flag on `parser` as a single-value argument.

        The argument uses `_SingleValueStoreAction` with `nargs=1`. The parsed
        value is converted to the destination type, and repeating the flag is
        rejected. `default`, `choices` and `help` are only forwarded when set.

        Args:
          parser: The parser that receives the new argument.
        """
        args = ["--" + self.name]
        if self.short_name is not None:
            args += ["-" + self.short_name]

        kwargs = {}
        if self._has_default_value():
            kwargs["default"] = self.default_value
        if self.choices is not None:
            kwargs["choices"] = self.choices
        if self.help_msg is not None:
            kwargs["help"] = self.help_msg

        parser.add_argument(
            *args,
            action=_SingleValueStoreAction,
            type=self.parse_type,
            dest_type=self._get_dest_type(),
            parse_to_dest_type_fn=self._get_parse_to_dest_type_fn(),
            required=self.required,
            nargs=1,
            **kwargs,
        )

    def _do_additional_validation(self) -> None:
        """Checks that `default_value` is consistent with `required`.

        Raises:
          ValueError: If a required flag has a default, if an optional flag has
            none, or if a non-None default is not an instance of the
            destination type.
        """
        if self.required:
            if self._has_default_value():
                raise ValueError("Required flags cannot have default value")
        else:
            if not self._has_default_value():
                raise ValueError("Optional flags must have a default value")

        if self._has_default_value() and self.default_value is not None:
            if not isinstance(self.default_value, self._get_dest_type()):
                raise ValueError("Default value must be of the same type as the destination type")
|
|
|
|
|
|
class EnumFlagDef(SingleValueFlagDef):
    """Definition for a flag that takes a value from an Enum.

    The command line value is parsed as a string and converted by value lookup,
    i.e. `enum_type(value)`. Unless `choices` is given, every member's value is
    an accepted choice.

    Sample usage:
      # This defines a flag that can be specified on the command line as:
      # --color=red
      flag = EnumFlagDef(name="color", enum_type=ColorsEnum, required=True)
      flag.add_argument_to_parser(argument_parser)
    """

    def __init__(self, *args, enum_type: type[enum.Enum], **kwargs):
        """Initializes the flag, deriving parse/dest types and choices from `enum_type`.

        Args:
          *args: Forwarded to SingleValueFlagDef.
          enum_type: The Enum class that values are converted to.
          **kwargs: Forwarded to SingleValueFlagDef. `parse_type` and
            `dest_type` are not allowed. `choices`, if given, may hold enum
            values or enum members, and is stored as the members' values.

        Raises:
          TypeError: If `enum_type` is not an Enum subclass.
          ValueError: If `parse_type` or `dest_type` is passed, or if an entry of
            `choices` is not a valid `enum_type` value.
        """
        # Check for a class first: issubclass() on a non-class raises its own,
        # less helpful TypeError.
        if not (isinstance(enum_type, type) and issubclass(enum_type, enum.Enum)):
            raise TypeError('"enum_type" must be of type Enum')

        # These properties are set by "enum_type" so don't let the caller set them.
        if "parse_type" in kwargs:
            raise ValueError('Cannot set "parse_type" for EnumFlagDef; set "enum_type" instead')
        kwargs["parse_type"] = str

        if "dest_type" in kwargs:
            raise ValueError('Cannot set "dest_type" for EnumFlagDef; set "enum_type" instead')
        kwargs["dest_type"] = enum_type

        choices = kwargs.get("choices")
        if choices is None:
            kwargs["choices"] = [x.value for x in enum_type]
        else:
            # Verify that entries in `choices` are valid enum values. Store them
            # as values in a list: argparse compares `choices` against the
            # parsed string, so enum members would never match. A list also
            # survives a one-shot iterable (e.g. a generator) being consumed
            # here.
            validated_choices = []
            for x in choices:
                try:
                    validated_choices.append(enum_type(x).value)
                except ValueError:
                    raise ValueError('Invalid value in "choices": "{}"'.format(x)) from None
            kwargs["choices"] = validated_choices

        super().__init__(*args, **kwargs)
|
|
|
|
|
|
class MultiValuesFlagDef(FlagDef):
    """Definition for a flag that takes multiple values.

    The values follow the flag as separate command line tokens. The flag may
    be given at most once, duplicate values are rejected, and the parsed result
    defaults to an empty list when the flag is absent.

    Sample usage:
      # This defines a flag that can be specified on the command line as:
      # --colors red green blue
      flag = MultiValuesFlagDef(name="colors", parse_type=str, required=True)
      flag.add_argument_to_parser(argument_parser)
    """

    def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
        """Registers this flag on `parser` as a one-or-more-values argument.

        The argument uses `_MultiValuesAppendAction` with `nargs="+"` and an
        empty-list default. `choices` and `help` are only forwarded when set.

        Args:
          parser: The parser that receives the new argument.
        """
        args = ["--" + self.name]
        if self.short_name is not None:
            args += ["-" + self.short_name]

        kwargs = {}
        if self.choices is not None:
            kwargs["choices"] = self.choices
        if self.help_msg is not None:
            kwargs["help"] = self.help_msg

        parser.add_argument(
            *args,
            action=_MultiValuesAppendAction,
            type=self.parse_type,
            dest_type=self._get_dest_type(),
            parse_to_dest_type_fn=self._get_parse_to_dest_type_fn(),
            required=self.required,
            default=[],
            nargs="+",
            **kwargs,
        )

    def _do_additional_validation(self) -> None:
        """No validation beyond what FlagDef already performs."""
        pass
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class BooleanFlagDef(FlagDef):
    """Definition for a Boolean flag.

    A boolean flag is always optional, defaults to False and takes no values.
    Giving the flag on the command line sets it to True. The `required` field is
    not forwarded to the parser.
    """

    def _do_additional_validation(self) -> None:
        """Rejects FlagDef options that are meaningless for a value-less flag.

        Raises:
          ValueError: If `dest_type`, `parse_to_dest_type_fn` or `choices` is
            set. They are checked in that order.
        """
        forbidden_settings = (
            (self.dest_type, "dest_type cannot be set for BooleanFlagDef"),
            (self.parse_to_dest_type_fn, "parse_to_dest_type_fn cannot be set for BooleanFlagDef"),
            (self.choices, "choices cannot be set for BooleanFlagDef"),
        )
        for setting, message in forbidden_settings:
            if setting is not None:
                raise ValueError(message)

    def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None:
        """Registers this flag on `parser` as a zero-argument switch.

        Args:
          parser: The parser that receives the new argument.
        """
        option_strings = ["--" + self.name]
        if self.short_name is not None:
            option_strings.append("-" + self.short_name)

        extra_kwargs = {} if self.help_msg is None else {"help": self.help_msg}

        parser.add_argument(
            *option_strings,
            action=_BooleanValueStoreAction,
            type=bool,
            required=False,
            default=False,
            nargs=0,
            **extra_kwargs,
        )
|