# ai-station/app.py

import os
import chainlit as cl
import re
from datetime import datetime
import shutil
import uuid
import ollama # Import spostato all'inizio per efficienza
from qdrant_client import QdrantClient
from qdrant_client.http.models import PointStruct
# Map login email -> role name. The role doubles as the workspace
# sub-directory; unknown emails fall back to 'guest' at the lookup sites.
USER_ROLES = {
    'moglie@esempio.com': 'business',
    'ingegnere@esempio.com': 'engineering',
    'architetto@esempio.com': 'architecture',
    'admin@esempio.com': 'admin'
}
# Root directory holding one sub-directory per role for saved files.
WORKSPACES_DIR = "./workspaces"
def create_workspace(user_role):
    """Ensure the workspace directory for *user_role* exists.

    Args:
        user_role: Role name used as the sub-directory under WORKSPACES_DIR.
    """
    # exist_ok avoids the check-then-create race (TOCTOU) of the previous
    # os.path.exists() guard and is a no-op when the directory exists.
    os.makedirs(os.path.join(WORKSPACES_DIR, user_role), exist_ok=True)
def save_code_to_file(code, user_role):
    """Persist a generated code snippet into the role's workspace.

    Args:
        code: Source text to write.
        user_role: Workspace sub-directory to write into.

    Returns:
        The path of the newly written file.
    """
    # %f (microseconds) prevents two snippets saved within the same second
    # from silently overwriting each other.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    file_name = f"code_{timestamp}.py"
    file_path = os.path.join(WORKSPACES_DIR, user_role, file_name)
    # Explicit encoding so output is stable across platforms/locales.
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(code)
    return file_path
def limit_history(history, max_len=20):
    """Trim a chat history to its most recent entries.

    Args:
        history: Chronological list of chat messages.
        max_len: Maximum number of trailing entries to keep (default 20,
            matching the previous hard-coded limit).

    Returns:
        The original list unchanged when it is short enough, otherwise a
        new list holding only the last *max_len* entries.
    """
    if len(history) > max_len:
        return history[-max_len:]
    return history
async def connect_to_qdrant():
    """Return a Qdrant client, creating the 'documents' collection if absent."""
    qdrant = QdrantClient("http://qdrant:6333")
    name = "documents"
    try:
        qdrant.get_collection(name)
    except Exception:
        # Lookup failed -> assume the collection does not exist yet.
        qdrant.create_collection(
            collection_name=name,
            vectors_config={"size": 768, "distance": "Cosine"},
        )
    return qdrant
async def get_embeddings(text, max_chars=12000):
    """Embed *text* with the 'nomic-embed-text' model via Ollama.

    Args:
        text: Input text to embed.
        max_chars: Truncation cap guarding against oversized inputs
            (default 12000, the previous hard-coded limit).

    Returns:
        A single embedding vector (list of floats).
    """
    # Build the client from the full base URL (NOT host=...) so the
    # default "127.0.0.1:random-port" fallback is never used.
    ollama_api_base = os.getenv('OLLAMA_API_BASE', 'http://192.168.1.243:11434')
    client = ollama.Client(ollama_api_base)
    # Truncate oversized input rather than failing the embed call.
    if len(text) > max_chars:
        text = text[:max_chars]
    response = client.embed(model='nomic-embed-text', input=text)
    # Response compatibility: newer Ollama returns a batched 'embeddings'
    # list, older versions a single 'embedding'.
    if 'embeddings' in response:
        return response['embeddings'][0]
    return response['embedding']
# Nuova funzione per CERCARE nei documenti (RAG)
async def search_qdrant(query_text, user_role):
    """Retrieve RAG context for *query_text* from the 'documents' collection.

    Args:
        query_text: User question to embed and match.
        user_role: Currently unused; kept for interface compatibility
            (could later scope the search per role).

    Returns:
        Newline-joined context lines, or "" when nothing matches or the
        search fails (best-effort: chat proceeds without context).
    """
    try:
        qdrant_client = await connect_to_qdrant()
        query_embedding = await get_embeddings(query_text)
        # Fetch the 3 documents most similar to the question.
        search_result = qdrant_client.search(
            collection_name="documents",
            query_vector=query_embedding,
            limit=3
        )
        contexts = []
        for hit in search_result:
            # BUG FIX: hits are ScoredPoint objects, not dicts. The old
            # "'payload' in hit" raised, was swallowed by the except
            # below, and the function always returned "".
            payload = getattr(hit, "payload", None) or {}
            if "file_name" in payload:
                contexts.append(f"Documento: {payload['file_name']}")
        return "\n".join(contexts)
    except Exception as e:
        # Best-effort: log and fall back to chatting without context.
        print(f"Errore ricerca: {e}")
        return ""
