ai-station/.venv/lib/python3.12/site-packages/opentelemetry/instrumentation/weaviate/wrapper.py

313 lines
8.0 KiB
Python

import json
import logging
from typing import Optional
from opentelemetry import context as context_api
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.instrumentation.weaviate.utils import dont_throw
from opentelemetry.semconv.trace import SpanAttributes
# Module-level logger for this wrapper module.
logger = logging.getLogger(__name__)
def _with_tracer_wrapper(func):
"""Helper for providing tracer for wrapper functions."""
def _with_tracer(tracer, to_wrap):
def wrapper(wrapped, instance, args, kwargs):
return func(tracer, to_wrap, wrapped, instance, args, kwargs)
return wrapper
return _with_tracer
def _set_span_attribute(span, name, value):
if value is not None:
if value != "":
span.set_attribute(name, value)
return
@_with_tracer_wrapper
def _wrap(tracer, to_wrap, wrapped, instance, args, kwargs):
    """Call ``wrapped`` inside a span described by the ``to_wrap`` mapping.

    ``to_wrap`` is read for three keys: ``"span_name"`` (the span's name),
    ``"method"`` (recorded as the DB operation and used to pick the argument
    list to capture) and ``"object"`` (used to look up an instrumentor).
    Returns whatever ``wrapped(*args, **kwargs)`` returns.
    """
    # When instrumentation is suppressed in the current context, call straight
    # through without creating a span.
    if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
        return wrapped(*args, **kwargs)

    method_name = to_wrap.get("method")
    with tracer.start_as_current_span(to_wrap.get("span_name")) as span:
        span.set_attribute(SpanAttributes.DB_SYSTEM, "weaviate")
        span.set_attribute(SpanAttributes.DB_OPERATION, method_name)

        # Objects without a registered instrumentor still get the span and
        # the two generic attributes above.
        instrumentor = InstrumentorFactory.from_name(to_wrap.get("object"))
        if instrumentor:
            instrumentor.instrument(method_name, span, args, kwargs)

        return wrapped(*args, **kwargs)
def count_or_none(obj):
    """Return ``len(obj)`` for a truthy ``obj``; otherwise return ``None``."""
    return len(obj) if obj else None
class ArgsGetter:
    """Fetch a call argument by position or keyword and serialize it to JSON.

    An instance holds the positional and keyword arguments of one call.
    Calling it with ``(index, name)`` looks the argument up positionally
    first and falls back to the keyword of that name.
    """

    def __init__(self, args, kwargs):
        self.args = args
        self.kwargs = kwargs

    def __call__(self, index, name):
        """Return the selected argument as a JSON string, or ``None``.

        ``None`` is returned when the argument is absent, when it is falsy,
        or when it cannot be serialized; in the last case a warning is
        logged instead of raising.
        """
        try:
            obj = self.args[index]
        except IndexError:
            obj = self.kwargs.get(name)

        if not obj:
            return None

        try:
            return json.dumps(obj)
        except (TypeError, ValueError):
            # json.dumps reports unserializable input with TypeError
            # (unsupported type) or ValueError (e.g. a circular reference).
            # It never raises JSONDecodeError, which is a decoding error, so
            # the previous handler could not fire and the exception escaped.
            logger.warning(
                "Failed to decode argument (%s) (%s) to JSON", index, name
            )
            return None
class _Instrumentor:
    """Base class that copies selected call arguments onto a span.

    The methods read ``self.namespace`` (attribute-name prefix) and
    ``self.mapped_attributes`` (method name -> list of argument names);
    neither is defined here, so subclasses are expected to provide them.
    """

    def map_attributes(self, span, method_name, attributes, args, kwargs):
        """Record each name in ``attributes`` as a span attribute.

        The position of a name in ``attributes`` is used as its positional
        index when reading ``args``; the name itself is the ``kwargs`` key.
        Attribute keys have the form ``<namespace>.<method_name>.<name>``.
        """
        lookup = ArgsGetter(args, kwargs)
        for position, arg_name in enumerate(attributes):
            key = f"{self.namespace}.{method_name}.{arg_name}"
            _set_span_attribute(span, key, lookup(position, arg_name))

    @dont_throw
    def instrument(self, method_name, span, args, kwargs):
        """Map the arguments configured for ``method_name`` onto ``span``.

        Does nothing when the method has no entry or an empty entry in
        ``mapped_attributes``.
        """
        attributes = self.mapped_attributes.get(method_name)
        if not attributes:
            return
        self.map_attributes(span, method_name, attributes, args, kwargs)
