- Free: $0, 30 req/min, market data only
- Bronze: $29/mo, 100 req/min, + companies & search
- Gold: $99/mo, 500 req/min, + filings & historical
- Platinum: $299/mo, 2000 req/min, all features + priority
- Plan-aware rate limiting per API key (or per IP for keyless free-tier requests)
- Usage tracking with daily aggregation
- GET /api/v1/plans — plan listing
- POST /api/v1/plans/register — instant free API key
- GET /api/v1/plans/usage — usage stats
- /pricing — dark-themed HTML pricing page
- X-RateLimit-* and X-Plan headers on every response
- Restricted endpoints return an upgrade prompt
- Updated OpenAPI spec with a security scheme
- 53 handlers, compiles cleanly
185 lines
5.0 KiB
Go
185 lines
5.0 KiB
Go
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/sentinela-go/internal/db"
|
|
)
|
|
|
|
// PlanConfig describes a single subscription tier.
type PlanConfig struct {
	// Name is the human-readable plan name used in client-facing messages.
	Name string
	// RateLimit is the plan's allowance in requests per minute.
	RateLimit int
	// Endpoints lists request-path prefixes the plan may access; they are
	// matched against the request path by isEndpointAllowed.
	Endpoints []string
	// Features lists feature identifiers associated with the plan. The
	// middleware in this file does not consult them.
	Features []string
}

