# Source: ai-station/.venv/lib/python3.12/site-packages/opentelemetry/instrumentation/milvus/wrapper.py
# File-viewer metadata: 530 lines, 15 KiB, Python

import time
from opentelemetry.instrumentation.milvus.utils import dont_throw
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
from opentelemetry import context as context_api
from opentelemetry.instrumentation.utils import (
_SUPPRESS_INSTRUMENTATION_KEY,
)
from opentelemetry.semconv_ai import Events, EventAttributes
from opentelemetry.semconv_ai import SpanAttributes as AISpanAttributes
from pymilvus.client.types import Status
from pymilvus.exceptions import ErrorCode
def _build_code_to_error_type():
    """Map Milvus numeric status/error codes to their symbolic names.

    Every integer-valued attribute found on ``Status`` contributes a
    ``code -> attribute name`` entry. ``ErrorCode`` members are applied
    afterwards, so they win whenever both sources define the same number.

    The table is built inside a function so the loop variables do not end up
    as stray module-level globals.
    """
    mapping = {}
    for attr_name in dir(Status):
        attr_value = getattr(Status, attr_name)
        if isinstance(attr_value, int):
            mapping[attr_value] = attr_name
    mapping.update({code.value: code.name for code in ErrorCode})
    return mapping


# Used by `_wrap` to translate an exception's `code` into an error-type name.
code_to_error_type = _build_code_to_error_type()
def _with_tracer_wrapper(func):
    """Helper for providing tracer for wrapper functions.

    Decorating ``func`` yields a factory. Calling the factory with the tracer,
    the five metric instruments and the ``to_wrap`` description returns a
    ``wrapper(wrapped, instance, args, kwargs)`` callable that forwards to
    ``func`` with those bound values followed by ``wrapped``, ``args`` and
    ``kwargs``. ``instance`` is accepted but not forwarded.
    """

    def _with_tracer(
        tracer,
        query_duration_metric,
        distance_metric,
        insert_units_metric,
        upsert_units_metric,
        delete_units_metric,
        to_wrap,
    ):
        # Bundle everything known at instrumentation time once, so the
        # per-call wrapper only has to append the call-specific pieces.
        bound = (
            tracer,
            query_duration_metric,
            distance_metric,
            insert_units_metric,
            upsert_units_metric,
            delete_units_metric,
            to_wrap,
        )

        def wrapper(wrapped, instance, args, kwargs):
            return func(*bound, wrapped, args, kwargs)

        return wrapper

    return _with_tracer
def _set_span_attribute(span, name, value):
    """Set ``name`` on ``span`` unless ``value`` is ``None`` or an empty string.

    Other falsy values such as ``0`` or ``False`` are still recorded.
    """
    if value is not None and value != "":
        span.set_attribute(name, value)
@_with_tracer_wrapper
def _wrap(
    tracer,
    query_duration_metric,
    distance_metric,
    insert_units_metric,
    upsert_units_metric,
    delete_units_metric,
    to_wrap,
    wrapped,
    args,
    kwargs
):
    """Instruments and calls every function defined in TO_WRAP.

    Opens a span named ``to_wrap["span_name"]``, records request attributes
    for the operation in ``to_wrap["method"]``, calls ``wrapped`` and then
    records result events and metrics.

    Args:
        tracer: Tracer used to start the span.
        query_duration_metric: Histogram-like instrument (may be falsy);
            receives the call duration in seconds for ``query`` calls only.
        distance_metric: Instrument handed to ``set_search_response`` for
            ``search`` / ``hybrid_search`` results.
        insert_units_metric, upsert_units_metric, delete_units_metric:
            Instruments handed to ``_set_response_attributes``.
        to_wrap: Mapping that provides ``"method"`` and ``"span_name"``.
        wrapped: The original callable.
        args, kwargs: Arguments forwarded unchanged to ``wrapped``.

    Returns:
        Whatever ``wrapped`` returns.

    Raises:
        Any exception raised by ``wrapped``; the span's error-type attribute
        is set before the exception is re-raised.
    """
    if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
        return wrapped(*args, **kwargs)

    method = to_wrap.get("method")
    name = to_wrap.get("span_name")

    # Request attributes are derived from keyword arguments only; positional
    # arguments are forwarded to `wrapped` but never inspected.
    attribute_setters = {
        "insert": _set_insert_attributes,
        "upsert": _set_upsert_attributes,
        "delete": _set_delete_attributes,
        "search": _set_search_attributes,
        "get": _set_get_attributes,
        "query": _set_query_attributes,
        "create_collection": _set_create_collection_attributes,
        "hybrid_search": _set_hybrid_search_attributes,
    }

    with tracer.start_as_current_span(name) as span:
        span.set_attribute(SpanAttributes.DB_SYSTEM, "milvus")
        span.set_attribute(SpanAttributes.DB_OPERATION, method)

        set_request_attributes = attribute_setters.get(method)
        if set_request_attributes is not None:
            set_request_attributes(span, kwargs)

        try:
            # perf_counter() is monotonic and high-resolution. time.time() can
            # jump backwards or stand still, and the `duration > 0` guard
            # below would then silently drop the measurement.
            start_time = time.perf_counter()
            return_value = wrapped(*args, **kwargs)
            end_time = time.perf_counter()

            if method == "query":
                _add_query_result_events(span, return_value)
            if method in ("search", "hybrid_search"):
                _add_search_result_events(span, return_value)
        except Exception as e:
            # Prefer the symbolic name of the exception's `code`; fall back to
            # the exception class name when there is no known code.
            error_type = code_to_error_type.get(
                getattr(e, "code", None), type(e).__name__
            )
            span.set_attribute(ERROR_TYPE, error_type)
            raise

        shared_attributes = {
            SpanAttributes.DB_SYSTEM: "milvus",
            SpanAttributes.DB_OPERATION: method,
        }

        duration = end_time - start_time
        if duration > 0 and query_duration_metric and method == "query":
            query_duration_metric.record(duration, shared_attributes)

        if return_value:
            if method in ("search", "hybrid_search"):
                set_search_response(distance_metric, shared_attributes, return_value)

            _set_response_attributes(
                insert_units_metric,
                upsert_units_metric,
                delete_units_metric,
                shared_attributes,
                return_value,
            )

        return return_value