class _SchemaInstrumentorV3(_Instrumentor):
    """Schema instrumentor for v3; v4 uses _CollectionsInstrumentor instead."""

    namespace = "db.weaviate.schema"

    # Method name -> argument names to record. The position of a name in its
    # list is the positional index used when the argument is read from args.
    mapped_attributes = {
        "get": [
            "class_name",
        ],
        "create_class": [
            "schema_class",
        ],
        "create": [
            "schema",
        ],
        "delete_class": [
            "class_name",
        ],
    }
class _CollectionsInstrumentor(_Instrumentor):
    """Argument mapping for collection-level calls (create/get/delete)."""

    namespace = "db.weaviate.collections"

    # Method name -> argument names to record. The position of a name in its
    # list is the positional index used when the argument is read from args.
    mapped_attributes = {
        "create": [
            "name",
        ],
        "create_from_dict": [
            "config",
        ],
        "get": [
            "name",
        ],
        "delete": [
            "name",
        ],
    }
class _DataObjectInstrumentorV3(_Instrumentor):
    """Data-object instrumentor for v3; v4 uses _DataObjectInstrumentor."""

    # NOTE(review): unlike most sibling namespaces this one has no "db."
    # prefix; reproduced as-is because it determines emitted attribute names.
    namespace = "weaviate.data.crud_data"

    # Method name -> argument names to record. The position of a name in its
    # list is the positional index used when the argument is read from args.
    mapped_attributes = {
        "create": ["data_object", "class_name", "uuid", "vector", "consistency_level", "tenant"],
        "validate": ["data_object", "class_name", "uuid", "vector"],
        "get": [
            "uuid", "additional_properties", "with_vector", "class_name",
            "node_name", "consistency_level", "limit", "after", "offset",
            "sort", "tenant",
        ],
    }
class _DataObjectInstrumentor(_Instrumentor):
    """Argument mapping for insert/replace/update calls on collection data."""

    # NOTE(review): unlike most sibling namespaces this one has no "db."
    # prefix; reproduced as-is because it determines emitted attribute names.
    namespace = "weaviate.collections.data"

    # Method name -> argument names to record. The position of a name in its
    # list is the positional index used when the argument is read from args,
    # which is why "insert" orders its names differently from the other two.
    mapped_attributes = {
        "insert": ["properties", "references", "uuid", "vector"],
        "replace": ["uuid", "properties", "references", "vector"],
        "update": ["uuid", "properties", "references", "vector"],
    }
class _BatchInstrumentorV3(_Instrumentor):
    """Batch instrumentor for v3; v4 uses _BatchInstrumentor instead."""

    namespace = "db.weaviate.batch"

    # Method name -> argument names to record. The position of a name in its
    # list is the positional index used when the argument is read from args.
    # "flush" has an empty list, so no per-argument attributes are set for it.
    mapped_attributes = {
        "add_data_object": ["data_object", "class_name", "uuid", "vector", "tenant"],
        "flush": [],
    }
class _BatchInstrumentor(_Instrumentor):
    """Argument mapping for ``add_object`` calls on a collection batch."""

    namespace = "db.weaviate.collections.batch"

    # Method name -> argument names to record. The position of a name in its
    # list is the positional index used when the argument is read from args.
    mapped_attributes = {
        "add_object": ["properties", "references", "uuid", "vector"],
    }
class _QueryInstrumentorV3(_Instrumentor):
    """Query instrumentor for v3; v4 uses _QueryInstrumentor instead."""

    namespace = "db.weaviate.query"

    # Method name -> argument names to record. The position of a name in its
    # list is the positional index used when the argument is read from args.
    mapped_attributes = {
        "get": ["class_name", "properties"],
        "aggregate": [
            "class_name",
        ],
        "raw": [
            "gql_query",
        ],
    }
