""" 错误恢复增强 — 错误分类 + 退避重试 + 会话快照 参考 Claude Code conversationRecovery.ts 设计: - 错误分类: 可重试 vs 不可重试 - 退避策略: 指数退避 + 抖动 - 会话快照: 崩溃时保存状态,启动时恢复 """ from __future__ import annotations import json import logging import os import random import time from dataclasses import dataclass, field from enum import Enum from typing import Any, Dict, List, Optional, Tuple logger = logging.getLogger(__name__) # ──────────────────────────── 错误类型 ──────────────────────────── class ErrorType(str, Enum): """错误分类""" RETRYABLE = "retryable" # 可重试(网络/限流/服务端) NON_RETRYABLE = "non_retryable" # 不可重试(认证/校验) DEGRADED = "degraded" # 降级运行(部分功能不可用) FATAL = "fatal" # 致命错误(需人工介入) # ──────────────────────────── 退避配置 ──────────────────────────── @dataclass class BackoffConfig: """退避策略配置""" base_delay_ms: float = 1000 # 基础延迟 max_delay_ms: float = 60000 # 最大延迟 multiplier: float = 2.0 # 退避乘数 jitter: float = 0.1 # 抖动比例 (0-1) max_retries: int = 3 # 最大重试次数 # ──────────────────────────── 错误分类器 ──────────────────────────── class ErrorClassifier: """ 错误分类器 — 判断错误是否可重试及对应的退避策略。 参考 Claude Code API 错误处理: - 429 Rate Limit → 指数退避 - 5xx Server Error → 线性退避 - 网络超时 → 立即重试(最多3次) - 401/403 → 不可重试 """ # 可重试错误模式(按优先级排序) RETRYABLE_PATTERNS: List[Tuple[str, Optional[BackoffConfig]]] = [ # (匹配模式, 自定义退避配置 | None=使用默认) ("rate limit", BackoffConfig(base_delay_ms=5000, multiplier=2.0, max_delay_ms=120000, max_retries=5)), ("too many requests", BackoffConfig(base_delay_ms=5000, multiplier=2.0, max_delay_ms=120000, max_retries=5)), ("429", BackoffConfig(base_delay_ms=5000, multiplier=2.0, max_delay_ms=120000, max_retries=5)), ("timed out", BackoffConfig(base_delay_ms=500, multiplier=1.5, max_delay_ms=10000, max_retries=3)), ("timeout", BackoffConfig(base_delay_ms=500, multiplier=1.5, max_delay_ms=10000, max_retries=3)), ("connection error", BackoffConfig(base_delay_ms=500, multiplier=1.5, max_delay_ms=10000, max_retries=3)), ("connection reset", BackoffConfig(base_delay_ms=500, multiplier=1.5, max_delay_ms=10000, max_retries=3)), ("server disconnected", BackoffConfig(base_delay_ms=1000, multiplier=2.0, max_delay_ms=30000, max_retries=3)), ("internal server error", BackoffConfig(base_delay_ms=2000, multiplier=2.0, max_delay_ms=30000, max_retries=3)), ("service unavailable", BackoffConfig(base_delay_ms=2000, multiplier=2.0, max_delay_ms=60000, max_retries=3)), ("temporarily unavailable", BackoffConfig(base_delay_ms=1000, multiplier=2.0, max_delay_ms=30000, max_retries=3)), ("bad gateway", BackoffConfig(base_delay_ms=1000, multiplier=2.0, max_delay_ms=30000, max_retries=3)), ("gateway timeout", BackoffConfig(base_delay_ms=1000, multiplier=2.0, max_delay_ms=30000, max_retries=3)), ] # 不可重试错误模式 NON_RETRYABLE_PATTERNS = [ "unauthorized", "authentication", "invalid api key", "forbidden", "not found", "validation error", "bad request", "402", # Payment Required ] def __init__(self, default_backoff: Optional[BackoffConfig] = None): self.default_backoff = default_backoff or BackoffConfig() def classify(self, error: Exception) -> Tuple[ErrorType, BackoffConfig]: """ 分类错误并返回退避策略。 Returns: (ErrorType, BackoffConfig) """ err_str = str(error).lower() err_type = type(error).__name__.lower() # 检查可重试 for pattern, backoff in self.RETRYABLE_PATTERNS: if pattern in err_str or pattern in err_type: return ErrorType.RETRYABLE, backoff or self.default_backoff # 检查不可重试 for pattern in self.NON_RETRYABLE_PATTERNS: if pattern in err_str or pattern in err_type: return ErrorType.NON_RETRYABLE, self.default_backoff # 默认: 可重试(保守策略:未知错误也重试一次) return ErrorType.RETRYABLE, self.default_backoff def compute_delay(self, attempt: int, backoff: BackoffConfig) -> float: """ 计算第 N 次重试的延迟(指数退避 + 抖动)。 Args: attempt: 第几次重试(0-based) backoff: 退避配置 Returns: 延迟秒数 """ delay = backoff.base_delay_ms * (backoff.multiplier ** attempt) delay = min(delay, backoff.max_delay_ms) # 添加抖动 jitter_range = delay * backoff.jitter delay = delay + random.uniform(-jitter_range, jitter_range) delay = max(0, delay) return delay / 1000 # 转为秒 # ──────────────────────────── 重试执行器 ──────────────────────────── class RetryExecutor: """带退避策略的异步重试执行器""" def __init__(self, classifier: Optional[ErrorClassifier] = None): self.classifier = classifier or ErrorClassifier() async def execute_with_retry( self, fn, *args, max_retries: Optional[int] = None, on_retry: Optional[callable] = None, **kwargs, ) -> Any: """ 使用退避策略执行异步函数。 Args: fn: 异步可调用对象 max_retries: 覆盖默认最大重试次数 on_retry: 重试回调 (attempt, error, delay) -> None Returns: fn 的返回值 Raises: 最后一次失败时的异常(如果所有重试都失败) """ last_error = None for attempt in range(3): # 初始 attempt 用于分类 try: return await fn(*args, **kwargs) except Exception as e: last_error = e error_type, backoff = self.classifier.classify(e) if error_type == ErrorType.NON_RETRYABLE: logger.warning("不可重试错误,直接抛出: %s", e) raise effective_max = max_retries if max_retries is not None else backoff.max_retries if attempt >= effective_max: logger.error("已达最大重试次数 (%d): %s", effective_max, e) raise delay = self.classifier.compute_delay(attempt, backoff) logger.warning( "重试 %d/%d,等待 %.1fs: %s", attempt + 1, effective_max, delay, str(e)[:200], ) if on_retry: try: on_retry(attempt, e, delay) except Exception: pass time.sleep(delay) # 同步等待 raise last_error # type: ignore # ──────────────────────────── 会话快照与恢复 ──────────────────────────── class ConversationRecovery: """ 会话崩溃恢复 — 参考 Claude Code conversationRecovery.ts。 在关键节点自动保存快照,崩溃后可恢复最近状态。 """ def __init__(self, snapshot_dir: Optional[str] = None): self.snapshot_dir = snapshot_dir or os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data", "snapshots", ) def _snapshot_path(self, session_id: str) -> str: os.makedirs(self.snapshot_dir, exist_ok=True) safe_id = session_id.replace("/", "_").replace("\\", "_") return os.path.join(self.snapshot_dir, f"{safe_id}.json") async def save_snapshot( self, session_id: str, messages: List[Dict[str, Any]], extra: Optional[Dict[str, Any]] = None, ) -> bool: """ 保存会话快照。 Args: session_id: 会话标识 messages: 消息列表 extra: 额外状态数据(模型、配置等) Returns: 是否保存成功 """ try: snapshot = { "session_id": session_id, "saved_at": time.time(), "message_count": len(messages), "messages": messages[-100:], # 最多保存最近 100 条 "extra": extra or {}, } path = self._snapshot_path(session_id) with open(path, "w", encoding="utf-8") as f: json.dump(snapshot, f, ensure_ascii=False, default=str) logger.info("会话快照已保存: %s (%d 条消息)", session_id, len(messages)) return True except Exception as e: logger.error("会话快照保存失败: %s", e) return False async def restore_snapshot( self, session_id: str, ) -> Optional[Dict[str, Any]]: """ 恢复会话快照。 Returns: 快照数据字典,若不存在则返回 None """ try: path = self._snapshot_path(session_id) if not os.path.exists(path): return None with open(path, "r", encoding="utf-8") as f: snapshot = json.load(f) age = time.time() - snapshot.get("saved_at", 0) logger.info( "会话快照已恢复: %s (%d 条消息, %.0f 秒前)", session_id, snapshot.get("message_count", 0), age, ) return snapshot except Exception as e: logger.error("会话快照恢复失败: %s", e) return None async def delete_snapshot(self, session_id: str) -> bool: """删除会话快照(正常退出时调用)。""" try: path = self._snapshot_path(session_id) if os.path.exists(path): os.remove(path) logger.info("会话快照已删除: %s", session_id) return True except Exception as e: logger.error("会话快照删除失败: %s", e) return False async def mark_interrupted(self, session_id: str) -> bool: """ 标记会话为异常中断(崩溃时调用)。 下次启动时前端可检测此标记并提示恢复。 """ try: path = self._snapshot_path(session_id) # 读取现有快照 snapshot = {} if os.path.exists(path): with open(path, "r", encoding="utf-8") as f: snapshot = json.load(f) snapshot["interrupted"] = True snapshot["interrupted_at"] = time.time() with open(path, "w", encoding="utf-8") as f: json.dump(snapshot, f, ensure_ascii=False, default=str) logger.info("会话已标记为中断: %s", session_id) return True except Exception as e: logger.error("标记会话中断失败: %s", e) return False def list_interrupted_sessions(self) -> List[Dict[str, Any]]: """列出所有中断的会话快照。""" interrupted = [] try: os.makedirs(self.snapshot_dir, exist_ok=True) for filename in os.listdir(self.snapshot_dir): if not filename.endswith(".json"): continue path = os.path.join(self.snapshot_dir, filename) try: with open(path, "r", encoding="utf-8") as f: snapshot = json.load(f) if snapshot.get("interrupted"): age = time.time() - snapshot.get("interrupted_at", 0) interrupted.append({ "session_id": snapshot.get("session_id"), "message_count": snapshot.get("message_count", 0), "interrupted_at": snapshot.get("interrupted_at"), "age_seconds": age, "path": path, }) except Exception: continue interrupted.sort(key=lambda s: s.get("interrupted_at", 0), reverse=True) except Exception as e: logger.error("列出中断会话失败: %s", e) return interrupted