ai-station/.venv/lib/python3.12/site-packages/chainlit/slack/app.py

524 lines
15 KiB
Python

import asyncio
import os
import re
import uuid
from functools import partial
from typing import Dict, List, Optional, Union
import httpx
from slack_bolt.adapter.fastapi.async_handler import AsyncSlackRequestHandler
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler
from slack_bolt.async_app import AsyncApp
from chainlit.config import config
from chainlit.context import ChainlitContext, HTTPSession, context, context_var
from chainlit.data import get_data_layer
from chainlit.element import Element, ElementDict
from chainlit.emitter import BaseChainlitEmitter
from chainlit.logger import logger
from chainlit.message import Message, StepDict
from chainlit.types import Feedback
from chainlit.user import PersistedUser, User
from chainlit.user_session import user_session
class SlackEmitter(BaseChainlitEmitter):
    """Chainlit emitter that forwards inline elements and assistant messages to Slack.

    Elements are uploaded as files to ``channel_id`` (inside ``thread_ts`` when
    set); assistant messages are posted through the ``say`` callable, with
    thumbs-up / thumbs-down buttons appended when a data layer is configured.
    """

    # Slack limits the text of a single section block to 3000 characters, so
    # longer outputs are spread over several consecutive section blocks.
    _MAX_SECTION_TEXT_LENGTH = 3000

    def __init__(
        self,
        session: HTTPSession,
        app: AsyncApp,
        channel_id: str,
        say,
        thread_ts: Optional[str] = None,
    ):
        """Store the Slack app, target channel/thread and the ``say`` callable.

        Args:
            session: Session handed to the base emitter; also used to look up
                persisted files in ``send_element``.
            app: Slack app whose client performs file uploads.
            channel_id: Channel that receives uploaded files.
            say: Awaitable callable used to post messages.
            thread_ts: Timestamp of the thread to reply in, if any.
        """
        super().__init__(session)
        self.app = app
        self.channel_id = channel_id
        self.say = say
        self.thread_ts = thread_ts

    async def send_element(self, element_dict: ElementDict):
        """Upload an inline element to Slack; non-inline elements are ignored.

        The content comes from the session's persisted file when the element's
        ``chainlitKey`` is known, otherwise it is downloaded from the element's
        ``url``. Nothing is uploaded when no content could be obtained.
        """
        if element_dict.get("display") != "inline":
            return

        persisted_file = self.session.files.get(element_dict.get("chainlitKey") or "")
        file: Optional[Union[bytes, str]] = None

        if persisted_file:
            # Persisted files are uploaded by path.
            file = str(persisted_file["path"])
        elif file_url := element_dict.get("url"):
            async with httpx.AsyncClient() as client:
                response = await client.get(file_url)
                if response.status_code == 200:
                    file = response.content

        if not file:
            return

        await self.app.client.files_upload_v2(
            channel=self.channel_id,
            thread_ts=self.thread_ts,
            file=file,
            title=element_dict.get("name"),
        )

    async def send_step(self, step_dict: StepDict):
        """Post an assistant message step to Slack.

        Steps that are not of type ``assistant_message`` or that have an empty
        output are skipped. When a data layer is available, feedback buttons
        carrying the current run id (or the step id as a fallback) are added.
        """
        output = step_dict.get("output")
        if not output or step_dict.get("type") != "assistant_message":
            return

        # One section block per chunk keeps every block within Slack's limit;
        # an output that already fits yields exactly one block.
        limit = self._MAX_SECTION_TEXT_LENGTH
        blocks: List[Dict] = [
            {
                "type": "section",
                "text": {"type": "mrkdwn", "text": output[start : start + limit]},
            }
            for start in range(0, len(output), limit)
        ]

        if get_data_layer():
            current_run = context.current_run
            scorable_id = current_run.id if current_run else step_dict.get("id")
            blocks.append(
                {
                    "type": "actions",
                    "elements": [
                        {
                            "action_id": "thumbdown",
                            "type": "button",
                            "text": {
                                "type": "plain_text",
                                "emoji": True,
                                "text": ":thumbsdown:",
                            },
                            "value": scorable_id,
                        },
                        {
                            "action_id": "thumbup",
                            "type": "button",
                            "text": {
                                "type": "plain_text",
                                "emoji": True,
                                "text": ":thumbsup:",
                            },
                            "value": scorable_id,
                        },
                    ],
                }
            )

        await self.say(text=output, blocks=blocks, thread_ts=self.thread_ts)

    async def update_step(self, step_dict: StepDict):
        """Handle a step update by sending the step again via ``send_step``.

        Only ``assistant_message`` steps are forwarded. ``.get`` is used so a
        step dict without a ``type`` key is skipped instead of raising.
        """
        if step_dict.get("type") != "assistant_message":
            return
        await self.send_step(step_dict)