@cl.on_chat_start
async def chat_start():
    """Initialise the session: resolve the role, prep the workspace, greet."""
    # Hardcoded identity for testing purposes; replace with real auth later.
    user_email = "admin@esempio.com"
    # Unknown emails become guests.
    user_role = USER_ROLES.get(user_email, 'guest')
    # Make sure the role's workspace directory exists.
    create_workspace(user_role)
    # Fresh per-session state.
    cl.user_session.set("history", [])
    cl.user_session.set("role", user_role)
    # Role -> greeting lookup replaces the old if/elif chain.
    greetings = {
        'admin': "Welcome, Admin!",
        'engineering': "Welcome, Engineer!",
        'business': "Welcome, Business User!",
        'architecture': "Welcome, Architect!",
    }
    greeting = greetings.get(user_role, "Welcome, Guest!")
    await cl.Message(content=greeting).send()
@cl.on_message
async def message(message):
    """Handle one user turn: optional RAG context, file uploads, model call.

    Side effects: saves uploads into the role workspace, indexes .txt
    uploads in Qdrant, writes extracted code blocks to disk, and updates
    the per-session chat history.
    """
    user_role = cl.user_session.get("role", 'guest')
    if not user_role:
        await cl.Message(content="User role not found").send()
        return
    # Full base URL (not host=...) so the default local fallback is unused.
    ollama_api_base = os.getenv('OLLAMA_API_BASE', 'http://192.168.1.243:11434')
    try:
        client = ollama.Client(ollama_api_base)
        # Bounded session history; the new user turn is persisted.
        history = limit_history(cl.user_session.get("history", []))
        history.append({"role": "user", "content": message.content})
        # --- RAG step: search the documents before chatting ---
        messages = list(history)
        context_text = await search_qdrant(message.content, user_role)
        if context_text:
            # BUG FIX: inject the context only into this call's message
            # list. The old code inserted it into the persisted history,
            # so a new system prompt accumulated on every turn.
            system_prompt = f"Contexto dai documenti:\n{context_text}\n\nRispondi usando questo contesto."
            messages.insert(0, {"role": "system", "content": system_prompt})
        # --- Handle uploaded files ---
        if message.elements:
            uploaded_files = []
            for element in message.elements:
                try:
                    # Save the upload into the role's workspace.
                    dest_path = os.path.join(WORKSPACES_DIR, user_role, element.name)
                    with open(element.path, 'rb') as src, open(dest_path, 'wb') as dst:
                        shutil.copyfileobj(src, dst)
                    # Index .txt uploads in Qdrant for later retrieval.
                    if element.name.endswith('.txt'):
                        with open(dest_path, 'r', encoding='utf-8') as f:
                            content = f.read()
                        embeddings = await get_embeddings(content)
                        qdrant_client = await connect_to_qdrant()
                        # BUG FIX: Qdrant point ids must be str or int; a
                        # raw UUID object fails client-side validation.
                        point = PointStruct(
                            id=str(uuid.uuid4()),
                            vector=embeddings,
                            payload={"file_name": element.name},
                        )
                        qdrant_client.upsert(collection_name="documents", points=[point])
                        await cl.Message(content=f"Documento '{element.name}' indicizzato.").send()
                    uploaded_files.append(element.name)
                except Exception as e:
                    await cl.Message(content=f"Error saving {element.name}: {e}").send()
            if uploaded_files:
                await cl.Message(content=f"Files saved: {', '.join(uploaded_files)}").send()
        # Call the model with the (optionally context-augmented) messages.
        response = client.chat(model='qwen2.5-coder:7b', messages=messages)
        answer = response['message']['content']
        # Save any ```python``` blocks as downloadable files.
        code_blocks = re.findall(r"```python(.*?)```", answer, re.DOTALL)
        elements = []
        for code in code_blocks:
            file_path = save_code_to_file(code, user_role)
            elements.append(cl.File(name=os.path.basename(file_path), path=file_path))
        # Persist the assistant reply (without the transient system prompt).
        history.append({"role": "assistant", "content": answer})
        cl.user_session.set("history", history)
        # Send the final message with attached code files, if any.
        await cl.Message(content=answer, elements=elements).send()
    except Exception as e:
        await cl.Message(content=f"Error: {e}").send()