# 344 lines · 12 KiB · Python (file-viewer header left over from copying this listing)
import logging
import os
import re
import shutil
import uuid
from datetime import datetime
from typing import Dict, List, Optional, Union

import chainlit as cl
import fitz  # PyMuPDF
import ollama
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from qdrant_client import AsyncQdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams

# === ROBUST IMPORT FIX ===
# Handles differences between Chainlit 2.x versions: BaseStorageClient is
# looked up in each of these modules in turn. If none of them provides it,
# the ImportError raised by the last attempt propagates.
try:
    from chainlit.data.storage_clients import BaseStorageClient
except ImportError:
    try:
        from chainlit.data.base import BaseStorageClient
    except ImportError:
        from chainlit.data.storage_clients.base import BaseStorageClient

# === CONFIGURATION ===
# Service endpoints, overridable through environment variables.
# NOTE(review): the defaults embed a database password and a LAN IP address;
# every real deployment should supply these via the environment instead.
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://ai_user:secure_password_here@postgres:5432/ai_station")
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://192.168.1.243:11434")
QDRANT_URL = os.getenv("QDRANT_URL", "http://qdrant:6333")

# Filesystem locations (relative to the working directory by default), now
# also overridable so the app can run with a read-only or different CWD.
WORKSPACES_DIR = os.getenv("WORKSPACES_DIR", "./workspaces")
STORAGE_DIR = os.getenv("STORAGE_DIR", "./.files")

# Created at import time so later handlers can assume the directories exist.
os.makedirs(STORAGE_DIR, exist_ok=True)
os.makedirs(WORKSPACES_DIR, exist_ok=True)

# === USER AND ROLE MAPPING ===
# Keys are lower-case e-mail addresses; get_user_profile() lower-cases the
# incoming address before looking it up here.
USER_PROFILES = {
    "giuseppe@defranceschi.pro": dict(
        role="admin",
        name="Giuseppe",
        workspace="admin_workspace",
        rag_collection="admin_docs",
        capabilities=["debug", "system_prompts", "user_management", "all_models"],
        show_code=True,
    ),
    "federica.tecchio@gmail.com": dict(
        role="business",
        name="Federica",
        workspace="business_workspace",
        rag_collection="contabilita",
        capabilities=["pdf_upload", "basic_chat"],
        show_code=False,
    ),
    "giuseppe.defranceschi@gmail.com": dict(
        role="admin",
        name="Giuseppe",
        workspace="admin_workspace",
        rag_collection="admin_docs",
        capabilities=["debug", "system_prompts", "user_management", "all_models"],
        show_code=True,
    ),
    "riccardob545@gmail.com": dict(
        role="engineering",
        name="Riccardo",
        workspace="engineering_workspace",
        rag_collection="engineering_docs",
        capabilities=["code_execution", "data_viz", "advanced_chat"],
        show_code=True,
    ),
    "giuliadefranceschi05@gmail.com": dict(
        role="architecture",
        name="Giulia",
        workspace="architecture_workspace",
        rag_collection="architecture_manuals",
        capabilities=["visual_chat", "pdf_upload", "image_gen"],
        show_code=False,
    ),
}

