604 lines
20 KiB
Python
604 lines
20 KiB
Python
import asyncio
import html
import json
import os
import re
import shutil
import threading
import time
import uuid
from datetime import datetime
from functools import lru_cache
from typing import Optional, Dict, Any, List

import chainlit as cl
import ollama
import requests
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.types import ThreadDict
from docling.document_converter import DocumentConverter
from qdrant_client import AsyncQdrantClient
from qdrant_client.models import (
    PointStruct,
    Distance,
    VectorParams,
    SparseVectorParams,
    Prefetch,
)
|
|
|
|
# === FIX IMPORT ROBUSTO Storage Client ===
|
|
try:
|
|
from chainlit.data.storage_clients import BaseStorageClient
|
|
except ImportError:
|
|
try:
|
|
from chainlit.data.base import BaseStorageClient
|
|
except ImportError:
|
|
from chainlit.data.storage_clients.base import BaseStorageClient
|
|
|
|
|
|
# =========================
|
|
# CONFIG
|
|
# =========================
|
|
# SECURITY: Fail fast if DATABASE_URL is not set
|
|
# A value of only whitespace (e.g. a stray newline from a secrets file) is
# treated as missing; surrounding whitespace is stripped from real values.
DATABASE_URL = (os.getenv("DATABASE_URL") or "").strip()
if not DATABASE_URL:
    raise EnvironmentError(
        "DATABASE_URL environment variable is required. "
        "Set it via: export DATABASE_URL='postgresql+asyncpg://user:pass@host:5432/db'"
    )

# Service URLs - can be overridden via environment.
# NOTE(review): two fallbacks point at a fixed LAN host; consider requiring
# them explicitly in production deployments.
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://192.168.1.243:11434")
QDRANT_URL = os.getenv("QDRANT_URL", "http://qdrant:6333")
BGE_API_URL = os.getenv("BGE_API_URL", "http://192.168.1.243:8001/embed")

# Ollama model tags.
VISION_MODEL = "minicpm-v"
DEFAULT_TEXT_MODEL = "glm-4.6:cloud"
MINIMAX_MODEL = "minimax-m2.1:cloud"

# Models offered in the chat-settings dropdown.
MODEL_CHOICES = [
    DEFAULT_TEXT_MODEL,
    MINIMAX_MODEL,
    "llama3.2",
    "mistral",
    "qwen2.5-coder:32b",
]

# Local directories (per-user workspaces and Chainlit element storage).
# Overridable via environment; the defaults match the previous hard-coded paths.
WORKSPACES_DIR = os.getenv("WORKSPACES_DIR", "./workspaces")
STORAGE_DIR = os.getenv("STORAGE_DIR", "./.files")
os.makedirs(STORAGE_DIR, exist_ok=True)
os.makedirs(WORKSPACES_DIR, exist_ok=True)
|
|
|
|
# =========================
|
|
# USER PROFILES (loaded from JSON for security)
|
|
# =========================
|
|
def load_user_profiles(config_path: Optional[str] = None) -> Dict[str, Dict]:
    """Load user profiles from a JSON config file.

    The file is expected to look like ``{"user_profiles": {<key>: {...}}}``.
    The keys are what the OAuth/chat handlers look up (presumably lowercase
    e-mail addresses - verify against the deployed config).

    Args:
        config_path: Path of the JSON file. When omitted, the
            ``USER_PROFILES_PATH`` environment variable is used, falling back
            to ``config/user_profiles.json``.

    Returns:
        The ``user_profiles`` mapping. Returns ``{}`` (so every user falls
        back to the guest profile) when the file is missing, unreadable, not
        valid JSON, or not shaped as expected. A warning is printed in each case.
    """
    if config_path is None:
        config_path = os.getenv("USER_PROFILES_PATH", "config/user_profiles.json")

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except FileNotFoundError:
        print(f"WARNING: User profiles config not found at {config_path}")
        return {}
    except json.JSONDecodeError as e:
        print(f"WARNING: Invalid JSON in user profiles config: {e}")
        return {}
    except (OSError, UnicodeDecodeError) as e:
        # Permission problems, directories, bad encodings: degrade to "no
        # profiles" (least privilege) instead of crashing at import time.
        print(f"WARNING: Cannot read user profiles config at {config_path}: {e}")
        return {}

    # Validate the shape: callers use .get() on both levels.
    if not isinstance(config, dict):
        print(f"WARNING: User profiles config at {config_path} must be a JSON object")
        return {}
    profiles = config.get("user_profiles", {})
    if not isinstance(profiles, dict):
        print(f"WARNING: 'user_profiles' in {config_path} must be a JSON object")
        return {}
    return profiles
