ai-station/app.py

344 lines
12 KiB
Python

import logging
import os
import re
import shutil
import uuid
from datetime import datetime
from typing import Dict, List, Optional, Union

import chainlit as cl
import fitz  # PyMuPDF
import ollama
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from qdrant_client import AsyncQdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams

# === ROBUST IMPORT ===
# BaseStorageClient lives in different modules across Chainlit 2.x releases.
try:
    from chainlit.data.storage_clients import BaseStorageClient
except ImportError:
    try:
        from chainlit.data.base import BaseStorageClient
    except ImportError:
        from chainlit.data.storage_clients.base import BaseStorageClient
# === CONFIGURATION ===
# Service endpoints are read from the environment. The fallbacks assume
# container hostnames ("postgres", "qdrant") plus a LAN address for Ollama.
# NOTE(review): the DATABASE_URL fallback embeds a database password in
# source code; consider requiring the env var in production.
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql+asyncpg://ai_user:secure_password_here@postgres:5432/ai_station",
)
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://192.168.1.243:11434")
QDRANT_URL = os.getenv("QDRANT_URL", "http://qdrant:6333")

# Local folders: per-profile workspaces and the Chainlit file store.
WORKSPACES_DIR = "./workspaces"
STORAGE_DIR = "./.files"
for _required_dir in (STORAGE_DIR, WORKSPACES_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir
# === USER / ROLE MAPPING ===
# Lower-case login e-mail -> session profile. Each profile holds:
#   role, display name, workspace folder, Qdrant collection used for RAG,
#   capability tags, and "show_code" (whether code in answers is saved).


def _admin_profile() -> dict:
    """Return a new, independent copy of the administrator profile.

    Two accounts share this profile; building a fresh dict per account keeps
    them from sharing mutable lists.
    """
    return {
        "role": "admin",
        "name": "Giuseppe",
        "workspace": "admin_workspace",
        "rag_collection": "admin_docs",
        "capabilities": ["debug", "system_prompts", "user_management", "all_models"],
        "show_code": True,
    }


USER_PROFILES = {
    "giuseppe@defranceschi.pro": _admin_profile(),
    "federica.tecchio@gmail.com": dict(
        role="business",
        name="Federica",
        workspace="business_workspace",
        rag_collection="contabilita",
        capabilities=["pdf_upload", "basic_chat"],
        show_code=False,
    ),
    "giuseppe.defranceschi@gmail.com": _admin_profile(),
    "riccardob545@gmail.com": dict(
        role="engineering",
        name="Riccardo",
        workspace="engineering_workspace",
        rag_collection="engineering_docs",
        capabilities=["code_execution", "data_viz", "advanced_chat"],
        show_code=True,
    ),
    "giuliadefranceschi05@gmail.com": dict(
        role="architecture",
        name="Giulia",
        workspace="architecture_workspace",
        rag_collection="architecture_manuals",
        capabilities=["visual_chat", "pdf_upload", "image_gen"],
        show_code=False,
    ),
}
# === CUSTOM LOCAL STORAGE CLIENT ===
# Implements every abstract method Chainlit's BaseStorageClient requires
# (written against Chainlit 2.8.3).
class LocalStorageClient(BaseStorageClient):
    """Local-filesystem storage for Chainlit files/elements.

    Object keys are resolved relative to ``storage_path``. Keys that would
    land outside that directory (absolute paths, ``..`` segments, symlinks
    pointing elsewhere) are rejected, because keys can embed client-supplied
    file names.

    NOTE(review): file I/O here is synchronous inside async methods; fine for
    small uploads, consider ``asyncio.to_thread`` for large ones.
    """

    def __init__(self, storage_path: str):
        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)

    def _resolve(self, object_key: str) -> str:
        """Return the absolute on-disk path for *object_key*.

        Raises:
            ValueError: if the key resolves to ``storage_path`` itself or to
                anything outside it.
        """
        root = os.path.realpath(self.storage_path)
        candidate = os.path.realpath(os.path.join(root, object_key))
        try:
            inside = os.path.commonpath([root, candidate]) == root
        except ValueError:  # e.g. different drives on Windows
            inside = False
        if not inside or candidate == root:
            raise ValueError(f"Invalid storage object key: {object_key!r}")
        return candidate

    async def upload_file(
        self,
        object_key: str,
        data: Union[bytes, str],
        mime: str = "application/octet-stream",
        overwrite: bool = True,
        content_disposition: Optional[str] = None,
    ) -> Dict[str, str]:
        """Write *data* under *object_key* and return its key and read URL.

        Args:
            object_key: relative key inside the storage directory.
            data: raw bytes, or text (stored UTF-8 encoded).
            mime: accepted for interface compatibility; not stored locally.
            overwrite: when False and the file already exists, refuse.
            content_disposition: accepted for interface compatibility; unused.

        Raises:
            ValueError: if *object_key* escapes the storage directory.
            FileExistsError: if ``overwrite`` is False and the file exists.
        """
        file_path = self._resolve(object_key)
        if not overwrite and os.path.exists(file_path):
            raise FileExistsError(file_path)
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        # Text content cannot be written to a binary-mode file as-is.
        payload = data.encode("utf-8") if isinstance(data, str) else data
        with open(file_path, "wb") as f:
            f.write(payload)
        return {"object_key": object_key, "url": f"/files/{object_key}"}

    async def get_read_url(self, object_key: str) -> str:
        """Return the URL under which *object_key* is expected to be served.

        NOTE(review): nothing in this file mounts STORAGE_DIR at ``/files``;
        confirm the route exists elsewhere.
        """
        return f"/files/{object_key}"

    async def delete_file(self, object_key: str) -> bool:
        """Delete *object_key*; return True if a file was removed, else False."""
        try:
            os.remove(self._resolve(object_key))
        except (FileNotFoundError, ValueError):
            # Missing file or an invalid (escaping) key: nothing was deleted.
            return False
        return True

    async def close(self):
        """No resources to release for local storage."""
        pass
