package aiapm

import (
	"database/sql"
	"encoding/json"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/google/uuid"
)

// aiCallsInsertSQL inserts one ai_calls row. InsertCall and InsertCallBatch
// share it so the column list cannot drift out of sync with the argument
// order produced by aiCallInsertArgs.
const aiCallsInsertSQL = `
INSERT INTO ai_calls (id, timestamp, service_name, project_id, vendor, model,
	tokens_in, tokens_out, tokens_cache_read, tokens_cache_write,
	estimated_cost, latency_ms, ttfb_ms, status, error_message, stream, cached, tags)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18)`

// CreateTable creates the ai_calls table and its indexes. Every statement
// uses IF NOT EXISTS, so running it against an initialised database is a
// no-op.
//
// NOTE(review): several statements are sent in a single Exec without
// arguments; this relies on the driver accepting multi-statement queries.
func CreateTable(db *sql.DB) error {
	schema := `
CREATE TABLE IF NOT EXISTS ai_calls (
	id                 UUID PRIMARY KEY DEFAULT gen_random_uuid(),
	timestamp          TIMESTAMPTZ NOT NULL DEFAULT NOW(),
	service_name       VARCHAR(255) NOT NULL,
	project_id         VARCHAR(255) NOT NULL DEFAULT '',
	vendor             VARCHAR(100) NOT NULL,
	model              VARCHAR(255) NOT NULL,
	tokens_in          INT NOT NULL DEFAULT 0,
	tokens_out         INT NOT NULL DEFAULT 0,
	tokens_cache_read  INT NOT NULL DEFAULT 0,
	tokens_cache_write INT NOT NULL DEFAULT 0,
	estimated_cost     DOUBLE PRECISION NOT NULL DEFAULT 0,
	latency_ms         INT NOT NULL DEFAULT 0,
	ttfb_ms            INT NOT NULL DEFAULT 0,
	status             VARCHAR(20) NOT NULL DEFAULT 'success',
	error_message      TEXT,
	stream             BOOLEAN NOT NULL DEFAULT FALSE,
	cached             BOOLEAN NOT NULL DEFAULT FALSE,
	tags               JSONB
);
CREATE INDEX IF NOT EXISTS idx_ai_calls_timestamp ON ai_calls(timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_ai_calls_service ON ai_calls(service_name);
CREATE INDEX IF NOT EXISTS idx_ai_calls_vendor ON ai_calls(vendor);
CREATE INDEX IF NOT EXISTS idx_ai_calls_model ON ai_calls(model);
CREATE INDEX IF NOT EXISTS idx_ai_calls_project ON ai_calls(project_id);
CREATE INDEX IF NOT EXISTS idx_ai_calls_status ON ai_calls(status);
CREATE INDEX IF NOT EXISTS idx_ai_calls_vendor_model ON ai_calls(vendor, model);
`
	_, err := db.Exec(schema)
	return err
}

// aiCallInsertArgs returns the 18 positional arguments for aiCallsInsertSQL.
// An empty ID is replaced with a fresh UUID and a zero Timestamp with the
// current time. r is received by value, so the caller's record is untouched.
func aiCallInsertArgs(r AICallRecord) ([]any, error) {
	if r.ID == "" {
		r.ID = uuid.New().String()
	}
	if r.Timestamp.IsZero() {
		r.Timestamp = time.Now()
	}
	tags, err := json.Marshal(r.Tags)
	if err != nil {
		return nil, fmt.Errorf("marshal tags for ai call %q: %w", r.ID, err)
	}
	return []any{
		r.ID, r.Timestamp, r.ServiceName, r.ProjectID, r.Vendor, r.Model,
		r.TokensIn, r.TokensOut, r.TokensCacheRead, r.TokensCacheWrite,
		r.EstimatedCost, r.LatencyMs, r.TTFBMs, r.Status, r.ErrorMessage,
		r.Stream, r.Cached, tags,
	}, nil
}

// InsertCall inserts a single AI call record. A missing ID or timestamp is
// filled in before insertion. It returns an error if the tags cannot be
// JSON-encoded or the INSERT fails.
func InsertCall(db *sql.DB, r AICallRecord) error {
	args, err := aiCallInsertArgs(r)
	if err != nil {
		return err
	}
	_, err = db.Exec(aiCallsInsertSQL, args...)
	return err
}

// InsertCallBatch inserts multiple AI call records in a single transaction
// using one prepared statement. Either every record is committed or, on the
// first error, the whole transaction is rolled back. An empty slice is a
// no-op.
func InsertCallBatch(db *sql.DB, records []AICallRecord) error {
	if len(records) == 0 {
		return nil
	}
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	// Rollback after a successful Commit is a harmless no-op.
	defer tx.Rollback()

	stmt, err := tx.Prepare(aiCallsInsertSQL)
	if err != nil {
		return err
	}
	defer stmt.Close()

	for _, r := range records {
		args, err := aiCallInsertArgs(r)
		if err != nil {
			return err
		}
		if _, err := stmt.Exec(args...); err != nil {
			return err
		}
	}
	return tx.Commit()
}