# === CUSTOM LOCAL STORAGE CLIENT ===
# Implements the abstract methods of Chainlit's BaseStorageClient
# (written against Chainlit 2.8.3).
class LocalStorageClient(BaseStorageClient):
    """Local filesystem storage for Chainlit files/elements.

    Each object is stored at ``<storage_path>/<object_key>``. Keys that would
    resolve outside ``storage_path`` (e.g. containing ``..`` or an absolute
    path) are rejected.
    """

    def __init__(self, storage_path: str):
        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)

    def _resolve(self, object_key: str) -> str:
        """Return the absolute path for *object_key*; raise ValueError if it escapes storage_path."""
        root = os.path.realpath(self.storage_path)
        candidate = os.path.realpath(os.path.join(root, object_key))
        if os.path.commonpath([root, candidate]) != root:
            raise ValueError(f"Invalid object key: {object_key!r}")
        return candidate

    async def upload_file(
        self,
        object_key: str,
        data: Union[bytes, str],
        mime: str = "application/octet-stream",
        overwrite: bool = True,
        content_disposition: Optional[str] = None,
    ) -> Dict[str, str]:
        """Write *data* under *object_key* and return its key and URL.

        ``str`` data is encoded as UTF-8. ``mime``, ``overwrite`` and
        ``content_disposition`` are accepted for interface compatibility but
        not used: an existing file is always replaced.

        Raises:
            ValueError: if *object_key* resolves outside ``storage_path``.
        """
        file_path = self._resolve(object_key)
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        payload = data.encode("utf-8") if isinstance(data, str) else data
        with open(file_path, "wb") as f:
            f.write(payload)
        # NOTE(review): nothing visible in this file serves STORAGE_DIR under
        # /files — confirm the route exists where the app is mounted.
        return {"object_key": object_key, "url": f"/files/{object_key}"}

    async def get_read_url(self, object_key: str) -> str:
        """Return the URL path under which *object_key* is expected to be served."""
        return f"/files/{object_key}"

    async def delete_file(self, object_key: str) -> bool:
        """Delete *object_key*; return False if it is missing or the key is invalid."""
        try:
            file_path = self._resolve(object_key)
        except ValueError:
            return False
        try:
            # EAFP avoids the exists()/remove() race of a check-then-act.
            os.remove(file_path)
        except FileNotFoundError:
            return False
        return True

    async def close(self):
        """No resources to release for filesystem storage."""
        pass

# === DATA LAYER ===
@cl.data_layer
def get_data_layer():
    """Build Chainlit's SQLAlchemy data layer backed by local file storage.

    Threads are persisted to DATABASE_URL (limit of 1000 threads per user).
    Files are stored on disk under STORAGE_DIR.
    """
    file_storage = LocalStorageClient(storage_path=STORAGE_DIR)
    return SQLAlchemyDataLayer(
        conninfo=DATABASE_URL,
        storage_provider=file_storage,
        user_thread_limit=1000,
    )

# === OAUTH CALLBACK ===
@cl.oauth_callback
def oauth_callback(
    provider_id: str,
    token: str,
    raw_user_data: Dict[str, str],
    default_user: cl.User,
) -> Optional[cl.User]:
    """Attach the role/workspace profile to an OAuth-authenticated user.

    For Google logins the e-mail address is lower-cased and resolved with
    get_user_profile(), which falls back to a guest profile for unknown
    addresses. The profile is then merged into ``default_user.metadata``.
    Users from other providers are returned unchanged.
    """
    if provider_id != "google":
        return default_user

    # `or ""` also covers an explicit None value, which `.get(..., "")` does not.
    email = (raw_user_data.get("email") or "").lower()

    # Optional allow-list: uncomment to reject users not listed in USER_PROFILES.
    # if email not in USER_PROFILES:
    #     return None

    profile = get_user_profile(email)

    default_user.metadata.update({
        "picture": raw_user_data.get("picture", ""),
        "role": profile["role"],
        "workspace": profile["workspace"],
        "rag_collection": profile["rag_collection"],
        # Copy the list so later edits to the user's metadata cannot mutate
        # the shared entry in USER_PROFILES.
        "capabilities": list(profile["capabilities"]),
        "show_code": profile["show_code"],
        "display_name": profile["name"],
    })
    return default_user

# === UTILITY FUNCTIONS ===
def get_user_profile(user_email: str) -> Dict:
    """Return the profile registered for *user_email*, or a guest profile.

    The lookup is case-insensitive because the address is lower-cased first.
    Known users get the dict stored in USER_PROFILES. Unknown users get a
    newly built guest profile on every call.
    """
    registered = USER_PROFILES.get(user_email.lower())
    if registered is not None:
        return registered
    guest_profile = {
        "role": "guest",
        "name": "Ospite",
        "workspace": "guest_workspace",
        "rag_collection": "documents",
        "capabilities": [],
        "show_code": False,
    }
    return guest_profile