|
|
|
|
# Load profiles at startup
|
|
# Known users, keyed as in the JSON config; read once at import time.
USER_PROFILES = load_user_profiles()

# Profile applied to anyone not listed in USER_PROFILES.
GUEST_PROFILE = dict(
    role="guest",
    name="Guest",
    workspace="guest",
    rag_collection="public",
    capabilities=["basic_chat"],
    show_code=False,
)

# Per-role starting values for the chat settings panel: model, retrieval
# depth (top_k), sampling temperature, RAG switch and image-analysis detail.
ROLE_DEFAULTS = dict(
    admin=dict(
        model=DEFAULT_TEXT_MODEL,
        top_k=6,
        temperature=0.3,
        rag_enabled=True,
        vision_detail="high",
    ),
    engineering=dict(
        model=MINIMAX_MODEL,
        top_k=5,
        temperature=0.3,
        rag_enabled=True,
        vision_detail="low",
    ),
    business=dict(
        model=DEFAULT_TEXT_MODEL,
        top_k=4,
        temperature=0.2,
        rag_enabled=True,
        vision_detail="auto",
    ),
    architecture=dict(
        model=DEFAULT_TEXT_MODEL,
        top_k=4,
        temperature=0.3,
        rag_enabled=True,
        vision_detail="high",
    ),
    guest=dict(
        model=DEFAULT_TEXT_MODEL,
        top_k=3,
        temperature=0.2,
        rag_enabled=False,
        vision_detail="auto",
    ),
)
|
|
|
|
|
|
# =========================
|
|
# STORAGE
|
|
# =========================
|
|
class LocalStorageClient(BaseStorageClient):
    """Chainlit storage provider that keeps element files on the local disk.

    Each object key is used as a path relative to ``storage_path``. Read URLs
    have the form ``/files/<object_key>``.
    NOTE(review): confirm that something actually serves the ``/files`` route.
    """

    def __init__(self, storage_path: str):
        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)

    def _resolve(self, object_key: str) -> str:
        """Map ``object_key`` to an absolute path strictly inside ``storage_path``.

        Raises:
            ValueError: if the key is empty, absolute, or escapes the storage
                root (e.g. via ``..``). Keys may embed upload file names, which
                are presumably user-controlled - hence the check.
        """
        root = os.path.realpath(self.storage_path)
        target = os.path.realpath(os.path.join(root, object_key))
        if target == root or os.path.commonpath([root, target]) != root:
            raise ValueError(f"Invalid storage object key: {object_key!r}")
        return target

    async def upload_file(
        self,
        object_key: str,
        data: Any,
        mime: str = "application/octet-stream",
        overwrite: bool = True,
    ) -> Dict[str, str]:
        """Write ``data`` (bytes, or str encoded as UTF-8) under ``object_key``.

        ``mime`` and ``overwrite`` are accepted for interface compatibility;
        an existing file is always replaced.
        """
        file_path = self._resolve(object_key)
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        if isinstance(data, str):
            # Defensive: Chainlit may hand over text content as str
            # (TODO confirm); "wb" mode requires bytes.
            data = data.encode("utf-8")
        with open(file_path, "wb") as f:
            f.write(data)
        return {"object_key": object_key, "url": f"/files/{object_key}"}

    async def get_read_url(self, object_key: str) -> str:
        """Return the URL under which ``object_key`` is expected to be served."""
        return f"/files/{object_key}"

    async def delete_file(self, object_key: str) -> bool:
        """Delete the file for ``object_key``; return False if it did not exist."""
        path = self._resolve(object_key)
        try:
            os.remove(path)
        except FileNotFoundError:
            # Also covers the file vanishing between lookup and removal.
            return False
        return True

    async def close(self):
        """Nothing to release: no connections or handles are kept open."""
        pass