def _encode_filter(_filter):
    """Return ``str(_filter)`` for a truthy filter, otherwise ``None``."""
    return str(_filter) if _filter else None
def _encode_partition_name(partition_name):
    """Return ``str(partition_name)`` when it is truthy, otherwise ``None``."""
    return str(partition_name) if partition_name else None
def _encode_include(include):
    """Return ``str(include)`` when it is truthy, otherwise ``None``."""
    return str(include) if include else None
def count_or_none(obj):
    """Return ``len(obj)`` when it is a positive length, otherwise ``None``.

    ``None``, empty containers and objects that have no length (for example a
    bare integer) all yield ``None`` instead of raising. The length is taken
    directly rather than after a truthiness test, so containers whose truth
    value is ambiguous (e.g. multi-element NumPy arrays) are counted too.

    Callers in this module invoke this helper in the middle of a sequence of
    span-attribute writes, so an exception raised here would prevent the
    remaining attributes from being recorded.
    """
    if obj is None:
        return None
    try:
        size = len(obj)
    except TypeError:
        # Unsized value: there is nothing meaningful to count.
        return None
    return size or None
@dont_throw
def _set_response_attributes(
    insert_units_metric,
    upsert_units_metric,
    delete_units_metric,
    shared_attributes,
    response
):
    """Add the unit counts reported in a dict response to the counters.

    Non-dict responses are ignored. For each of ``upsert_count``,
    ``insert_count`` and ``delete_count`` present in ``response``, the
    matching counter is incremented by that value. A falsy count (``None``,
    ``0``) is recorded as ``0``.

    A counter that is ``None``/falsy is skipped, mirroring the guards applied
    to the duration and distance metrics elsewhere in this module. One
    missing instrument therefore does not stop the others being updated.
    """
    if not isinstance(response, dict):
        return

    counters = (
        ("upsert_count", upsert_units_metric),
        ("insert_count", insert_units_metric),
        ("delete_count", delete_units_metric),
    )
    for key, metric in counters:
        if key in response and metric:
            metric.add(response[key] or 0, shared_attributes)
@dont_throw
def _set_create_collection_attributes(span, kwargs):
    """Copy the ``create_collection`` keyword arguments onto ``span``.

    Records the collection name, dimension, primary field name, metric type,
    timeout, id type and vector field name. Each value goes through
    ``_set_span_attribute``, which skips ``None`` and empty strings.
    """

    def record(attribute, value):
        _set_span_attribute(span, attribute, value)

    arg = kwargs.get
    record(AISpanAttributes.MILVUS_CREATE_COLLECTION_NAME, arg("collection_name"))
    record(AISpanAttributes.MILVUS_CREATE_COLLECTION_DIMENSION, arg("dimension"))
    record(
        AISpanAttributes.MILVUS_CREATE_COLLECTION_PRIMARY_FIELD,
        arg("primary_field_name"),
    )
    record(AISpanAttributes.MILVUS_CREATE_COLLECTION_METRIC_TYPE, arg("metric_type"))
    record(AISpanAttributes.MILVUS_CREATE_COLLECTION_TIMEOUT, arg("timeout"))
    record(AISpanAttributes.MILVUS_CREATE_COLLECTION_ID_TYPE, arg("id_type"))
    record(
        AISpanAttributes.MILVUS_CREATE_COLLECTION_VECTOR_FIELD,
        arg("vector_field_name"),
    )