def create_workspace(workspace_name: str) -> str:
    """Ensure ``WORKSPACES_DIR/<workspace_name>`` exists and return that path."""
    workspace_path = os.path.join(WORKSPACES_DIR, workspace_name)
    os.makedirs(workspace_path, exist_ok=True)  # idempotent: no error if present
    return workspace_path

def save_code_to_file(code: str, workspace: str, base_dir: Optional[str] = None) -> str:
    """Write *code* to a new ``.py`` file in the workspace and return its path.

    Files are named ``code_<YYYYmmdd_HHMMSS>_<8 hex chars>.py``. The random
    suffix keeps several snippets saved within the same second from
    overwriting each other, which second-resolution timestamps alone did not.

    Args:
        code: Source text to write (UTF-8).
        workspace: Workspace sub-directory name; created if missing.
        base_dir: Root directory for workspaces; defaults to WORKSPACES_DIR.

    Returns:
        The path of the written file.
    """
    root = WORKSPACES_DIR if base_dir is None else base_dir
    target_dir = os.path.join(root, workspace)
    os.makedirs(target_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = f"code_{timestamp}_{uuid.uuid4().hex[:8]}.py"
    file_path = os.path.join(target_dir, file_name)
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(code)
    return file_path

def extract_text_from_pdf(pdf_path: str) -> str:
    """Return the text of every page of *pdf_path*, joined with newlines.

    Returns ``""`` if the file cannot be opened or parsed. The failure is
    logged rather than raised, so uploads of broken PDFs don't abort the chat.
    """
    try:
        # The context manager closes the document even if get_text() raises.
        with fitz.open(pdf_path) as doc:
            return "\n".join(page.get_text() for page in doc)
    except Exception:
        logging.getLogger(__name__).warning(
            "Could not extract text from %s", pdf_path, exc_info=True
        )
        return ""

# === QDRANT FUNCTIONS ===
async def get_qdrant_client() -> AsyncQdrantClient:
    """Create a new AsyncQdrantClient pointed at QDRANT_URL.

    A fresh client is built on every call and nothing here closes it, so the
    caller owns its lifetime.
    """
    client = AsyncQdrantClient(url=QDRANT_URL)
    return client

async def ensure_collection(collection_name: str, vector_size: int = 768):
    """Create *collection_name* with cosine distance if it does not exist yet.

    Args:
        collection_name: Qdrant collection to check/create.
        vector_size: Embedding dimension; must match the embedding model
            (768 presumably matches nomic-embed-text used in this file).
    """
    client = await get_qdrant_client()
    try:
        if not await client.collection_exists(collection_name):
            await client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
            )
    finally:
        # get_qdrant_client() returns a new client per call; release it.
        await client.close()

async def get_embeddings(
    text: str,
    model: str = "nomic-embed-text",
    max_chars: int = 2000,
) -> list:
    """Embed *text* with an Ollama embedding model.

    Only the first *max_chars* characters are sent. The async Ollama client
    is used so the event loop is not blocked while the server works.

    Returns:
        The embedding vector, or ``[]`` if the request fails or the response
        contains no vector.
    """
    client = ollama.AsyncClient(host=OLLAMA_URL)
    try:
        response = await client.embed(model=model, input=text[:max_chars])
        # embed() returns a list under "embeddings"; the single "embedding"
        # fallback presumably covers older response shapes.
        if "embeddings" in response:
            vectors = response["embeddings"]
            return vectors[0] if vectors else []
        return response.get("embedding", [])
    except Exception:
        # Narrower than the former bare `except:`, which also swallowed
        # asyncio.CancelledError and broke task cancellation.
        logging.getLogger(__name__).warning(
            "Embedding request to %s failed", OLLAMA_URL, exc_info=True
        )
        return []

