implementazione BGE-M3 Dense
This commit is contained in:
parent
4c4e7b92a7
commit
9cef64f9ea
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
357
app.py
357
app.py
|
|
@ -2,17 +2,19 @@ import os
|
|||
import re
|
||||
import uuid
|
||||
import shutil
|
||||
import pandas as pd # NUOVO: Gestione Excel
|
||||
import httpx # NUOVO: Chiamate API Remote
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, List
|
||||
import chainlit as cl
|
||||
import ollama
|
||||
import fitz # PyMuPDF
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from qdrant_client.models import PointStruct, Distance, VectorParams
|
||||
from qdrant_client import models
|
||||
from qdrant_client.models import PointStruct, Distance, VectorParams, SparseVectorParams, SparseIndexParams
|
||||
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
|
||||
|
||||
# === FIX IMPORT ROBUSTO ===
|
||||
# Gestisce le differenze tra le versioni di Chainlit 2.x
|
||||
# === FIX IMPORT ===
|
||||
try:
|
||||
from chainlit.data.storage_clients import BaseStorageClient
|
||||
except ImportError:
|
||||
|
|
@ -23,8 +25,11 @@ except ImportError:
|
|||
|
||||
# === CONFIGURAZIONE ===
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://ai_user:secure_password_here@postgres:5432/ai_station")
|
||||
# PUNTANO AL SERVER .243 (Il "Cervello")
|
||||
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://192.168.1.243:11434")
|
||||
BGE_API_URL = os.getenv("BGE_API_URL", "http://192.168.1.243:8001")
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://qdrant:6333")
|
||||
|
||||
WORKSPACES_DIR = "./workspaces"
|
||||
STORAGE_DIR = "./.files"
|
||||
|
||||
|
|
@ -75,178 +80,172 @@ USER_PROFILES = {
|
|||
}
|
||||
}
|
||||
|
||||
# === CUSTOM LOCAL STORAGE CLIENT (FIXED) ===# Questa classe ora implementa tutti i metodi astratti richiesti da Chainlit 2.8.3
|
||||
# === STORAGE CLIENT ===
|
||||
class LocalStorageClient(BaseStorageClient):
|
||||
"""Storage locale su filesystem per file/elementi"""
|
||||
|
||||
def __init__(self, storage_path: str):
|
||||
self.storage_path = storage_path
|
||||
os.makedirs(storage_path, exist_ok=True)
|
||||
|
||||
async def upload_file(
|
||||
self,
|
||||
object_key: str,
|
||||
data: bytes,
|
||||
mime: str = "application/octet-stream",
|
||||
overwrite: bool = True,
|
||||
) -> Dict[str, str]:
|
||||
async def upload_file(self, object_key: str, data: bytes, mime: str = "application/octet-stream", overwrite: bool = True) -> Dict[str, str]:
|
||||
file_path = os.path.join(self.storage_path, object_key)
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(data)
|
||||
with open(file_path, "wb") as f: f.write(data)
|
||||
return {"object_key": object_key, "url": f"/files/{object_key}"}
|
||||
|
||||
# Implementazione metodi obbligatori mancanti nella versione precedente
|
||||
async def get_read_url(self, object_key: str) -> str:
|
||||
return f"/files/{object_key}"
|
||||
|
||||
async def get_read_url(self, object_key: str) -> str: return f"/files/{object_key}"
|
||||
async def delete_file(self, object_key: str) -> bool:
|
||||
file_path = os.path.join(self.storage_path, object_key)
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
return True
|
||||
path = os.path.join(self.storage_path, object_key)
|
||||
if os.path.exists(path): os.remove(path); return True
|
||||
return False
|
||||
async def close(self): pass
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
# === DATA LAYER ===
|
||||
@cl.data_layer
|
||||
def get_data_layer():
|
||||
return SQLAlchemyDataLayer(
|
||||
conninfo=DATABASE_URL,
|
||||
user_thread_limit=1000,
|
||||
storage_provider=LocalStorageClient(storage_path=STORAGE_DIR)
|
||||
)
|
||||
return SQLAlchemyDataLayer(conninfo=DATABASE_URL, user_thread_limit=1000, storage_provider=LocalStorageClient(STORAGE_DIR))
|
||||
|
||||
# === OAUTH CALLBACK ===
|
||||
# === OAUTH ===
|
||||
@cl.oauth_callback
|
||||
def oauth_callback(
|
||||
provider_id: str,
|
||||
token: str,
|
||||
raw_user_data: Dict[str, str],
|
||||
default_user: cl.User,
|
||||
) -> Optional[cl.User]:
|
||||
def oauth_callback(provider_id: str, token: str, raw_user_data: Dict[str, str], default_user: cl.User) -> Optional[cl.User]:
|
||||
if provider_id == "google":
|
||||
email = raw_user_data.get("email", "").lower()
|
||||
|
||||
# Verifica se utente è autorizzato (opzionale: blocca se non in lista)
|
||||
# if email not in USER_PROFILES:
|
||||
# return None
|
||||
|
||||
# Recupera profilo o usa default Guest
|
||||
profile = USER_PROFILES.get(email, get_user_profile("guest"))
|
||||
|
||||
default_user.metadata.update({
|
||||
"picture": raw_user_data.get("picture", ""),
|
||||
"role": profile["role"],
|
||||
"workspace": profile["workspace"],
|
||||
"rag_collection": profile["rag_collection"],
|
||||
"capabilities": profile["capabilities"],
|
||||
"show_code": profile["show_code"],
|
||||
"display_name": profile["name"]
|
||||
"role": profile["role"], "workspace": profile["workspace"],
|
||||
"rag_collection": profile["rag_collection"], "capabilities": profile["capabilities"],
|
||||
"show_code": profile["show_code"], "display_name": profile["name"]
|
||||
})
|
||||
return default_user
|
||||
return default_user
|
||||
|
||||
# === UTILITY FUNCTIONS ===
|
||||
def get_user_profile(user_email: str) -> Dict:
|
||||
return USER_PROFILES.get(user_email.lower(), {
|
||||
"role": "guest",
|
||||
"name": "Ospite",
|
||||
"workspace": "guest_workspace",
|
||||
"rag_collection": "documents",
|
||||
"capabilities": [],
|
||||
"show_code": False
|
||||
})
|
||||
def get_user_profile(email: str) -> Dict:
|
||||
return USER_PROFILES.get(email.lower(), {"role": "guest", "name": "Ospite", "workspace": "guest_workspace", "rag_collection": "documents", "show_code": False})
|
||||
|
||||
def create_workspace(workspace_name: str) -> str:
|
||||
path = os.path.join(WORKSPACES_DIR, workspace_name)
|
||||
def create_workspace(name: str) -> str:
|
||||
path = os.path.join(WORKSPACES_DIR, name)
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return path
|
||||
|
||||
def save_code_to_file(code: str, workspace: str) -> str:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
file_name = f"code_{timestamp}.py"
|
||||
file_path = os.path.join(WORKSPACES_DIR, workspace, file_name)
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
return file_path
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
path = os.path.join(WORKSPACES_DIR, workspace, f"code_{ts}.py")
|
||||
with open(path, "w", encoding="utf-8") as f: f.write(code)
|
||||
return path
|
||||
|
||||
def extract_text_from_pdf(pdf_path: str) -> str:
|
||||
# === PARSING DOCUMENTI ===
|
||||
def extract_text_from_pdf(path: str) -> str:
|
||||
try:
|
||||
doc = fitz.open(pdf_path)
|
||||
text = "\n".join([page.get_text() for page in doc])
|
||||
doc.close()
|
||||
return text
|
||||
except Exception:
|
||||
doc = fitz.open(path)
|
||||
return "\n".join([page.get_text() for page in doc])
|
||||
except: return ""
|
||||
|
||||
def extract_text_from_excel(path: str) -> str:
|
||||
"""Estrae testo da Excel convertendo i fogli in Markdown"""
|
||||
try:
|
||||
xl = pd.read_excel(path, sheet_name=None)
|
||||
text_content = []
|
||||
for sheet, df in xl.items():
|
||||
text_content.append(f"\n--- Foglio Excel: {sheet} ---\n")
|
||||
# Pulisce NaN e converte in stringa
|
||||
clean_df = df.fillna("").astype(str)
|
||||
text_content.append(clean_df.to_markdown(index=False))
|
||||
return "\n".join(text_content)
|
||||
except Exception as e:
|
||||
print(f"❌ Errore Excel: {e}")
|
||||
return ""
|
||||
|
||||
# === QDRANT FUNCTIONS ===
|
||||
async def get_qdrant_client() -> AsyncQdrantClient:
|
||||
return AsyncQdrantClient(url=QDRANT_URL)
|
||||
|
||||
async def ensure_collection(collection_name: str):
|
||||
client = await get_qdrant_client()
|
||||
if not await client.collection_exists(collection_name):
|
||||
await client.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=VectorParams(size=768, distance=Distance.COSINE)
|
||||
)
|
||||
|
||||
# === AI & EMBEDDINGS (Remoto) ===
|
||||
async def get_embeddings(text: str) -> list:
|
||||
client = ollama.Client(host=OLLAMA_URL)
|
||||
try:
|
||||
response = client.embed(model='nomic-embed-text', input=text[:2000])
|
||||
if 'embeddings' in response: return response['embeddings'][0]
|
||||
return response.get('embedding', [])
|
||||
except: return []
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
f"{BGE_API_URL}/embed",
|
||||
json={"texts": [text], "normalize": True}
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
# Gestisce sia il vecchio formato (lista diretta) che il nuovo (dict)
|
||||
if isinstance(data, list): return data[0] # Vecchia API
|
||||
if "dense" in data: return data["dense"][0] # Nuova API Hybrid
|
||||
if "embeddings" in data: return data["embeddings"][0] # API precedente
|
||||
except Exception as e:
|
||||
print(f"⚠️ Errore Embedding: {e}")
|
||||
return []
|
||||
|
||||
async def index_document(file_name: str, content: str, collection_name: str) -> bool:
|
||||
try:
|
||||
await ensure_collection(collection_name)
|
||||
embedding = await get_embeddings(content)
|
||||
if not embedding: return False
|
||||
|
||||
qdrant = await get_qdrant_client()
|
||||
await qdrant.upsert(
|
||||
collection_name=collection_name,
|
||||
points=[PointStruct(
|
||||
id=str(uuid.uuid4()),
|
||||
vector=embedding,
|
||||
payload={"file_name": file_name, "content": content[:3000], "indexed_at": datetime.now().isoformat()}
|
||||
)]
|
||||
async def ensure_collection(name: str):
|
||||
client = AsyncQdrantClient(url=QDRANT_URL)
|
||||
if not await client.collection_exists(name):
|
||||
# Creiamo una collezione ottimizzata
|
||||
await client.create_collection(
|
||||
collection_name=name,
|
||||
vectors_config={
|
||||
"bge_dense": VectorParams(size=1024, distance=Distance.COSINE)
|
||||
}
|
||||
# Se in futuro abilitiamo lo sparse output dal .243, aggiungeremo:
|
||||
# sparse_vectors_config={"bge_sparse": SparseVectorParams(index=SparseIndexParams(on_disk=False))}
|
||||
)
|
||||
return True
|
||||
except: return False
|
||||
|
||||
async def index_document(filename: str, content: str, collection: str) -> bool:
|
||||
try:
|
||||
await ensure_collection(collection)
|
||||
chunks = [content[i:i+3000] for i in range(0, len(content), 3000)]
|
||||
|
||||
qdrant = AsyncQdrantClient(url=QDRANT_URL)
|
||||
points = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Ottieni embedding (assume che get_embeddings ritorni la lista float)
|
||||
# Nota: Se hai aggiornato l'API .243 per ritornare un dict {"dense": ...},
|
||||
# devi aggiornare get_embeddings per estrarre ["dense"]!
|
||||
|
||||
# Vedere funzione get_embeddings aggiornata sotto
|
||||
emb = await get_embeddings(chunk)
|
||||
|
||||
if emb:
|
||||
points.append(PointStruct(
|
||||
id=str(uuid.uuid4()),
|
||||
# Vettori nominati
|
||||
vector={"bge_dense": emb},
|
||||
payload={"file_name": filename, "content": chunk, "chunk_id": i}
|
||||
))
|
||||
|
||||
if points:
|
||||
await qdrant.upsert(collection_name=collection, points=points)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Index Error: {e}")
|
||||
return False
|
||||
|
||||
async def search_qdrant(query: str, collection: str) -> str:
|
||||
try:
|
||||
client = await get_qdrant_client()
|
||||
client = AsyncQdrantClient(url=QDRANT_URL)
|
||||
if not await client.collection_exists(collection): return ""
|
||||
|
||||
emb = await get_embeddings(query)
|
||||
if not emb: return ""
|
||||
res = await client.query_points(collection_name=collection, query=emb, limit=3)
|
||||
return "\n\n".join([hit.payload['content'] for hit in res.points if hit.payload])
|
||||
|
||||
# Ricerca mirata sul vettore BGE
|
||||
res = await client.query_points(
|
||||
collection_name=collection,
|
||||
query=emb,
|
||||
using="bge_dense", # Specifica quale indice usare
|
||||
limit=5
|
||||
)
|
||||
return "\n\n".join([f"📄 {hit.payload['file_name']}:\n{hit.payload['content']}" for hit in res.points if hit.payload])
|
||||
except: return ""
|
||||
|
||||
# === CHAINLIT HANDLERS ===
|
||||
|
||||
# === CHAT LOGIC ===
|
||||
@cl.on_chat_start
|
||||
async def on_chat_start():
|
||||
user = cl.user_session.get("user")
|
||||
|
||||
if not user:
|
||||
# Fallback locale se non c'è auth
|
||||
user_email = "guest@local"
|
||||
profile = get_user_profile(user_email)
|
||||
email = "guest@local"
|
||||
profile = get_user_profile(email)
|
||||
else:
|
||||
user_email = user.identifier
|
||||
# I metadati sono già popolati dalla callback oauth
|
||||
profile = USER_PROFILES.get(user_email, get_user_profile("guest"))
|
||||
email = user.identifier
|
||||
profile = USER_PROFILES.get(email, get_user_profile("guest"))
|
||||
|
||||
# Salva in sessione
|
||||
cl.user_session.set("email", user_email)
|
||||
cl.user_session.set("email", email)
|
||||
cl.user_session.set("role", profile["role"])
|
||||
cl.user_session.set("workspace", profile["workspace"])
|
||||
cl.user_session.set("rag_collection", profile["rag_collection"])
|
||||
|
|
@ -254,91 +253,65 @@ async def on_chat_start():
|
|||
|
||||
create_workspace(profile["workspace"])
|
||||
|
||||
# === SETTINGS WIDGETS ===
|
||||
settings_widgets = [
|
||||
cl.input_widget.Select(
|
||||
id="model",
|
||||
label="Modello AI",
|
||||
values=["glm-4.6:cloud", "llama3.2", "mistral", "qwen2.5-coder:32b"],
|
||||
initial_value="glm-4.6:cloud",
|
||||
),
|
||||
cl.input_widget.Slider(
|
||||
id="temperature",
|
||||
label="Temperatura",
|
||||
initial=0.7, min=0, max=2, step=0.1,
|
||||
),
|
||||
settings = [
|
||||
cl.input_widget.Select(id="model", label="Modello", values=["glm-4.6:cloud", "llama3"], initial_value="glm-4.6:cloud"),
|
||||
cl.input_widget.Slider(id="temp", label="Temperatura", initial=0.5, min=0, max=1, step=0.1)
|
||||
]
|
||||
if profile["role"] == "admin":
|
||||
settings_widgets.append(cl.input_widget.Switch(id="rag_enabled", label="Abilita RAG", initial=True))
|
||||
settings.append(cl.input_widget.Switch(id="rag", label="RAG Attivo", initial=True))
|
||||
|
||||
await cl.ChatSettings(settings_widgets).send()
|
||||
|
||||
await cl.Message(
|
||||
content=f"👋 Ciao **{profile['name']}**!\n"
|
||||
f"Ruolo: `{profile['role']}` | Workspace: `{profile['workspace']}`\n"
|
||||
).send()
|
||||
await cl.ChatSettings(settings).send()
|
||||
await cl.Message(content=f"👋 Ciao **{profile['name']}**! Pronto per l'automazione.").send()
|
||||
|
||||
@cl.on_settings_update
|
||||
async def on_settings_update(settings):
|
||||
cl.user_session.set("settings", settings)
|
||||
await cl.Message(content="✅ Impostazioni aggiornate").send()
|
||||
async def on_settings_update(s): cl.user_session.set("settings", s)
|
||||
|
||||
@cl.on_message
|
||||
async def on_message(message: cl.Message):
|
||||
workspace = cl.user_session.get("workspace")
|
||||
rag_collection = cl.user_session.get("rag_collection")
|
||||
user_role = cl.user_session.get("role")
|
||||
show_code = cl.user_session.get("show_code")
|
||||
|
||||
role = cl.user_session.get("role")
|
||||
settings = cl.user_session.get("settings", {})
|
||||
model = settings.get("model", "glm-4.6:cloud")
|
||||
temperature = settings.get("temperature", 0.7)
|
||||
rag_enabled = settings.get("rag_enabled", True) if user_role == "admin" else True
|
||||
|
||||
# 1. GESTIONE FILE
|
||||
|
||||
# 1. FILE UPLOAD (PDF & EXCEL)
|
||||
if message.elements:
|
||||
for element in message.elements:
|
||||
dest = os.path.join(WORKSPACES_DIR, workspace, element.name)
|
||||
shutil.copy(element.path, dest)
|
||||
if element.name.endswith(".pdf"):
|
||||
text = extract_text_from_pdf(dest)
|
||||
if text:
|
||||
await index_document(element.name, text, rag_collection)
|
||||
await cl.Message(content=f"✅ **{element.name}** indicizzato.").send()
|
||||
for el in message.elements:
|
||||
dest = os.path.join(WORKSPACES_DIR, workspace, el.name)
|
||||
shutil.copy(el.path, dest)
|
||||
|
||||
content = ""
|
||||
if el.name.endswith(".pdf"):
|
||||
content = extract_text_from_pdf(dest)
|
||||
elif el.name.endswith((".xlsx", ".xls")):
|
||||
await cl.Message(content=f"📊 Analisi Excel **{el.name}**...").send()
|
||||
content = extract_text_from_excel(dest)
|
||||
|
||||
if content:
|
||||
ok = await index_document(el.name, content, rag_collection)
|
||||
icon = "✅" if ok else "❌"
|
||||
await cl.Message(content=f"{icon} **{el.name}** elaborato.").send()
|
||||
|
||||
# 2. RAG
|
||||
context = ""
|
||||
if rag_enabled:
|
||||
context = await search_qdrant(message.content, rag_collection)
|
||||
# 2. RAG & GENERATION
|
||||
rag_active = settings.get("rag", True) if role == "admin" else True
|
||||
context = await search_qdrant(message.content, rag_collection) if rag_active else ""
|
||||
|
||||
system_prompt = "Sei un assistente esperto."
|
||||
if context: system_prompt += f"\n\nCONTESTO:\n{context}"
|
||||
prompt = "Sei un esperto di automazione industriale."
|
||||
if context: prompt += f"\n\nUSA QUESTO CONTESTO (Manuali/Excel):\n{context}"
|
||||
|
||||
# 3. GENERAZIONE
|
||||
client = ollama.AsyncClient(host=OLLAMA_URL)
|
||||
msg = cl.Message(content="")
|
||||
await msg.send()
|
||||
|
||||
stream = await client.chat(
|
||||
model=model,
|
||||
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": message.content}],
|
||||
options={"temperature": temperature},
|
||||
stream=True
|
||||
)
|
||||
try:
|
||||
client = ollama.AsyncClient(host=OLLAMA_URL)
|
||||
stream = await client.chat(
|
||||
model=settings.get("model", "glm-4.6:cloud"),
|
||||
messages=[{"role": "system", "content": prompt}, {"role": "user", "content": message.content}],
|
||||
options={"temperature": settings.get("temp", 0.5)},
|
||||
stream=True
|
||||
)
|
||||
async for chunk in stream:
|
||||
await msg.stream_token(chunk['message']['content'])
|
||||
except Exception as e:
|
||||
await msg.stream_token(f"Errore connessione AI: {e}")
|
||||
|
||||
full_resp = ""
|
||||
async for chunk in stream:
|
||||
token = chunk['message']['content']
|
||||
full_resp += token
|
||||
await msg.stream_token(token)
|
||||
await msg.update()
|
||||
|
||||
# 4. SALVATAGGIO CODICE
|
||||
if show_code:
|
||||
blocks = re.findall(r"``````", full_resp, re.DOTALL)
|
||||
elements = []
|
||||
for code in blocks:
|
||||
path = save_code_to_file(code.strip(), workspace)
|
||||
elements.append(cl.File(name=os.path.basename(path), path=path, display="inline"))
|
||||
if elements:
|
||||
await cl.Message(content="💾 Codice salvato", elements=elements).send()
|
||||
await msg.update()
|
||||
|
|
@ -26,4 +26,8 @@ aiofiles>=23.0.0
|
|||
sniffio
|
||||
aiohttp
|
||||
boto3>=1.28.0
|
||||
azure-storage-file-datalake>=12.14.0
|
||||
azure-storage-file-datalake>=12.14.0
|
||||
# NUOVI PER EXCEL
|
||||
pandas
|
||||
openpyxl
|
||||
tabulate
|
||||
|
|
|
|||
Loading…
Reference in New Issue