# Chainlit RAG assistant: Google-OAuth user profiles, local file storage,
# Qdrant vector search, remote Ollama + BGE embedding endpoints.
import os
import re
import shutil
import uuid
from datetime import datetime
from typing import Dict, List, Optional

import chainlit as cl
import fitz  # PyMuPDF
import httpx  # remote API calls (BGE embedding server)
import ollama
import pandas as pd  # Excel parsing
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import (
    Distance,
    PointStruct,
    SparseIndexParams,
    SparseVectorParams,
    VectorParams,
)
# === IMPORT COMPATIBILITY SHIM ===
# BaseStorageClient has moved between Chainlit releases; probe each known
# module path in order so the app keeps working across versions.
try:
    from chainlit.data.storage_clients import BaseStorageClient
except ImportError:
    try:
        from chainlit.data.base import BaseStorageClient
    except ImportError:
        from chainlit.data.storage_clients.base import BaseStorageClient
# === CONFIGURATION ===
# NOTE(review): the DATABASE_URL fallback ships credentials in source; prefer
# requiring the env var in production instead of a default password.
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://ai_user:secure_password_here@postgres:5432/ai_station")

# These point at the .243 server (the remote "brain"): Ollama + BGE embeddings.
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://192.168.1.243:11434")
BGE_API_URL = os.getenv("BGE_API_URL", "http://192.168.1.243:8001")
QDRANT_URL = os.getenv("QDRANT_URL", "http://qdrant:6333")

WORKSPACES_DIR = "./workspaces"
STORAGE_DIR = "./.files"

# Make sure both local directories exist before any handler touches them.
for _required_dir in (STORAGE_DIR, WORKSPACES_DIR):
    os.makedirs(_required_dir, exist_ok=True)
# === USER / ROLE MAPPING ===
# Static access-control table keyed by lowercase Google account email.
# Each profile carries: role, display name, workspace directory, Qdrant RAG
# collection, capability flags, and whether generated code is shown.

# Both admin accounts share the same profile shape; each entry below gets its
# own copy so mutating one at runtime cannot affect the other.
_ADMIN_PROFILE = {
    "role": "admin",
    "name": "Giuseppe",
    "workspace": "admin_workspace",
    "rag_collection": "admin_docs",
    "capabilities": ["debug", "system_prompts", "user_management", "all_models"],
    "show_code": True,
}

USER_PROFILES = {
    "giuseppe@defranceschi.pro": dict(_ADMIN_PROFILE),
    "giuseppe.defranceschi@gmail.com": dict(_ADMIN_PROFILE),
    "federica.tecchio@gmail.com": {
        "role": "business",
        "name": "Federica",
        "workspace": "business_workspace",
        "rag_collection": "contabilita",
        "capabilities": ["pdf_upload", "basic_chat"],
        "show_code": False,
    },
    "riccardob545@gmail.com": {
        "role": "engineering",
        "name": "Riccardo",
        "workspace": "engineering_workspace",
        "rag_collection": "engineering_docs",
        "capabilities": ["code_execution", "data_viz", "advanced_chat"],
        "show_code": True,
    },
    "giuliadefranceschi05@gmail.com": {
        "role": "architecture",
        "name": "Giulia",
        "workspace": "architecture_workspace",
        "rag_collection": "architecture_manuals",
        "capabilities": ["visual_chat", "pdf_upload", "image_gen"],
        "show_code": False,
    },
}
# === STORAGE CLIENT ===
class LocalStorageClient(BaseStorageClient):
    """Filesystem-backed storage provider for the Chainlit data layer.

    Files are written under ``storage_path`` and referenced back via
    ``/files/<key>`` URLs (a static route must serve that directory).
    """

    def __init__(self, storage_path: str):
        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)

    def _resolve(self, object_key: str) -> str:
        """Map *object_key* to an absolute path inside the storage root.

        Object keys may derive from user-supplied file names, so reject any
        key that escapes the root (e.g. "../..", absolute paths).
        """
        root = os.path.abspath(self.storage_path)
        path = os.path.abspath(os.path.join(root, object_key))
        if os.path.commonpath([root, path]) != root:
            raise ValueError(f"Invalid object key: {object_key!r}")
        return path

    async def upload_file(self, object_key: str, data: bytes, mime: str = "application/octet-stream", overwrite: bool = True) -> Dict[str, str]:
        """Write *data* to disk and return its key plus a relative URL.

        Note: ``mime`` and ``overwrite`` are accepted for interface
        compatibility but not used (files are always overwritten).
        """
        file_path = self._resolve(object_key)
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, "wb") as f:
            f.write(data)
        return {"object_key": object_key, "url": f"/files/{object_key}"}

    async def get_read_url(self, object_key: str) -> str:
        """Return the URL the frontend uses to fetch the stored file."""
        return f"/files/{object_key}"

    async def delete_file(self, object_key: str) -> bool:
        """Delete the stored file; True if it existed (EAFP, race-safe)."""
        try:
            os.remove(self._resolve(object_key))
            return True
        except FileNotFoundError:
            return False

    async def close(self):
        """Nothing to release for local-disk storage."""
        pass