|
|
|
|
|
|
@cl.data_layer
def get_data_layer():
    """Provide Chainlit's SQLAlchemy data layer with local file storage.

    Persistence goes to ``DATABASE_URL``. Element files are handled by a
    LocalStorageClient rooted at ``STORAGE_DIR``.
    """
    file_store = LocalStorageClient(STORAGE_DIR)
    return SQLAlchemyDataLayer(
        conninfo=DATABASE_URL,
        storage_provider=file_store,
    )
|
|
|
|
|
|
# =========================
|
|
# OAUTH
|
|
# =========================
|
|
@cl.oauth_callback
def oauth_callback(
    provider_id: str,
    token: str,
    raw_user_data: Dict[str, str],
    default_user: cl.User,
) -> Optional[cl.User]:
    """Accept OAuth logins and attach the user's profile to their metadata.

    For Google, the lowercased e-mail selects an entry of USER_PROFILES;
    unknown e-mails get GUEST_PROFILE. Logins from other providers are
    returned unchanged, with no profile metadata.
    """
    if provider_id != "google":
        return default_user

    # The e-mail may be absent or null; both map to "unknown user" (guest).
    email = (raw_user_data.get("email") or "").lower()
    profile = USER_PROFILES.get(email, GUEST_PROFILE)

    def _field(key: str):
        # Profiles come from a hand-edited JSON file: a missing field falls
        # back to the guest value instead of failing the login with KeyError.
        return profile.get(key, GUEST_PROFILE[key])

    default_user.metadata.update(
        {
            "role": _field("role"),
            "workspace": _field("workspace"),
            "rag_collection": _field("rag_collection"),
            "show_code": _field("show_code"),
            "display_name": _field("name"),
        }
    )
    return default_user
|
|
|
|
|
|
def create_workspace(workspace_name: str) -> str:
    """Make sure ``WORKSPACES_DIR/<workspace_name>`` exists and return that path.

    An already existing directory is left as is.
    """
    workspace_dir = os.path.join(WORKSPACES_DIR, workspace_name)
    os.makedirs(workspace_dir, exist_ok=True)
    return workspace_dir
|
|
|
|
|
|
# =========================
|
|
# CORE: DOCLING
|
|
# =========================
|
|
# One Docling converter is reused across calls instead of being rebuilt (and
# its models reloaded) on every document. The lock serialises conversions,
# because sharing one converter between worker threads is not verified safe
# here. NOTE(review): check Docling's thread-safety guarantees.
_DOCLING_LOCK = threading.Lock()


@lru_cache(maxsize=1)
def _get_docling_converter() -> DocumentConverter:
    """Build the Docling converter on first use and return the cached instance."""
    return DocumentConverter()


def process_file_with_docling(file_path: str) -> str:
    """Convert the document at ``file_path`` to Markdown with Docling.

    Blocking; call it from a worker thread in async code.

    Returns:
        The Markdown export, or "" if construction or conversion fails (the
        error is printed). Callers treat "" as "could not read the file".
        A failed converter construction is not cached and is retried on the
        next call.
    """
    try:
        with _DOCLING_LOCK:
            result = _get_docling_converter().convert(file_path)
        return result.document.export_to_markdown()
    except Exception as e:
        print(f"❌ Docling Error: {e}")
        return ""
|
|
|
|
|
|
# =========================
|
|
# CORE: BGE-M3 embeddings
|
|
# =========================
|
|
def get_bge_embeddings(text: str) -> Optional[Dict[str, Any]]:
    """Ask the BGE embedding service to embed a single text.

    Only the first 8000 characters of ``text`` are sent, with a 30 s timeout.

    Returns:
        The first entry of the response's ``data`` list, or None when that
        list is empty. Returns None on any error (printed). Callers read
        ``"dense"`` and ``"sparse"`` from the result.
    """
    try:
        response = requests.post(
            BGE_API_URL,
            json={"texts": [text[:8000]]},
            timeout=30,
        )
        response.raise_for_status()
        items = response.json().get("data", [])
        return items[0] if items else None
    except Exception as e:
        print(f"❌ BGE API Error: {e}")
        return None
|
|
|
|
|
|
class _EmbeddingUnavailable(Exception):
    """Internal signal that an embedding call failed and must not be cached."""