# === DATA LAYER ===
@cl.data_layer
def get_data_layer():
    """Build Chainlit's SQLAlchemy persistence layer.

    Threads/steps go to DATABASE_URL (up to 1000 threads per user); file
    elements are stored on disk through LocalStorageClient under STORAGE_DIR.
    """
    local_storage = LocalStorageClient(storage_path=STORAGE_DIR)
    return SQLAlchemyDataLayer(
        conninfo=DATABASE_URL,
        storage_provider=local_storage,
        user_thread_limit=1000,
    )
# === OAUTH CALLBACK ===
@cl.oauth_callback
def oauth_callback(
    provider_id: str,
    token: str,
    raw_user_data: Dict[str, str],
    default_user: cl.User,
) -> Optional[cl.User]:
    """Attach the role profile mapped to the user's e-mail to *default_user*.

    Only the ``google`` provider is enriched; any other provider gets
    *default_user* back unchanged. Google accounts that are not listed in
    USER_PROFILES are NOT rejected: they receive the guest profile.

    Returns:
        *default_user*, with its metadata updated for Google logins.
    """
    if provider_id != "google":
        return default_user

    # The e-mail may be missing or null in the provider payload; treat both
    # as "" so the lookup falls back to the guest profile instead of crashing.
    email = (raw_user_data.get("email") or "").lower()

    # To admit only known accounts, uncomment:
    # if email not in USER_PROFILES:
    #     return None

    profile = get_user_profile(email)
    default_user.metadata.update({
        "picture": raw_user_data.get("picture", ""),
        "role": profile["role"],
        "workspace": profile["workspace"],
        "rag_collection": profile["rag_collection"],
        "capabilities": profile["capabilities"],
        "show_code": profile["show_code"],
        "display_name": profile["name"],
    })
    return default_user
# === UTILITY FUNCTIONS ===
def get_user_profile(user_email: Optional[str]) -> Dict:
    """Return the profile mapped to *user_email*, or a guest profile.

    The lookup ignores case and surrounding whitespace. ``None``, "" or an
    unknown address yields a newly built guest profile, so callers never
    share (or mutate) a common default dict.
    """
    key = (user_email or "").strip().lower()
    profile = USER_PROFILES.get(key)
    if profile is not None:
        return profile
    return {
        "role": "guest",
        "name": "Ospite",
        "workspace": "guest_workspace",
        "rag_collection": "documents",
        "capabilities": [],
        "show_code": False,
    }
def create_workspace(workspace_name: str) -> str:
    """Make sure a workspace folder exists under WORKSPACES_DIR.

    Args:
        workspace_name: folder name, e.g. a profile's "workspace" value.

    Returns:
        The joined path ``WORKSPACES_DIR/workspace_name`` (created if missing).
    """
    target_dir = os.path.join(WORKSPACES_DIR, workspace_name)
    os.makedirs(name=target_dir, exist_ok=True)
    return target_dir