@cl.data_layer
def get_data_layer():
    """Wire Chainlit persistence to Postgres plus local file storage."""
    storage = LocalStorageClient(STORAGE_DIR)
    return SQLAlchemyDataLayer(
        conninfo=DATABASE_URL,
        user_thread_limit=1000,
        storage_provider=storage,
    )
# === OAUTH ===
@cl.oauth_callback
def oauth_callback(provider_id: str, token: str, raw_user_data: Dict[str, str], default_user: cl.User) -> Optional[cl.User]:
    """Attach the mapped profile (role, workspace, RAG collection, …) to the
    authenticated user's metadata, then return the user.

    Fix: unmapped emails fall back to the guest profile returned by
    get_user_profile, which historically did not define every key — use
    .get() with defaults instead of direct indexing so a guest login can no
    longer raise KeyError. Non-Google providers pass through unchanged.
    """
    if provider_id == "google":
        email = raw_user_data.get("email", "").lower()
        profile = USER_PROFILES.get(email, get_user_profile("guest"))
        default_user.metadata.update({
            "picture": raw_user_data.get("picture", ""),
            "role": profile.get("role", "guest"),
            "workspace": profile.get("workspace", "guest_workspace"),
            "rag_collection": profile.get("rag_collection", "documents"),
            "capabilities": profile.get("capabilities", []),
            "show_code": profile.get("show_code", False),
            "display_name": profile.get("name", "Ospite"),
        })
    return default_user
def get_user_profile(email: str) -> Dict:
    """Return the profile for *email* (case-insensitive), or a guest profile.

    Fix: the guest fallback now includes "capabilities" so it has the same
    shape as real USER_PROFILES entries and callers can index it safely.
    """
    guest_profile = {
        "role": "guest",
        "name": "Ospite",
        "workspace": "guest_workspace",
        "rag_collection": "documents",
        "capabilities": [],
        "show_code": False,
    }
    return USER_PROFILES.get(email.lower(), guest_profile)
def create_workspace(name: str) -> str:
    """Ensure the workspace directory *name* exists and return its path."""
    workspace_path = os.path.join(WORKSPACES_DIR, name)
    os.makedirs(workspace_path, exist_ok=True)
    return workspace_path
def save_code_to_file(code: str, workspace: str) -> str:
    """Persist a generated code snippet into the user's workspace.

    Returns the path of the new timestamped ``code_YYYYmmdd_HHMMSS.py`` file.
    Fix: create the workspace directory first so the write cannot fail when
    the workspace has not been initialised yet.
    """
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    target_dir = os.path.join(WORKSPACES_DIR, workspace)
    os.makedirs(target_dir, exist_ok=True)
    path = os.path.join(target_dir, f"code_{ts}.py")
    with open(path, "w", encoding="utf-8") as f:
        f.write(code)
    return path
# === DOCUMENT PARSING ===
def extract_text_from_pdf(path: str) -> str:
    """Extract plain text from every page of a PDF; "" on any failure.

    Fixes: the document is now closed (the original leaked the file handle),
    and the bare ``except`` — which also swallowed KeyboardInterrupt and
    SystemExit — is narrowed to Exception and logged.
    """
    try:
        with fitz.open(path) as doc:
            return "\n".join(page.get_text() for page in doc)
    except Exception as e:
        print(f"❌ Errore PDF: {e}")
        return ""