@lru_cache(maxsize=1000)
def _bge_embeddings_lru(text: str) -> Dict[str, Any]:
    # lru_cache never stores raised exceptions, so only successes are memoized
    # and a failed text is retried on its next lookup.
    result = get_bge_embeddings(text)
    if result is None:
        raise _EmbeddingUnavailable
    return result


def get_bge_embeddings_cached(text: str) -> Optional[Dict[str, Any]]:
    """Memoized get_bge_embeddings (up to 1000 texts) that does not cache failures.

    Returns None when the embedding service fails. Cached results are shared
    objects; do not mutate them.
    """
    try:
        return _bge_embeddings_lru(text)
    except _EmbeddingUnavailable:
        return None


# Preserve the cache-management API the lru_cache-decorated version exposed.
get_bge_embeddings_cached.cache_info = _bge_embeddings_lru.cache_info
get_bge_embeddings_cached.cache_clear = _bge_embeddings_lru.cache_clear
|
|
|
|
|
|
# =========================
|
|
# CORE: QDRANT
|
|
# =========================
|
|
async def ensure_collection(collection_name: str):
    """Create the Qdrant collection ``collection_name`` if it does not exist.

    The collection gets a named "dense" vector (size 1024, cosine distance)
    and a named "sparse" vector. 1024 presumably matches the BGE-M3 dense
    output - TODO confirm against the embedding service.
    """
    client = AsyncQdrantClient(url=QDRANT_URL)
    try:
        if await client.collection_exists(collection_name):
            return
        await client.create_collection(
            collection_name=collection_name,
            vectors_config={"dense": VectorParams(size=1024, distance=Distance.COSINE)},
            sparse_vectors_config={"sparse": SparseVectorParams()},
        )
    finally:
        # Each call opens its own client; release its connections.
        await client.close()
|
|
|
|
|
|
async def index_document(
    file_name: str,
    content: str,
    collection_name: str,
    chunk_size: int = 2000,
    overlap: int = 200,
) -> int:
    """Chunk ``content``, embed each chunk and upsert the chunks into Qdrant.

    Windows of ``chunk_size`` characters advance by ``chunk_size - overlap``.
    Chunking stops once a window reaches the end of the text, so no chunk is
    stored that is fully contained in the previous one. Each point's payload
    stores ``file_name``, the chunk text and one shared ``indexed_at``
    timestamp (naive local time, ISO format).

    Args:
        file_name: Name recorded in every point's payload.
        content: Text to index (e.g. Docling Markdown).
        collection_name: Target collection; created if missing.
        chunk_size: Characters per chunk (must be > 0).
        overlap: Characters shared by consecutive chunks (0 <= overlap < chunk_size).

    Returns:
        Number of points upserted. Chunks whose embedding failed are skipped.

    Raises:
        ValueError: on invalid ``chunk_size`` / ``overlap``.
    """
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError("chunk_size must be > 0 and 0 <= overlap < chunk_size")

    await ensure_collection(collection_name)

    step = chunk_size - overlap
    indexed_at = datetime.now().isoformat()
    points: List[PointStruct] = []

    for offset in range(0, len(content), step):
        chunk = content[offset : offset + chunk_size]
        # get_bge_embeddings is a blocking HTTP call: run it in a worker
        # thread so the event loop keeps serving other sessions.
        embedding_data = await asyncio.to_thread(get_bge_embeddings, chunk)
        if embedding_data:
            points.append(
                PointStruct(
                    id=str(uuid.uuid4()),
                    # Assumes the service returns "sparse" in a form Qdrant
                    # accepts (e.g. {"indices": [...], "values": [...]}) - TODO confirm.
                    vector={"dense": embedding_data["dense"], "sparse": embedding_data["sparse"]},
                    payload={
                        "file_name": file_name,
                        "content": chunk,
                        "indexed_at": indexed_at,
                    },
                )
            )
        if offset + chunk_size >= len(content):
            break  # this window already reached the end of the text

    if not points:
        return 0

    client = AsyncQdrantClient(url=QDRANT_URL)
    try:
        await client.upsert(collection_name=collection_name, points=points)
    finally:
        await client.close()
    return len(points)