# Module-wide Slack app; the bot token and signing secret are read from the
# environment when this module is imported.
slack_app = AsyncApp(
    token=os.getenv("SLACK_BOT_TOKEN"),
    signing_secret=os.getenv("SLACK_SIGNING_SECRET"),
)
async def start_socket_mode():
    """
    Start the Slack app in Socket Mode.

    The handler is authenticated with the value of the SLACK_WEBSOCKET_TOKEN
    environment variable and started asynchronously.
    """
    websocket_token = os.environ.get("SLACK_WEBSOCKET_TOKEN")
    socket_handler = AsyncSocketModeHandler(slack_app, websocket_token)
    await socket_handler.start_async()
def init_slack_context(
    session: HTTPSession,
    slack_channel_id: str,
    event,
    say,
    thread_ts: Optional[str] = None,
) -> ChainlitContext:
    """Create and activate a Chainlit context backed by a ``SlackEmitter``.

    The new context is stored in ``context_var``; afterwards the raw Slack
    event and a pre-bound history fetcher are placed in the user session under
    ``"slack_event"`` and ``"fetch_slack_message_history"``.

    Returns:
        The newly created context.
    """
    ctx = ChainlitContext(
        session=session,
        emitter=SlackEmitter(
            session=session,
            app=slack_app,
            channel_id=slack_channel_id,
            say=say,
            thread_ts=thread_ts,
        ),
    )
    # The context must be active before user_session is written to.
    context_var.set(ctx)

    history_fetcher = partial(
        fetch_message_history, channel_id=slack_channel_id, thread_ts=thread_ts
    )
    user_session.set("slack_event", event)
    user_session.set("fetch_slack_message_history", history_fetcher)
    return ctx
# Prefix prepended to the Slack-derived identifier of every Chainlit user.
USER_PREFIX = "slack_"

# In-memory cache of resolved users, keyed by Slack user id.
users_by_slack_id: Dict[str, Union[User, PersistedUser]] = {}

# Cached user id of the bot itself; filled in by get_bot_user_id().
bot_user_id: Optional[str] = None

# Request handler wrapping the Slack app.
slack_app_handler = AsyncSlackRequestHandler(slack_app)
async def get_bot_user_id() -> Optional[str]:
    """Return the bot's Slack user id, caching it in the module-level variable.

    The id is obtained from ``auth_test`` on first use. ``None`` is returned
    when the call fails (the error is logged) or the response is not ok.
    """
    global bot_user_id

    if bot_user_id:
        return bot_user_id

    try:
        auth_info = await slack_app.client.auth_test()
        if not auth_info.get("ok"):
            return None
        bot_user_id = auth_info.get("user_id")
        return bot_user_id
    except Exception as e:
        logger.error(f"Failed to get bot user ID: {e}")
        return None
def clean_content(message: Optional[str]) -> str:
    """Strip Slack user mentions (``<@...>``) and surrounding whitespace.

    Args:
        message: Raw message text. ``None`` is accepted because callers pass
            the result of ``event.get("text")``, which is ``None`` when the
            key is absent.

    Returns:
        The text without mention tokens, or ``""`` when ``message`` is ``None``.
    """
    if message is None:
        # re.sub raises TypeError on None; treat missing text as empty.
        return ""
    return re.sub(r"<@[\w]+>", "", message).strip()
