package middleware

import (
	"errors"
	"fmt"
	"math"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/sentinela-go/internal/db"
)

// PlanConfig describes what a subscription plan may access.
type PlanConfig struct {
	// Name is the human-readable plan name used in this middleware's error
	// messages.
	Name string
	// RateLimit is the plan's default request budget per minute. An API
	// key's own positive RateLimit takes precedence over it.
	RateLimit int
	// Endpoints lists the path prefixes the plan may call. A prefix matches
	// the path itself and anything below it ("/api/v1" matches
	// "/api/v1/companies" but not "/api/v1x").
	Endpoints []string
	// Features names the plan's capabilities. This middleware does not
	// consult it; presumably it is surfaced elsewhere — verify against callers.
	Features []string
}

// Plans maps a plan identifier (as stored on an API key) to its
// configuration. Requests without an API key, and keys whose plan is not
// listed here, get the "free" plan's endpoint access.
var Plans = map[string]PlanConfig{
	"free": {
		Name:      "Free",
		RateLimit: 30,
		Endpoints: []string{"/health", "/api/v1/market", "/api/v1/plans", "/docs", "/pricing"},
		Features:  []string{"market_data"},
	},
	"bronze": {
		Name:      "Bronze",
		RateLimit: 100,
		Endpoints: []string{"/health", "/api/v1/market", "/api/v1/companies", "/api/v1/search", "/api/v1/plans", "/docs", "/pricing"},
		Features:  []string{"market_data", "companies", "search"},
	},
	"gold": {
		Name:      "Gold",
		RateLimit: 500,
		Endpoints: []string{"/api/v1", "/health", "/docs", "/pricing"},
		Features:  []string{"market_data", "companies", "search", "filings", "historical", "bulk"},
	},
	"platinum": {
		Name:      "Platinum",
		RateLimit: 2000,
		Endpoints: []string{"/api/v1", "/health", "/docs", "/pricing"},
		Features:  []string{"market_data", "companies", "search", "filings", "historical", "bulk", "webhooks", "csv_export", "priority"},
	},
}

// tokenBucket holds the rate-limit state for one client (an API key or, for
// anonymous callers, a client IP).
type tokenBucket struct {
	tokens    float64   // tokens currently available; refilled lazily on access
	lastCheck time.Time // when tokens was last refilled
}

// PlanMiddleware enforces per-plan endpoint access and rate limits.
// Construct it through NewPlanMiddleware.
type PlanMiddleware struct {
	database *db.DB
	mu       sync.Mutex              // guards buckets
	buckets  map[string]*tokenBucket // keyed by "key:<api key>" or client IP
}

// NewPlanMiddleware returns a fiber handler that identifies the caller by API
// key, enforces the endpoint list of the caller's plan, applies a token-bucket
// rate limit, and records usage for keyed requests.
//
// It starts a background goroutine that evicts idle rate-limit buckets; that
// goroutine has no stop mechanism and runs for the rest of the process.
func NewPlanMiddleware(database *db.DB) fiber.Handler {
	pm := &PlanMiddleware{
		database: database,
		buckets:  make(map[string]*tokenBucket),
	}
	go pm.evictIdleBuckets(5*time.Minute, 10*time.Minute)
	return pm.handle
}

// evictIdleBuckets wakes every interval and drops buckets that have not been
// touched for longer than maxIdle. Dropping is lossless as long as maxIdle is
// at least one minute: a bucket refills completely within a minute, which is
// exactly the state a freshly created bucket starts in. It never returns.
func (pm *PlanMiddleware) evictIdleBuckets(interval, maxIdle time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for range ticker.C {
		now := time.Now()
		pm.mu.Lock()
		for key, b := range pm.buckets {
			if now.Sub(b.lastCheck) > maxIdle {
				delete(pm.buckets, key)
			}
		}
		pm.mu.Unlock()
	}
}