|
|
|
|
|
|
async def search_hybrid(query: str, collection_name: str, limit: int = 4) -> str:
    """Retrieve RAG context for ``query`` from ``collection_name``.

    A sparse-vector search preselects ``limit * 2`` candidates, which are
    then scored with the dense vector; the top ``limit`` hits are returned.

    Returns:
        Hits formatted as ``--- DA <file_name> ---\\n<content>`` and joined
        by blank lines. Returns "" when the collection does not exist, the
        query cannot be embedded, or Qdrant fails (the error is printed), so
        the chat can proceed without RAG.
    """
    client = AsyncQdrantClient(url=QDRANT_URL)
    try:
        if not await client.collection_exists(collection_name):
            return ""

        # Blocking HTTP call: keep it off the event loop.
        query_emb = await asyncio.to_thread(get_bge_embeddings, query)
        if not query_emb:
            return ""

        results = await client.query_points(
            collection_name=collection_name,
            prefetch=[Prefetch(query=query_emb["sparse"], using="sparse", limit=limit * 2)],
            query=query_emb["dense"],
            using="dense",
            limit=limit,
        )
    except Exception as e:
        print(f"❌ Qdrant search error: {e}")
        return ""
    finally:
        await client.close()

    sections = []
    for hit in results.points:
        # Tolerate points written without the expected payload fields.
        payload = hit.payload or {}
        sections.append(f"--- DA {payload.get('file_name', '?')} ---\n{payload.get('content', '')}")
    return "\n\n".join(sections)
|
|
|
|
|
|
# =========================
|
|
# UX HELPERS (S-Tier: clarity, consistency)
|
|
# =========================
|
|
def role_to_badge_class(role: str) -> str:
    """Return the CSS modifier class for a role's badge.

    Unknown roles get the guest class, so arbitrary role strings never end up
    in the HTML ``class`` attribute.
    """
    known_roles = frozenset(("admin", "engineering", "business", "architecture", "guest"))
    suffix = role if role in known_roles else "guest"
    return "dfm-badge--" + suffix
|
|
|
|
|
|
def build_system_prompt(system_instruction: str, has_rag: bool, has_files: bool) -> str:
    """Assemble the (Italian) system prompt for the chat model.

    The prompt always contains a fixed preamble. A guidance line is appended
    when RAG context is present, and another when session files/images were
    processed. The user's own instruction, whitespace-trimmed, is appended
    last if it is not blank.
    """
    parts = [
        "Sei un assistente tecnico esperto.\n",
        "Obiettivo: rispondere in modo chiaro, preciso e operativo.\n",
        "- Se mancano dettagli, fai 1-2 domande mirate.\n",
        "- Se scrivi codice, includi snippet piccoli e verificabili.\n",
    ]
    if has_rag:
        parts.append("- Usa il contesto RAG come fonte primaria quando presente.\n")
    if has_files:
        parts.append("- Se sono presenti file/immagini, sfrutta le informazioni estratte.\n")

    extra = system_instruction.strip()
    if extra:
        parts.extend(("\nIstruzione aggiuntiva (utente): ", extra, "\n"))
    return "".join(parts)
|
|
|
|
|
|
# Opening fence: ``` + any info string without backticks (python, c++,
# objective-c, "python title=x") + optional CR + LF. The body runs lazily up to
# the next ```. Accepting any info string keeps fences paired: a rejected
# opener would otherwise let a closing fence be mistaken for an opener.
_CODE_FENCE_RE = re.compile(r"```[^\n`]*\r?\n(.*?)```", re.DOTALL)


def extract_code_blocks(text: str) -> List[str]:
    """Return the bodies of the fenced Markdown code blocks in ``text``, in order.

    Bodies are returned verbatim, including their trailing newline.
    """
    return _CODE_FENCE_RE.findall(text)
|
|
|
|
|
|
async def log_metrics(metrics: dict):
    """Print per-message metrics to stdout as ``METRICS: <dict>``.

    Deliberately minimal: plain stdout, no external sink.
    """
    print("METRICS:", metrics, sep=" ", end="\n")