async def get_user(slack_user_id: str):
    """Resolve a Slack user id to a Chainlit user, caching the result.

    On a cache miss the Slack profile is fetched and a ``User`` is built whose
    identifier is ``USER_PREFIX`` plus the profile e-mail (or the Slack id when
    no e-mail is present). If a data layer exists, the user is also created
    there and the persisted user replaces the cached one; creation errors are
    logged and the non-persisted user is kept.
    """
    cached = users_by_slack_id.get(slack_user_id)
    if slack_user_id in users_by_slack_id:
        return cached

    info = await slack_app.client.users_info(user=slack_user_id)
    profile = info["user"]["profile"]
    identifier = USER_PREFIX + (profile.get("email") or slack_user_id)

    user = User(identifier=identifier, metadata=profile)
    # Cache right away so the entry exists while the data layer call is pending.
    users_by_slack_id[slack_user_id] = user

    data_layer = get_data_layer()
    if data_layer:
        try:
            persisted_user = await data_layer.create_user(user)
            if persisted_user:
                users_by_slack_id[slack_user_id] = persisted_user
        except Exception as e:
            logger.error(f"Error creating user: {e}")

    return users_by_slack_id[slack_user_id]
async def fetch_message_history(
    channel_id: str, thread_ts: Optional[str] = None, limit=30
):
    """Fetch up to ``limit`` messages from a channel or from one of its threads.

    With ``thread_ts`` the thread's replies are requested, otherwise the
    channel history.

    Returns:
        The ``messages`` entry of the Slack response.

    Raises:
        Exception: If the response's ``ok`` flag is falsy.
    """
    client = slack_app.client
    if thread_ts:
        result = await client.conversations_replies(
            channel=channel_id, ts=thread_ts, limit=limit
        )
    else:
        result = await client.conversations_history(channel=channel_id, limit=limit)

    if not result["ok"]:
        raise Exception(f"Failed to fetch messages: {result['error']}")
    return result["messages"]
async def download_slack_file(url, token):
    """Download ``url`` using ``token`` as a bearer token.

    Returns:
        The response body as bytes on HTTP 200, otherwise ``None``.
    """
    auth_headers = {"Authorization": f"Bearer {token}"}
    async with httpx.AsyncClient() as http:
        reply = await http.get(url, headers=auth_headers)
        return reply.content if reply.status_code == 200 else None
async def download_slack_files(session: HTTPSession, files, token):
    """Download Slack file attachments and turn them into inline elements.

    All files are downloaded concurrently from their ``url_private``; each
    non-empty download is persisted in ``session``. An element is produced for
    every persisted file that is present in ``session.files``.

    Returns:
        A list of elements built with ``Element.from_dict``.
    """
    payloads = await asyncio.gather(
        *[download_slack_file(entry.get("url_private"), token) for entry in files]
    )

    persisted_refs = []
    for entry, payload in zip(files, payloads):
        if not payload:
            # Failed or empty downloads are skipped.
            continue
        ref = await session.persist_file(
            name=entry.get("name"), mime=entry.get("mimetype"), content=payload
        )
        persisted_refs.append(ref)

    elements = []
    for ref in persisted_refs:
        if ref["id"] not in session.files:
            continue
        stored = session.files[ref["id"]]
        elements.append(
            Element.from_dict(
                {
                    "id": stored["id"],
                    "name": stored["name"],
                    "path": str(stored["path"]),
                    "chainlitKey": stored["id"],
                    "display": "inline",
                    "type": Element.infer_type_from_mime(stored["type"]),
                }
            )
        )
    return elements
async def add_reaction_if_enabled(event, emoji: str = "eyes"):
    """Add ``emoji`` as a reaction to the event's message when the feature is on.

    Does nothing unless ``config.features.slack.reaction_on_message_received``
    is truthy. Failures to add the reaction are logged as warnings and not
    propagated.
    """
    if not config.features.slack.reaction_on_message_received:
        return

    try:
        await slack_app.client.reactions_add(
            channel=event["channel"], timestamp=event["ts"], name=emoji
        )
    except Exception as e:
        logger.warning(f"Failed to add reaction: {e}")