async def index_document(file_name: str, content: str, collection_name: str) -> bool:
    """Embed *content* and upsert it as a single point into *collection_name*.

    The collection is created if needed. The payload stores the file name, the
    first 3000 characters of *content* and an ISO timestamp.

    Returns:
        True on success. False if the embedding is empty or any Qdrant or
        embedding error occurs; errors are logged, not raised.
    """
    try:
        await ensure_collection(collection_name)
        embedding = await get_embeddings(content)
        if not embedding:
            return False

        qdrant = await get_qdrant_client()
        try:
            await qdrant.upsert(
                collection_name=collection_name,
                points=[
                    PointStruct(
                        id=str(uuid.uuid4()),
                        vector=embedding,
                        payload={
                            "file_name": file_name,
                            "content": content[:3000],
                            "indexed_at": datetime.now().isoformat(),
                        },
                    )
                ],
            )
        finally:
            # get_qdrant_client() returns a new client per call; release it.
            await qdrant.close()
        return True
    except Exception:
        logging.getLogger(__name__).exception(
            "Failed to index %s into %s", file_name, collection_name
        )
        return False

async def search_qdrant(query: str, collection: str, limit: int = 3) -> str:
    """Return the stored text of the points most similar to *query*.

    Up to *limit* nearest points in *collection* are fetched, and their
    ``content`` payloads are joined with blank lines.

    Returns:
        The joined text, or ``""`` when the collection does not exist, the
        query cannot be embedded, or any error occurs (errors are logged).
    """
    client = None
    try:
        client = await get_qdrant_client()
        if not await client.collection_exists(collection):
            return ""
        emb = await get_embeddings(query)
        if not emb:
            return ""
        res = await client.query_points(collection_name=collection, query=emb, limit=limit)
        # Skip points whose payload lacks "content" instead of letting one
        # KeyError discard every result.
        return "\n\n".join(
            hit.payload["content"]
            for hit in res.points
            if hit.payload and hit.payload.get("content")
        )
    except Exception:
        logging.getLogger(__name__).warning(
            "Qdrant search in %r failed", collection, exc_info=True
        )
        return ""
    finally:
        # get_qdrant_client() returns a new client per call; release it.
        if client is not None:
            await client.close()

# === CHAINLIT HANDLERS ===

@cl.on_chat_start
async def on_chat_start():
    """Initialise a new chat session.

    Resolves the user's profile and stores role, workspace, RAG collection and
    show_code in the user session. It then creates the workspace directory and
    sends the settings panel (admins also get a RAG toggle) and a welcome
    message.
    """
    user = cl.user_session.get("user")

    if user:
        user_email = user.identifier
    else:
        # Local fallback when authentication is disabled.
        user_email = "guest@local"

    # get_user_profile() lower-cases the address, matching oauth_callback.
    # A raw USER_PROFILES lookup would miss mixed-case identifiers and
    # silently downgrade the user to guest.
    profile = get_user_profile(user_email)

    # Store the profile in the session for on_message.
    cl.user_session.set("email", user_email)
    cl.user_session.set("role", profile["role"])
    cl.user_session.set("workspace", profile["workspace"])
    cl.user_session.set("rag_collection", profile["rag_collection"])
    cl.user_session.set("show_code", profile["show_code"])

    create_workspace(profile["workspace"])

    # === SETTINGS WIDGETS ===
    settings_widgets = [
        cl.input_widget.Select(
            id="model",
            label="Modello AI",
            values=["glm-4.6:cloud", "llama3.2", "mistral", "qwen2.5-coder:32b"],
            initial_value="glm-4.6:cloud",
        ),
        cl.input_widget.Slider(
            id="temperature",
            label="Temperatura",
            initial=0.7, min=0, max=2, step=0.1,
        ),
    ]
    if profile["role"] == "admin":
        settings_widgets.append(cl.input_widget.Switch(id="rag_enabled", label="Abilita RAG", initial=True))

    await cl.ChatSettings(settings_widgets).send()

    await cl.Message(
        content=f"👋 Ciao **{profile['name']}**!\n"
        f"Ruolo: `{profile['role']}` | Workspace: `{profile['workspace']}`\n"
    ).send()

@cl.on_settings_update
async def on_settings_update(settings):
    """Store the updated chat settings in the session and confirm to the user."""
    cl.user_session.set("settings", settings)
    confirmation = cl.Message(content="✅ Impostazioni aggiornate")
    await confirmation.send()