|
|
|
|
|
|
# =========================
|
|
# CHAINLIT HANDLERS
|
|
# =========================
|
|
@cl.on_chat_start
async def start():
    """Initialise a chat session.

    Resolves the user's profile, ensures the workspace exists, shows the
    role badge, sends the settings panel pre-filled with the role defaults,
    and announces readiness. Stores ``profile``, ``role_defaults`` and
    ``settings`` in the user session for the message handler.
    """
    # 1) User profile. Lowercase the identifier because oauth_callback looks
    #    profiles up by lowercased e-mail; both lookups must agree.
    user = cl.user_session.get("user")
    email = user.identifier.lower() if user else "guest"

    profile = USER_PROFILES.get(email, GUEST_PROFILE)
    cl.user_session.set("profile", profile)

    create_workspace(profile["workspace"])

    role = profile.get("role", "guest")
    defaults = ROLE_DEFAULTS.get(role, ROLE_DEFAULTS["guest"])
    cl.user_session.set("role_defaults", defaults)

    # 2) Badge. Profile values are HTML-escaped before interpolation. Lines
    #    carry no leading indentation so Markdown cannot mistake them for an
    #    indented code block.
    safe_name = html.escape(str(profile["name"]))
    safe_role = html.escape(str(role).upper())
    safe_workspace = html.escape(str(profile["workspace"]))
    badge_html = "\n".join(
        [
            f'<div class="dfm-badge {role_to_badge_class(role)}">',
            f"<span><b>{safe_name}</b></span>",
            f'<span style="opacity:.8">{safe_role}</span>',
            f'<span style="opacity:.7">· {safe_workspace}</span>',
            "</div>",
        ]
    )
    await cl.Message(content=badge_html).send()

    # 3) Settings panel, pre-filled with the role defaults.
    settings = await cl.ChatSettings(
        [
            cl.input_widget.Switch(
                id="rag_enabled",
                label="📚 Usa Conoscenza Documenti",
                initial=bool(defaults["rag_enabled"]),
                description="Attiva la ricerca nei documenti caricati (consigliato).",
            ),
            cl.input_widget.Slider(
                id="top_k",
                label="Profondità Ricerca (documenti)",
                initial=int(defaults["top_k"]),
                min=1,
                max=10,
                step=1,
                description="Più documenti = risposta più completa ma più lenta.",
            ),
            cl.input_widget.Select(
                id="model",
                label="🤖 Modello AI",
                values=MODEL_CHOICES,
                initial_value=str(defaults["model"]),
            ),
            cl.input_widget.Slider(
                id="temperature",
                label="Creatività",
                initial=float(defaults["temperature"]),
                min=0,
                max=1,
                step=0.1,
                description="Bassa = più precisione (consigliato per codice).",
            ),
            cl.input_widget.Select(
                id="vision_detail",
                label="🔍 Dettaglio Analisi Immagini",
                values=["auto", "low", "high"],
                initial_value=str(defaults["vision_detail"]),
            ),
            cl.input_widget.TextInput(
                id="system_instruction",
                label="✏️ Istruzione Sistema (opzionale)",
                initial="",
                placeholder="es: Rispondi con bullet points e includi esempi",
                description="Personalizza stile/format delle risposte.",
            ),
        ]
    ).send()

    cl.user_session.set("settings", settings)

    await cl.Message(
        content=(
            "✅ Ai Station online.\n"
            f"• Workspace: `{profile['workspace']}`\n"
            f"• Default modello: `{defaults['model']}`\n"
            f"• Vision: `{VISION_MODEL}`"
        )
    ).send()
|
|
|
|
|
|
@cl.on_settings_update
async def setupagentsettings(settings):
    """Store the updated chat settings in the session and echo a summary."""
    cl.user_session.set("settings", settings)

    model = settings.get("model")
    rag_state = "ON" if settings.get("rag_enabled") else "OFF"
    top_k = settings.get("top_k")
    creativity = settings.get("temperature")
    vision = settings.get("vision_detail")

    summary_parts = (
        "✅ Impostazioni aggiornate:\n",
        f"• Modello: `{model}`\n",
        f"• RAG: {rag_state} · top_k={top_k}\n",
        f"• Creatività: {creativity}\n",
        f"• Vision detail: `{vision}`",
    )
    await cl.Message(content="".join(summary_parts)).send()