@dont_throw
def _set_insert_attributes(span, kwargs):
    """Copy the ``insert`` keyword arguments onto ``span``.

    Records the collection name, the length of ``data`` (via
    ``count_or_none``), the timeout and the stringified partition name.
    """

    def record(attribute, value):
        _set_span_attribute(span, attribute, value)

    arg = kwargs.get
    record(AISpanAttributes.MILVUS_INSERT_COLLECTION_NAME, arg("collection_name"))
    record(AISpanAttributes.MILVUS_INSERT_DATA_COUNT, count_or_none(arg("data")))
    record(AISpanAttributes.MILVUS_INSERT_TIMEOUT, arg("timeout"))
    record(
        AISpanAttributes.MILVUS_INSERT_PARTITION_NAME,
        _encode_partition_name(arg("partition_name")),
    )
@dont_throw
def _set_get_attributes(span, kwargs):
    """Copy the ``get`` keyword arguments onto ``span``.

    Records the collection name, the timeout, and the lengths of ``ids``,
    ``output_fields`` and ``partition_names`` (each via ``count_or_none``).
    """

    def record(attribute, value):
        _set_span_attribute(span, attribute, value)

    arg = kwargs.get
    record(AISpanAttributes.MILVUS_GET_COLLECTION_NAME, arg("collection_name"))
    record(AISpanAttributes.MILVUS_GET_IDS_COUNT, count_or_none(arg("ids")))
    record(
        AISpanAttributes.MILVUS_GET_OUTPUT_FIELDS_COUNT,
        count_or_none(arg("output_fields")),
    )
    record(AISpanAttributes.MILVUS_GET_TIMEOUT, arg("timeout"))
    record(
        AISpanAttributes.MILVUS_GET_PARTITION_NAMES_COUNT,
        count_or_none(arg("partition_names")),
    )
@dont_throw
def _set_search_attributes(span, kwargs):
    """Copy the ``search`` keyword arguments onto ``span``.

    Records the collection name, number of query entries in ``data``, filter,
    limit, output-field count, search params, timeout, partition names (count
    and stringified list), ANN field, and the stringified list of per-vector
    lengths.

    Differences from a plain copy:
      * The filter goes through ``_encode_filter``, as the query and delete
        attribute setters do, so the recorded value is always a string.
      * ``data=None`` is treated like an absent ``data``, and entries without
        a length are skipped when computing vector dimensions, so the
        dimension attribute never fails on unexpected input.
        NOTE(review): a single flat vector passed as ``data`` therefore
        produces no dimension attribute; confirm whether callers use that
        shape.
    """
    _set_span_attribute(
        span,
        AISpanAttributes.MILVUS_SEARCH_COLLECTION_NAME,
        kwargs.get("collection_name"),
    )
    _set_span_attribute(
        span,
        AISpanAttributes.MILVUS_SEARCH_DATA_COUNT,
        count_or_none(kwargs.get("data")),
    )
    _set_span_attribute(
        span,
        AISpanAttributes.MILVUS_SEARCH_FILTER,
        _encode_filter(kwargs.get("filter")),
    )
    _set_span_attribute(span, AISpanAttributes.MILVUS_SEARCH_LIMIT, kwargs.get("limit"))
    _set_span_attribute(
        span,
        AISpanAttributes.MILVUS_SEARCH_OUTPUT_FIELDS_COUNT,
        count_or_none(kwargs.get("output_fields")),
    )
    _set_span_attribute(
        span,
        AISpanAttributes.MILVUS_SEARCH_SEARCH_PARAMS,
        _encode_include(kwargs.get("search_params")),
    )
    _set_span_attribute(
        span, AISpanAttributes.MILVUS_SEARCH_TIMEOUT, kwargs.get("timeout")
    )
    _set_span_attribute(
        span,
        AISpanAttributes.MILVUS_SEARCH_PARTITION_NAMES_COUNT,
        count_or_none(kwargs.get("partition_names")),
    )
    _set_span_attribute(
        span, AISpanAttributes.MILVUS_SEARCH_ANNS_FIELD, kwargs.get("anns_field")
    )
    _set_span_attribute(
        span,
        AISpanAttributes.MILVUS_SEARCH_PARTITION_NAMES,
        _encode_partition_name(kwargs.get("partition_names")),
    )

    # `kwargs.get("data", [])` would still return None for an explicit
    # `data=None`, so normalise it here before iterating.
    query_vectors = kwargs.get("data")
    if query_vectors is None:
        query_vectors = []
    vector_dims = [len(vec) for vec in query_vectors if hasattr(vec, "__len__")]
    _set_span_attribute(
        span,
        AISpanAttributes.MILVUS_SEARCH_QUERY_VECTOR_DIMENSION,
        _encode_include(vector_dims),
    )
