218 lines
8.6 KiB
Python
218 lines
8.6 KiB
Python
|
|
# Copyright 2025 CloudZero
|
||
|
|
#
|
||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
# you may not use this file except in compliance with the License.
|
||
|
|
# You may obtain a copy of the License at
|
||
|
|
#
|
||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
#
|
||
|
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
# See the License for the specific language governing permissions and
|
||
|
|
# limitations under the License.
|
||
|
|
#
|
||
|
|
# CHANGELOG: 2025-07-23 - Added support for using LiteLLM_SpendLogs table for CBF mapping (ishaan-jaff)
|
||
|
|
# CHANGELOG: 2025-01-19 - Refactored to use daily spend tables for proper CBF mapping (erik.peterson)
|
||
|
|
# CHANGELOG: 2025-01-19 - Migrated from pandas to polars for database operations (erik.peterson)
|
||
|
|
# CHANGELOG: 2025-01-19 - Initial database module for LiteLLM data extraction (erik.peterson)
|
||
|
|
|
||
|
|
"""Database connection and data extraction for LiteLLM."""
|
||
|
|
|
||
|
|
from datetime import datetime, timedelta
|
||
|
|
from typing import Any, Dict, Optional
|
||
|
|
|
||
|
|
import polars as pl
|
||
|
|
|
||
|
|
|
||
|
|
class LiteLLMDatabase:
    """Handle LiteLLM PostgreSQL database connections and queries.

    Every query goes through the LiteLLM proxy's shared Prisma client using
    ``client.db.query_raw``; this class keeps no state of its own.
    """

    def _ensure_prisma_client(self):
        """Return the proxy's Prisma client, failing if no database is connected.

        Returns:
            The ``prisma_client`` object exported by
            ``litellm.proxy.proxy_server``.

        Raises:
            Exception: If ``prisma_client`` is ``None``.
        """
        # Imported at call time so the current value of the module global is
        # read on every call rather than frozen when this module is imported.
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise Exception(
                "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
            )
        return prisma_client

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted PostgreSQL identifier.

        Embedded double quotes are doubled so the name cannot terminate the
        quoted identifier early when it is spliced into SQL text.
        """
        return '"' + name.replace('"', '""') + '"'

    async def get_usage_data_for_hour(
        self, target_hour: datetime, limit: Optional[int] = 1000
    ) -> pl.DataFrame:
        """Retrieve spend logs for one clock hour from ``LiteLLM_SpendLogs``.

        The window is ``[hour_start, hour_start + 1h)`` where ``hour_start`` is
        ``target_hour`` truncated to the hour. Rows are ordered by
        ``startTime`` ascending.

        Args:
            target_hour: Any instant inside the hour of interest.
            limit: Maximum number of rows to return. A falsy value (``None``
                or ``0``) disables the cap. When the cap is hit the result is
                truncated to the earliest rows; no further pages are fetched.

        Returns:
            A polars DataFrame built from the query result, or an empty
            DataFrame when the query returns nothing.

        Raises:
            Exception: If the query or the DataFrame conversion fails; the
                original error is attached as ``__cause__``.
            ValueError: If ``limit`` is truthy but cannot be converted to int.
            TypeError: If ``limit`` is truthy but of a type ``int()`` rejects.
        """
        client = self._ensure_prisma_client()

        # Calculate hour range
        hour_start = target_hour.replace(minute=0, second=0, microsecond=0)
        hour_end = hour_start + timedelta(hours=1)

        # Convert datetime objects to ISO format strings for PostgreSQL compatibility.
        # NOTE(review): for a timezone-aware ``target_hour`` the ISO string
        # carries a UTC offset, which a ``::timestamp`` cast ignores — confirm
        # callers pass naive datetimes in the same zone the table stores.
        hour_start_str = hour_start.isoformat()
        hour_end_str = hour_end.isoformat()

        # Query to get spend logs for the specific hour
        query = """
        SELECT *
        FROM "LiteLLM_SpendLogs"
        WHERE "startTime" >= $1::timestamp
          AND "startTime" < $2::timestamp
        ORDER BY "startTime" ASC
        """

        if limit:
            # Coerce to int so that only a number can ever reach the SQL text.
            query += f" LIMIT {int(limit)}"

        try:
            db_response = await client.db.query_raw(query, hour_start_str, hour_end_str)
            # Convert the response to polars DataFrame
            return pl.DataFrame(db_response) if db_response else pl.DataFrame()
        except Exception as e:
            raise Exception(
                f"Error retrieving spend logs for hour {target_hour}: {str(e)}"
            ) from e

    async def get_table_info(self) -> Dict[str, Any]:
        """Get information about the ``LiteLLM_SpendLogs`` table.

        Returns:
            A dict with keys ``'columns'`` (rows of ``column_name``,
            ``data_type``, ``is_nullable`` from ``information_schema.columns``),
            ``'row_count'`` and ``'table_breakdown'`` (``{'spend_logs': count}``).
            The count comes from ``_get_table_row_count`` and is therefore
            ``0`` both for an empty table and when counting failed.

        Raises:
            Exception: If the column query fails; the original error is
                attached as ``__cause__``.
        """
        client = self._ensure_prisma_client()

        try:
            # Get row count from SpendLogs table
            spend_logs_count = await self._get_table_row_count('LiteLLM_SpendLogs')

            # Get column structure from spend logs table
            query = """
            SELECT column_name, data_type, is_nullable
            FROM information_schema.columns
            WHERE table_name = 'LiteLLM_SpendLogs'
            ORDER BY ordinal_position;
            """
            columns_response = await client.db.query_raw(query)

            return {
                'columns': columns_response,
                'row_count': spend_logs_count,
                'table_breakdown': {
                    'spend_logs': spend_logs_count
                }
            }
        except Exception as e:
            raise Exception(f"Error getting table info: {str(e)}") from e

    async def _get_table_row_count(self, table_name: str) -> int:
        """Get the row count of *table_name*, best-effort.

        Args:
            table_name: Unquoted table name; it is quoted and escaped here.

        Returns:
            The ``COUNT(*)`` of the table as an int, or ``0`` when the query
            returns no rows or fails for any reason.

        Raises:
            Exception: Only from ``_ensure_prisma_client`` when no database
                is connected; query errors are logged and reported as ``0``.
        """
        import logging

        client = self._ensure_prisma_client()

        try:
            # Identifiers cannot be bound as query parameters, so the name is
            # spliced in as an escaped, quoted identifier instead.
            query = f'SELECT COUNT(*) as count FROM {self._quote_identifier(table_name)}'
            response = await client.db.query_raw(query)

            if response:
                # Coerce so the declared ``int`` return type holds even if the
                # driver hands the count back as a string or ``None``.
                return int(response[0].get('count', 0) or 0)
            return 0
        except Exception:
            # Deliberately best-effort: a failed count must not abort schema
            # discovery, but leave a trace so a 0 is not mistaken for "empty".
            logging.getLogger(__name__).warning(
                "Could not count rows in table %s; reporting 0",
                table_name,
                exc_info=True,
            )
            return 0

    async def _describe_table(self, client, table_name: str) -> Dict[str, Any]:
        """Collect columns, primary keys, foreign keys, indexes and row count for one table.

        Args:
            client: Prisma client as returned by ``_ensure_prisma_client``.
            table_name: Unquoted name of a table in the ``public`` schema.

        Returns:
            A dict with keys ``'columns'``, ``'primary_keys'``,
            ``'foreign_keys'``, ``'indexes'`` and ``'row_count'``.
        """
        # Get column information
        columns_query = """
        SELECT
            column_name,
            data_type,
            is_nullable,
            column_default,
            character_maximum_length,
            numeric_precision,
            numeric_scale,
            ordinal_position
        FROM information_schema.columns
        WHERE table_name = $1
          AND table_schema = 'public'
        ORDER BY ordinal_position;
        """
        columns_response = await client.db.query_raw(columns_query, table_name)

        # Get primary key information
        pk_query = """
        SELECT a.attname
        FROM pg_index i
        JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
        WHERE i.indrelid = $1::regclass AND i.indisprimary;
        """
        # The regclass cast parses its input as an identifier, so the name is
        # passed double-quoted to keep its mixed case intact.
        pk_response = await client.db.query_raw(
            pk_query, self._quote_identifier(table_name)
        )
        primary_keys = [row['attname'] for row in pk_response] if pk_response else []

        # Get foreign key information
        fk_query = """
        SELECT
            tc.constraint_name,
            kcu.column_name,
            ccu.table_name AS foreign_table_name,
            ccu.column_name AS foreign_column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
            ON tc.constraint_name = kcu.constraint_name
        JOIN information_schema.constraint_column_usage AS ccu
            ON ccu.constraint_name = tc.constraint_name
        WHERE tc.constraint_type = 'FOREIGN KEY'
          AND tc.table_name = $1;
        """
        fk_response = await client.db.query_raw(fk_query, table_name)
        foreign_keys = fk_response if fk_response else []

        # Get indexes
        indexes_query = """
        SELECT
            i.relname AS index_name,
            array_agg(a.attname ORDER BY a.attnum) AS column_names,
            ix.indisunique AS is_unique
        FROM pg_class t
        JOIN pg_index ix ON t.oid = ix.indrelid
        JOIN pg_class i ON i.oid = ix.indexrelid
        JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey)
        WHERE t.relname = $1
          AND t.relkind = 'r'
        GROUP BY i.relname, ix.indisunique
        ORDER BY i.relname;
        """
        indexes_response = await client.db.query_raw(indexes_query, table_name)
        indexes = indexes_response if indexes_response else []

        # Get row count; any failure here degrades to 0 instead of aborting
        # the description of this table.
        try:
            row_count = await self._get_table_row_count(table_name)
        except Exception:
            row_count = 0

        return {
            'columns': columns_response,
            'primary_keys': primary_keys,
            'foreign_keys': foreign_keys,
            'indexes': indexes,
            'row_count': row_count
        }

    async def discover_all_tables(self) -> Dict[str, Any]:
        """Discover all tables in the LiteLLM database and their schemas.

        Lists the tables in the ``public`` schema whose name matches
        ``LiteLLM_%`` and describes each one in turn.

        Returns:
            A dict with keys ``'tables'`` (table name -> description dict with
            ``'columns'``, ``'primary_keys'``, ``'foreign_keys'``,
            ``'indexes'``, ``'row_count'``), ``'table_count'`` and
            ``'table_names'`` (sorted by name).

        Raises:
            Exception: If any discovery query fails; the original error is
                attached as ``__cause__``.
        """
        client = self._ensure_prisma_client()

        try:
            # Get all LiteLLM tables
            litellm_tables_query = """
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = 'public'
              AND table_name LIKE 'LiteLLM_%'
            ORDER BY table_name;
            """
            tables_response = await client.db.query_raw(litellm_tables_query)
            # Treat a missing result the same as "no tables found".
            table_names = [row['table_name'] for row in (tables_response or [])]

            # Get detailed schema for each table
            tables_info = {}
            for table_name in table_names:
                tables_info[table_name] = await self._describe_table(client, table_name)

            return {
                'tables': tables_info,
                'table_count': len(table_names),
                'table_names': table_names
            }
        except Exception as e:
            raise Exception(f"Error discovering tables: {str(e)}") from e
|
||
|
|
|