// Plans maps the plan identifier stored on an API key to that plan's
// configuration. The middleware treats unknown identifiers as "free".
var Plans = map[string]PlanConfig{
	"free": {
		Name:      "Free",
		RateLimit: 30,
		Endpoints: []string{"/health", "/api/v1/market", "/api/v1/plans", "/docs", "/pricing"},
		Features:  []string{"market_data"},
	},
	"bronze": {
		Name:      "Bronze",
		RateLimit: 100,
		Endpoints: []string{"/health", "/api/v1/market", "/api/v1/companies", "/api/v1/search", "/api/v1/plans", "/docs", "/pricing"},
		Features:  []string{"market_data", "companies", "search"},
	},
	"gold": {
		Name:      "Gold",
		RateLimit: 500,
		// "/api/v1" opens the whole v1 API to this tier.
		Endpoints: []string{"/api/v1", "/health", "/docs", "/pricing"},
		Features:  []string{"market_data", "companies", "search", "filings", "historical", "bulk"},
	},
	"platinum": {
		Name:      "Platinum",
		RateLimit: 2000,
		Endpoints: []string{"/api/v1", "/health", "/docs", "/pricing"},
		Features:  []string{"market_data", "companies", "search", "filings", "historical", "bulk", "webhooks", "csv_export", "priority"},
	},
}
|
|
|
|
// tokenBucket holds one client's rate-limit state. tokens is fractional
// because it is topped up in proportion to elapsed time.
type tokenBucket struct {
	tokens    float64   // tokens currently available to spend
	lastCheck time.Time // when tokens was last brought up to date
}
|
|
|
|
type PlanMiddleware struct {
|
|
database *db.DB
|
|
mu sync.Mutex
|
|
buckets map[string]*tokenBucket
|
|
}
|
|
|
|
func NewPlanMiddleware(database *db.DB) fiber.Handler {
|
|
pm := &PlanMiddleware{
|
|
database: database,
|
|
buckets: make(map[string]*tokenBucket),
|
|
}
|
|
|
|
// Cleanup old buckets every 5 minutes
|
|
go func() {
|
|
for {
|
|
time.Sleep(5 * time.Minute)
|
|
pm.mu.Lock()
|
|
now := time.Now()
|
|
for k, b := range pm.buckets {
|
|
if now.Sub(b.lastCheck) > 10*time.Minute {
|
|
delete(pm.buckets, k)
|
|
}
|
|
}
|
|
pm.mu.Unlock()
|
|
}
|
|
}()
|
|
|
|
return pm.handle
|
|
}
|
|
|
|
func (pm *PlanMiddleware) handle(c *fiber.Ctx) error {
|
|
path := c.Path()
|
|
|
|
// Always allow these without any checks
|
|
if path == "/health" || path == "/docs" || path == "/docs/openapi.yaml" || path == "/pricing" {
|
|
c.Set("X-Plan", "free")
|
|
return c.Next()
|
|
}
|
|
|
|
// Plans endpoints are always accessible
|
|
if strings.HasPrefix(path, "/api/v1/plans") {
|
|
c.Set("X-Plan", "free")
|
|
return c.Next()
|
|
}
|
|
|
|
apiKey := c.Get("X-API-Key")
|
|
if apiKey == "" {
|
|
auth := c.Get("Authorization")
|
|
if strings.HasPrefix(auth, "Bearer ") {
|
|
apiKey = strings.TrimPrefix(auth, "Bearer ")
|
|
}
|
|
}
|
|
|
|
plan := "free"
|
|
rateLimit := 30
|
|
var keyRecord *db.APIKey
|
|
bucketKey := c.IP() // default: rate limit by IP
|
|
|
|
if apiKey != "" {
|
|
ak, err := pm.database.GetAPIKey(apiKey)
|
|
if err != nil {
|
|
return c.Status(401).JSON(fiber.Map{
|
|
"error": "invalid API key",
|
|
"message": "The provided API key is not valid or has been deactivated.",
|
|
})
|
|
}
|
|
keyRecord = ak
|
|
plan = ak.Plan
|
|
rateLimit = ak.RateLimit
|
|
bucketKey = fmt.Sprintf("key:%s", ak.Key)
|
|
}
|
|
|
|
// Check endpoint access
|
|
planCfg, ok := Plans[plan]
|
|
if !ok {
|
|
planCfg = Plans["free"]
|
|
}
|
|
|
|
if !isEndpointAllowed(path, planCfg.Endpoints) {
|
|
return c.Status(403).JSON(fiber.Map{
|
|
"error": "plan_restricted",
|
|
"message": fmt.Sprintf("Your %s plan does not include access to this endpoint.", planCfg.Name),
|
|
"current_plan": plan,
|
|
"upgrade_url": "/pricing",
|
|
})
|
|
}
|
|
|
|
// Rate limiting
|
|
pm.mu.Lock()
|
|
b, ok := pm.buckets[bucketKey]
|
|
if !ok {
|
|
b = &tokenBucket{tokens: float64(rateLimit), lastCheck: time.Now()}
|
|
pm.buckets[bucketKey] = b
|
|
}
|
|
|
|
now := time.Now()
|
|
elapsed := now.Sub(b.lastCheck).Seconds()
|
|
rate := float64(rateLimit) / 60.0
|
|
b.tokens += elapsed * rate
|
|
if b.tokens > float64(rateLimit) {
|
|
b.tokens = float64(rateLimit)
|
|
}
|
|
b.lastCheck = now
|
|
|
|
remaining := int(b.tokens)
|
|
if b.tokens < 1 {
|
|
pm.mu.Unlock()
|
|
resetTime := time.Now().Add(time.Duration(float64(time.Second) / rate))
|
|
c.Set("X-RateLimit-Limit", fmt.Sprintf("%d", rateLimit))
|
|
c.Set("X-RateLimit-Remaining", "0")
|
|
c.Set("X-RateLimit-Reset", fmt.Sprintf("%d", resetTime.Unix()))
|
|
c.Set("X-Plan", plan)
|
|
return c.Status(429).JSON(fiber.Map{
|
|
"error": "rate_limit_exceeded",
|
|
"message": fmt.Sprintf("Rate limit of %d requests/minute exceeded for %s plan.", rateLimit, planCfg.Name),
|
|
"plan": plan,
|
|
"limit": rateLimit,
|
|
"retry_after": fmt.Sprintf("%.1fs", 1.0/rate),
|
|
"upgrade_url": "/pricing",
|
|
})
|
|
}
|
|
b.tokens--
|
|
remaining = int(b.tokens)
|
|
pm.mu.Unlock()
|
|
|
|
// Set rate limit headers
|
|
resetTime := time.Now().Add(time.Minute)
|
|
c.Set("X-RateLimit-Limit", fmt.Sprintf("%d", rateLimit))
|
|
c.Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
|
|
c.Set("X-RateLimit-Reset", fmt.Sprintf("%d", resetTime.Unix()))
|
|
c.Set("X-Plan", plan)
|
|
|
|
startTime := time.Now()
|
|
err := c.Next()
|
|
|
|
// Async usage tracking
|
|
if keyRecord != nil {
|
|
go func(ak *db.APIKey, endpoint string, status int, latency int, ip string) {
|
|
_ = pm.database.IncrementUsage(ak.ID, endpoint, status, latency, ip)
|
|
}(keyRecord, path, c.Response().StatusCode(), int(time.Since(startTime).Milliseconds()), c.IP())
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// isEndpointAllowed reports whether path lies under one of the allowed
// prefixes. Matching respects path-segment boundaries: the prefix
// "/api/v1/market" admits "/api/v1/market" and "/api/v1/market/quotes" but not
// "/api/v1/marketing". A prefix that already ends in "/" admits anything
// beneath it. A nil or empty allowed list admits nothing.
func isEndpointAllowed(path string, allowed []string) bool {
	for _, prefix := range allowed {
		if !strings.HasPrefix(path, prefix) {
			continue
		}
		// Exact match, or the prefix ends where a new path segment begins.
		if len(path) == len(prefix) || strings.HasSuffix(prefix, "/") || path[len(prefix)] == '/' {
			return true
		}
	}
	return false
}
|