"""
Hook system — a framework for registering and triggering event hooks.

Modeled on the design of Claude Code's src/utils/hooks.ts:
- 6 events: UserPromptSubmit / PreToolUse / PostToolUse / Stop / SessionStart / Notification
- 3 hook types: shell / python / http
- Wildcard matching: ``tool_name`` is matched against an fnmatch-style pattern (e.g. ``Bash*``)
"""

from __future__ import annotations

import asyncio
import fnmatch
import inspect
import json
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union

logger = logging.getLogger(__name__)


# ──────────────────────────── Event types ────────────────────────────

class HookEvent(str, Enum):
    """Hook event types, mirroring the Claude Code hooks interface."""

    USER_PROMPT_SUBMIT = "UserPromptSubmit"  # before the user's input is submitted
    PRE_TOOL_USE = "PreToolUse"              # before a tool runs
    POST_TOOL_USE = "PostToolUse"            # after a tool has run
    STOP = "Stop"                            # conversation finished
    SESSION_START = "SessionStart"           # session started
    NOTIFICATION = "Notification"            # event notification


# ──────────────────────────── Data structures ────────────────────────────

@dataclass
class HookConfig:
    """Configuration of a single hook.

    Exactly one of ``shell_command`` / ``python_handler`` / ``http_url`` is
    expected to be set; when several are set, the manager picks the first one
    in that order.
    """

    event: HookEvent
    matcher: str = "*"  # tool-name / event-name pattern, supports * wildcards
    description: str = ""

    # Hook handler (choose one of the three)
    shell_command: Optional[str] = None                   # shell command line
    python_handler: Optional[Callable[..., Any]] = None   # Python callable (sync or async)
    http_url: Optional[str] = None                        # HTTP endpoint

    timeout_ms: int = 60000
    enabled: bool = True

    def matches(self, tool_name: str) -> bool:
        """Return True when this hook is enabled and ``tool_name`` fits ``matcher``.

        A matcher of ``"*"`` matches everything. Other patterns use
        fnmatch-style wildcards and are compared case-sensitively.
        """
        if not self.enabled:
            return False
        if self.matcher == "*":
            return True
        # fnmatchcase rather than fnmatch: fnmatch.fnmatch runs both arguments
        # through os.path.normcase, which would make tool-name matching
        # case-insensitive on Windows only. Tool names are not file paths.
        return fnmatch.fnmatchcase(tool_name, self.matcher)


@dataclass
class HookContext:
    """Context data handed to a hook when an event fires."""

    event: HookEvent
    tool_name: Optional[str] = None
    tool_input: Optional[Dict[str, Any]] = None
    tool_output: Optional[str] = None
    session_id: Optional[str] = None
    agent_name: Optional[str] = None
    user_id: Optional[str] = None
    messages: Optional[List[Dict[str, Any]]] = None
    extra: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Build the dictionary sent to shell/http hooks.

        ``tool_output`` is truncated to its first 2000 characters and
        ``messages`` is intentionally not part of the payload.
        """
        return {
            "event": self.event.value,
            "tool_name": self.tool_name,
            "tool_input": self.tool_input,
            "tool_output": (self.tool_output[:2000] if self.tool_output else None),
            "session_id": self.session_id,
            "agent_name": self.agent_name,
            "user_id": self.user_id,
            "extra": self.extra,
        }


@dataclass
class HookResult:
    """Outcome of running a hook (or the aggregate of several hooks)."""

    allowed: bool = True                                       # False = operation denied
    reason: str = ""                                           # why it was denied
    modified_input: Optional[Dict[str, Any]] = None            # PreToolUse may rewrite tool arguments
    modified_messages: Optional[List[Dict[str, Any]]] = None   # UserPromptSubmit may rewrite messages
    data: Dict[str, Any] = field(default_factory=dict)         # extra data


def _result_from_mapping(payload: Dict[str, Any], data: Any) -> HookResult:
    """Translate a hook's dict-shaped response into a HookResult."""
    return HookResult(
        allowed=payload.get("allowed", True),
        reason=payload.get("reason", ""),
        modified_input=payload.get("modified_input"),
        modified_messages=payload.get("modified_messages"),
        data=data,
    )


# ──────────────────────────── Hook manager ────────────────────────────