def save_code_to_file(code: str, workspace: str) -> str:
    """Write *code* to a new ``code_<timestamp>.py`` file in *workspace*.

    The timestamp has one-second resolution, and several code blocks from
    one answer are usually saved within the same second. Files are therefore
    created exclusively: on a name clash a numeric suffix is appended
    (``code_<ts>_1.py``, ...) instead of overwriting the earlier file.

    Args:
        code: source text to save (written as UTF-8).
        workspace: workspace folder name under WORKSPACES_DIR (created if
            missing).

    Returns:
        The path of the newly written file.
    """
    directory = os.path.join(WORKSPACES_DIR, workspace)
    os.makedirs(directory, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    attempt = 0
    while True:
        suffix = f"_{attempt}" if attempt else ""
        file_path = os.path.join(directory, f"code_{timestamp}{suffix}.py")
        try:
            # Mode "x" fails if the file exists, so nothing is ever clobbered.
            with open(file_path, "x", encoding="utf-8") as f:
                f.write(code)
        except FileExistsError:
            attempt += 1
            continue
        return file_path
def extract_text_from_pdf(pdf_path: str) -> str:
    """Return the text of every page of *pdf_path*, joined with newlines.

    Returns "" if the file cannot be opened or parsed. The failure is logged
    rather than silently discarded.
    """
    try:
        # The context manager closes the document even if get_text() raises;
        # the previous explicit close() was skipped on errors.
        with fitz.open(pdf_path) as doc:
            return "\n".join(page.get_text() for page in doc)
    except Exception:
        logging.getLogger(__name__).exception(
            "PDF text extraction failed for %s", pdf_path
        )
        return ""
# === QDRANT FUNCTIONS ===
# Process-wide client, created lazily. The previous code built a new client
# (with its own connection handling) on every call and never closed it.
_qdrant_client: Optional[AsyncQdrantClient] = None


async def get_qdrant_client() -> AsyncQdrantClient:
    """Return the shared AsyncQdrantClient for QDRANT_URL, creating it on first use."""
    global _qdrant_client
    # No await between the check and the assignment, so concurrent
    # coroutines on the same event loop cannot create two clients.
    if _qdrant_client is None:
        _qdrant_client = AsyncQdrantClient(url=QDRANT_URL)
    return _qdrant_client
async def ensure_collection(collection_name: str):
    """Create *collection_name* in Qdrant unless it already exists.

    New collections use 768-dimensional vectors with cosine distance. The
    size must match the embedding model used by get_embeddings
    (nomic-embed-text).
    """
    qdrant = await get_qdrant_client()
    already_there = await qdrant.collection_exists(collection_name)
    if already_there:
        return
    await qdrant.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(distance=Distance.COSINE, size=768),
    )
async def get_embeddings(text: str) -> list:
    """Embed *text* with Ollama's ``nomic-embed-text`` model.

    Only the first 2000 characters are sent.

    Returns:
        The embedding vector, or [] if the request fails or the response
        contains no vector (the failure is logged).
    """
    # AsyncClient keeps the event loop free while Ollama works; the previous
    # synchronous Client blocked every Chainlit session during the request.
    client = ollama.AsyncClient(host=OLLAMA_URL)
    try:
        response = await client.embed(model="nomic-embed-text", input=text[:2000])
        # `embed` returns a batch under "embeddings"; the singular
        # "embedding" key is kept as a fallback for other response shapes.
        if "embeddings" in response:
            vectors = response["embeddings"]
            return vectors[0] if vectors else []
        return response.get("embedding", [])
    except Exception:
        # `except Exception` (not bare) lets task cancellation propagate.
        logging.getLogger(__name__).exception(
            "Embedding request to %s failed", OLLAMA_URL
        )
        return []
