Files
sentinela-go/internal/api/middleware/plans.go
Rainbow a2b0db8f3f feat: tiered API plans (Free/Bronze/Gold/Platinum)
- Free: $0, 30 req/min, market data only
- Bronze: $29/mo, 100 req/min, + companies & search
- Gold: $99/mo, 500 req/min, + filings & historical
- Platinum: $299/mo, 2000 req/min, all features + priority
- Plan-aware rate limiting per API key (or per IP for free)
- Usage tracking with daily aggregation
- GET /api/v1/plans — plan listing
- POST /api/v1/plans/register — instant free API key
- GET /api/v1/plans/usage — usage stats
- /pricing — dark-themed HTML pricing page
- X-RateLimit-* and X-Plan headers on every response
- Restricted endpoints return upgrade prompt
- Updated OpenAPI spec with security scheme
- 53 handlers, compiles clean
2026-02-10 12:55:45 -03:00

185 lines
5.0 KiB
Go

package middleware
import (
"fmt"
"strings"
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/sentinela-go/internal/db"
)
// PlanConfig describes one API plan tier.
type PlanConfig struct {
// Name is the display name used in user-facing error messages.
Name string
// RateLimit is the plan's advertised allowance in requests per minute.
// Note that handle takes the effective limit from the API key record (or a
// fixed default for keyless callers), not from this field.
RateLimit int
// Endpoints lists the path prefixes the plan may call; a request is allowed
// when its path matches one of them (see isEndpointAllowed).
Endpoints []string
// Features lists feature identifiers included in the plan. Nothing in this
// file reads it.
Features []string
}
// Plans maps a lower-case plan identifier ("free", "bronze", "gold",
// "platinum") to its configuration. handle looks a caller's plan up here and
// falls back to the "free" entry when the identifier is not present.
var Plans = map[string]PlanConfig{
"free": {Name: "Free", RateLimit: 30, Endpoints: []string{"/health", "/api/v1/market", "/api/v1/plans", "/docs", "/pricing"}, Features: []string{"market_data"}},
"bronze": {Name: "Bronze", RateLimit: 100, Endpoints: []string{"/health", "/api/v1/market", "/api/v1/companies", "/api/v1/search", "/api/v1/plans", "/docs", "/pricing"}, Features: []string{"market_data", "companies", "search"}},
// Gold and Platinum allow the whole /api/v1 prefix; they differ in rate limit and features.
"gold": {Name: "Gold", RateLimit: 500, Endpoints: []string{"/api/v1", "/health", "/docs", "/pricing"}, Features: []string{"market_data", "companies", "search", "filings", "historical", "bulk"}},
"platinum": {Name: "Platinum", RateLimit: 2000, Endpoints: []string{"/api/v1", "/health", "/docs", "/pricing"}, Features: []string{"market_data", "companies", "search", "filings", "historical", "bulk", "webhooks", "csv_export", "priority"}},
}
// tokenBucket is the rate-limit state kept for one caller.
type tokenBucket struct {
// tokens is the number of requests currently available; it is fractional
// because it is refilled in proportion to elapsed time.
tokens float64
// lastCheck is when tokens was last refilled. The cleanup goroutine also
// uses it to detect idle buckets.
lastCheck time.Time
}
// PlanMiddleware enforces plan-based endpoint access and rate limits.
type PlanMiddleware struct {
// database resolves API keys and records usage.
database *db.DB
// mu guards buckets.
mu sync.Mutex
// buckets holds one token bucket per caller, keyed by "key:<api key>" for
// keyed requests and by client IP otherwise.
buckets map[string]*tokenBucket
}
// NewPlanMiddleware builds the plan-enforcement middleware backed by database
// and returns its request handler.
//
// As a side effect it starts a background goroutine that wakes every five
// minutes and drops rate-limit buckets that have not been touched for more
// than ten minutes. That goroutine has no stop signal and keeps running after
// this function returns.
func NewPlanMiddleware(database *db.DB) fiber.Handler {
	const (
		sweepInterval = 5 * time.Minute
		idleCutoff    = 10 * time.Minute
	)

	mw := &PlanMiddleware{
		buckets:  map[string]*tokenBucket{},
		database: database,
	}

	// evictIdle removes every bucket whose last refill is older than idleCutoff.
	evictIdle := func() {
		mw.mu.Lock()
		defer mw.mu.Unlock()

		current := time.Now()
		for key, bucket := range mw.buckets {
			if current.Sub(bucket.lastCheck) > idleCutoff {
				delete(mw.buckets, key)
			}
		}
	}

	go func() {
		for {
			time.Sleep(sweepInterval)
			evictIdle()
		}
	}()

	return mw.handle
}
// handle is the request handler returned by NewPlanMiddleware. For each
// request it:
//
//  1. lets public paths (/health, /docs, /docs/openapi.yaml, /pricing and
//     anything starting with /api/v1/plans) through, tagged "X-Plan: free",
//     with no key lookup or rate limiting;
//  2. resolves the caller's plan from the X-API-Key header or an
//     "Authorization: Bearer <key>" header and answers 401 when the key
//     lookup fails; requests without a key are treated as the free plan;
//  3. answers 403 when the plan's endpoint prefixes do not cover the path;
//  4. applies a token bucket per API key (per client IP for keyless callers)
//     that refills at rateLimit tokens per minute, answering 429 when fewer
//     than one token is available;
//  5. sets the X-RateLimit-* and X-Plan headers, runs the rest of the chain
//     and, for keyed requests, records usage in a background goroutine.
//
// It returns the error produced by the downstream chain or by writing the
// JSON error response.
func (pm *PlanMiddleware) handle(c *fiber.Ctx) error {
	path := c.Path()

	// Public endpoints skip key lookup, plan checks and rate limiting.
	if path == "/health" || path == "/docs" || path == "/docs/openapi.yaml" || path == "/pricing" ||
		strings.HasPrefix(path, "/api/v1/plans") {
		c.Set("X-Plan", "free")
		return c.Next()
	}

	// Accept the key from X-API-Key, falling back to a Bearer token.
	apiKey := c.Get("X-API-Key")
	if apiKey == "" {
		if auth := c.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
			apiKey = strings.TrimPrefix(auth, "Bearer ")
		}
	}

	// Fiber may return strings backed by buffers it reuses once the handler
	// has returned. The IP is stored as a long-lived map key and handed to a
	// goroutine below, so keep a private copy.
	clientIP := strings.Clone(c.IP())

	plan := "free"
	rateLimit := 30
	var keyRecord *db.APIKey
	bucketKey := clientIP // keyless callers are limited per IP
	if apiKey != "" {
		ak, err := pm.database.GetAPIKey(apiKey)
		if err != nil {
			return c.Status(401).JSON(fiber.Map{
				"error":   "invalid API key",
				"message": "The provided API key is not valid or has been deactivated.",
			})
		}
		keyRecord = ak
		plan = ak.Plan
		rateLimit = ak.RateLimit
		bucketKey = fmt.Sprintf("key:%s", ak.Key)
	}

	// Unknown plan identifiers get the free plan's endpoint set.
	planCfg, ok := Plans[plan]
	if !ok {
		planCfg = Plans["free"]
	}
	if !isEndpointAllowed(path, planCfg.Endpoints) {
		return c.Status(403).JSON(fiber.Map{
			"error":        "plan_restricted",
			"message":      fmt.Sprintf("Your %s plan does not include access to this endpoint.", planCfg.Name),
			"current_plan": plan,
			"upgrade_url":  "/pricing",
		})
	}

	// Token bucket: capacity is rateLimit, refilled continuously at
	// rateLimit tokens per minute.
	limit := float64(rateLimit)
	rate := limit / 60.0 // tokens per second

	pm.mu.Lock()
	now := time.Now()
	b, ok := pm.buckets[bucketKey]
	if !ok {
		b = &tokenBucket{tokens: limit, lastCheck: now}
		pm.buckets[bucketKey] = b
	}
	b.tokens += now.Sub(b.lastCheck).Seconds() * rate
	if b.tokens > limit {
		b.tokens = limit
	}
	b.lastCheck = now
	granted := b.tokens >= 1
	if granted {
		b.tokens--
	}
	remaining := int(b.tokens)
	pm.mu.Unlock()

	c.Set("X-RateLimit-Limit", fmt.Sprintf("%d", rateLimit))
	c.Set("X-Plan", plan)

	if !granted {
		// Time for one token to accrue. A non-positive limit never refills;
		// such callers stay denied, but the hint falls back to one minute
		// instead of dividing by zero.
		wait := 60.0 // seconds
		if rate > 0 {
			wait = 1.0 / rate
		}
		resetTime := now.Add(time.Duration(wait * float64(time.Second)))
		c.Set("X-RateLimit-Remaining", "0")
		c.Set("X-RateLimit-Reset", fmt.Sprintf("%d", resetTime.Unix()))
		return c.Status(429).JSON(fiber.Map{
			"error":       "rate_limit_exceeded",
			"message":     fmt.Sprintf("Rate limit of %d requests/minute exceeded for %s plan.", rateLimit, planCfg.Name),
			"plan":        plan,
			"limit":       rateLimit,
			"retry_after": fmt.Sprintf("%.1fs", wait),
			"upgrade_url": "/pricing",
		})
	}

	c.Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
	c.Set("X-RateLimit-Reset", fmt.Sprintf("%d", now.Add(time.Minute).Unix()))

	startTime := time.Now()
	err := c.Next()

	// Record usage off the request path, for keyed requests only.
	// NOTE(review): the status is read as it stands when the chain returns; if
	// the chain returned an error that is turned into a response afterwards,
	// that later status is not what gets recorded — confirm this is acceptable.
	if keyRecord != nil {
		go func(ak *db.APIKey, endpoint string, status int, latency int, ip string) {
			// Best effort: a failed usage write must not affect the response,
			// and there is no caller left to report it to.
			_ = pm.database.IncrementUsage(ak.ID, endpoint, status, latency, ip)
		}(keyRecord, strings.Clone(path), c.Response().StatusCode(), int(time.Since(startTime).Milliseconds()), clientIP)
	}
	return err
}
// isEndpointAllowed reports whether path is covered by one of the allowed
// prefixes.
//
// A prefix covers a path only at a path-segment boundary: the path equals the
// prefix, the prefix itself ends in "/", or the character following the prefix
// in the path is "/". So "/api/v1/market" covers "/api/v1/market" and
// "/api/v1/market/quotes" but not "/api/v1/marketplace". Empty entries are
// ignored rather than matching every path. A nil or empty allowed list
// permits nothing.
func isEndpointAllowed(path string, allowed []string) bool {
	for _, prefix := range allowed {
		if prefix == "" || !strings.HasPrefix(path, prefix) {
			continue
		}
		if len(path) == len(prefix) || strings.HasSuffix(prefix, "/") || path[len(prefix)] == '/' {
			return true
		}
	}
	return false
}