# ai-station/app.py

import os
import re
import uuid
import shutil
import requests
import time
import json
from datetime import datetime
from typing import Optional, Dict, List, Any
import chainlit as cl
import ollama
from docling.document_converter import DocumentConverter
from qdrant_client import AsyncQdrantClient
# IMPORT FIX: import the needed classes directly from the library
from qdrant_client.models import PointStruct, Distance, VectorParams, SparseVectorParams, Prefetch
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.types import ThreadDict
from functools import lru_cache
# === ROBUST IMPORT FIX ===
try:
    from chainlit.data.storage_clients import BaseStorageClient
except ImportError:
    try:
        from chainlit.data.base import BaseStorageClient
    except ImportError:
        from chainlit.data.storage_clients.base import BaseStorageClient
# === CONFIGURATION ===
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://ai_user:secure_password_here@postgres:5432/ai_station")
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://192.168.1.243:11434")
QDRANT_URL = os.getenv("QDRANT_URL", "http://qdrant:6333")
BGE_API_URL = os.getenv("BGE_API_URL", "http://192.168.1.243:8001/embed")
VISION_MODEL = "minicpm-v"
DEFAULT_TEXT_MODEL = "glm-4.6:cloud"
WORKSPACES_DIR = "./workspaces"
STORAGE_DIR = "./.files"
os.makedirs(STORAGE_DIR, exist_ok=True)
os.makedirs(WORKSPACES_DIR, exist_ok=True)
# === USER MAPPING ===
USER_PROFILES = {
    "giuseppe@defranceschi.pro": { "role": "admin", "name": "Giuseppe", "workspace": "admin_workspace", "rag_collection": "admin_docs", "capabilities": ["debug", "all"], "show_code": True },
    "giuseppe.defranceschi@gmail.com": { "role": "admin", "name": "Giuseppe", "workspace": "admin_workspace", "rag_collection": "admin_docs", "capabilities": ["debug", "all"], "show_code": True },
    "federica.tecchio@gmail.com": { "role": "business", "name": "Federica", "workspace": "business_workspace", "rag_collection": "contabilita", "capabilities": ["basic_chat"], "show_code": False },
    "riccardob545@gmail.com": { "role": "engineering", "name": "Riccardo", "workspace": "engineering_workspace", "rag_collection": "engineering_docs", "capabilities": ["code"], "show_code": True },
    "giuliadefranceschi05@gmail.com": { "role": "architecture", "name": "Giulia", "workspace": "architecture_workspace", "rag_collection": "architecture_manuals", "capabilities": ["visual"], "show_code": False }
}
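# Least-privilege fallback for any email not listed above; it mirrors the
# inline guest profile used in the OAuth callback and is shared by the
# chat-start and chat-resume handlers below.
GUEST_PROFILE = {"role": "guest", "name": "Guest", "workspace": "guest", "rag_collection": "public", "capabilities": [], "show_code": False}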
# === STORAGE CLIENT ===
class LocalStorageClient(BaseStorageClient):
    def __init__(self, storage_path: str):
        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)

    async def upload_file(self, object_key: str, data: bytes, mime: str = "application/octet-stream", overwrite: bool = True) -> Dict[str, str]:
        file_path = os.path.join(self.storage_path, object_key)
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, "wb") as f:
            f.write(data)
        return {"object_key": object_key, "url": f"/files/{object_key}"}

    async def get_read_url(self, object_key: str) -> str:
        return f"/files/{object_key}"

    async def delete_file(self, object_key: str) -> bool:
        path = os.path.join(self.storage_path, object_key)
        if os.path.exists(path):
            os.remove(path)
            return True
        return False

    async def close(self):
        pass
@cl.data_layer
def get_data_layer():
    return SQLAlchemyDataLayer(conninfo=DATABASE_URL, storage_provider=LocalStorageClient(STORAGE_DIR))
# === OAUTH & UTILS ===
@cl.oauth_callback
def oauth_callback(provider_id: str, token: str, raw_user_data: Dict[str, str], default_user: cl.User) -> Optional[cl.User]:
if provider_id == "google":
email = raw_user_data.get("email", "").lower()
profile = USER_PROFILES.get(email, USER_PROFILES.get("guest", {"role": "guest", "name": "Guest", "workspace": "guest", "rag_collection": "public", "show_code": False}))
default_user.metadata.update({"role": profile["role"], "workspace": profile["workspace"], "rag_collection": profile["rag_collection"], "show_code": profile["show_code"], "display_name": profile["name"]})
return default_user
return default_user
def create_workspace(workspace_name: str) -> str:
    path = os.path.join(WORKSPACES_DIR, workspace_name)
    os.makedirs(path, exist_ok=True)
    return path
# === CORE: DOCLING ===
def process_file_with_docling(file_path: str) -> str:
    try:
        converter = DocumentConverter()
        result = converter.convert(file_path)
        return result.document.export_to_markdown()
    except Exception as e:
        print(f"❌ Docling Error: {e}")
        return ""
# === CORE: BGE-M3 CLIENT ===
def get_bge_embeddings(text: str) -> Optional[Dict[str, Any]]:
    try:
        payload = {"texts": [text[:8000]]}
        response = requests.post(BGE_API_URL, json=payload, timeout=30)
        response.raise_for_status()
        data = response.json().get("data", [])
        if data:
            return data[0]
        return None
    except Exception as e:
        print(f"❌ BGE API Error: {e}")
        return None
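# The rest of the pipeline assumes each entry of "data" is a dict holding a
# "dense" list (1024 floats, matching the Qdrant collection below) and a
# "sparse" lexical-weights structure; the exact sparse format depends on the
# BGE-M3 server behind BGE_API_URL and is an assumption this client does not
# validate.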
# === CORE: QDRANT ===
async def ensure_collection(collection_name: str):
    client = AsyncQdrantClient(url=QDRANT_URL)
    if not await client.collection_exists(collection_name):
        await client.create_collection(
            collection_name=collection_name,
            vectors_config={"dense": VectorParams(size=1024, distance=Distance.COSINE)},
            sparse_vectors_config={"sparse": SparseVectorParams()}
        )
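# Each point thus stores two named vectors: "dense" (1024-dim, cosine
# distance, matching BGE-M3's dense embeddings) and "sparse" (lexical
# weights), which is what enables the two-stage hybrid query in
# search_hybrid below.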
async def index_document(file_name: str, content: str, collection_name: str):
    await ensure_collection(collection_name)
    client = AsyncQdrantClient(url=QDRANT_URL)
    chunk_size = 2000
    overlap = 200
    points = []
    for i in range(0, len(content), chunk_size - overlap):
        chunk = content[i : i + chunk_size]
        embedding_data = get_bge_embeddings(chunk)
        if embedding_data:
            points.append(PointStruct(
                id=str(uuid.uuid4()),
                vector={
                    "dense": embedding_data["dense"],
                    "sparse": embedding_data["sparse"]
                },
                payload={
                    "file_name": file_name,
                    "content": chunk,
                    "indexed_at": datetime.now().isoformat()
                }
            ))
    if points:
        await client.upsert(collection_name=collection_name, points=points)
        return len(points)
    return 0
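# Worked example of the sliding window above: with chunk_size=2000 and
# overlap=200, chunks start every 1800 characters, so a 5000-character
# document produces slices [0:2000], [1800:3800] and [3600:5000], each
# sharing 200 characters with its neighbour.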
async def search_hybrid(query: str, collection_name: str, limit: int = 4) -> str:
    client = AsyncQdrantClient(url=QDRANT_URL)
    if not await client.collection_exists(collection_name):
        return ""
    query_emb = get_bge_embeddings(query)
    if not query_emb:
        return ""
    # FIX HERE: use the correctly imported Prefetch object
    results = await client.query_points(
        collection_name=collection_name,
        prefetch=[
            Prefetch(
                query=query_emb["sparse"],
                using="sparse",
                limit=limit * 2
            )
        ],
        query=query_emb["dense"],
        using="dense",
        limit=limit
    )
    context = []
    for hit in results.points:
        context.append(f"--- FROM {hit.payload['file_name']} ---\n{hit.payload['content']}")
    return "\n\n".join(context)
# === Embedding cache ===
@lru_cache(maxsize=1000)
def get_bge_embeddings_cached(text: str):
    """Cache for repeated queries."""
    return get_bge_embeddings(text)
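# Note: search_hybrid above still calls get_bge_embeddings directly; routing
# query embeddings through get_bge_embeddings_cached instead would let
# repeated questions skip the HTTP round-trip (document chunks are better
# left uncached, since each chunk is unique).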
# === CHAINLIT HANDLERS ===
@cl.on_chat_start
async def start():
    # 1. User profile
    user = cl.user_session.get("user")
    email = user.identifier if user else "guest"
    profile = USER_PROFILES.get(email, GUEST_PROFILE)
    cl.user_session.set("profile", profile)
    create_workspace(profile["workspace"])
    # 2. Custom HTML badge
    role_color = {
        "admin": "#e74c3c",
        "engineering": "#3498db",
        "business": "#2ecc71",
        "architecture": "#9b59b6",
    }.get(profile["role"], "#95a5a6")
    badge_html = f"""
    <div style="background:{role_color}; padding:8px; border-radius:8px; margin-bottom:16px;">
    👤 <b>{profile['name']}</b> | 🔧 {profile['role'].upper()} | 📁 {profile['workspace']}
    </div>
    """
    await cl.Message(content=badge_html).send()
    # 3. Settings UI
    settings = await cl.ChatSettings(
        [
            cl.input_widget.Slider(
                id="top_k",
                label="Number of RAG documents",
                initial=4,
                min=1,
                max=10,
                step=1,
            ),
            cl.input_widget.Select(
                id="vision_detail",
                label="Image analysis detail",
                values=["auto", "low", "high"],
                initial_value="auto",
            ),
            cl.input_widget.TextInput(
                id="system_instruction",
                label="Custom system instruction (optional)",
                initial="",
                placeholder="E.g.: Always answer in a technical format...",
            ),
            cl.input_widget.Select(
                id="model",
                label="Reasoning model",
                values=[DEFAULT_TEXT_MODEL, "llama3.2", "mistral", "qwen2.5-coder:32b"],
                initial_value=DEFAULT_TEXT_MODEL,
            ),
            cl.input_widget.Slider(
                id="temperature",
                label="Creativity (temperature)",
                initial=0.3,
                min=0,
                max=1,
                step=0.1,
            ),
            cl.input_widget.Switch(
                id="rag_enabled",
                label="Use document knowledge (RAG)",
                initial=True,
            ),
        ]
    ).send()
    cl.user_session.set("settings", settings)
    # 4. Welcome message
    await cl.Message(
        content=(
            f"🚀 **Vision-RAG Hybrid System Online**\n"
            f"User: {profile['name']} | Workspace: {profile['workspace']}\n"
            f"Engine: Docling + BGE-M3 + {VISION_MODEL}"
        )
    ).send()
cl.user_session.set("settings", settings)
await cl.Message(f"🚀 **Vision-RAG Hybrid System Online**\nUtente: {profile['name']} | Workspace: {profile['workspace']}\nEngine: Docling + BGE-M3 + {VISION_MODEL}").send()
@cl.on_settings_update
async def setup_agent(settings):
    cl.user_session.set("settings", settings)
    await cl.Message(content=f"✅ Settings updated: model {settings['model']}").send()
async def log_metrics(metrics: dict):
    # Minimal version: log to stdout
    print("[METRICS]", metrics)
    # Later these could be:
    # - saved to Postgres
    # - pushed to Prometheus / Grafana
    # - written to a JSON file for weekly analysis
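# A minimal sketch of the "JSON file for weekly analysis" option listed
# above, assuming an append-only JSON-lines file; the helper name and path
# are illustrative and nothing calls it yet.
def append_metrics_jsonl(metrics: dict, path: str = "./metrics.jsonl"):
    # One JSON object per line keeps the file streamable and grep-friendly.
    with open(path, "a") as f:
        f.write(json.dumps(metrics) + "\n")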
# --- Resume chat handler ---
@cl.on_chat_resume
async def on_chat_resume(thread: ThreadDict):
    """
    Called when the user clicks 'Resume' on an archived chat.
    Chainlit already reloads the messages into the UI; here we only restore the session.
    """
    # If needed, the user's identifier can be recovered from the thread
    user_identifier = thread.get("userIdentifier")
    profile = USER_PROFILES.get(user_identifier, GUEST_PROFILE)
    cl.user_session.set("profile", profile)
    # Custom state (e.g. default settings) could also be restored here,
    # or we can simply greet the user.
    await cl.Message(
        content="👋 Welcome back! We can pick up this conversation where we left off."
    ).send()
@cl.on_message
async def main(message: cl.Message):
    start_time = time.time()
    profile = cl.user_session.get("profile")
    settings = cl.user_session.get("settings", {})
    selected_model = settings.get("model", DEFAULT_TEXT_MODEL)
    temperature = settings.get("temperature", 0.3)
    rag_enabled = settings.get("rag_enabled", True)
    workspace = create_workspace(profile["workspace"])
    images_for_vision = []
    doc_context = ""
    rag_context = ""  # initialized here so it always exists
    # 1. FILE HANDLING
    if message.elements:
        for element in message.elements:
            file_path = os.path.join(workspace, element.name)
            shutil.copy(element.path, file_path)
            if element.mime and "image" in element.mime:
                images_for_vision.append(file_path)
                msg_img = cl.Message(
                    content=f"👁️ Analyzing image **{element.name}** with {VISION_MODEL}..."
                )
                await msg_img.send()
                with open(file_path, "rb") as img_file:
                    img_bytes = img_file.read()
                client_sync = ollama.Client(host=OLLAMA_URL)
                res = client_sync.chat(
                    model=VISION_MODEL,
                    messages=[{
                        "role": "user",
                        "content": (
                            "Analyze this technical image. Transcribe any text and codes, "
                            "and describe diagrams or tables in detail."
                        ),
                        "images": [img_bytes],
                    }],
                )
                desc = res["message"]["content"]
                doc_context += f"\n\n[IMAGE DESCRIPTION {element.name}]:\n{desc}"
                msg_img.content = f"✅ Image analyzed:\n{desc[:200]}..."
                await msg_img.update()
            elif element.name.endswith((".pdf", ".docx")):
                msg_doc = cl.Message(
                    content=f"📄 Reading **{element.name}** with Docling (tables/formulas)..."
                )
                await msg_doc.send()
                markdown_content = process_file_with_docling(file_path)
                if markdown_content:
                    chunks = await index_document(
                        element.name, markdown_content, profile["rag_collection"]
                    )
                    msg_doc.content = (
                        f"✅ **{element.name}**: converted and stored {chunks} "
                        "chunks in the vector DB."
                    )
                    doc_context += (
                        f"\n\n[FILE CONTENT {element.name}]:\n"
                        f"{markdown_content[:1000]}..."
                    )
                else:
                    msg_doc.content = f"❌ Error reading {element.name}"
                await msg_doc.update()
    # 2. RAG RETRIEVAL
    if rag_enabled and not images_for_vision:
        rag_context = await search_hybrid(
            message.content,
            profile["rag_collection"],
            limit=int(settings.get("top_k", 4)),  # honor the top_k slider
        )
    final_context = ""
    if rag_context:
        final_context += f"RAG CONTEXT:\n{rag_context}\n"
    if doc_context:
        final_context += f"CURRENT SESSION CONTEXT:\n{doc_context}\n"
    system_prompt = (
        "You are an expert technical assistant. Use the provided context "
        "(including Markdown tables and image descriptions) to answer "
        "precisely. Cite the source documents."
    )
    # Honor the optional custom system instruction from the settings UI
    custom_instruction = settings.get("system_instruction", "")
    if custom_instruction:
        system_prompt += f"\n{custom_instruction}"
    msg = cl.Message(content="")
    await msg.send()
    error = None
    # 3. GENERATION
    try:
        client_async = ollama.AsyncClient(host=OLLAMA_URL)
        stream = await client_async.chat(
            model=selected_model,
            messages=[
                {"role": "system", "content": system_prompt},
                {
                    "role": "user",
                    "content": f"Question: {message.content}\n\n{final_context}",
                },
            ],
            options={"temperature": temperature},
            stream=True,
        )
        async for chunk in stream:
            content = chunk["message"]["content"]
            await msg.stream_token(content)
        await msg.update()
    except Exception as e:
        error = str(e)
        await msg.stream_token(f"❌ AI error: {error}")
        await msg.update()
    # 4. CODE SAVING
    if profile["show_code"]:
        # Extract the bodies of fenced code blocks from the answer
        code_blocks = re.findall(r"```(?:\w+)?\n(.*?)```", msg.content, re.DOTALL)
        if code_blocks:
            for i, code in enumerate(code_blocks):
                fname = f"script_{datetime.now().strftime('%H%M%S')}_{i}.py"
                with open(os.path.join(workspace, fname), "w") as f:
                    f.write(code.strip())
                await cl.Message(
                    content=f"💾 Script saved: `{fname}`"
                ).send()
    # 5. METRICS (AT THE END)
    elapsed = time.time() - start_time
    # Since rag_context is a concatenated string, "rag_hits" can be estimated
    # by counting the separator emitted by search_hybrid ('--- FROM ')
    if rag_context:
        rag_hits = rag_context.count("--- FROM ")
    else:
        rag_hits = 0
    metrics = {
        "response_time": elapsed,
        "rag_hits": rag_hits,
        "model": selected_model,
        "user_role": profile["role"],
        "error": error,
    }
    await log_metrics(metrics)