async def process_slack_message(
    event,
    say,
    thread_id: str,
    thread_name: Optional[str] = None,
    bind_thread_to_user=False,
    thread_ts: Optional[str] = None,
):
    """Run the configured Chainlit callbacks for one incoming Slack message.

    A fresh ``HTTPSession`` and Slack context are created, attached files are
    downloaded, then ``on_chat_start``, ``on_message`` and ``on_chat_end`` are
    invoked when configured. If a data layer exists the thread is updated
    (errors there are logged, not raised). The session is always deleted at
    the end, including when a download or callback raises.

    Args:
        event: Slack event payload; must contain ``user`` and ``channel``.
        say: Awaitable callable used by the emitter to post replies.
        thread_id: Chainlit thread id to associate with the session.
        thread_name: Thread name to persist; defaults to the message content.
        bind_thread_to_user: Whether a persisted user's id is stored on the
            thread.
        thread_ts: Slack thread timestamp replies should be posted in.
    """
    await add_reaction_if_enabled(event)
    user = await get_user(event["user"])
    channel_id = event["channel"]
    # .get() yields None when the key is absent (or null); normalise so
    # clean_content() and the download step always get usable values.
    text = event.get("text") or ""
    slack_files = event.get("files") or []

    session = HTTPSession(
        id=str(uuid.uuid4()),
        thread_id=thread_id,
        user=user,
        client_type="slack",
    )
    ctx = init_slack_context(
        session=session,
        slack_channel_id=channel_id,
        event=event,
        say=say,
        thread_ts=thread_ts,
    )

    try:
        file_elements = await download_slack_files(
            session, slack_files, slack_app.client.token
        )

        if on_chat_start := config.code.on_chat_start:
            await on_chat_start()

        msg = Message(
            content=clean_content(text),
            elements=file_elements,
            type="user_message",
            author=user.metadata.get("real_name"),
        )

        if on_message := config.code.on_message:
            await on_message(msg)

        if on_chat_end := config.code.on_chat_end:
            await on_chat_end()

        if data_layer := get_data_layer():
            user_id = None
            if isinstance(user, PersistedUser):
                user_id = user.id if bind_thread_to_user else None
            try:
                await data_layer.update_thread(
                    thread_id=thread_id,
                    name=thread_name or msg.content,
                    metadata=ctx.session.to_persistable(),
                    user_id=user_id,
                )
            except Exception as e:
                logger.error(f"Error updating thread: {e}")
    finally:
        # Mirror handle_reaction_added: release the session even on failure.
        await ctx.session.delete()
@slack_app.event("app_home_opened")
async def handle_app_home_opened(event, say):
    """No-op listener registered for the ``app_home_opened`` event."""
    return None
@slack_app.event("app_mention")
async def handle_app_mentions(event, say):
    """Process an ``app_mention`` event as a Slack message.

    The Slack thread is the event's ``thread_ts`` when present, otherwise the
    event's own ``ts``; the Chainlit thread id is a UUIDv5 derived from it.
    """
    own_ts = event["ts"]
    thread_ts = event.get("thread_ts", own_ts)
    chainlit_thread_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, thread_ts))
    await process_slack_message(
        event, say, thread_id=chainlit_thread_id, thread_ts=thread_ts
    )
@slack_app.event("message")
async def handle_message(message, say):
    """Process a ``message`` event, binding the resulting thread to the user.

    The Slack thread is the message's ``thread_ts`` when present, otherwise its
    own ``ts``; the Chainlit thread id is a UUIDv5 derived from it.
    """
    own_ts = message["ts"]
    thread_ts = message.get("thread_ts", own_ts)
    chainlit_thread_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, thread_ts))
    await process_slack_message(
        event=message,
        say=say,
        thread_id=chainlit_thread_id,
        bind_thread_to_user=True,
        thread_ts=thread_ts,
    )