@dont_throw
def _set_hybrid_search_attributes(span, kwargs):
    """Copy the ``hybrid_search`` keyword arguments onto ``span``.

    Records the collection name, a stringified summary of each request in
    ``reqs`` (its ``anns_field`` and ``param``), the ranker's class name, the
    limit, the number of requests, output-field count, timeout and partition
    names (count and stringified list).

    Differences from a plain copy:
      * ``reqs=None`` is treated as "no requests" instead of raising, which
        would otherwise skip every attribute after the collection name.
      * The ranker type is recorded only when a ranker was passed; without
        one the attribute is omitted rather than set to ``"NoneType"``.
    """
    _set_span_attribute(
        span,
        AISpanAttributes.MILVUS_SEARCH_COLLECTION_NAME,
        kwargs.get("collection_name"),
    )

    reqs_info = [
        {"anns_field": req.anns_field, "param": req.param}
        for req in (kwargs.get("reqs") or [])
    ]
    _set_span_attribute(
        span,
        AISpanAttributes.MILVUS_SEARCH_ANNSEARCH_REQUEST,
        _encode_include(reqs_info),
    )

    ranker = kwargs.get("ranker")
    _set_span_attribute(
        span,
        AISpanAttributes.MILVUS_SEARCH_RANKER_TYPE,
        type(ranker).__name__ if ranker is not None else None,
    )
    _set_span_attribute(span, AISpanAttributes.MILVUS_SEARCH_LIMIT, kwargs.get("limit"))
    _set_span_attribute(
        span,
        AISpanAttributes.MILVUS_SEARCH_DATA_COUNT,
        count_or_none(kwargs.get("reqs")),
    )
    _set_span_attribute(
        span,
        AISpanAttributes.MILVUS_SEARCH_OUTPUT_FIELDS_COUNT,
        count_or_none(kwargs.get("output_fields")),
    )
    _set_span_attribute(
        span, AISpanAttributes.MILVUS_SEARCH_TIMEOUT, kwargs.get("timeout")
    )
    _set_span_attribute(
        span,
        AISpanAttributes.MILVUS_SEARCH_PARTITION_NAMES_COUNT,
        count_or_none(kwargs.get("partition_names")),
    )
    _set_span_attribute(
        span,
        AISpanAttributes.MILVUS_SEARCH_PARTITION_NAMES,
        _encode_partition_name(kwargs.get("partition_names")),
    )
@dont_throw
def _set_query_attributes(span, kwargs):
    """Copy the ``query`` keyword arguments onto ``span``.

    Records the collection name, the stringified filter, the timeout, the
    limit, and the lengths of ``output_fields``, ``ids`` and
    ``partition_names`` (each via ``count_or_none``).
    """

    def record(attribute, value):
        _set_span_attribute(span, attribute, value)

    arg = kwargs.get
    record(AISpanAttributes.MILVUS_QUERY_COLLECTION_NAME, arg("collection_name"))
    record(AISpanAttributes.MILVUS_QUERY_FILTER, _encode_filter(arg("filter")))
    record(
        AISpanAttributes.MILVUS_QUERY_OUTPUT_FIELDS_COUNT,
        count_or_none(arg("output_fields")),
    )
    record(AISpanAttributes.MILVUS_QUERY_TIMEOUT, arg("timeout"))
    record(AISpanAttributes.MILVUS_QUERY_IDS_COUNT, count_or_none(arg("ids")))
    record(
        AISpanAttributes.MILVUS_QUERY_PARTITION_NAMES_COUNT,
        count_or_none(arg("partition_names")),
    )
    record(AISpanAttributes.MILVUS_QUERY_LIMIT, arg("limit"))