class _QueryInstrumentor(_Instrumentor):
    """Argument mapping for the fetch-by-id and fetch-objects query calls."""

    namespace = "db.weaviate.collections.query"

    # Method name -> argument names to record. The position of a name in its
    # list is the positional index used when the argument is read from args.
    mapped_attributes = {
        "fetch_object_by_id": ["uuid", "include_vector", "return_properties", "return_references"],
        "fetch_objects": [
            "limit", "offset", "after", "filters", "sort",
            "include_vector", "return_metadata", "return_properties",
            "return_references",
        ],
    }
class _AggregateBuilderInstrumentor(_Instrumentor):
    """Instrumentor for the aggregate builder's ``do`` call."""

    namespace = "db.weaviate.gql.aggregate"

    # "do" maps to an empty list, so no per-argument attributes are recorded.
    mapped_attributes = {"do": []}
class _GetBuilderInstrumentorV3(_Instrumentor):
    """Get-builder instrumentor for v3; v4 uses _GetBuilderInstrumentor."""

    namespace = "db.weaviate.query.get"

    # "do" maps to an empty list, so no per-argument attributes are recorded.
    mapped_attributes = {"do": []}
class _GetBuilderInstrumentor(_Instrumentor):
    """Instrumentor for the get builder's ``do`` call.

    NOTE(review): InstrumentorFactory.from_name resolves the name
    "GetBuilder" to _GetBuilderInstrumentorV3, so this class does not appear
    to be returned by the factory - confirm whether that is intended.
    """

    namespace = "db.weaviate.gql.get"

    # "do" maps to an empty list, so no per-argument attributes are recorded.
    mapped_attributes = {"do": []}
class _GraphQLInstrumentor(_Instrumentor):
    """Instrumentor for the ``do`` call of objects named "GraphQL"."""

    namespace = "db.weaviate.gql.filter"

    # "do" maps to an empty list, so no per-argument attributes are recorded.
    mapped_attributes = {"do": []}
class _RawInstrumentor(_Instrumentor):
    """Argument mapping for the client's ``graphql_raw_query`` call."""

    namespace = "db.weaviate.client"

    # Method name -> argument names to record. The position of a name in its
    # list is the positional index used when the argument is read from args.
    mapped_attributes = {
        "graphql_raw_query": [
            "gql_query",
        ],
    }
class InstrumentorFactory:
    """Resolve a wrapped object's name to the instrumentor that handles it."""

    # Object name -> instrumentor class.
    #
    # "GetBuilder" used to be tested twice in the lookup chain: once for
    # _GetBuilderInstrumentorV3 and again, later, for _GetBuilderInstrumentor.
    # The first match always returned, so the second branch was unreachable.
    # The table keeps the mapping that was actually in effect (V3), which
    # leaves runtime results unchanged while removing the dead branch.
    _INSTRUMENTORS = {
        # Instrumentors documented as v3.
        "Schema": _SchemaInstrumentorV3,
        "DataObject": _DataObjectInstrumentorV3,
        "Batch": _BatchInstrumentorV3,
        "Query": _QueryInstrumentorV3,
        "GetBuilder": _GetBuilderInstrumentorV3,
        # Remaining instrumentors.
        "_Collections": _CollectionsInstrumentor,
        "_DataCollection": _DataObjectInstrumentor,
        "_BatchCollection": _BatchInstrumentor,
        "_FetchObjectByIDQuery": _QueryInstrumentor,
        "_FetchObjectsQuery": _QueryInstrumentor,
        "_QueryGRPC": _QueryInstrumentor,
        "AggregateBuilder": _AggregateBuilderInstrumentor,
        "GraphQL": _GraphQLInstrumentor,
        "WeaviateClient": _RawInstrumentor,
    }

    @classmethod
    def from_name(cls, name: str) -> Optional[_Instrumentor]:
        """Return a new instrumentor instance for ``name``.

        Returns ``None`` when ``name`` is not a string or has no registered
        instrumentor. A fresh instance is created on every call.
        """
        # Non-string names (e.g. a missing "object" entry yielding None)
        # never matched any branch before; keep returning None for them.
        if not isinstance(name, str):
            return None
        instrumentor_cls = cls._INSTRUMENTORS.get(name)
        if instrumentor_cls is None:
            return None
        return instrumentor_cls()