@slack_app.event("reaction_added")
async def handle_reaction_added(event):
    """Run the reaction callback when someone reacts to one of the bot's messages.

    Reactions added by the bot itself, events without a channel, and reactions
    on messages whose ``user`` / ``bot_id`` do not match the bot's id are
    ignored. Otherwise a session and Slack context are created, then
    ``on_chat_start`` and ``on_slack_reaction_added`` are invoked when
    configured; the session is always deleted afterwards.
    """
    bot_id = await get_bot_user_id()
    if event.get("user") == bot_id:
        return

    item = event.get("item", {})
    channel_id = item.get("channel")
    thread_ts = item.get("ts")
    if not channel_id:
        logger.warning(
            "reaction_added event missing channel_id, skipping context setup"
        )
        return

    try:
        result = await slack_app.client.conversations_replies(
            channel=channel_id, ts=thread_ts, limit=1
        )
        if result.get("ok"):
            messages = result.get("messages")
            message = messages[0]
            message_user = message.get("user")
            message_bot_id = message.get("bot_id")
            if message_user != bot_id and message_bot_id != bot_id:
                return
        else:
            raise Exception(
                f"Failed to fetch message: {result.get('error', 'Unknown error')}"
            )
    except Exception as e:
        # Any lookup problem (API error, empty result) aborts the handling.
        logger.warning(f"Failed to fetch message for reaction: {e}")
        return

    async def say(text: str = "", **kwargs):
        # SlackEmitter.send_step calls say(..., thread_ts=...); passing
        # thread_ts explicitly as well would raise "multiple values for
        # keyword argument". Only default it when the caller did not supply it.
        kwargs.setdefault("thread_ts", thread_ts)
        await slack_app.client.chat_postMessage(
            channel=channel_id, text=text, **kwargs
        )

    user = await get_user(event["user"])
    thread_id = (
        str(uuid.uuid5(uuid.NAMESPACE_DNS, thread_ts))
        if thread_ts
        else str(uuid.uuid4())
    )
    session = HTTPSession(
        id=str(uuid.uuid4()),
        thread_id=thread_id,
        user=user,
        client_type="slack",
    )
    ctx = init_slack_context(
        session=session,
        slack_channel_id=channel_id,
        event=event,
        say=say,
        thread_ts=thread_ts,
    )

    try:
        if on_chat_start := config.code.on_chat_start:
            await on_chat_start()
        if on_slack_reaction_added := config.code.on_slack_reaction_added:
            await on_slack_reaction_added(event)
    finally:
        await ctx.session.delete()
@slack_app.block_action("thumbdown")
async def thumb_down(ack, context, body):
    """Record negative feedback and replace the buttons with a confirmation.

    The action is acknowledged first. When a data layer exists, a ``Feedback``
    with value 0 is upserted for the id carried by the button. The message is
    then updated: its ``actions`` blocks are dropped and a confirmation
    section is appended.
    """
    await ack()

    step_id = body["actions"][0]["value"]
    source_message = body["message"]
    thread_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, source_message["thread_ts"]))

    data_layer = get_data_layer()
    if data_layer:
        await data_layer.upsert_feedback(
            Feedback(forId=step_id, value=0, threadId=thread_id)
        )

    text = source_message["text"]
    confirmation = {
        "type": "section",
        "text": {"type": "mrkdwn", "text": ":thumbsdown: Feedback received."},
    }
    remaining = [
        block for block in source_message["blocks"] if block["type"] != "actions"
    ]
    updated_blocks = [*remaining, confirmation]

    await context.client.chat_update(
        channel=body["channel"]["id"],
        ts=body["container"]["message_ts"],
        text=text,
        blocks=updated_blocks,
    )
@slack_app.block_action("thumbup")
async def thumb_up(ack, context, body):
    """Record positive feedback and replace the buttons with a confirmation.

    The action is acknowledged first. When a data layer exists, a ``Feedback``
    with value 1 is upserted for the id carried by the button. The message is
    then updated: its ``actions`` blocks are dropped and a confirmation
    section is appended.
    """
    await ack()

    step_id = body["actions"][0]["value"]
    source_message = body["message"]
    thread_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, source_message["thread_ts"]))

    data_layer = get_data_layer()
    if data_layer:
        await data_layer.upsert_feedback(
            Feedback(forId=step_id, value=1, threadId=thread_id)
        )

    text = source_message["text"]
    confirmation = {
        "type": "section",
        "text": {"type": "mrkdwn", "text": ":thumbsup: Feedback received."},
    }
    remaining = [
        block for block in source_message["blocks"] if block["type"] != "actions"
    ]
    updated_blocks = [*remaining, confirmation]

    await context.client.chat_update(
        channel=body["channel"]["id"],
        ts=body["container"]["message_ts"],
        text=text,
        blocks=updated_blocks,
    )