// handle is the fiber handler returned by NewPlanMiddleware.
//
// Public paths (see isPublicPath) skip every check. Otherwise the caller is
// identified by API key (X-API-Key or "Authorization: Bearer"), falling back
// to the client IP on the free plan for anonymous requests. It responds with:
//   - 401 when an API key is supplied but cannot be loaded;
//   - 403 when the plan does not include the requested path;
//   - 429 when the caller's bucket has no whole token left.
//
// Allowed requests get X-RateLimit-* and X-Plan headers and, for keyed
// requests, an asynchronous usage record.
func (pm *PlanMiddleware) handle(c *fiber.Ctx) error {
	// fiber reuses the buffers behind c.Path() (and c.IP() when a ProxyHeader
	// is configured) after the handler returns. Both values outlive the
	// request here — as a bucket map key and in the usage goroutine — so take
	// private copies.
	path := strings.Clone(c.Path())

	if isPublicPath(path) {
		c.Set("X-Plan", "free")
		return c.Next()
	}

	ip := strings.Clone(c.IP())
	plan := "free"
	bucketKey := ip // anonymous callers are rate limited by IP
	var keyRecord *db.APIKey
	if apiKey := extractAPIKey(c); apiKey != "" {
		ak, err := pm.database.GetAPIKey(apiKey)
		if err != nil || ak == nil {
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
				"error":   "invalid API key",
				"message": "The provided API key is not valid or has been deactivated.",
			})
		}
		keyRecord = ak
		plan = ak.Plan
		bucketKey = fmt.Sprintf("key:%s", ak.Key)
	}

	planCfg, ok := Plans[plan]
	if !ok {
		planCfg = Plans["free"]
	}
	if !isEndpointAllowed(path, planCfg.Endpoints) {
		return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
			"error":        "plan_restricted",
			"message":      fmt.Sprintf("Your %s plan does not include access to this endpoint.", planCfg.Name),
			"current_plan": plan,
			"upgrade_url":  "/pricing",
		})
	}

	// A key's own limit wins. A non-positive value would make the refill
	// rate zero (and every computed wait infinite), so fall back to the plan
	// default in that case.
	rateLimit := planCfg.RateLimit
	if keyRecord != nil && keyRecord.RateLimit > 0 {
		rateLimit = keyRecord.RateLimit
	}
	perSecond := float64(rateLimit) / 60.0
	limitHeader := strconv.Itoa(rateLimit)

	tokens, allowed := pm.take(bucketKey, rateLimit)
	if !allowed {
		wait := (1 - tokens) / perSecond // seconds until one whole token is back
		c.Set("X-RateLimit-Limit", limitHeader)
		c.Set("X-RateLimit-Remaining", "0")
		c.Set("X-RateLimit-Reset", unixAfter(wait))
		c.Set(fiber.HeaderRetryAfter, strconv.Itoa(int(math.Ceil(wait))))
		c.Set("X-Plan", plan)
		return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
			"error":       "rate_limit_exceeded",
			"message":     fmt.Sprintf("Rate limit of %d requests/minute exceeded for %s plan.", rateLimit, planCfg.Name),
			"plan":        plan,
			"limit":       rateLimit,
			"retry_after": fmt.Sprintf("%.1fs", wait),
			"upgrade_url": "/pricing",
		})
	}

	// Reset is the moment the bucket will be full again.
	c.Set("X-RateLimit-Limit", limitHeader)
	c.Set("X-RateLimit-Remaining", strconv.Itoa(int(tokens)))
	c.Set("X-RateLimit-Reset", unixAfter((float64(rateLimit)-tokens)/perSecond))
	c.Set("X-Plan", plan)

	start := time.Now()
	err := c.Next()

	if keyRecord != nil {
		status := responseStatus(c, err)
		latencyMs := int(time.Since(start).Milliseconds())
		go func(ak *db.APIKey, endpoint string, status, latencyMs int, ip string) {
			// Usage tracking is best-effort and must not affect the
			// response; there is no logger in scope, so the error is
			// dropped deliberately.
			_ = pm.database.IncrementUsage(ak.ID, endpoint, status, latencyMs, ip)
		}(keyRecord, path, status, latencyMs, ip)
	}
	return err
}

// take refills the bucket identified by key at limit tokens per minute
// (capped at limit) and then tries to consume one token. It returns the
// tokens left afterwards and whether a token was consumed. New buckets start
// full.
func (pm *PlanMiddleware) take(key string, limit int) (float64, bool) {
	pm.mu.Lock()
	defer pm.mu.Unlock()

	now := time.Now()
	capacity := float64(limit)
	b, ok := pm.buckets[key]
	if !ok {
		b = &tokenBucket{tokens: capacity, lastCheck: now}
		pm.buckets[key] = b
	}
	refill := now.Sub(b.lastCheck).Seconds() * capacity / 60.0
	b.tokens = math.Min(capacity, b.tokens+refill)
	b.lastCheck = now

	if b.tokens < 1 {
		return b.tokens, false
	}
	b.tokens--
	return b.tokens, true
}

// extractAPIKey reads the caller's API key from the X-API-Key header, falling
// back to an "Authorization: Bearer <key>" header. It returns "" when neither
// carries a key.
func extractAPIKey(c *fiber.Ctx) string {
	if key := c.Get("X-API-Key"); key != "" {
		return key
	}
	if auth := c.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
		return strings.TrimPrefix(auth, "Bearer ")
	}
	return ""
}

// responseStatus approximates the status code the client will receive. When
// the downstream chain returned an error, fiber's error handler has not run
// yet and the response still carries its default status, so mirror fiber's
// default handler instead: the *fiber.Error code, otherwise 500.
// NOTE(review): a custom app ErrorHandler may map errors differently.
func responseStatus(c *fiber.Ctx, err error) int {
	if err == nil {
		return c.Response().StatusCode()
	}
	var fe *fiber.Error
	if errors.As(err, &fe) {
		return fe.Code
	}
	return fiber.StatusInternalServerError
}

// unixAfter formats the instant that lies the given number of seconds from
// now as whole Unix seconds, rounded up. Rounding up keeps clients that wait
// until the advertised time from arriving early.
func unixAfter(seconds float64) string {
	at := time.Now().Add(time.Duration(seconds * float64(time.Second)))
	secs := at.Unix()
	if at.Nanosecond() > 0 {
		secs++
	}
	return strconv.FormatInt(secs, 10)
}

// isPublicPath reports whether path bypasses authentication, plan checks and
// rate limiting entirely.
func isPublicPath(path string) bool {
	switch path {
	case "/health", "/docs", "/docs/openapi.yaml", "/pricing":
		return true
	}
	return hasPathPrefix(path, "/api/v1/plans")
}

// isEndpointAllowed reports whether path falls under any of the allowed
// prefixes, matching on whole path segments (see hasPathPrefix).
func isEndpointAllowed(path string, allowed []string) bool {
	for _, prefix := range allowed {
		if hasPathPrefix(path, prefix) {
			return true
		}
	}
	return false
}

// hasPathPrefix reports whether path equals prefix or continues it with a new
// segment: "/api/v1/market" matches "/api/v1/market" and "/api/v1/market/x"
// but not "/api/v1/marketing". A prefix ending in "/" matches anything that
// starts with it.
func hasPathPrefix(path, prefix string) bool {
	if !strings.HasPrefix(path, prefix) {
		return false
	}
	return len(path) == len(prefix) ||
		strings.HasSuffix(prefix, "/") ||
		path[len(prefix)] == '/'
}