// buildWhereClause turns the non-empty fields of f into a " WHERE ..." clause
// joined with AND, numbering placeholders from $startArg. It returns the
// clause (empty when no filter is set) and the matching argument values.
func buildWhereClause(f AICallFilter, startArg int) (string, []any) {
	var conditions []string
	var args []any
	// add appends "<expr> $N" with N being the next placeholder number.
	add := func(expr string, v any) {
		args = append(args, v)
		conditions = append(conditions, expr+" $"+strconv.Itoa(startArg+len(args)-1))
	}

	if !f.From.IsZero() {
		add("timestamp >=", f.From)
	}
	if !f.To.IsZero() {
		add("timestamp <=", f.To)
	}
	if f.ServiceName != "" {
		add("service_name =", f.ServiceName)
	}
	if f.ProjectID != "" {
		add("project_id =", f.ProjectID)
	}
	if f.Vendor != "" {
		add("vendor =", f.Vendor)
	}
	if f.Model != "" {
		add("model =", f.Model)
	}
	if f.Status != "" {
		add("status =", f.Status)
	}

	if len(conditions) == 0 {
		return "", args
	}
	return " WHERE " + strings.Join(conditions, " AND "), args
}

// QueryCalls returns AI call records matching filter, newest first.
// filter.Limit defaults to 100 when not positive. A negative filter.Offset
// is treated as 0, because PostgreSQL rejects a negative OFFSET. Any row that
// cannot be scanned aborts the query with an error instead of being skipped.
func QueryCalls(db *sql.DB, filter AICallFilter) ([]AICallRecord, error) {
	where, args := buildWhereClause(filter, 1)
	limit := filter.Limit
	if limit <= 0 {
		limit = 100
	}
	offset := filter.Offset
	if offset < 0 {
		offset = 0
	}
	q := `SELECT id, timestamp, service_name, project_id, vendor, model,
		tokens_in, tokens_out, tokens_cache_read, tokens_cache_write,
		estimated_cost, latency_ms, ttfb_ms, status, COALESCE(error_message,''),
		stream, cached, COALESCE(tags::text,'{}')
	FROM ai_calls` + where +
		` ORDER BY timestamp DESC LIMIT ` + strconv.Itoa(limit) +
		` OFFSET ` + strconv.Itoa(offset)

	rows, err := db.Query(q, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var records []AICallRecord
	for rows.Next() {
		var r AICallRecord
		var tagsJSON string
		if err := rows.Scan(&r.ID, &r.Timestamp, &r.ServiceName, &r.ProjectID,
			&r.Vendor, &r.Model, &r.TokensIn, &r.TokensOut, &r.TokensCacheRead,
			&r.TokensCacheWrite, &r.EstimatedCost, &r.LatencyMs, &r.TTFBMs,
			&r.Status, &r.ErrorMessage, &r.Stream, &r.Cached, &tagsJSON); err != nil {
			return nil, err
		}
		// Best effort: tags come back as PostgreSQL-rendered JSONB text, and a
		// decode failure should not hide the rest of the record.
		_ = json.Unmarshal([]byte(tagsJSON), &r.Tags)
		records = append(records, r)
	}
	return records, rows.Err()
}

// GetUsageSummary returns aggregate usage statistics for calls matching
// filter: totals, averages, error rate, cache-hit rate and distinct counts.
// The cached-call count is computed in the same statement as the other
// aggregates, so every figure in the summary comes from one query.
func GetUsageSummary(db *sql.DB, filter AICallFilter) (*AIUsageSummary, error) {
	where, args := buildWhereClause(filter, 1)
	q := `SELECT COUNT(*),
		COALESCE(SUM(tokens_in),0), COALESCE(SUM(tokens_out),0),
		COALESCE(SUM(tokens_cache_read),0), COALESCE(SUM(tokens_cache_write),0),
		COALESCE(SUM(estimated_cost),0),
		COALESCE(AVG(latency_ms),0), COALESCE(AVG(ttfb_ms),0),
		COUNT(*) FILTER (WHERE status = 'error'),
		COUNT(*) FILTER (WHERE cached),
		COUNT(DISTINCT model), COUNT(DISTINCT vendor), COUNT(DISTINCT service_name)
	FROM ai_calls` + where

	s := &AIUsageSummary{}
	var cachedCount int64
	if err := db.QueryRow(q, args...).Scan(
		&s.TotalCalls, &s.TotalTokensIn, &s.TotalTokensOut,
		&s.TotalCacheRead, &s.TotalCacheWrite, &s.TotalCost,
		&s.AvgLatencyMs, &s.AvgTTFBMs, &s.ErrorCount, &cachedCount,
		&s.UniqueModels, &s.UniqueVendors, &s.UniqueServices); err != nil {
		return nil, err
	}
	if s.TotalCalls > 0 {
		s.ErrorRate = float64(s.ErrorCount) / float64(s.TotalCalls)
		s.CacheHitRate = float64(cachedCount) / float64(s.TotalCalls)
	}
	return s, nil
}

// GetModelStats returns per-(vendor, model) statistics for calls matching
// filter, ordered by total estimated cost, highest first.
func GetModelStats(db *sql.DB, filter AICallFilter) ([]AIModelStats, error) {
	where, args := buildWhereClause(filter, 1)
	q := `SELECT vendor, model, COUNT(*),
		COALESCE(SUM(tokens_in + tokens_out),0), COALESCE(SUM(estimated_cost),0),
		COALESCE(AVG(latency_ms),0), COUNT(*) FILTER (WHERE status = 'error')
	FROM ai_calls` + where + `
	GROUP BY vendor, model ORDER BY SUM(estimated_cost) DESC`

	rows, err := db.Query(q, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var stats []AIModelStats
	for rows.Next() {
		var s AIModelStats
		if err := rows.Scan(&s.Vendor, &s.Model, &s.TotalCalls, &s.TotalTokens,
			&s.TotalCost, &s.AvgLatencyMs, &s.ErrorCount); err != nil {
			return nil, err
		}
		if s.TotalCalls > 0 {
			s.ErrorRate = float64(s.ErrorCount) / float64(s.TotalCalls)
		}
		stats = append(stats, s)
	}
	return stats, rows.Err()
}

// GetVendorStats returns per-vendor statistics for calls matching filter,
// ordered by total estimated cost, highest first.
func GetVendorStats(db *sql.DB, filter AICallFilter) ([]AIVendorStats, error) {
	where, args := buildWhereClause(filter, 1)
	q := `SELECT vendor, COUNT(*),
		COALESCE(SUM(tokens_in + tokens_out),0), COALESCE(SUM(estimated_cost),0),
		COALESCE(AVG(latency_ms),0), COUNT(DISTINCT model),
		COUNT(*) FILTER (WHERE status = 'error')
	FROM ai_calls` + where + `
	GROUP BY vendor ORDER BY SUM(estimated_cost) DESC`

	rows, err := db.Query(q, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var stats []AIVendorStats
	for rows.Next() {
		var s AIVendorStats
		if err := rows.Scan(&s.Vendor, &s.TotalCalls, &s.TotalTokens, &s.TotalCost,
			&s.AvgLatencyMs, &s.ModelCount, &s.ErrorCount); err != nil {
			return nil, err
		}
		if s.TotalCalls > 0 {
			s.ErrorRate = float64(s.ErrorCount) / float64(s.TotalCalls)
		}
		stats = append(stats, s)
	}
	return stats, rows.Err()
}

// GetCostTimeseries returns the summed estimated cost and call count per time
// bucket for calls matching filter, oldest bucket first.
//
// interval selects the bucket width:
//   - "1h", "6h", "1d" and "7d" are fixed-width buckets aligned to the Unix
//     epoch. "1d" buckets therefore start at 00:00 UTC, and "7d" buckets
//     start on Thursdays, because 1970-01-01 was a Thursday.
//   - "1m" is a UTC calendar month.
//
// Any other value falls back to "1d".
func GetCostTimeseries(db *sql.DB, filter AICallFilter, interval string) ([]TimeseriesPoint, error) {
	// epochBucket floors timestamp to a multiple of the given number of
	// seconds since the epoch. The width always comes from the switch below,
	// never from caller input, so splicing it into the SQL text is safe.
	epochBucket := func(seconds int) string {
		n := strconv.Itoa(seconds)
		return "to_timestamp(floor(EXTRACT(EPOCH FROM timestamp) / " + n + ") * " + n + ")"
	}

	var bucket string
	switch interval {
	case "1h":
		bucket = epochBucket(3600)
	case "6h":
		bucket = epochBucket(6 * 3600)
	case "7d":
		bucket = epochBucket(7 * 24 * 3600)
	case "1m":
		// Months vary in length, so truncate on the UTC calendar instead.
		bucket = "date_trunc('month', timestamp AT TIME ZONE 'UTC') AT TIME ZONE 'UTC'"
	default: // "1d" and unrecognised intervals
		bucket = epochBucket(24 * 3600)
	}

	where, args := buildWhereClause(filter, 1)
	q := `SELECT ` + bucket + ` AS bucket, COALESCE(SUM(estimated_cost),0), COUNT(*)
	FROM ai_calls` + where + `
	GROUP BY bucket ORDER BY bucket ASC`

	rows, err := db.Query(q, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var points []TimeseriesPoint
	for rows.Next() {
		var p TimeseriesPoint
		if err := rows.Scan(&p.Timestamp, &p.Value, &p.Count); err != nil {
			return nil, err
		}
		points = append(points, p)
	}
	return points, rows.Err()
}