def extract_text_from_excel(path: str) -> str:
    """Flatten every sheet of an Excel workbook into Markdown text.

    Returns "" (and logs the error) when the file cannot be read or rendered.
    """
    try:
        parts = []
        # sheet_name=None loads the whole workbook as {name: DataFrame}.
        for sheet_name, frame in pd.read_excel(path, sheet_name=None).items():
            parts.append(f"\n--- Foglio Excel: {sheet_name} ---\n")
            # Blank out NaN and force strings so to_markdown renders cleanly.
            parts.append(frame.fillna("").astype(str).to_markdown(index=False))
        return "\n".join(parts)
    except Exception as e:
        print(f"❌ Errore Excel: {e}")
        return ""
# === AI & EMBEDDINGS (remote) ===
async def get_embeddings(text: str) -> list:
    """Fetch a dense embedding for *text* from the remote BGE API.

    Handles three historical response shapes — a bare list (old API),
    {"dense": ...} (hybrid API) and {"embeddings": ...} (previous API) —
    and returns [] on any failure.
    """
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                f"{BGE_API_URL}/embed",
                json={"texts": [text], "normalize": True},
            )
            if resp.status_code == 200:
                data = resp.json()
                if isinstance(data, list):
                    # Oldest API: bare list of vectors.
                    return data[0]
                # Newer APIs wrap the vectors in a dict; check newest first.
                for key in ("dense", "embeddings"):
                    if key in data:
                        return data[key][0]
    except Exception as e:
        print(f"⚠️ Errore Embedding: {e}")
    return []
async def ensure_collection(name: str):
    """Create the Qdrant collection *name* if it does not exist yet.

    The collection uses a single named dense vector "bge_dense"
    (1024 floats, cosine distance) matching the remote embedding API.
    Fix: close the AsyncQdrantClient when done instead of leaking the
    connection on every call.
    """
    client = AsyncQdrantClient(url=QDRANT_URL)
    try:
        if not await client.collection_exists(name):
            await client.create_collection(
                collection_name=name,
                vectors_config={
                    "bge_dense": VectorParams(size=1024, distance=Distance.COSINE)
                },
                # If the .243 server ever enables sparse output, also pass:
                # sparse_vectors_config={"bge_sparse": SparseVectorParams(index=SparseIndexParams(on_disk=False))}
            )
    finally:
        await client.close()
async def index_document(filename: str, content: str, collection: str) -> bool:
    """Chunk *content*, embed each chunk remotely, and upsert into Qdrant.

    Returns True only when at least one chunk was embedded and stored;
    False on empty/failed embeddings or any error.
    Fixes: the Qdrant client is now closed, and the no-points case returns
    an explicit False so callers report ❌ when nothing was indexed.
    """
    try:
        await ensure_collection(collection)
        # Fixed-size character chunks; 3000 chars keeps each request small.
        chunks = [content[i:i + 3000] for i in range(0, len(content), 3000)]

        qdrant = AsyncQdrantClient(url=QDRANT_URL)
        try:
            points = []
            for chunk_id, chunk in enumerate(chunks):
                emb = await get_embeddings(chunk)  # [] on failure -> chunk skipped
                if emb:
                    points.append(PointStruct(
                        id=str(uuid.uuid4()),
                        # Named vector; must match ensure_collection's config.
                        vector={"bge_dense": emb},
                        payload={"file_name": filename, "content": chunk, "chunk_id": chunk_id},
                    ))

            if not points:
                return False
            await qdrant.upsert(collection_name=collection, points=points)
            return True
        finally:
            await qdrant.close()
    except Exception as e:
        print(f"Index Error: {e}")
        return False
async def search_qdrant(query: str, collection: str) -> str:
    """Embed *query* and return the top-5 matching chunks as display text.

    Returns "" when the collection is missing, the embedding fails, or any
    error occurs. Fixes: the bare ``except`` is narrowed to Exception (and
    logged), and the Qdrant client is closed instead of leaked.
    """
    try:
        client = AsyncQdrantClient(url=QDRANT_URL)
        try:
            if not await client.collection_exists(collection):
                return ""

            emb = await get_embeddings(query)
            if not emb:
                return ""

            res = await client.query_points(
                collection_name=collection,
                query=emb,
                using="bge_dense",  # search the named dense index
                limit=5,
            )
            return "\n\n".join(
                f"📄 {hit.payload['file_name']}:\n{hit.payload['content']}"
                for hit in res.points
                if hit.payload
            )
        finally:
            await client.close()
    except Exception as e:
        print(f"Search Error: {e}")
        return ""