|
|
|
|
|
|
@cl.on_chat_resume
async def on_chat_resume(thread: ThreadDict):
    """Restore per-user session state when a persisted thread is resumed.

    This handler does not re-send the settings panel. When the session has
    no settings yet, it is therefore seeded with the role defaults (and an
    empty system instruction), so the message handler uses role-appropriate
    values rather than generic fallbacks.
    """
    # Lowercase to match oauth_callback, which looks profiles up by
    # lowercased e-mail; a missing identifier falls back to the guest profile.
    user_identifier = (thread.get("userIdentifier") or "").lower()
    profile = USER_PROFILES.get(user_identifier, GUEST_PROFILE)
    cl.user_session.set("profile", profile)
    create_workspace(profile["workspace"])

    role = profile.get("role", "guest")
    defaults = ROLE_DEFAULTS.get(role, ROLE_DEFAULTS["guest"])
    cl.user_session.set("role_defaults", defaults)
    if not cl.user_session.get("settings"):
        cl.user_session.set("settings", {**defaults, "system_instruction": ""})

    await cl.Message(content="Bentornato! Riprendiamo da qui.").send()
|
|
|
|
|
|
def _safe_upload_name(name: Optional[str]) -> str:
    """Reduce an uploaded file name to a bare name that stays inside the workspace."""
    # Normalise Windows separators too, then drop any directory part.
    base = os.path.basename((name or "").replace("\\", "/"))
    if base in ("", ".", ".."):
        base = f"upload_{uuid.uuid4().hex}"
    return base


async def _analyze_image(display_name: str, file_path: str, vision_detail: str) -> str:
    """Describe an uploaded image with the vision model, reporting progress in chat.

    Returns a context section for the prompt, or "" if the analysis failed.
    """
    msg_img = cl.Message(content=f"🖼️ Analizzo immagine `{display_name}` con `{VISION_MODEL}`...")
    await msg_img.send()

    try:
        with open(file_path, "rb") as imgfile:
            imgbytes = imgfile.read()

        # Async client: the vision call can take a while and must not block
        # the event loop shared by all sessions.
        res = await ollama.AsyncClient(host=OLLAMA_URL).chat(
            model=VISION_MODEL,
            messages=[
                {
                    "role": "user",
                    "content": (
                        "Analizza questa immagine tecnica. "
                        "Trascrivi testi/codici e descrivi diagrammi o tabelle in dettaglio. "
                        f"Dettaglio richiesto: {vision_detail}."
                    ),
                    "images": [imgbytes],
                }
            ],
        )
        desc = res.get("message", {}).get("content", "") or ""
    except Exception as e:
        msg_img.content = f"❌ Errore analisi immagine: {e}"
        await msg_img.update()
        return ""

    msg_img.content = f"✅ Immagine analizzata: {desc[:300]}..."
    await msg_img.update()
    return f"\n\n## DESCRIZIONE IMMAGINE: {display_name}\n{desc}\n"


async def _ingest_document(display_name: str, file_path: str, collection_name: str) -> str:
    """Convert a PDF/DOCX with Docling and index it, reporting progress in chat.

    Returns a context section (the first 2000 Markdown chars), or "" if the
    file could not be read. An indexing failure is reported but still
    returns the section.
    """
    msg_doc = cl.Message(content=f"📄 Leggo `{display_name}` con Docling (tabelle/formule)...")
    await msg_doc.send()

    # Docling conversion is blocking and CPU-heavy: run it in a worker thread.
    markdown_content = await asyncio.to_thread(process_file_with_docling, file_path)
    if not markdown_content:
        msg_doc.content = f"❌ Errore lettura `{display_name}`."
        await msg_doc.update()
        return ""

    try:
        chunks = await index_document(display_name, markdown_content, collection_name)
        msg_doc.content = f"✅ `{display_name}` convertito e indicizzato ({chunks} chunks)."
    except Exception as e:
        # Indexing (Qdrant/BGE) failed, but the content is still usable for
        # this message.
        msg_doc.content = f"⚠️ `{display_name}` convertito ma non indicizzato: {e}"
    await msg_doc.update()
    return f"\n\n## CONTENUTO FILE: {display_name}\n{markdown_content[:2000]}\n"


async def _save_code_blocks(text: str, workspace: str) -> None:
    """Save each fenced code block of ``text`` as a .py file in ``workspace``."""
    # Date + time + random suffix: names no longer collide across days or
    # between responses within the same second.
    stamp = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
    for i, code in enumerate(extract_code_blocks(text)):
        fname = f"script_{stamp}_{i}.py"
        try:
            with open(os.path.join(workspace, fname), "w", encoding="utf-8") as f:
                f.write(code.strip())
        except OSError as e:
            await cl.Message(content=f"❌ Errore salvataggio `{fname}`: {e}").send()
        else:
            await cl.Message(content=f"💾 Script salvato: `{fname}`").send()


