# ai-station/.venv/lib/python3.12/site-packages/opentelemetry/instrumentation/bedrock/streaming_wrapper.py
#
# 75 lines, 2.5 KiB, Python
# (Raw | Normal View | History)

import json
from opentelemetry.instrumentation.bedrock.utils import (
dont_throw,
)
from wrapt import ObjectProxy
class StreamingWrapper(ObjectProxy):
    """Proxy a Bedrock streaming response body and accumulate its events.

    Iterating the proxy yields the wrapped stream's events unchanged. As each
    event passes through, its decoded JSON payload is folded into a single
    response-like dict; once the wrapped stream is exhausted, that dict is
    passed to ``stream_done_callback`` (if one was given).

    NOTE(review): ``wrapt.ObjectProxy`` forwards assignments to attribute
    names that do not start with ``_self_`` to the wrapped object, so the
    ``_stream_done_callback`` / ``_accumulating_body`` state presumably lives
    on the wrapped response. The names are kept unchanged for compatibility
    with any code that reads them -- confirm before renaming.
    """

    def __init__(
        self,
        response,
        stream_done_callback=None,
    ):
        """Wrap ``response``.

        Args:
            response: the iterable stream to proxy; all other attribute
                access is delegated to it by ``ObjectProxy``.
            stream_done_callback: optional callable invoked once with the
                accumulated body after the stream has been fully consumed.
        """
        super().__init__(response)
        self._stream_done_callback = stream_done_callback
        self._accumulating_body = {}

    def __iter__(self):
        """Yield every event of the wrapped stream, recording each one first.

        The done-callback fires only after the wrapped stream is exhausted;
        if the consumer stops iterating early it is never called.
        """
        for event in self.__wrapped__:
            self._process_event(event)
            yield event
        if self._stream_done_callback:
            self._stream_done_callback(self._accumulating_body)

    @dont_throw
    def _process_event(self, event):
        """Fold one raw stream event into ``self._accumulating_body``.

        Events without a ``chunk`` are ignored. The chunk's ``bytes`` payload
        is decoded as JSON. Payloads without a ``type`` field are merged by
        ``_accumulate_events``. Typed payloads are handled by their
        ``message_*`` / ``content_block_*`` event name (these names match the
        Anthropic Messages streaming events). Any error is absorbed by
        ``dont_throw`` so instrumentation never breaks the caller's stream.
        """
        chunk = event.get("chunk")
        if not chunk:
            return

        decoded_chunk = json.loads(chunk.get("bytes").decode())
        event_type = decoded_chunk.get("type")

        if event_type is None:
            self._accumulate_events(decoded_chunk)
        elif event_type == "message_start":
            # The initial message skeleton becomes the accumulated body.
            self._accumulating_body = decoded_chunk.get("message")
        elif event_type == "content_block_start":
            self._accumulating_body["content"].append(
                decoded_chunk.get("content_block")
            )
        elif event_type == "content_block_delta":
            text = (decoded_chunk.get("delta") or {}).get("text")
            # Non-text deltas (e.g. tool-input JSON) carry no "text"; skip
            # them instead of failing on `str += None`.
            if text is not None:
                self._accumulating_body["content"][-1]["text"] += text
        elif event_type == "message_delta":
            # The final stop reason and the cumulative usage arrive here,
            # after message_start seeded placeholder values.
            self._accumulating_body.update(decoded_chunk.get("delta") or {})
            usage = decoded_chunk.get("usage")
            if usage:
                self._accumulating_body.setdefault("usage", {}).update(usage)
        elif event_type == "message_stop":
            self._accumulating_body["invocation_metrics"] = decoded_chunk.get(
                "amazon-bedrock-invocationMetrics"
            )

    @staticmethod
    def _combine(existing, new):
        """Return ``existing += new``, treating ``None`` on either side as absent."""
        if existing is None:
            return new
        if new is None:
            return existing
        # `+=` (not `+`) keeps in-place extension for list values.
        existing += new
        return existing

    def _accumulate_events(self, event):
        """Merge one untyped decoded chunk into ``self._accumulating_body``.

        Per top-level key of ``event``:
        - ``contentBlockDelta``: its ``delta.text`` is appended to
          ``outputText`` (deltas without text are skipped).
        - a key already in the body: the new value is combined with ``+=``
          (strings concatenate, numbers sum); ``None`` values are ignored.
        - ``messageStop``: its ``stopReason`` is stored as ``stop_reason``.
        - any other key: stored as-is.
        """
        body = self._accumulating_body
        for key, value in event.items():
            if key == "contentBlockDelta":
                text = ((value or {}).get("delta") or {}).get("text")
                if text is not None:
                    body["outputText"] = (body.get("outputText") or "") + text
            elif key in body:
                body[key] = self._combine(body[key], value)
            elif key == "messageStop":
                body["stop_reason"] = value.get("stopReason")
            else:
                body[key] = value