class HookManager:
    """
    Registers hooks and triggers them for events.

    Usage:
        manager = HookManager()
        manager.register(HookConfig(
            event=HookEvent.PRE_TOOL_USE,
            matcher="Bash*",
            shell_command="/usr/local/bin/audit-hook --verbose",
        ))
        result = await manager.trigger(HookEvent.PRE_TOOL_USE, HookContext(...))

    Note: shell commands are tokenized with ``shlex.split`` and executed
    directly (no shell), so redirections, pipes and variable expansion are
    not interpreted.
    """

    def __init__(self, hooks: Optional[List[HookConfig]] = None):
        """Create a manager, optionally pre-registering ``hooks`` in order."""
        self._hooks: Dict[HookEvent, List[HookConfig]] = {e: [] for e in HookEvent}
        for h in (hooks or []):
            self.register(h)

    def register(self, config: HookConfig) -> None:
        """Register a hook under its event."""
        self._hooks[config.event].append(config)
        logger.info("Hook 注册: event=%s matcher=%s", config.event.value, config.matcher)

    def unregister(self, event: HookEvent, matcher: str) -> int:
        """Remove every hook of ``event`` whose matcher string equals ``matcher``.

        Returns:
            The number of hooks removed.
        """
        before = len(self._hooks[event])
        self._hooks[event] = [h for h in self._hooks[event] if h.matcher != matcher]
        removed = before - len(self._hooks[event])
        logger.info("Hook 移除: event=%s matcher=%s removed=%d", event.value, matcher, removed)
        return removed

    def get_hooks(self, event: HookEvent) -> List[HookConfig]:
        """Return a copy of the list of hooks registered for ``event``."""
        return list(self._hooks.get(event, []))

    async def trigger(
        self,
        event: HookEvent,
        context: HookContext,
    ) -> HookResult:
        """
        Run every hook of ``event`` that matches ``context.tool_name``.

        Hooks run sequentially in registration order. A denial
        (``allowed=False``) does NOT stop the loop: the remaining hooks still
        run so that notification/audit hooks observe the event. The aggregate
        result is denied if any hook denied, and keeps the first non-empty
        denial reason.

        A hook that times out or raises is logged and skipped; it does not
        deny the operation.

        Returns:
            The aggregated HookResult. When several hooks modify the input or
            the messages, the last modification wins; ``data`` dictionaries
            are merged in order.
        """
        final_result = HookResult(allowed=True)
        # Events without a tool name are matched as the literal "*", so only
        # hooks whose pattern accepts "*" apply to them.
        matching = [h for h in self._hooks.get(event, []) if h.matches(context.tool_name or "*")]

        if not matching:
            return final_result

        logger.debug("触发 Hook event=%s tool=%s hooks=%d", event.value, context.tool_name, len(matching))

        for hook in matching:
            try:
                result = await asyncio.wait_for(
                    self._execute_hook(hook, context),
                    timeout=hook.timeout_ms / 1000,
                )
                if not result.allowed:
                    logger.warning(
                        "Hook 拒绝操作: event=%s tool=%s reason=%s",
                        event.value, context.tool_name, result.reason,
                    )
                    # A denial blocks the downstream operation, but the
                    # remaining hooks keep running for notification/audit.
                    final_result.allowed = False
                    final_result.reason = final_result.reason or result.reason
                if result.modified_input is not None:
                    final_result.modified_input = result.modified_input
                if result.modified_messages is not None:
                    final_result.modified_messages = result.modified_messages
                if result.data:
                    final_result.data.update(result.data)
            except asyncio.TimeoutError:
                logger.error("Hook 超时 (%.1fs): event=%s matcher=%s",
                             hook.timeout_ms / 1000, event.value, hook.matcher)
            except Exception:
                logger.exception("Hook 执行异常: event=%s matcher=%s", event.value, hook.matcher)

        return final_result

    async def _execute_hook(self, hook: HookConfig, context: HookContext) -> HookResult:
        """Dispatch to the shell, python or http executor (first one configured)."""
        if hook.shell_command:
            return await self._execute_shell_hook(hook, context)
        if hook.python_handler:
            return await self._execute_python_hook(hook, context)
        if hook.http_url:
            return await self._execute_http_hook(hook, context)
        return HookResult(allowed=True)

    # ── Shell hook ──

    async def _execute_shell_hook(self, hook: HookConfig, context: HookContext) -> HookResult:
        """Run a shell hook: the JSON context goes to stdin, the result is read from stdout.

        A non-zero exit code denies the operation. Empty stdout, or stdout
        that is not a JSON object, allows it.

        Raises:
            asyncio.TimeoutError: the command did not finish within
                ``hook.timeout_ms``; the subprocess is killed first.
        """
        import shlex

        ctx_json = json.dumps(context.to_dict(), ensure_ascii=False)

        proc = await asyncio.create_subprocess_exec(
            *shlex.split(hook.shell_command or "true"),
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            stdout, stderr = await asyncio.wait_for(
                proc.communicate(ctx_json.encode("utf-8")),
                timeout=hook.timeout_ms / 1000,
            )
        except (asyncio.TimeoutError, asyncio.CancelledError):
            # trigger() wraps this coroutine in its own wait_for with the same
            # timeout, so expiry normally reaches us as CancelledError rather
            # than TimeoutError. Kill the child in both cases so it is never
            # left running.
            if proc.returncode is None:
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass  # the process exited between the check and the kill
                await proc.wait()
            raise

        if proc.returncode != 0:
            logger.warning("Shell Hook 退出非零: rc=%d stderr=%s",
                           proc.returncode, stderr.decode("utf-8", errors="replace")[:200])
            return HookResult(
                allowed=False,
                reason=f"Hook 返回非零退出码: {proc.returncode}",
            )

        # Parse stdout into a HookResult; undecodable bytes must not abort the hook.
        stdout_text = stdout.decode("utf-8", errors="replace").strip()
        if not stdout_text:
            return HookResult(allowed=True)

        try:
            data = json.loads(stdout_text)
        except json.JSONDecodeError:
            # stdout is not JSON → treat it as plain output, do not affect execution
            logger.debug("Shell Hook stdout (非JSON): %.200s", stdout_text)
            return HookResult(allowed=True)

        if not isinstance(data, dict):
            # Valid JSON but not an object (e.g. a list or a number): there is
            # nothing to read "allowed" from, so handle it like plain output.
            logger.debug("Shell Hook stdout (非对象JSON): %.200s", stdout_text)
            return HookResult(allowed=True)

        return _result_from_mapping(data, data.get("data", {}))

    # ── Python hook ──

    async def _execute_python_hook(self, hook: HookConfig, context: HookContext) -> HookResult:
        """Run a Python hook by calling the handler with the context.

        The handler may be sync or async. Its return value is interpreted as:
        ``None`` → allow; ``HookResult`` → used as is; ``dict`` → fields
        ``allowed`` / ``reason`` / ``modified_input`` / ``modified_messages``
        are read and the whole dict becomes ``data``; ``bool`` → allow/deny;
        anything else → allow.
        """
        if not hook.python_handler:
            return HookResult(allowed=True)

        result = hook.python_handler(context)
        # isawaitable (not iscoroutine) so handlers that return a Task, Future
        # or any object with __await__ are awaited too instead of being
        # mistaken for an unknown value and treated as "allowed".
        if inspect.isawaitable(result):
            result = await result

        if result is None:
            return HookResult(allowed=True)
        if isinstance(result, HookResult):
            return result
        if isinstance(result, dict):
            return _result_from_mapping(result, result)
        if isinstance(result, bool):
            return HookResult(allowed=result)
        return HookResult(allowed=True)

    # ── HTTP hook ──

    async def _execute_http_hook(self, hook: HookConfig, context: HookContext) -> HookResult:
        """Run an HTTP hook: POST the JSON context to an external service.

        A response status >= 400 denies the operation. A missing ``httpx``
        library or any failure while calling/parsing the response is logged
        and allows the operation.
        """
        try:
            import httpx
        except ImportError:
            logger.error("HTTP Hook 需要 httpx 库")
            return HookResult(allowed=True)

        try:
            async with httpx.AsyncClient(timeout=hook.timeout_ms / 1000) as client:
                resp = await client.post(
                    hook.http_url or "",
                    json=context.to_dict(),
                    headers={"Content-Type": "application/json"},
                )
                if resp.status_code >= 400:
                    logger.warning("HTTP Hook 返回 %d: %s", resp.status_code, resp.text[:200])
                    return HookResult(
                        allowed=False,
                        reason=f"HTTP Hook 返回 {resp.status_code}",
                    )
                data = resp.json() if resp.text else {}
                return _result_from_mapping(data, data)
        except Exception as e:
            logger.error("HTTP Hook 调用失败: %s", e)
            return HookResult(allowed=True)  # a failing HTTP hook does not block execution


# ──────────────────────────── Built-in hook examples ────────────────────────────

def create_audit_log_hook():
    """Create an audit hook that logs every PreToolUse event for any tool."""

    async def audit_handler(ctx: HookContext) -> None:
        logger.info(
            "[AUDIT] event=%s tool=%s agent=%s session=%s",
            ctx.event.value, ctx.tool_name, ctx.agent_name, ctx.session_id,
        )

    return HookConfig(
        event=HookEvent.PRE_TOOL_USE,
        matcher="*",
        description="审计日志:记录所有工具调用",
        python_handler=audit_handler,
    )


def create_security_hook(forbidden_commands: Optional[List[str]] = None):
    """Create a security hook that denies ``command_exec`` calls containing dangerous patterns.

    Args:
        forbidden_commands: substrings to look for (case-insensitively) in the
            JSON-serialized tool input. When omitted or empty, a built-in
            list is used.
    """
    dangerous = forbidden_commands or ["rm -rf", "sudo", "chmod 777", "DROP TABLE"]

    async def security_handler(ctx: HookContext) -> dict:
        # default=str: a value json cannot serialize must not raise here,
        # because HookManager.trigger() swallows handler exceptions and the
        # check would then silently let the command through.
        args_str = json.dumps(ctx.tool_input or {}, ensure_ascii=False, default=str).lower()
        for cmd in dangerous:
            if cmd.lower() in args_str:
                return {"allowed": False, "reason": f"检测到危险命令模式: {cmd}"}
        return {"allowed": True}

    return HookConfig(
        event=HookEvent.PRE_TOOL_USE,
        matcher="command_exec",
        description="安全拦截:检测危险命令",
        python_handler=security_handler,
    )