@cl.on_message
async def main(message: cl.Message):
    """Handle one user message end to end.

    Steps:
    1. Copy uploads into the user's workspace; describe images; convert and
       index PDF/DOCX files.
    2. Run hybrid retrieval, unless RAG is off or images were uploaded.
    3. Build the prompt.
    4. Stream the model answer.
    5. Save code blocks, for profiles with ``show_code``.
    6. Print timing metrics.
    """
    start_time = time.time()

    profile = cl.user_session.get("profile") or GUEST_PROFILE
    # Fall back to the role's defaults (not generic constants) for any
    # setting missing from the session, e.g. after a resumed thread.
    role_defaults = cl.user_session.get("role_defaults") or ROLE_DEFAULTS.get(
        profile.get("role", "guest"), ROLE_DEFAULTS["guest"]
    )
    settings = cl.user_session.get("settings") or {}

    selected_model = settings.get("model", role_defaults["model"])
    temperature = float(settings.get("temperature", role_defaults["temperature"]))
    rag_enabled = bool(settings.get("rag_enabled", role_defaults["rag_enabled"]))
    top_k = int(settings.get("top_k", role_defaults["top_k"]))
    vision_detail = settings.get("vision_detail", role_defaults["vision_detail"])
    system_instruction = (settings.get("system_instruction", "") or "").strip()

    workspace = create_workspace(profile["workspace"])

    # 1) Uploads (images / pdf / docx)
    images_for_vision: List[str] = []
    doc_context = ""

    for element in message.elements or []:
        source_path = getattr(element, "path", None)
        if not source_path:
            continue  # nothing on disk to copy (e.g. a URL-only element)

        display_name = element.name or "upload"
        # element.name is client-supplied: strip directories so the copy
        # cannot land outside the workspace.
        file_path = os.path.join(workspace, _safe_upload_name(element.name))
        shutil.copy(source_path, file_path)

        if (getattr(element, "mime", None) or "").startswith("image/"):
            images_for_vision.append(file_path)
            doc_context += await _analyze_image(display_name, file_path, vision_detail)
        elif display_name.lower().endswith((".pdf", ".docx")):
            doc_context += await _ingest_document(display_name, file_path, profile["rag_collection"])

    # 2) RAG retrieval (only when enabled and no images were uploaded)
    rag_context = ""
    if rag_enabled and not images_for_vision:
        rag_context = await search_hybrid(message.content, profile["rag_collection"], limit=top_k)

    has_rag = bool(rag_context.strip())
    has_files = bool(doc_context.strip())

    # 3) Prompt building
    system_prompt = build_system_prompt(system_instruction, has_rag=has_rag, has_files=has_files)

    final_context = ""
    if has_rag:
        final_context += "\n\n# CONTESTO RAG\n" + rag_context
    if has_files:
        final_context += "\n\n# CONTESTO FILE SESSIONE\n" + doc_context

    # 4) Generation (streamed)
    msg = cl.Message(content="")
    await msg.send()

    error: Optional[str] = None
    try:
        client_async = ollama.AsyncClient(host=OLLAMA_URL)
        stream = await client_async.chat(
            model=selected_model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Domanda: {message.content}\n{final_context}"},
            ],
            options={"temperature": temperature},
            stream=True,
        )
        async for chunk in stream:
            content = chunk.get("message", {}).get("content", "")
            if content:
                await msg.stream_token(content)
        await msg.update()
    except Exception as e:
        error = str(e)
        await msg.stream_token(f"\n\n❌ Errore AI: {error}")
        await msg.update()

    # 5) Save code blocks (only for profiles with show_code)
    if profile.get("show_code", False) and msg.content:
        await _save_code_blocks(msg.content, workspace)

    # 6) Metrics
    elapsed = time.time() - start_time
    metrics = {
        "response_time": elapsed,
        "rag_hits": rag_context.count("--- DA ") if rag_context else 0,
        "model": selected_model,
        "user_role": profile.get("role", "unknown"),
        "error": error,
    }
    await log_metrics(metrics)