# === CHAT LOGIC ===
@cl.on_chat_start
async def on_chat_start():
    """Initialise the session: resolve the user's profile, cache it in the
    session, create their workspace, and send the settings panel + greeting."""
    user = cl.user_session.get("user")
    if not user:
        # No authenticated user (e.g. auth disabled): fall back to guest.
        email = "guest@local"
        profile = get_user_profile(email)
    else:
        email = user.identifier
        profile = USER_PROFILES.get(email, get_user_profile("guest"))

    # Cache profile fields for later handlers (on_message reads these).
    cl.user_session.set("email", email)
    cl.user_session.set("role", profile["role"])
    cl.user_session.set("workspace", profile["workspace"])
    cl.user_session.set("rag_collection", profile["rag_collection"])
    cl.user_session.set("show_code", profile["show_code"])

    create_workspace(profile["workspace"])

    # NOTE(review): Chainlit widgets are usually imported as
    # ``from chainlit.input_widget import Select, Slider, Switch``;
    # ``cl.input_widget.…`` only resolves if that submodule is already
    # imported somewhere — verify this works on the deployed version.
    settings = [
        cl.input_widget.Select(id="model", label="Modello", values=["glm-4.6:cloud", "llama3"], initial_value="glm-4.6:cloud"),
        cl.input_widget.Slider(id="temp", label="Temperatura", initial=0.5, min=0, max=1, step=0.1)
    ]
    if profile["role"] == "admin":
        # Only admins get the toggle; everyone else always has RAG on
        # (see rag_active in on_message).
        settings.append(cl.input_widget.Switch(id="rag", label="RAG Attivo", initial=True))

    await cl.ChatSettings(settings).send()
    await cl.Message(content=f"👋 Ciao **{profile['name']}**! Pronto per l'automazione.").send()
@cl.on_settings_update
async def on_settings_update(s):
    """Persist the latest chat-settings payload so on_message can read it."""
    cl.user_session.set("settings", s)
@cl.on_message
async def on_message(message: cl.Message):
    """Handle one user turn: ingest any attached PDF/Excel files into the
    user's RAG collection, then stream an answer from the remote Ollama
    model with retrieved context prepended to the system prompt."""
    workspace = cl.user_session.get("workspace")
    rag_collection = cl.user_session.get("rag_collection")
    role = cl.user_session.get("role")
    settings = cl.user_session.get("settings", {})

    # 1. FILE UPLOAD (PDF & EXCEL)
    if message.elements:
        for el in message.elements:
            # Keep a copy of the upload inside the user's workspace.
            dest = os.path.join(WORKSPACES_DIR, workspace, el.name)
            shutil.copy(el.path, dest)

            content = ""
            if el.name.endswith(".pdf"):
                content = extract_text_from_pdf(dest)
            elif el.name.endswith((".xlsx", ".xls")):
                await cl.Message(content=f"📊 Analisi Excel **{el.name}**...").send()
                content = extract_text_from_excel(dest)

            # Only index files whose text extraction produced something.
            if content:
                ok = await index_document(el.name, content, rag_collection)
                icon = "✅" if ok else "❌"
                await cl.Message(content=f"{icon} **{el.name}** elaborato.").send()

    # 2. RAG & GENERATION
    # Admins may switch RAG off via the settings panel; other roles always use it.
    rag_active = settings.get("rag", True) if role == "admin" else True
    context = await search_qdrant(message.content, rag_collection) if rag_active else ""

    prompt = "Sei un esperto di automazione industriale."
    if context: prompt += f"\n\nUSA QUESTO CONTESTO (Manuali/Excel):\n{context}"

    # Send an empty message first, then stream tokens into it.
    msg = cl.Message(content="")
    await msg.send()

    try:
        client = ollama.AsyncClient(host=OLLAMA_URL)
        stream = await client.chat(
            model=settings.get("model", "glm-4.6:cloud"),
            messages=[{"role": "system", "content": prompt}, {"role": "user", "content": message.content}],
            options={"temperature": settings.get("temp", 0.5)},
            stream=True
        )
        async for chunk in stream:
            await msg.stream_token(chunk['message']['content'])
    except Exception as e:
        # Surface connection/model errors in the chat instead of crashing.
        await msg.stream_token(f"Errore connessione AI: {e}")

    await msg.update()