352 lines
14 KiB
Python
352 lines
14 KiB
Python
|
|
"""
|
|||
|
|
Hook 系统 — 事件钩子注册/触发框架
|
|||
|
|
|
|||
|
|
参考 Claude Code src/utils/hooks.ts 设计:
|
|||
|
|
- 6 种事件: UserPromptSubmit / PreToolUse / PostToolUse / Stop / SessionStart / Notification
|
|||
|
|
- 3 种 Hook 类型: shell / python / http
|
|||
|
|
- 通配符匹配: tool_name 支持 * 前缀匹配
|
|||
|
|
"""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import asyncio
|
|||
|
|
import fnmatch
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from enum import Enum
|
|||
|
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
|
|||
|
|
|
|||
|
|
# Module-level logger used for hook registration, dispatch, and failure messages.
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ──────────────────────────── 事件类型 ────────────────────────────
|
|||
|
|
|
|||
|
|
class HookEvent(str, Enum):
    """Points in the agent lifecycle at which hooks can be triggered.

    Names follow the Claude Code hook events. Because members also subclass
    ``str``, each one compares equal to its wire name, e.g.
    ``HookEvent.STOP == "Stop"``.
    """

    # Before the user's input is submitted.
    USER_PROMPT_SUBMIT = "UserPromptSubmit"
    # Before a tool executes.
    PRE_TOOL_USE = "PreToolUse"
    # After a tool has executed.
    POST_TOOL_USE = "PostToolUse"
    # When the conversation completes.
    STOP = "Stop"
    # When a session starts.
    SESSION_START = "SessionStart"
    # Generic event notification.
    NOTIFICATION = "Notification"
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ──────────────────────────── 数据结构 ────────────────────────────
|
|||
|
|
|
|||
|
|
@dataclass
class HookConfig:
    """Configuration of a single hook.

    Set one of ``shell_command`` / ``python_handler`` / ``http_url``. HookManager
    checks them in that order and runs only the first one that is set.
    """

    event: HookEvent
    # Glob pattern (fnmatch syntax: *, ?, [seq]) matched case-sensitively
    # against the tool name; "*" matches everything.
    matcher: str = "*"
    description: str = ""

    # Hook handler — choose exactly one of the three.
    # Command line; it is tokenized with shlex and executed without a shell.
    shell_command: Optional[str] = None
    # Callable invoked with the HookContext; may be sync or async.
    python_handler: Optional[Callable[..., Any]] = None
    # Endpoint that receives the context as a JSON POST.
    http_url: Optional[str] = None

    # Per-hook timeout in milliseconds.
    timeout_ms: int = 60000
    # Disabled hooks never match, so they are never executed.
    enabled: bool = True

    def matches(self, tool_name: str) -> bool:
        """Return True if this hook is enabled and *tool_name* matches ``matcher``.

        Args:
            tool_name: Tool name to test (callers pass "*" when there is none).

        Returns:
            False for disabled hooks; True for the catch-all matcher "*";
            otherwise the result of a case-sensitive glob match.
        """
        if not self.enabled:
            return False
        if self.matcher == "*":
            return True
        # fnmatchcase rather than fnmatch: fnmatch applies os.path.normcase, which
        # makes matching case-insensitive (and "/"-normalizing) on Windows only,
        # so the same config would match different tools on different platforms.
        return fnmatch.fnmatchcase(tool_name, self.matcher)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class HookContext:
    """Data handed to a hook when an event is triggered.

    Only ``event`` is required; every other field defaults to empty.
    """

    event: HookEvent
    tool_name: Optional[str] = None
    tool_input: Optional[Dict[str, Any]] = None
    tool_output: Optional[str] = None
    session_id: Optional[str] = None
    agent_name: Optional[str] = None
    user_id: Optional[str] = None
    messages: Optional[List[Dict[str, Any]]] = None
    extra: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the JSON-serializable view that shell and HTTP hooks receive.

        ``tool_output`` is clipped to its first 2000 characters, and an empty or
        missing output becomes ``None``. ``messages`` is not included.
        """
        clipped_output = self.tool_output[:2000] if self.tool_output else None
        return {
            "event": self.event.value,
            "tool_name": self.tool_name,
            "tool_input": self.tool_input,
            "tool_output": clipped_output,
            "session_id": self.session_id,
            "agent_name": self.agent_name,
            "user_id": self.user_id,
            "extra": self.extra,
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class HookResult:
    """Outcome of one hook, or the merged outcome of all hooks for an event.

    The defaults describe a hook that allows the operation and changes nothing.
    """

    # False means the operation is vetoed.
    allowed: bool = True
    # Explanation for a veto.
    reason: str = ""
    # Replacement tool arguments (used by PreToolUse hooks).
    modified_input: Optional[Dict[str, Any]] = None
    # Replacement message list (used by UserPromptSubmit hooks).
    modified_messages: Optional[List[Dict[str, Any]]] = None
    # Free-form extra data returned by the hook.
    data: Dict[str, Any] = field(default_factory=dict)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ──────────────────────────── Hook 管理器 ────────────────────────────
|
|||
|
|
|
|||
|
|
class HookManager:
    """
    Registry and dispatcher for hooks.

    Hooks are kept per ``HookEvent`` in registration order. ``trigger`` runs every
    enabled hook whose matcher matches ``context.tool_name`` and folds the
    individual results into a single ``HookResult``.

    Usage:
        manager = HookManager()
        manager.register(HookConfig(
            event=HookEvent.PRE_TOOL_USE,
            matcher="Bash*",
            shell_command='sh -c "echo Bash tool called >&2"',
        ))
        result = await manager.trigger(HookEvent.PRE_TOOL_USE, HookContext(...))

    Note: ``shell_command`` is split with ``shlex.split`` and executed directly
    (no shell). Redirections, pipes, ``&&`` and similar syntax are therefore passed
    through as literal arguments. Wrap the command in ``sh -c '...'`` (as above)
    when shell features are needed.
    """

    def __init__(self, hooks: Optional[List[HookConfig]] = None):
        # One list per event, created up front so every event key exists.
        self._hooks: Dict[HookEvent, List[HookConfig]] = {e: [] for e in HookEvent}
        for h in (hooks or []):
            self.register(h)

    def register(self, config: HookConfig) -> None:
        """Register a hook; it runs after hooks already registered for the same event."""
        self._hooks[config.event].append(config)
        logger.info("Hook 注册: event=%s matcher=%s", config.event.value, config.matcher)

    def unregister(self, event: HookEvent, matcher: str) -> int:
        """Remove every hook of *event* whose matcher string equals *matcher* exactly.

        Returns:
            The number of hooks removed.
        """
        before = len(self._hooks[event])
        self._hooks[event] = [h for h in self._hooks[event] if h.matcher != matcher]
        removed = before - len(self._hooks[event])
        logger.info("Hook 移除: event=%s matcher=%s removed=%d", event.value, matcher, removed)
        return removed

    def get_hooks(self, event: HookEvent) -> List[HookConfig]:
        """Return a copy of the hooks registered for *event*."""
        return list(self._hooks.get(event, []))

    async def trigger(
        self,
        event: HookEvent,
        context: HookContext,
    ) -> HookResult:
        """
        Run all hooks of *event* that match ``context.tool_name``.

        Matching hooks run sequentially in registration order, each bounded by its
        own ``timeout_ms``. A hook that returns ``allowed=False`` marks the
        aggregate result as denied. The remaining hooks still run so they can
        notify or audit, and the first non-empty denial reason is kept. A hook
        that times out or raises is logged and skipped: it neither allows nor
        denies (fail-open).

        Returns:
            The aggregated HookResult. For ``modified_input`` and
            ``modified_messages``, the last hook that sets the field wins. ``data``
            dicts are merged, with later hooks overriding earlier keys.
        """
        final_result = HookResult(allowed=True)
        # Events without a tool name are matched as "*", so only catch-all
        # matchers fire for them.
        matching = [h for h in self._hooks.get(event, [])
                    if h.matches(context.tool_name or "*")]

        if not matching:
            return final_result

        logger.debug("触发 Hook event=%s tool=%s hooks=%d",
                     event.value, context.tool_name, len(matching))

        for hook in matching:
            try:
                result = await asyncio.wait_for(
                    self._execute_hook(hook, context),
                    timeout=hook.timeout_ms / 1000,
                )

                if not result.allowed:
                    logger.warning(
                        "Hook 拒绝操作: event=%s tool=%s reason=%s",
                        event.value, context.tool_name, result.reason,
                    )
                    # Record the denial but keep running the remaining hooks
                    # so they can still notify or audit.
                    final_result.allowed = False
                    final_result.reason = final_result.reason or result.reason

                if result.modified_input is not None:
                    final_result.modified_input = result.modified_input

                if result.modified_messages is not None:
                    final_result.modified_messages = result.modified_messages

                if result.data:
                    final_result.data.update(result.data)

            except asyncio.TimeoutError:
                logger.error("Hook 超时 (%.1fs): event=%s matcher=%s",
                             hook.timeout_ms / 1000, event.value, hook.matcher)
            except Exception:
                logger.exception("Hook 执行异常: event=%s matcher=%s",
                                 event.value, hook.matcher)

        return final_result

    async def _execute_hook(self, hook: HookConfig, context: HookContext) -> HookResult:
        """Dispatch to the first configured handler: shell, then python, then http."""
        if hook.shell_command:
            return await self._execute_shell_hook(hook, context)
        if hook.python_handler:
            return await self._execute_python_hook(hook, context)
        if hook.http_url:
            return await self._execute_http_hook(hook, context)
        return HookResult(allowed=True)

    # ── Shell Hook ──

    async def _execute_shell_hook(self, hook: HookConfig, context: HookContext) -> HookResult:
        """Run a command hook: JSON context on stdin, optional JSON result on stdout.

        - A non-zero exit status denies the operation.
        - Exit status 0 with empty, non-JSON, or non-object stdout allows it.
        - Exit status 0 with a JSON object maps its "allowed", "reason",
          "modified_input", "modified_messages" and "data" keys onto HookResult.

        Raises:
            asyncio.TimeoutError: the command exceeded ``hook.timeout_ms``. The
                child process is killed before the error propagates.
        """
        import shlex

        ctx_json = json.dumps(context.to_dict(), ensure_ascii=False)

        argv = shlex.split(hook.shell_command or "true")
        if not argv:
            # A whitespace-only command would make create_subprocess_exec fail.
            logger.warning("Shell Hook 命令为空: matcher=%s", hook.matcher)
            return HookResult(allowed=True)

        proc = await asyncio.create_subprocess_exec(
            *argv,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        try:
            stdout, stderr = await asyncio.wait_for(
                proc.communicate(ctx_json.encode("utf-8")),
                timeout=hook.timeout_ms / 1000,
            )
        except BaseException:
            # Not just TimeoutError: trigger() wraps this coroutine in its own
            # wait_for with the same timeout. That timer starts earlier, so its
            # expiry can reach us as CancelledError. The child must be killed in
            # every case, or it is left running in the background.
            if proc.returncode is None:
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass  # exited between the check and the kill
                await proc.wait()
            raise

        if proc.returncode != 0:
            # errors="replace": a strict decode failing here would raise, and
            # trigger() would swallow the error — turning this denial into an allow.
            logger.warning("Shell Hook 退出非零: rc=%d stderr=%s",
                           proc.returncode, stderr.decode("utf-8", errors="replace")[:200])
            return HookResult(
                allowed=False,
                reason=f"Hook 返回非零退出码: {proc.returncode}",
            )

        # Parse stdout into a HookResult.
        stdout_text = stdout.decode("utf-8", errors="replace").strip()
        if not stdout_text:
            return HookResult(allowed=True)

        try:
            data = json.loads(stdout_text)
        except json.JSONDecodeError:
            # Plain-text stdout is informational only and does not affect execution.
            logger.debug("Shell Hook stdout (非JSON): %.200s", stdout_text)
            return HookResult(allowed=True)

        if not isinstance(data, dict):
            # Valid JSON but not an object (e.g. a list or bare value): no fields to read.
            logger.debug("Shell Hook stdout 不是 JSON 对象: %.200s", stdout_text)
            return HookResult(allowed=True)

        return HookResult(
            allowed=data.get("allowed", True),
            reason=data.get("reason", ""),
            modified_input=data.get("modified_input"),
            modified_messages=data.get("modified_messages"),
            data=data.get("data", {}),
        )

    # ── Python Hook ──

    async def _execute_python_hook(self, hook: HookConfig, context: HookContext) -> HookResult:
        """Call a Python handler (sync or async) and normalize its return value.

        Accepted return values:
        - ``None``: allow.
        - ``HookResult``: returned as is.
        - ``dict``: its keys are mapped onto HookResult, and the whole dict becomes ``data``.
        - ``bool``: becomes ``allowed``.
        Anything else is treated as allow.
        """
        import inspect

        if not hook.python_handler:
            return HookResult(allowed=True)

        result = hook.python_handler(context)
        # isawaitable also covers Futures and custom awaitables, not only coroutines.
        if inspect.isawaitable(result):
            result = await result

        if result is None:
            return HookResult(allowed=True)
        if isinstance(result, HookResult):
            return result
        if isinstance(result, dict):
            return HookResult(
                allowed=result.get("allowed", True),
                reason=result.get("reason", ""),
                modified_input=result.get("modified_input"),
                modified_messages=result.get("modified_messages"),
                data=result,
            )
        if isinstance(result, bool):
            return HookResult(allowed=result)
        return HookResult(allowed=True)

    # ── HTTP Hook ──

    async def _execute_http_hook(self, hook: HookConfig, context: HookContext) -> HookResult:
        """POST the JSON context to ``hook.http_url`` and map the JSON reply.

        An HTTP status of 400 or higher denies the operation. Transport errors,
        invalid JSON, a non-object body, or a missing ``httpx`` all allow it
        (fail-open), and each case is logged.
        """
        try:
            import httpx
        except ImportError:
            logger.error("HTTP Hook 需要 httpx 库")
            return HookResult(allowed=True)

        try:
            async with httpx.AsyncClient(timeout=hook.timeout_ms / 1000) as client:
                resp = await client.post(
                    hook.http_url or "",
                    json=context.to_dict(),
                    headers={"Content-Type": "application/json"},
                )
                if resp.status_code >= 400:
                    logger.warning("HTTP Hook 返回 %d: %s", resp.status_code, resp.text[:200])
                    return HookResult(
                        allowed=False,
                        reason=f"HTTP Hook 返回 {resp.status_code}",
                    )
                data = resp.json() if resp.text else {}
                if not isinstance(data, dict):
                    logger.warning("HTTP Hook 响应不是 JSON 对象: %.200s", resp.text)
                    return HookResult(allowed=True)
                return HookResult(
                    allowed=data.get("allowed", True),
                    reason=data.get("reason", ""),
                    modified_input=data.get("modified_input"),
                    modified_messages=data.get("modified_messages"),
                    data=data,
                )
        except Exception as e:
            logger.error("HTTP Hook 调用失败: %s", e)
            return HookResult(allowed=True)  # an HTTP hook failure does not block execution
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ──────────────────────────── 内置 Hook 示例 ────────────────────────────
|
|||
|
|
|
|||
|
|
def create_audit_log_hook():
    """Build a PreToolUse hook that writes one "[AUDIT]" INFO log line per tool call.

    The hook matches every tool ("*"). Its handler only logs and returns None, so
    it never blocks or modifies a call.
    """

    async def audit_handler(ctx: HookContext) -> None:
        fields = (ctx.event.value, ctx.tool_name, ctx.agent_name, ctx.session_id)
        logger.info("[AUDIT] event=%s tool=%s agent=%s session=%s", *fields)

    config = HookConfig(
        python_handler=audit_handler,
        description="审计日志:记录所有工具调用",
        matcher="*",
        event=HookEvent.PRE_TOOL_USE,
    )
    return config
|
|||
|
|
|
|||
|
|
|
|||
|
|
def create_security_hook(forbidden_commands: Optional[List[str]] = None):
    """Build a PreToolUse hook that vetoes ``command_exec`` calls containing dangerous patterns.

    The tool input is serialized to JSON, lower-cased, and searched for each
    pattern as a case-insensitive substring.

    Args:
        forbidden_commands: Substrings to block. ``None`` selects the built-in
            defaults ("rm -rf", "sudo", "chmod 777", "DROP TABLE"). An explicit
            list, even an empty one, is used as given. Empty strings are ignored,
            because they would match every input.

    Returns:
        A HookConfig for ``HookEvent.PRE_TOOL_USE`` that matches the
        ``command_exec`` tool.

    Note:
        This is a plain substring heuristic, not a sandbox. Differently spelled
        commands (extra whitespace, quoting, escapes) are not detected.
    """
    # Compare against None instead of using `or`, so an explicit empty list
    # is not silently replaced by the defaults.
    if forbidden_commands is None:
        forbidden_commands = ["rm -rf", "sudo", "chmod 777", "DROP TABLE"]

    # Snapshot the patterns once: later mutation of the caller's list cannot
    # change the hook, and lower-casing is not repeated on every call. The
    # original spelling is kept for the denial reason.
    patterns = tuple((cmd, cmd.lower()) for cmd in forbidden_commands if cmd)

    async def security_handler(ctx: HookContext) -> dict:
        args_str = json.dumps(ctx.tool_input or {}, ensure_ascii=False).lower()
        for original, lowered in patterns:
            if lowered in args_str:
                return {"allowed": False, "reason": f"检测到危险命令模式: {original}"}
        return {"allowed": True}

    return HookConfig(
        event=HookEvent.PRE_TOOL_USE,
        matcher="command_exec",
        description="安全拦截:检测危险命令",
        python_handler=security_handler,
    )
|