def _split_into_chunks(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split *text* into windows of at most *chunk_size* chars, consecutive
    windows overlapping by *overlap* chars; whitespace-only windows are dropped."""
    step = chunk_size - overlap
    # Stop once a window would only repeat the previous window's overlap.
    starts = range(0, max(len(text) - overlap, 1), step)
    windows = (text[start:start + chunk_size] for start in starts)
    return [window for window in windows if window.strip()]


async def index_document(
    file_name: str,
    content: str,
    collection_name: str,
    chunk_size: int = 2000,
    chunk_overlap: int = 200,
) -> bool:
    """Embed *content* chunk by chunk and store the chunks in Qdrant.

    Previously only one vector was computed, from the first 2000 characters,
    so the rest of any longer document could never be retrieved. Now every
    chunk gets its own point, with payload ``file_name``, ``content`` (the
    chunk text), ``chunk_index`` and ``indexed_at``.

    Args:
        file_name: name recorded in each point's payload.
        content: full document text.
        collection_name: target Qdrant collection (created if missing).
        chunk_size: max characters per chunk. get_embeddings only embeds the
            first 2000 characters, so larger values are partly ignored.
        chunk_overlap: characters shared by consecutive chunks.

    Returns:
        True if every chunk was embedded and upserted. False if *content* is
        blank, any embedding failed (nothing is written in that case), or
        Qdrant raised (the error is logged).

    Raises:
        ValueError: if ``chunk_size <= 0`` or not ``0 <= chunk_overlap < chunk_size``.
    """
    if chunk_size <= 0 or not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_size must be > 0 and 0 <= chunk_overlap < chunk_size")

    chunks = _split_into_chunks(content, chunk_size, chunk_overlap)
    if not chunks:
        return False

    try:
        await ensure_collection(collection_name)
        indexed_at = datetime.now().isoformat()
        points = []
        for position, chunk in enumerate(chunks):
            embedding = await get_embeddings(chunk)
            if not embedding:
                # All-or-nothing: do not store a partially indexed document.
                return False
            points.append(PointStruct(
                id=str(uuid.uuid4()),
                vector=embedding,
                payload={
                    "file_name": file_name,
                    "content": chunk,
                    "chunk_index": position,
                    "indexed_at": indexed_at,
                },
            ))
        qdrant = await get_qdrant_client()
        await qdrant.upsert(collection_name=collection_name, points=points)
        return True
    except Exception:
        logging.getLogger(__name__).exception(
            "Indexing %s into %s failed", file_name, collection_name
        )
        return False
async def search_qdrant(query: str, collection: str, limit: int = 3) -> str:
    """Return the text of the *limit* nearest points in *collection*.

    Texts are joined with a blank line. Points whose payload is missing or
    has no ``content`` are skipped instead of aborting the whole search.

    Returns:
        The joined context, or "" when the collection does not exist, the
        query cannot be embedded, nothing matches, or Qdrant raises (the
        error is logged).
    """
    try:
        client = await get_qdrant_client()
        if not await client.collection_exists(collection):
            return ""
        query_vector = await get_embeddings(query)
        if not query_vector:
            return ""
        result = await client.query_points(
            collection_name=collection, query=query_vector, limit=limit
        )
        texts = [
            hit.payload["content"]
            for hit in result.points
            if hit.payload and hit.payload.get("content")
        ]
        return "\n\n".join(texts)
    except Exception:
        logging.getLogger(__name__).exception(
            "Qdrant search in %s failed", collection
        )
        return ""
# === CHAINLIT HANDLERS ===
@cl.on_chat_start
async def on_chat_start():
    """Initialise a chat session.

    Resolves the user's profile, stores it in the session, ensures the
    workspace folder exists, publishes the settings panel (admins also get
    the RAG switch) and greets the user.
    """
    user = cl.user_session.get("user")
    if user:
        user_email = user.identifier
    else:
        # Local fallback when no authentication is configured.
        user_email = "guest@local"
    # get_user_profile matches case-insensitively, like the OAuth callback.
    # The previous exact-case lookup could give a mixed-case identifier the
    # guest profile here even though OAuth had mapped it to a real role.
    profile = get_user_profile(user_email)

    # Store in the session
    cl.user_session.set("email", user_email)
    cl.user_session.set("role", profile["role"])
    cl.user_session.set("workspace", profile["workspace"])
    cl.user_session.set("rag_collection", profile["rag_collection"])
    cl.user_session.set("show_code", profile["show_code"])
    create_workspace(profile["workspace"])

    # === SETTINGS WIDGETS ===
    settings_widgets = [
        cl.input_widget.Select(
            id="model",
            label="Modello AI",
            values=["glm-4.6:cloud", "llama3.2", "mistral", "qwen2.5-coder:32b"],
            initial_value="glm-4.6:cloud",
        ),
        cl.input_widget.Slider(
            id="temperature",
            label="Temperatura",
            initial=0.7, min=0, max=2, step=0.1,
        ),
    ]
    if profile["role"] == "admin":
        settings_widgets.append(
            cl.input_widget.Switch(id="rag_enabled", label="Abilita RAG", initial=True)
        )
    await cl.ChatSettings(settings_widgets).send()

    await cl.Message(
        content=f"👋 Ciao **{profile['name']}**!\n"
        f"Ruolo: `{profile['role']}` | Workspace: `{profile['workspace']}`\n"
    ).send()
@cl.on_settings_update
async def on_settings_update(settings):
    """Keep the latest chat settings in the session and confirm to the user."""
    cl.user_session.set("settings", settings)
    confirmation = cl.Message(content="✅ Impostazioni aggiornate")
    await confirmation.send()
# A fenced Markdown code block: ``` + optional info string (e.g. "python") +
# newline, then the code up to the closing ```. The previous pattern
# r"``````" only matched six literal backticks, so no code was ever saved.
_CODE_FENCE_RE = re.compile(r"```[^\n`]*\n(.*?)```", re.DOTALL)


async def _ingest_attachments(elements, workspace: str, rag_collection: str) -> None:
    """Copy uploaded files into the workspace and index PDFs for RAG."""
    for element in elements:
        # Keep only the last path component of the client-supplied name, so
        # a name like "../../app.py" cannot be written outside the workspace.
        file_name = os.path.basename(element.name or "")
        if not element.path or file_name in ("", ".", ".."):
            continue  # nothing on disk to copy, or no usable file name
        dest = os.path.join(WORKSPACES_DIR, workspace, file_name)
        shutil.copy(element.path, dest)

        if not file_name.lower().endswith(".pdf"):
            continue
        text = extract_text_from_pdf(dest)
        if not text:
            await cl.Message(content=f"⚠️ **{file_name}**: nessun testo estratto.").send()
            continue
        # Report success only when indexing actually succeeded.
        if await index_document(file_name, text, rag_collection):
            await cl.Message(content=f"✅ **{file_name}** indicizzato.").send()
        else:
            await cl.Message(content=f"⚠️ **{file_name}**: indicizzazione non riuscita.").send()


async def _stream_reply(model: str, temperature: float, system_prompt: str, user_text: str) -> str:
    """Stream the model's answer into a new Chainlit message; return the full text."""
    client = ollama.AsyncClient(host=OLLAMA_URL)
    msg = cl.Message(content="")
    await msg.send()
    stream = await client.chat(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_text},
        ],
        options={"temperature": temperature},
        stream=True,
    )
    parts: List[str] = []
    async for chunk in stream:
        token = chunk["message"]["content"]
        parts.append(token)
        await msg.stream_token(token)
    await msg.update()
    return "".join(parts)