# Fenced Markdown code block: opening ``` plus an optional info string
# (e.g. "python"), a newline, then the body (captured, non-greedy) up to the
# closing ```. The former pattern r"``````" could only match six literal
# backticks and never captured code.
_CODE_FENCE_RE = re.compile(r"```[^\n`]*\n(.*?)```", re.DOTALL)


async def _store_attachments(elements, workspace: str, rag_collection: str) -> None:
    """Copy uploaded files into the user's workspace and index PDFs for RAG."""
    target_dir = os.path.join(WORKSPACES_DIR, workspace)
    os.makedirs(target_dir, exist_ok=True)
    for element in elements:
        source_path = getattr(element, "path", None)
        # basename() drops directory components from the client-supplied name,
        # so an upload named "../../x" cannot escape the workspace.
        safe_name = os.path.basename(element.name or "")
        if not source_path or not safe_name:
            continue
        dest = os.path.join(target_dir, safe_name)
        shutil.copy(source_path, dest)

        if not safe_name.lower().endswith(".pdf"):
            continue
        text = extract_text_from_pdf(dest)
        if not text:
            continue
        # Only confirm indexing when it actually succeeded.
        if await index_document(safe_name, text, rag_collection):
            await cl.Message(content=f"✅ **{safe_name}** indicizzato.").send()
        else:
            await cl.Message(content=f"⚠️ Indicizzazione di **{safe_name}** non riuscita.").send()


async def _save_code_blocks(response: str, workspace: str) -> None:
    """Save each fenced code block in *response* to the workspace and attach the files."""
    files = []
    for code in _CODE_FENCE_RE.findall(response):
        code = code.strip()
        if not code:
            continue
        path = save_code_to_file(code, workspace)
        files.append(cl.File(name=os.path.basename(path), path=path, display="inline"))
    if files:
        await cl.Message(content="💾 Codice salvato", elements=files).send()


@cl.on_message
async def on_message(message: cl.Message):
    """Handle a user message: ingest attachments, retrieve context, stream a reply.

    Steps:
      1. Copy attached files into the workspace and index PDFs.
      2. If RAG is enabled, prepend retrieved context to the system prompt.
      3. Stream the model's answer from Ollama. Errors are shown inline.
      4. For users with show_code, save fenced code blocks from the answer.
    """
    workspace = cl.user_session.get("workspace")
    rag_collection = cl.user_session.get("rag_collection")
    user_role = cl.user_session.get("role")
    show_code = cl.user_session.get("show_code")

    settings = cl.user_session.get("settings", {})
    model = settings.get("model", "glm-4.6:cloud")
    temperature = settings.get("temperature", 0.7)
    # Only admins get the RAG toggle; everyone else always has RAG enabled.
    rag_enabled = settings.get("rag_enabled", True) if user_role == "admin" else True

    # 1. File handling
    if message.elements:
        await _store_attachments(message.elements, workspace, rag_collection)

    # 2. RAG
    context = await search_qdrant(message.content, rag_collection) if rag_enabled else ""
    system_prompt = "Sei un assistente esperto."
    if context:
        system_prompt += f"\n\nCONTESTO:\n{context}"

    # 3. Generation (streamed token by token into a single message)
    client = ollama.AsyncClient(host=OLLAMA_URL)
    msg = cl.Message(content="")
    await msg.send()

    parts: List[str] = []
    try:
        stream = await client.chat(
            model=model,
            messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": message.content}],
            options={"temperature": temperature},
            stream=True,
        )
        async for chunk in stream:
            token = chunk["message"]["content"]
            parts.append(token)
            await msg.stream_token(token)
    except Exception as exc:
        # Top-level handler boundary: report the failure in the chat instead
        # of leaving an empty message behind.
        logging.getLogger(__name__).exception("Ollama chat request failed (model=%s)", model)
        await msg.stream_token(f"\n\n❌ Errore durante la generazione: {exc}")
    await msg.update()
    full_resp = "".join(parts)

    # 4. Code saving
    if show_code:
        await _save_code_blocks(full_resp, workspace)