@dont_throw
def _add_query_result_events(span, kwargs):
    """Emit one ``DB_QUERY_RESULT`` span event per element of the result.

    Despite its name, ``kwargs`` is the value returned by the wrapped query
    call (that is what ``_wrap`` passes in). Each element is handed to
    ``span.add_event`` unchanged as the event's attributes.
    """
    query_rows = kwargs
    for row in query_rows:
        span.add_event(
            name=Events.DB_QUERY_RESULT.value,
            attributes=row,
        )
@dont_throw
def _add_search_result_events(span, kwargs):
    """Emit one ``DB_SEARCH_RESULT`` event per match and record match counts.

    Despite its name, ``kwargs`` is the value returned by the wrapped search
    call (that is what ``_wrap`` passes in). It is iterated as a sequence of
    per-query result lists whose matches expose ``"id"``, ``"distance"`` and
    ``"entity"`` keys.

    For every match an event is added that carries the query index, the match
    id, the distance (as a string) and the stringified entity. Match counts
    are then written as span attributes:
      * exactly one query: a single result-count attribute;
      * otherwise: one ``<result-count>_<query index>`` attribute per query.

    The accumulators that were filled but never read (all distances, match
    ids) have been removed; the events and attributes are unchanged.
    """
    single_query = len(kwargs) == 1
    total_matches = 0

    for query_idx, query_results in enumerate(kwargs):
        query_match_count = 0
        for match in query_results:
            distance = float(match["distance"])
            query_match_count += 1
            span.add_event(
                Events.DB_SEARCH_RESULT.value,
                attributes={
                    EventAttributes.DB_SEARCH_RESULT_QUERY_ID.value: query_idx,
                    EventAttributes.DB_SEARCH_RESULT_ID.value: match["id"],
                    EventAttributes.DB_SEARCH_RESULT_DISTANCE.value: str(distance),
                    EventAttributes.DB_SEARCH_RESULT_ENTITY.value: _encode_include(
                        match["entity"]
                    ),
                },
            )
        total_matches += query_match_count

        if not single_query:
            _set_span_attribute(
                span,
                f"{AISpanAttributes.MILVUS_SEARCH_RESULT_COUNT}_{query_idx}",
                query_match_count,
            )

    if single_query:
        _set_span_attribute(
            span, AISpanAttributes.MILVUS_SEARCH_RESULT_COUNT, total_matches
        )
@dont_throw
def _set_upsert_attributes(span, kwargs):
    """Copy the ``upsert`` keyword arguments onto ``span``.

    Records the collection name, the length of ``data`` (via
    ``count_or_none``), the timeout and the stringified partition name.
    """

    def record(attribute, value):
        _set_span_attribute(span, attribute, value)

    arg = kwargs.get
    record(AISpanAttributes.MILVUS_UPSERT_COLLECTION_NAME, arg("collection_name"))
    record(AISpanAttributes.MILVUS_UPSERT_DATA_COUNT, count_or_none(arg("data")))
    record(AISpanAttributes.MILVUS_UPSERT_TIMEOUT, arg("timeout"))
    record(
        AISpanAttributes.MILVUS_UPSERT_PARTITION_NAME,
        _encode_partition_name(arg("partition_name")),
    )
@dont_throw
def _set_delete_attributes(span, kwargs):
    """Copy the ``delete`` keyword arguments onto ``span``.

    Records the collection name, the timeout, the stringified partition name,
    the length of ``ids`` (via ``count_or_none``) and the stringified filter.
    """

    def record(attribute, value):
        _set_span_attribute(span, attribute, value)

    arg = kwargs.get
    record(AISpanAttributes.MILVUS_DELETE_COLLECTION_NAME, arg("collection_name"))
    record(AISpanAttributes.MILVUS_DELETE_TIMEOUT, arg("timeout"))
    record(
        AISpanAttributes.MILVUS_DELETE_PARTITION_NAME,
        _encode_partition_name(arg("partition_name")),
    )
    record(AISpanAttributes.MILVUS_DELETE_IDS_COUNT, count_or_none(arg("ids")))
    record(AISpanAttributes.MILVUS_DELETE_FILTER, _encode_filter(arg("filter")))
@dont_throw
def set_search_response(distance_metric, shared_attributes, response):
    """Record the distance of every match in a search response.

    ``response`` is iterated as a sequence of per-query result collections.
    For each match, the value under ``"distance"`` is recorded on
    ``distance_metric`` together with ``shared_attributes``. Matches without
    a distance are skipped, and nothing is recorded when ``distance_metric``
    is falsy.
    """
    for hits in response:
        for hit in hits:
            score = hit.get("distance")
            if not distance_metric or score is None:
                continue
            distance_metric.record(score, shared_attributes)