"""API rate-limiting middleware — a simple sliding-window limiter.

Redis is used when a client is available and answers ``PING``; otherwise an
in-process memory store is used as a fallback.
"""

from __future__ import annotations

import logging
import time
import uuid
from collections import defaultdict, deque
from typing import Deque, Dict, List, Optional, Tuple

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

logger = logging.getLogger(__name__)

# Default rate-limit configuration.
DEFAULT_RATE_LIMIT = 120  # max requests per window
DEFAULT_WINDOW_SEC = 60  # window length in seconds

# Limit applied to SENSITIVE_PATH_PREFIXES (uses DEFAULT_WINDOW_SEC).
SENSITIVE_RATE_LIMIT = 30

# Sensitive endpoints get a stricter limit than the default.
SENSITIVE_PATH_PREFIXES = [
    "/api/v1/auth/login",
    "/api/v1/agent-chat",
]

# Per-path limits, matched by prefix; these take priority over
# SENSITIVE_PATH_PREFIXES. Values are (max_requests, window_sec).
PATH_SPECIFIC_LIMITS: Dict[str, Tuple[int, int]] = {
    "/api/v1/auth/login": (5, 60),  # login: 5 requests / minute
    "/api/v1/webhooks": (60, 60),  # webhooks: 60 requests / minute
}

# ─── In-memory store (single process / used when Redis is unavailable) ───
# Maps a rate key to the monotonic timestamps of its accepted requests,
# oldest first.
# NOTE: keys are never evicted, so the mapping grows with the number of
# distinct rate keys seen by this process.
_memory_store: Dict[str, Deque[float]] = defaultdict(deque)


def _get_redis():
    """Return a Redis client that answers ``PING``, or ``None``.

    This is deliberately best-effort: an import error, a missing client or a
    failed ping all yield ``None`` so the caller can fall back to the
    in-memory store.
    """
    try:
        from app.core.redis_client import get_redis_client

        client = get_redis_client()
        if client:
            client.ping()
            return client
    except Exception:
        logger.debug(
            "Redis unavailable for rate limiting; using in-memory store",
            exc_info=True,
        )
    return None


def _check_memory(
    key: str, max_requests: int, window_sec: float
) -> Tuple[bool, int]:
    """Run the in-memory sliding-window check for ``key``.

    Args:
        key: Rate-limit bucket identifier.
        max_requests: Maximum accepted requests inside one window.
        window_sec: Window length in seconds.

    Returns:
        ``(allowed, remaining)``. Rejected requests are not recorded and
        always report ``remaining == 0``.
    """
    now = time.monotonic()
    window = _memory_store[key]

    # Drop timestamps that fell out of the window. Entries are appended in
    # monotonic order, so the oldest one is always on the left.
    cutoff = now - window_sec
    while window and window[0] < cutoff:
        window.popleft()

    if len(window) < max_requests:
        window.append(now)
        return True, max_requests - len(window)
    return False, 0


def _check_redis(
    client, key: str, max_requests: int, window_sec: int
) -> Tuple[bool, int]:
    """Run the sliding-window check against a Redis sorted set.

    The current request is added to the set before counting, so rejected
    requests are recorded too (unlike ``_check_memory``).

    Args:
        client: Redis client exposing ``pipeline()``; the calls are made
            without ``await``, so a synchronous client is expected.
        key: Sorted-set key for this rate-limit bucket.
        max_requests: Maximum requests inside one window.
        window_sec: Window length in seconds.

    Returns:
        ``(allowed, remaining)``.
    """
    now_ms = int(time.time() * 1000)
    window_ms = window_sec * 1000

    # Sorted-set members are unique, so the member needs a random suffix;
    # otherwise two requests in the same millisecond collapse into a single
    # entry and are under-counted.
    member = f"{now_ms}:{uuid.uuid4().hex}"

    pipe = client.pipeline()
    pipe.zadd(key, {member: now_ms})
    pipe.zremrangebyscore(key, 0, now_ms - window_ms)
    pipe.zcard(key)
    # Let idle keys expire on their own.
    pipe.expire(key, window_sec * 2)
    _, _, count, _ = pipe.execute()

    remaining = max(0, max_requests - count)
    return count <= max_requests, remaining


def _resolve_limit(path: str) -> Tuple[int, int, str]:
    """Pick the limit, window and bucket name that apply to ``path``.

    Returns:
        ``(max_requests, window_sec, bucket)``. Each ``PATH_SPECIFIC_LIMITS``
        prefix gets its own bucket so endpoints with different limits never
        share one counter.
    """
    for prefix, (limit, window) in PATH_SPECIFIC_LIMITS.items():
        if path.startswith(prefix):
            return limit, window, prefix
    if path.startswith(tuple(SENSITIVE_PATH_PREFIXES)):
        return SENSITIVE_RATE_LIMIT, DEFAULT_WINDOW_SEC, "sensitive"
    return DEFAULT_RATE_LIMIT, DEFAULT_WINDOW_SEC, "normal"


class RateLimiterMiddleware(BaseHTTPMiddleware):
    """API rate-limiting middleware.

    Rules (all per client IP, only for paths under ``/api/``):
    - Paths matching a ``PATH_SPECIFIC_LIMITS`` prefix use that entry's
      limit and window (e.g. login: 5 req / 60s, webhooks: 60 req / 60s).
    - Other paths matching ``SENSITIVE_PATH_PREFIXES`` (e.g. agent-chat):
      30 req / 60s.
    - Everything else: 120 req / 60s.
    - When the limit is exceeded, respond with 429 and ``Retry-After``.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Apply the rate limit to ``request`` or pass it downstream."""
        path = request.url.path

        # Non-API paths are not rate limited.
        if not path.startswith("/api/"):
            return await call_next(request)

        max_requests, window_sec, bucket = _resolve_limit(path)

        # Key = client IP + bucket.
        # NOTE(review): behind a reverse proxy request.client.host is
        # presumably the proxy address — confirm how the app is deployed.
        client_ip = request.client.host if request.client else "unknown"
        rate_key = f"rl:{client_ip}:{bucket}"

        # Prefer Redis; fall back to the in-memory store when Redis is
        # missing or fails mid-check, rather than failing the request.
        result: Optional[Tuple[bool, int]] = None
        redis_client = _get_redis()
        if redis_client is not None:
            try:
                result = _check_redis(
                    redis_client, rate_key, max_requests, window_sec
                )
            except Exception:
                logger.warning(
                    "Redis rate-limit check failed; "
                    "falling back to in-memory store",
                    exc_info=True,
                )
        if result is None:
            result = _check_memory(rate_key, max_requests, window_sec)
        allowed, remaining = result

        if not allowed:
            retry_after = window_sec
            logger.warning(
                "API 限流触发: ip=%s path=%s max=%d/%ds",
                client_ip,
                path,
                max_requests,
                window_sec,
            )
            return JSONResponse(
                status_code=429,
                content={
                    "detail": f"请求过于频繁,请 {retry_after}s 后重试",
                    "retry_after": retry_after,
                },
                headers={"Retry-After": str(retry_after)},
            )

        response = await call_next(request)

        # Expose the limit state to the client.
        response.headers["X-RateLimit-Limit"] = str(max_requests)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        return response