async def _save_code_blocks(reply: str, workspace: str) -> None:
    """Save each non-empty fenced code block of *reply* and attach the files."""
    files = []
    for code in _CODE_FENCE_RE.findall(reply):
        code = code.strip()
        if not code:
            continue
        path = save_code_to_file(code, workspace)
        files.append(cl.File(name=os.path.basename(path), path=path, display="inline"))
    if files:
        await cl.Message(content="💾 Codice salvato", elements=files).send()


@cl.on_message
async def on_message(message: cl.Message):
    """Handle one user turn.

    Steps: ingest attachments, retrieve RAG context, stream the model reply,
    and, for profiles with ``show_code``, save any generated code.
    """
    workspace = cl.user_session.get("workspace")
    rag_collection = cl.user_session.get("rag_collection")
    user_role = cl.user_session.get("role")
    show_code = cl.user_session.get("show_code")
    settings = cl.user_session.get("settings") or {}
    model = settings.get("model", "glm-4.6:cloud")
    temperature = settings.get("temperature", 0.7)
    # Only admins get the RAG switch; every other role always uses RAG.
    rag_enabled = settings.get("rag_enabled", True) if user_role == "admin" else True

    # 1. Uploaded files
    if message.elements:
        await _ingest_attachments(message.elements, workspace, rag_collection)

    # 2. RAG context
    context = await search_qdrant(message.content, rag_collection) if rag_enabled else ""
    system_prompt = "Sei un assistente esperto."
    if context:
        system_prompt += f"\n\nCONTESTO:\n{context}"

    # 3. Generation
    full_resp = await _stream_reply(model, temperature, system_prompt, message.content)

    # 4. Save generated code
    if show_code:
        await _save_code_blocks(full_resp, workspace)