diff --git a/(红头)项目核心文档汇总.md b/(红头)项目核心文档汇总.md
index f8dfa9a..a9123bd 100644
--- a/(红头)项目核心文档汇总.md
+++ b/(红头)项目核心文档汇总.md
@@ -87,6 +87,7 @@
- **AI框架**: LangChain
- **Agent Runtime**: 自研 ReAct 循环(零重构,寄生式复用现有服务)
- **Agent Orchestrator**: 多 Agent 编排引擎(路由/顺序/辩论三种模式)
+- **Agent 监控**: LLM 调用埋点 + 专属 Dashboard(Token/工具/Agent 用量统计)
- **数据库ORM**: SQLAlchemy
- **迁移工具**: Alembic
- **认证**: JWT
@@ -212,11 +213,14 @@ aiagent/
│ │ │ ├── orchestrator.py# 多 Agent 编排引擎(route/sequential/debate)
│ │ │ └── workflow_integration.py # 工作流桥接
│ │ ├── api/ # API路由
-│ │ │ └── agent_chat.py # Agent 独立聊天 API(新增)
+│ │ │ ├── agent_chat.py # Agent 聊天 + 多 Agent 编排 + LLM 调用日志(新增)
+│ │ │ └── agent_monitoring.py # Agent 监控 API 5 个端点(新增)
│ │ ├── core/ # 核心模块
│ │ ├── models/ # 数据库模型
+│ │ │ └── agent_llm_log.py # Agent LLM 调用日志模型(新增)
│ │ ├── schemas/ # Pydantic模式
│ │ ├── services/ # 业务逻辑
+│ │ │ └── agent_monitoring_service.py # Agent 监控服务(新增)
│ │ └── main.py # 应用入口
│ ├── alembic/ # 数据库迁移
│ ├── tests/ # 测试
@@ -328,17 +332,19 @@ pnpm dev
- `POST /api/v1/data-sources/{id}/test` - 测试数据源连接
- `POST /api/v1/data-sources/{id}/query` - 执行数据查询
-### Agent 对话 API(新增)
+### Agent 与编排 API(新增)
- `POST /api/v1/agent-chat/bare` - 默认 Agent 直接对话(无需预配置)
- `POST /api/v1/agent-chat/{agent_id}` - 与指定 Agent 对话(复用工作流配置)
- `POST /api/v1/agent-chat/orchestrate` - 多 Agent 编排(route/sequential/debate 三种模式)
-- `GET /api/v1/agent-monitoring/overview` - Agent 概览统计
-- `GET /api/v1/agent-monitoring/llm-calls` - LLM 调用记录列表(支持 days/limit 参数)
-- `GET /api/v1/agent-monitoring/agents-stats` - 各 Agent 用量排行
-- `GET /api/v1/agent-monitoring/tool-usage` - 工具调用频次统计
-- `GET /api/v1/agent-monitoring/daily-trend` - 每日 LLM 调用趋势
- - 请求体包含 message、mode、agents 列表(每个 Agent 可独立配置 system_prompt/model/temperature/tools)
- - 返回 final_answer、steps 追踪、agent_results
+ - 请求体: `{ message, mode, agents[] }` 每个 Agent 可独立配置 system_prompt/model/temperature/tools
+ - 返回: `{ final_answer, steps[], agent_results[] }`
+
+### Agent 监控 API(新增)
+- `GET /api/v1/agent-monitoring/overview` - Agent 概览统计(Agent 数、对话/LLM/工具调用次数、Token 用量)
+- `GET /api/v1/agent-monitoring/llm-calls?days=7&limit=50` - LLM 调用记录列表(模型、tokens、耗时、状态)
+- `GET /api/v1/agent-monitoring/agents-stats?days=7` - 各 Agent 用量排行(调用次数、Token、延迟)
+- `GET /api/v1/agent-monitoring/tool-usage?days=7` - 工具调用频次统计
+- `GET /api/v1/agent-monitoring/daily-trend?days=7` - 每日 LLM 调用趋势
### WebSocket API
- `ws://localhost:8037/ws/execution/{execution_id}` - 执行状态实时推送
@@ -457,7 +463,8 @@ alembic downgrade -1
- **第四-七阶段功能**: 100% ✅
- **自主 Agent Runtime**: 100% ✅(2026-04 新增)
- **多 Agent 编排**: 100% ✅(2026-05 新增)
-- **整体项目**: 约 95-97%
+- **Agent 监控 Dashboard**: 100% ✅(2026-05 新增)
+- **整体项目**: 约 97-98%
### 已完成核心功能
1. **完整的用户认证系统** - 注册、登录、JWT认证
@@ -481,8 +488,8 @@ alembic downgrade -1
### 近期开发重点(高优先级)
1. **预算接入** - Agent 内部 LLM 调用计入工作流执行预算
-2. **Agent Dashboard** - LLM 调用链路追踪、Token 消耗统计、执行历史
-3. **用户体验优化** - 工作流编辑器优化、Agent使用体验优化
+2. **用户体验优化** - 工作流编辑器优化、Agent使用体验优化
+3. **流式输出** - Agent 思考过程实时推送到前端
### 中期规划
1. **向量记忆** - 集成 Embedding API + 向量检索(语义记忆)
@@ -550,6 +557,6 @@ alembic downgrade -1
---
**最后更新**: 2026-05-01
-**文档版本**: 1.5
+**文档版本**: 1.6
*本文档基于项目现有文档整理生成,涵盖项目核心信息。详细技术方案请参考[方案-优化版.md](./方案-优化版.md)。DeepSeek 模型名与 Base URL 以官方文档为准,变更时请同步修订本节。*
\ No newline at end of file
diff --git a/backend/app/agent_runtime/__init__.py b/backend/app/agent_runtime/__init__.py
index df7e4f7..54af7e2 100644
--- a/backend/app/agent_runtime/__init__.py
+++ b/backend/app/agent_runtime/__init__.py
@@ -15,6 +15,7 @@ from app.agent_runtime.schemas import (
AgentLLMConfig,
AgentToolConfig,
AgentMemoryConfig,
+ AgentBudgetConfig,
AgentStep,
)
from app.agent_runtime.context import AgentContext
@@ -34,6 +35,7 @@ __all__ = [
"AgentLLMConfig",
"AgentToolConfig",
"AgentMemoryConfig",
+ "AgentBudgetConfig",
"AgentContext",
"AgentMemory",
"AgentToolManager",
diff --git a/backend/app/agent_runtime/core.py b/backend/app/agent_runtime/core.py
index 845b214..17503f5 100644
--- a/backend/app/agent_runtime/core.py
+++ b/backend/app/agent_runtime/core.py
@@ -13,7 +13,7 @@ from __future__ import annotations
import json
import logging
import time
-from typing import Any, Callable, Dict, List, Optional, Protocol, TypedDict
+from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Protocol, TypedDict
from app.agent_runtime.schemas import (
AgentConfig,
@@ -23,6 +23,7 @@ from app.agent_runtime.schemas import (
from app.agent_runtime.context import AgentContext
from app.agent_runtime.memory import AgentMemory
from app.agent_runtime.tool_manager import AgentToolManager
+from app.core.exceptions import WorkflowExecutionError
logger = logging.getLogger(__name__)
@@ -95,6 +96,11 @@ class AgentRuntime:
self.on_tool_executed = on_tool_executed
self.on_llm_call = on_llm_call
self._memory_context_loaded = False
+ self._llm_invocations = 0
+
+ # 预算回调:供 WorkflowEngine 注入,使 Agent 内部计数计入工作流预算
+ # 返回 True 表示预算充足;返回 False 或抛出异常表示超限
+ self.on_llm_invocation: Optional[Callable[[], Any]] = None
async def run(self, user_input: str) -> AgentResult:
"""
@@ -108,7 +114,7 @@ class AgentRuntime:
# 1. 首次运行时加载长期记忆到 system prompt
if not self._memory_context_loaded:
- await self._inject_memory_context()
+ await self._inject_memory_context(user_input)
self._memory_context_loaded = True
# 2. 追加用户消息
@@ -139,6 +145,32 @@ class AgentRuntime:
# 裁剪过长历史
messages = self.memory.trim_messages(self.context.messages)
+ # 预算检查:LLM 调用次数(在调用 LLM 之前检查,避免浪费额度)
+ budget = self.config.budget
+ if self._llm_invocations >= budget.max_llm_invocations:
+ err = f"已超过 LLM 调用预算({budget.max_llm_invocations} 次)"
+ logger.warning(err)
+ steps.append(AgentStep(iteration=self.context.iteration, type="final", content=err))
+ await self.memory.save_context(user_input, err, self.context.messages)
+ return AgentResult(success=False, content=err, truncated=True,
+ iterations_used=self.context.iteration,
+ tool_calls_made=self.context.tool_calls_made,
+ steps=steps, error=err)
+
+ # 调用外部 LLM 预算回调(WorkflowEngine 注入,将 Agent 的 LLM 计入工作流预算)
+ if self.on_llm_invocation:
+ try:
+ self.on_llm_invocation()
+ except Exception as e:
+ err = f"LLM 调用超出工作流预算: {e}"
+ logger.warning(err)
+ steps.append(AgentStep(iteration=self.context.iteration, type="final", content=err))
+ await self.memory.save_context(user_input, err, self.context.messages)
+ return AgentResult(success=False, content=err, truncated=True,
+ iterations_used=self.context.iteration,
+ tool_calls_made=self.context.tool_calls_made,
+ steps=steps, error=str(e))
+
# 调用 LLM
try:
response = await llm.chat(
@@ -166,6 +198,9 @@ class AgentRuntime:
error=err_str,
)
+ # 记录 LLM 调用次数(内部计数)
+ self._llm_invocations += 1
+
# 解析工具调用
tool_calls = self._extract_tool_calls(response)
content = self._extract_content(response)
@@ -247,9 +282,22 @@ class AgentRuntime:
self.context.add_tool_result(tcid, tname, result)
self.context.tool_calls_made += 1
+ # 预算检查:工具调用次数
+ if self.context.tool_calls_made > budget.max_tool_calls:
+ err = f"已超过工具调用预算({budget.max_tool_calls} 次)"
+ logger.warning(err)
+ steps.append(AgentStep(iteration=self.context.iteration, type="tool_result",
+ content=err, tool_name=tname))
+ return AgentResult(success=False, content=err, truncated=True,
+ iterations_used=self.context.iteration,
+ tool_calls_made=self.context.tool_calls_made,
+ steps=steps, error=err)
+
if self.on_tool_executed:
try:
await self.on_tool_executed(tname)
+ except WorkflowExecutionError:
+ raise
except Exception:
pass
@@ -284,9 +332,240 @@ class AgentRuntime:
steps=steps,
)
- async def _inject_memory_context(self) -> None:
+ async def run_stream(self, user_input: str) -> AsyncGenerator[dict, None]:
+ """
+ 流式执行 Agent 单轮对话。
+
+ 与 run() 逻辑相同,但在每个关键步骤 yield SSE 事件:
+ - think: LLM 思考中,准备调用工具
+ - tool_call: 即将执行工具
+ - tool_result: 工具执行完毕
+ - final: 最终回答
+ - error: 出错/预算超限
+ """
+ max_iter = max(1, self.config.llm.max_iterations)
+ self.context.iteration = 0
+ self.context.tool_calls_made = 0
+
+ # 1. 首次运行时加载长期记忆到 system prompt
+ if not self._memory_context_loaded:
+ await self._inject_memory_context(user_input)
+ self._memory_context_loaded = True
+
+ # 2. 追加用户消息
+ self.context.add_user_message(user_input)
+
+ # 3. ReAct 循环
+ llm = _LLMClient(self.config.llm)
+ tool_schemas = self.tool_manager.get_tool_schemas()
+ has_tools = self.tool_manager.has_tools()
+ steps: List[AgentStep] = []
+
+ llm_callback_ctx = {"step_type": "think", "tool_name": None}
+
+ def _llm_callback(metrics: Dict[str, Any]):
+ if self.on_llm_call:
+ metrics.update({
+ "session_id": self.context.session_id,
+ "user_id": self.config.user_id,
+ "step_type": llm_callback_ctx["step_type"],
+ "tool_name": llm_callback_ctx["tool_name"],
+ })
+ self.on_llm_call(metrics)
+
+ while self.context.iteration < max_iter:
+ self.context.iteration += 1
+ messages = self.memory.trim_messages(self.context.messages)
+
+ # 预算检查:LLM 调用次数(在调用 LLM 之前检查,避免浪费额度)
+ budget = self.config.budget
+ if self._llm_invocations >= budget.max_llm_invocations:
+ err = f"已超过 LLM 调用预算({budget.max_llm_invocations} 次)"
+ logger.warning(err)
+ yield {"type": "error", "content": err, "iteration": self.context.iteration,
+ "truncated": True}
+ await self.memory.save_context(user_input, err, self.context.messages)
+ return
+
+ # 调用外部 LLM 预算回调(WorkflowEngine 注入)
+ if self.on_llm_invocation:
+ try:
+ self.on_llm_invocation()
+ except Exception as e:
+ err = f"LLM 调用超出工作流预算: {e}"
+ logger.warning(err)
+ yield {"type": "error", "content": err, "iteration": self.context.iteration,
+ "truncated": True}
+ return
+
+ # 调用 LLM
+ try:
+ response = await llm.chat(
+ messages=messages,
+ tools=tool_schemas if has_tools and self.context.iteration == 1 else
+ (tool_schemas if has_tools else None),
+ iteration=self.context.iteration,
+ on_completion=_llm_callback,
+ )
+ except Exception as e:
+ err_str = str(e)
+ logger.error("LLM 调用失败 (iteration=%s): %s", self.context.iteration, err_str)
+ if self.context.iteration < max_iter and self._is_retryable(err_str):
+ yield {"type": "error", "content": f"LLM 调用失败(可重试): {err_str}",
+ "iteration": self.context.iteration}
+ continue
+ yield {"type": "error", "content": f"LLM 调用失败: {err_str}",
+ "iteration": self.context.iteration}
+ return
+
+ # 记录 LLM 调用次数(内部计数)
+ self._llm_invocations += 1
+
+ # 解析工具调用
+ tool_calls = self._extract_tool_calls(response)
+ content = self._extract_content(response)
+ reasoning = getattr(response, "reasoning_content", None) or (
+ response.get("reasoning_content") if isinstance(response, dict) else None
+ )
+
+ if not tool_calls:
+ # LLM 直接返回文本 → 结束
+ self.context.add_assistant_message(content)
+ final_text = content or "(模型未返回有效内容)"
+ yield {
+ "type": "final",
+ "content": final_text,
+ "reasoning": reasoning,
+ "iteration": self.context.iteration,
+ "iterations_used": self.context.iteration,
+ "tool_calls_made": self.context.tool_calls_made,
+ "session_id": self.context.session_id,
+ }
+ await self.memory.save_context(user_input, final_text, self.context.messages)
+ return
+
+ # 有工具调用 → 先记录 assistant 消息
+ self.context.add_assistant_message(content or "", tool_calls, reasoning)
+
+ # yield think 事件
+ tc_names = [tc["function"]["name"] for tc in tool_calls]
+ tc_args_list = []
+ for tc in tool_calls:
+ try:
+ tc_args_list.append(json.loads(tc["function"].get("arguments", "{}")))
+ except (json.JSONDecodeError, TypeError):
+ tc_args_list.append({})
+
+ yield {
+ "type": "think",
+ "content": content or f"调用工具: {', '.join(tc_names)}",
+ "reasoning": reasoning,
+ "tool_names": tc_names,
+ "iteration": self.context.iteration,
+ }
+
+ steps.append(AgentStep(
+ iteration=self.context.iteration,
+ type="think",
+ content=content or f"调用工具: {', '.join(tc_names)}",
+ reasoning=reasoning,
+ tool_name=tc_names[0] if len(tc_names) == 1 else None,
+ tool_input=tc_args_list[0] if len(tc_args_list) == 1 else None,
+ ))
+
+ if self.execution_logger:
+ self.execution_logger.info(
+ f"Agent 调用 {len(tool_calls)} 个工具",
+ data={"tool_calls": tc_names, "iteration": self.context.iteration},
+ )
+
+ # 逐一执行工具
+ for tc in tool_calls:
+ tfn = tc.get("function", {})
+ tname = tfn.get("name", "unknown")
+ tcid = tc.get("id", f"call_{self.context.iteration}_{self.context.tool_calls_made}")
+
+ try:
+ targs = json.loads(tfn.get("arguments", "{}"))
+ except (json.JSONDecodeError, TypeError):
+ targs = {}
+
+ # yield tool_call 事件
+ yield {
+ "type": "tool_call",
+ "name": tname,
+ "input": targs,
+ "iteration": self.context.iteration,
+ }
+
+ logger.info("Agent 执行工具 [%s]: %s", tname, targs)
+ result = await self.tool_manager.execute(tname, targs)
+
+ # yield tool_result 事件
+ yield {
+ "type": "tool_result",
+ "name": tname,
+ "result": result[:500] + "..." if len(result) > 500 else result,
+ "iteration": self.context.iteration,
+ }
+
+ steps.append(AgentStep(
+ iteration=self.context.iteration,
+ type="tool_result",
+ content=f"工具 {tname} 返回结果",
+ tool_name=tname,
+ tool_input=targs,
+ tool_result=result[:500] + "..." if len(result) > 500 else result,
+ ))
+
+ self.context.add_tool_result(tcid, tname, result)
+ self.context.tool_calls_made += 1
+
+ # 预算检查:工具调用次数
+ if self.context.tool_calls_made > budget.max_tool_calls:
+ err = f"已超过工具调用预算({budget.max_tool_calls} 次)"
+ logger.warning(err)
+ yield {"type": "error", "content": err, "iteration": self.context.iteration,
+ "truncated": True}
+ return
+
+ if self.on_tool_executed:
+ try:
+ await self.on_tool_executed(tname)
+ except WorkflowExecutionError:
+ raise
+ except Exception:
+ pass
+
+ if self.execution_logger:
+ preview = result[:300] + "..." if len(result) > 300 else result
+ self.execution_logger.info(
+ f"工具 {tname} 执行完成",
+ data={"tool_name": tname, "result_preview": preview},
+ )
+
+ # 达到最大迭代次数
+ last_content = ""
+ for m in reversed(self.context.messages):
+ if m.get("role") == "assistant" and m.get("content"):
+ last_content = m["content"]
+ break
+
+ logger.warning("Agent 达到最大迭代次数 (%s)", max_iter)
+ await self.memory.save_context(user_input, last_content or "(已达最大迭代次数)", self.context.messages)
+ yield {
+ "type": "final",
+ "content": last_content or "已达最大迭代次数,但模型未返回最终回答。",
+ "iteration": self.context.iteration,
+ "iterations_used": self.context.iteration,
+ "tool_calls_made": self.context.tool_calls_made,
+ "truncated": True,
+ "session_id": self.context.session_id,
+ }
+
+ async def _inject_memory_context(self, query: str = "") -> None:
"""加载长期记忆并注入 system prompt。"""
- mem_text = await self.memory.initialize()
+ mem_text = await self.memory.initialize(query=query)
if mem_text:
enriched = (
self.config.system_prompt.rstrip("\n")
diff --git a/backend/app/agent_runtime/memory.py b/backend/app/agent_runtime/memory.py
index facb27f..7824cce 100644
--- a/backend/app/agent_runtime/memory.py
+++ b/backend/app/agent_runtime/memory.py
@@ -15,6 +15,7 @@ from app.services.persistent_memory_service import (
save_persistent_memory,
persist_enabled,
)
+from app.services.embedding_service import embedding_service, VectorEntry
logger = logging.getLogger(__name__)
@@ -35,25 +36,31 @@ class AgentMemory:
session_key: Optional[str] = None,
persist: bool = True,
max_history: int = 20,
+ vector_memory_enabled: bool = True,
+ vector_memory_top_k: int = 5,
):
self.scope_kind = scope_kind
self.scope_id = scope_id or "default"
self.session_key = session_key or "default_session"
self.persist = persist and persist_enabled()
self.max_history = max_history
+ self.vector_memory_enabled = vector_memory_enabled
+ self.vector_memory_top_k = vector_memory_top_k
# 从长期记忆加载的上下文(启动时加载)
self._long_term_context: Dict[str, Any] = {}
# 记录已压缩的消息数,避免重复压缩
self._last_compressed_msg_count = 0
- async def initialize(self) -> str:
+ async def initialize(self, query: str = "") -> str:
"""
- 初始化记忆:从 DB/Redis 加载长期记忆,构造初始上下文文本。
+ 初始化记忆:从 DB/Redis 加载长期记忆 + 向量检索相关历史。
返回注入 system prompt 的记忆文本块。
"""
if not self.persist or not self.scope_id:
return ""
+ parts: List[str] = []
+
db: Optional[Session] = None
try:
db = SessionLocal()
@@ -62,8 +69,6 @@ class AgentMemory:
)
if payload and isinstance(payload, dict):
self._long_term_context = payload
- # 构建注入 system prompt 的记忆文本
- parts = []
profile = payload.get("user_profile")
if profile and isinstance(profile, dict):
profile_text = json.dumps(profile, ensure_ascii=False)
@@ -78,15 +83,93 @@ class AgentMemory:
if history and isinstance(history, list) and len(history) > 0:
summary = self._summarize_history(history)
parts.append(f"## 历史对话摘要\n{summary}")
-
- if parts:
- return "\n\n".join(parts)
except Exception as e:
logger.warning("加载长期记忆失败: %s", e)
finally:
if db:
db.close()
- return ""
+
+ # 2. 向量检索:查找语义相关的历史对话
+ if self.vector_memory_enabled and self.scope_kind and self.scope_id:
+ vector_text = await self._vector_search(query)
+ if vector_text:
+ parts.append(vector_text)
+
+ return "\n\n".join(parts) if parts else ""
+
+ async def _vector_search(self, query: str = "") -> str:
+ """
+ 向量检索语义相关的历史记忆,返回格式化的文本块。
+ 若无 query 则返回最近 Top-5 条记忆。
+ """
+ from app.models.agent_vector_memory import AgentVectorMemory
+
+ db: Optional[Session] = None
+ try:
+ db = SessionLocal()
+ # 查询当前 scope 的所有向量记忆(按时间倒序)
+ rows = (
+ db.query(AgentVectorMemory)
+ .filter(
+ AgentVectorMemory.scope_kind == self.scope_kind,
+ AgentVectorMemory.scope_id == self.scope_id,
+ )
+ .order_by(AgentVectorMemory.created_at.desc())
+ .limit(50) # 最多取最近 50 条做相似度计算
+ .all()
+ )
+
+ if not rows:
+ return ""
+
+ entries: List[VectorEntry] = []
+ for row in rows:
+ emb = embedding_service.deserialize_embedding(row.embedding) if row.embedding else []
+ entries.append({
+ "id": row.id,
+ "scope_kind": row.scope_kind,
+ "scope_id": row.scope_id,
+ "content_text": row.content_text,
+ "embedding": emb,
+ "metadata": row.metadata_ or {},
+ })
+
+ matched: List[VectorEntry] = []
+
+ if query and query.strip():
+ # 有 query:生成 embedding 做语义搜索
+ query_emb = await embedding_service.generate_embedding(query)
+ if query_emb:
+ matched = await embedding_service.similarity_search(
+ query_emb, entries, top_k=self.vector_memory_top_k
+ )
+ else:
+ # 无 query:返回最近几条
+ matched = entries[: self.vector_memory_top_k]
+ for m in matched:
+ m["score"] = 1.0
+
+ if not matched:
+ return ""
+
+ # 格式化为文本块
+ lines = ["## 相关历史记忆"]
+ for i, m in enumerate(matched, 1):
+ text = m.get("content_text", "")[:500]
+ meta = m.get("metadata", {})
+ entry_type = meta.get("type", "对话")
+ lines.append(f"{i}. [{entry_type}] {text}")
+ if m.get("score", 1.0) < 1.0:
+ lines[-1] += f" (匹配度: {m['score']:.2f})"
+
+ return "\n".join(lines)
+
+ except Exception as e:
+ logger.warning("向量检索失败: %s", e)
+ return ""
+ finally:
+ if db:
+ db.close()
async def save_context(
self, user_message: str, assistant_reply: str,
@@ -114,12 +197,50 @@ class AgentMemory:
db, self.scope_kind, self.scope_id,
self.session_key, self._long_term_context,
)
+
+ # 保存向量记忆(异步生成 embedding 并存储)
+ if self.vector_memory_enabled:
+ await self._save_vector_memory(
+ db, user_message, assistant_reply
+ )
except Exception as e:
logger.warning("保存长期记忆失败: %s", e)
finally:
if db:
db.close()
+ async def _save_vector_memory(
+ self, db: Session, user_message: str, assistant_reply: str,
+ ) -> None:
+ """生成 embedding 并保存到向量记忆表。"""
+ from app.models.agent_vector_memory import AgentVectorMemory
+
+ content_text = f"用户: {user_message}\n助手: {assistant_reply}"
+ if len(content_text) > 8000:
+ content_text = content_text[:8000]
+
+ try:
+ # 生成 embedding
+ embedding = await embedding_service.generate_embedding(content_text)
+ embedding_json = embedding_service.serialize_embedding(embedding) if embedding else ""
+
+ record = AgentVectorMemory(
+ scope_kind=self.scope_kind,
+ scope_id=self.scope_id,
+ session_key=self.session_key,
+ content_text=content_text[:2000],
+ embedding=embedding_json or None,
+ metadata_={
+ "type": "conversation_turn",
+ },
+ )
+ db.add(record)
+ db.commit()
+ logger.debug("已保存向量记忆 (scope=%s/%s)", self.scope_kind, self.scope_id)
+ except Exception as e:
+ logger.warning("保存向量记忆失败: %s", e)
+ db.rollback()
+
async def _compress_and_summarize(
self, messages: List[Dict[str, Any]]
) -> None:
diff --git a/backend/app/agent_runtime/schemas.py b/backend/app/agent_runtime/schemas.py
index 367c0d7..de07ee5 100644
--- a/backend/app/agent_runtime/schemas.py
+++ b/backend/app/agent_runtime/schemas.py
@@ -20,6 +20,8 @@ class AgentMemoryConfig(BaseModel):
max_history_messages: int = 20 # 注入 LLM 的上文最大消息数
session_key: Optional[str] = None # 会话标识,默认自动生成
persist_to_db: bool = True # 是否写入 MySQL 长期记忆
+ vector_memory_enabled: bool = True # 是否启用向量记忆(语义检索)
+ vector_memory_top_k: int = 5 # 向量检索 Top-K
class AgentLLMConfig(BaseModel):
@@ -35,6 +37,12 @@ class AgentLLMConfig(BaseModel):
extra_body: Optional[Dict[str, Any]] = None
+class AgentBudgetConfig(BaseModel):
+ """Agent 执行预算配置"""
+ max_llm_invocations: int = 200 # LLM 调用次数上限
+ max_tool_calls: int = 500 # 工具调用次数上限
+
+
class AgentConfig(BaseModel):
"""Agent 完整配置"""
name: str = "default_agent"
@@ -42,6 +50,7 @@ class AgentConfig(BaseModel):
llm: AgentLLMConfig = Field(default_factory=AgentLLMConfig)
tools: AgentToolConfig = Field(default_factory=AgentToolConfig)
memory: AgentMemoryConfig = Field(default_factory=AgentMemoryConfig)
+ budget: AgentBudgetConfig = Field(default_factory=AgentBudgetConfig)
user_id: Optional[str] = None
diff --git a/backend/app/agent_runtime/tool_manager.py b/backend/app/agent_runtime/tool_manager.py
index 38c37ff..80dd682 100644
--- a/backend/app/agent_runtime/tool_manager.py
+++ b/backend/app/agent_runtime/tool_manager.py
@@ -3,9 +3,8 @@ Agent 工具管理器:包装已有 ToolRegistry,提供 Agent 需要的工具
"""
from __future__ import annotations
-import json
import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List, Optional
from app.services.tool_registry import tool_registry
@@ -58,6 +57,8 @@ class AgentToolManager:
"""
执行工具调用。
+ 优先查找内置工具,其次查找数据库自定义工具(HTTP / Code)。
+
Args:
name: 工具名称
args: 工具参数字典
@@ -65,27 +66,8 @@ class AgentToolManager:
Returns:
工具执行结果的字符串表示
"""
- func: Optional[Callable] = tool_registry.get_tool_function(name)
- if not func:
- err = f"工具 '{name}' 不存在"
- logger.error(err)
- return json.dumps({"error": err}, ensure_ascii=False)
-
- logger.info("Agent 执行工具: %s, 参数: %s", name, args)
- try:
- import asyncio
- if asyncio.iscoroutinefunction(func):
- result = await func(**args)
- else:
- result = func(**args)
-
- if isinstance(result, (dict, list)):
- return json.dumps(result, ensure_ascii=False)
- return str(result)
- except Exception as e:
- err_msg = f"工具 '{name}' 执行失败: {e}"
- logger.error(err_msg, exc_info=True)
- return json.dumps({"error": err_msg}, ensure_ascii=False)
+ logger.info("Agent 执行工具: %s", name)
+ return await tool_registry.execute_tool(name, args)
@staticmethod
def _extract_tool_name(schema: Dict[str, Any]) -> Optional[str]:
diff --git a/backend/app/agent_runtime/workflow_integration.py b/backend/app/agent_runtime/workflow_integration.py
index fbcb2ac..8566c7e 100644
--- a/backend/app/agent_runtime/workflow_integration.py
+++ b/backend/app/agent_runtime/workflow_integration.py
@@ -13,6 +13,7 @@ from app.agent_runtime.schemas import (
AgentConfig,
AgentLLMConfig,
AgentToolConfig,
+ AgentBudgetConfig,
)
logger = logging.getLogger(__name__)
@@ -24,6 +25,8 @@ async def run_agent_node(
execution_logger: Optional[Any] = None,
user_id: Optional[str] = None,
on_tool_executed: Optional[Any] = None,
+ on_llm_invocation: Optional[Any] = None,
+ budget_limits: Optional[Dict[str, int]] = None,
) -> Dict[str, Any]:
"""
在工作流中执行 Agent 节点。
@@ -72,6 +75,14 @@ async def run_agent_node(
if node_data.get("base_url"):
llm_config.base_url = node_data["base_url"]
+ # 3a. 构建预算配置(接收工作流级预算限制)
+ budget = AgentBudgetConfig()
+ if budget_limits:
+ if "max_llm_invocations" in budget_limits:
+ budget.max_llm_invocations = max(1, int(budget_limits["max_llm_invocations"]))
+ if "max_tool_calls" in budget_limits:
+ budget.max_tool_calls = max(1, int(budget_limits["max_tool_calls"]))
+
agent_config = AgentConfig(
name=node_data.get("label", "agent_node"),
system_prompt=formatted_prompt,
@@ -84,6 +95,7 @@ async def run_agent_node(
"enabled": node_data.get("memory", True),
"persist_to_db": node_data.get("memory", True),
},
+ budget=budget,
user_id=user_id,
)
@@ -93,6 +105,9 @@ async def run_agent_node(
execution_logger=execution_logger,
on_tool_executed=on_tool_executed,
)
+ # 注入 LLM 预算回调(使 Agent 内部 LLM 调用计入工作流预算)
+ if on_llm_invocation:
+ runtime.on_llm_invocation = on_llm_invocation
result = await runtime.run(query)
diff --git a/backend/app/api/agent_chat.py b/backend/app/api/agent_chat.py
index 9339c34..c9d683b 100644
--- a/backend/app/api/agent_chat.py
+++ b/backend/app/api/agent_chat.py
@@ -8,8 +8,10 @@ POST /api/v1/agent-chat/bare
from __future__ import annotations
import logging
-from typing import Any, Dict, List, Optional
-from fastapi import APIRouter, Depends, HTTPException
+import json
+from typing import Any, AsyncGenerator, Dict, List, Optional
+from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from app.core.database import get_db
@@ -23,6 +25,7 @@ from app.agent_runtime import (
AgentConfig,
AgentLLMConfig,
AgentToolConfig,
+ AgentBudgetConfig,
AgentStep,
AgentOrchestrator,
OrchestratorAgentConfig,
@@ -64,6 +67,14 @@ def _make_llm_logger(
return _log
+async def _sse_stream(gen: AsyncGenerator[dict, None]) -> AsyncGenerator[str, None]:
+ """将 run_stream 生成的 dict 事件格式化为 SSE 文本流。"""
+ async for event in gen:
+ event_type = event.get("type", "message")
+ data = {k: v for k, v in event.items() if k != "type"}
+ yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
+
+
class ChatRequest(BaseModel):
message: str
session_id: Optional[str] = None
@@ -205,6 +216,39 @@ async def chat_bare(
)
+@router.post("/bare/stream")
+async def chat_bare_stream(
+ req: ChatRequest,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+):
+ """无需 Agent 配置,使用默认设置直接对话(流式 SSE)。"""
+ config = AgentConfig(
+ name="bare_agent",
+ system_prompt="你是一个有用的AI助手。请使用可用工具来帮助用户完成任务。",
+ llm=AgentLLMConfig(
+ model=req.model or (
+ "gpt-4o-mini" if settings.OPENAI_API_KEY and settings.OPENAI_API_KEY != "your-openai-api-key"
+ else "deepseek-v4-flash"
+ ),
+ temperature=req.temperature or 0.7,
+ max_iterations=req.max_iterations or 10,
+ ),
+ user_id=current_user.id,
+ )
+ on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id)
+ runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
+ return StreamingResponse(
+ _sse_stream(runtime.run_stream(req.message)),
+ media_type="text/event-stream",
+ headers={
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ "X-Accel-Buffering": "no",
+ },
+ )
+
+
@router.post("/{agent_id}", response_model=ChatResponse)
async def chat_with_agent(
agent_id: str,
@@ -225,9 +269,25 @@ async def chat_with_agent(
# 查找 agent 节点的配置(或第一个 llm 节点的配置)
agent_node_cfg = _find_agent_node_config(nodes)
+ # 构建 system prompt,并自动注入智能体名称
+ system_prompt = agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。"
+ if agent.name:
+ name_prefix = f"你的名字是{agent.name}"
+ if name_prefix not in system_prompt:
+ system_prompt = f"{name_prefix}。\n\n{system_prompt}"
+
+ # 合并执行预算:Agent.budget_config 覆盖默认值
+ budget = AgentBudgetConfig()
+ if agent.budget_config and isinstance(agent.budget_config, dict):
+ bc = agent.budget_config
+ if "max_llm_invocations" in bc and bc["max_llm_invocations"] is not None:
+ budget.max_llm_invocations = max(1, int(bc["max_llm_invocations"]))
+ if "max_tool_calls" in bc and bc["max_tool_calls"] is not None:
+ budget.max_tool_calls = max(1, int(bc["max_tool_calls"]))
+
config = AgentConfig(
name=agent.name,
- system_prompt=agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。",
+ system_prompt=system_prompt,
llm=AgentLLMConfig(
provider=agent_node_cfg.get("provider", "openai"),
model=req.model or agent_node_cfg.get("model", "gpt-4o-mini"),
@@ -238,6 +298,7 @@ async def chat_with_agent(
include_tools=agent_node_cfg.get("tools", []),
exclude_tools=agent_node_cfg.get("exclude_tools", []),
),
+ budget=budget,
user_id=current_user.id,
)
@@ -256,6 +317,68 @@ async def chat_with_agent(
)
+@router.post("/{agent_id}/stream")
+async def chat_with_agent_stream(
+ agent_id: str,
+ req: ChatRequest,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+):
+ """与指定的 Agent 对话(流式 SSE)。"""
+ agent = db.query(Agent).filter(Agent.id == agent_id).first()
+ if not agent:
+ raise HTTPException(status_code=404, detail="Agent 不存在")
+ if agent.user_id and agent.user_id != current_user.id and current_user.role != "admin":
+ raise HTTPException(status_code=403, detail="无权访问该 Agent")
+
+ wc = agent.workflow_config or {}
+ nodes = wc.get("nodes", [])
+ agent_node_cfg = _find_agent_node_config(nodes)
+
+ system_prompt = agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。"
+ if agent.name:
+ name_prefix = f"你的名字是{agent.name}"
+ if name_prefix not in system_prompt:
+ system_prompt = f"{name_prefix}。\n\n{system_prompt}"
+
+ budget = AgentBudgetConfig()
+ if agent.budget_config and isinstance(agent.budget_config, dict):
+ bc = agent.budget_config
+ if "max_llm_invocations" in bc and bc["max_llm_invocations"] is not None:
+ budget.max_llm_invocations = max(1, int(bc["max_llm_invocations"]))
+ if "max_tool_calls" in bc and bc["max_tool_calls"] is not None:
+ budget.max_tool_calls = max(1, int(bc["max_tool_calls"]))
+
+ config = AgentConfig(
+ name=agent.name,
+ system_prompt=system_prompt,
+ llm=AgentLLMConfig(
+ provider=agent_node_cfg.get("provider", "openai"),
+ model=req.model or agent_node_cfg.get("model", "gpt-4o-mini"),
+ temperature=req.temperature or float(agent_node_cfg.get("temperature", 0.7)),
+ max_iterations=req.max_iterations or int(agent_node_cfg.get("max_iterations", 10)),
+ ),
+ tools=AgentToolConfig(
+ include_tools=agent_node_cfg.get("tools", []),
+ exclude_tools=agent_node_cfg.get("exclude_tools", []),
+ ),
+ budget=budget,
+ user_id=current_user.id,
+ )
+
+ on_llm_call = _make_llm_logger(db, agent_id=agent_id, user_id=current_user.id)
+ runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
+ return StreamingResponse(
+ _sse_stream(runtime.run_stream(req.message)),
+ media_type="text/event-stream",
+ headers={
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ "X-Accel-Buffering": "no",
+ },
+ )
+
+
def _find_agent_node_config(nodes: list) -> Dict[str, Any]:
"""从工作流节点列表中查找第一个 agent 类型或 llm 类型的节点配置。"""
if not nodes:
diff --git a/backend/app/api/knowledge_base.py b/backend/app/api/knowledge_base.py
new file mode 100644
index 0000000..6609b3d
--- /dev/null
+++ b/backend/app/api/knowledge_base.py
@@ -0,0 +1,251 @@
+"""
+知识库 RAG API。
+
+提供知识库管理、文档上传、语义搜索和 RAG 查询接口。
+"""
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, List, Optional
+
+from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
+from pydantic import BaseModel
+from sqlalchemy.orm import Session
+
+from app.core.database import get_db
+from app.api.auth import get_current_user
+from app.models.user import User
+from app.services.knowledge_service import (
+ create_knowledge_base,
+ delete_document,
+ delete_knowledge_base,
+ get_knowledge_base,
+ list_documents,
+ list_knowledge_bases,
+ rag_query,
+ search,
+ upload_document,
+)
+
+logger = logging.getLogger(__name__)
+router = APIRouter(prefix="/api/v1/knowledge-bases", tags=["knowledge-base"])
+
+
+# ─── Schema ──────────────────────────────────────────────────────
+
+
+class KBCreateRequest(BaseModel):
+ name: str
+ description: str = ""
+ chunk_size: int = 500
+ chunk_overlap: int = 50
+
+
+class KBResponse(BaseModel):
+ id: str
+ name: str
+ description: Optional[str] = ""
+ user_id: Optional[str] = ""
+ chunk_size: int = 500
+ chunk_overlap: int = 50
+ doc_count: int = 0
+ created_at: Optional[str] = ""
+ updated_at: Optional[str] = ""
+
+
+class DocumentResponse(BaseModel):
+ id: str
+ kb_id: str
+ filename: str
+ file_type: str
+ file_size: int = 0
+ status: str = "pending"
+ error_message: Optional[str] = None
+ chunk_count: int = 0
+ created_at: Optional[str] = None
+
+
+class SearchRequest(BaseModel):
+ query: str
+ top_k: int = 5
+ min_score: float = 0.3
+
+
+class SearchResult(BaseModel):
+ chunk_id: str
+ content: str
+ score: float
+ metadata: Dict[str, Any] = {}
+
+
+class SearchResponse(BaseModel):
+ results: List[SearchResult] = []
+
+
+class RAGRequest(BaseModel):
+ query: str
+ top_k: int = 5
+ min_score: float = 0.3
+
+
+class RAGResponse(BaseModel):
+ query: str
+ context: str = ""
+ sources: List[Dict[str, Any]] = []
+ found: bool = False
+
+
+# ─── 知识库 CRUD ──────────────────────────────────────────────
+
+
+@router.post("", response_model=KBResponse)
+async def api_create_kb(
+ req: KBCreateRequest,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+):
+ """创建知识库。"""
+ kb = create_knowledge_base(
+ db=db,
+ name=req.name,
+ user_id=current_user.id,
+ description=req.description,
+ chunk_size=req.chunk_size,
+ chunk_overlap=req.chunk_overlap,
+ )
+ return KBResponse(**kb.to_dict())
+
+
+@router.get("", response_model=List[KBResponse])
+async def api_list_kb(
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+):
+ """列出知识库。"""
+ kbs = list_knowledge_bases(db, user_id=current_user.id)
+ return [KBResponse(**kb.to_dict()) for kb in kbs]
+
+
+@router.get("/{kb_id}", response_model=KBResponse)
+async def api_get_kb(
+ kb_id: str,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+):
+ """获取知识库详情。"""
+ kb = get_knowledge_base(db, kb_id)
+ if not kb:
+ raise HTTPException(status_code=404, detail="知识库不存在")
+ return KBResponse(**kb.to_dict())
+
+
+@router.delete("/{kb_id}")
+async def api_delete_kb(
+ kb_id: str,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+):
+ """删除知识库。"""
+ ok = delete_knowledge_base(db, kb_id)
+ if not ok:
+ raise HTTPException(status_code=404, detail="知识库不存在")
+ return {"message": "知识库已删除"}
+
+
+# ─── 文档管理 ──────────────────────────────────────────────────
+
+
+@router.post("/{kb_id}/documents", response_model=DocumentResponse)
+async def api_upload_document(
+ kb_id: str,
+ file: UploadFile = File(...),
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+):
+ """上传文档到知识库(自动解析、分块、生成 Embedding)。"""
+ if not file.filename:
+ raise HTTPException(status_code=400, detail="文件名不能为空")
+
+ content = await file.read()
+ if not content:
+ raise HTTPException(status_code=400, detail="文件内容为空")
+
+ try:
+ doc = await upload_document(
+ db=db,
+ kb_id=kb_id,
+ filename=file.filename,
+ file_content=content,
+ )
+ except ValueError as e:
+ raise HTTPException(status_code=404, detail=str(e))
+
+ if doc.status == "failed":
+        # 解析/嵌入失败时仍返回 HTTP 200,通过 status="failed" 与 error_message 字段告知前端
+ return DocumentResponse(**doc.to_dict())
+
+ return DocumentResponse(**doc.to_dict())
+
+
+@router.get("/{kb_id}/documents", response_model=List[DocumentResponse])
+async def api_list_documents(
+ kb_id: str,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+):
+ """列出知识库中的文档。"""
+ docs = list_documents(db, kb_id)
+ return [DocumentResponse(**d.to_dict()) for d in docs]
+
+
+@router.delete("/{kb_id}/documents/{doc_id}")
+async def api_delete_document(
+ kb_id: str,
+ doc_id: str,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+):
+ """删除文档。"""
+ ok = delete_document(db, doc_id)
+ if not ok:
+ raise HTTPException(status_code=404, detail="文档不存在")
+ return {"message": "文档已删除"}
+
+
+# ─── 搜索 & RAG ────────────────────────────────────────────────
+
+
+@router.post("/{kb_id}/search", response_model=SearchResponse)
+async def api_search(
+ kb_id: str,
+ req: SearchRequest,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+):
+ """语义搜索知识库。"""
+ results = await search(
+ db=db,
+ kb_id=kb_id,
+ query=req.query,
+ top_k=req.top_k,
+ min_score=req.min_score,
+ )
+ return SearchResponse(results=[SearchResult(**r) for r in results])
+
+
+@router.post("/{kb_id}/rag", response_model=RAGResponse)
+async def api_rag(
+ kb_id: str,
+ req: RAGRequest,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+):
+ """RAG 查询:搜索相关片段并格式化为上下文。"""
+ result = await rag_query(
+ db=db,
+ kb_id=kb_id,
+ query=req.query,
+ top_k=req.top_k,
+ min_score=req.min_score,
+ )
+ return RAGResponse(**result)
diff --git a/backend/app/api/tools.py b/backend/app/api/tools.py
index 3c9e73e..ac1869a 100644
--- a/backend/app/api/tools.py
+++ b/backend/app/api/tools.py
@@ -1,21 +1,42 @@
"""
-工具管理API
+工具市场 API — 管理、测试、发现和安装工具。
+
+提供内置工具和用户自定义工具(HTTP / 代码段)的 CRUD、测试和执行。
"""
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, List, Optional
+
from fastapi import APIRouter, Depends, HTTPException, Query
+from pydantic import BaseModel
from sqlalchemy.orm import Session
-from typing import List, Optional
+
+from app.api.auth import get_current_user
from app.core.database import get_db
from app.models.tool import Tool
-from app.services.tool_registry import tool_registry
-from app.api.auth import get_current_user
from app.models.user import User
-from pydantic import BaseModel
+from app.services.tool_registry import tool_registry
+logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/tools", tags=["tools"])
+# ─── Schema ──────────────────────────────────────────────────────
+
+
class ToolCreate(BaseModel):
- """创建工具请求"""
+ name: str
+ description: str
+ category: Optional[str] = None
+ function_schema: dict
+ implementation_type: str # builtin / http / code / workflow
+ implementation_config: Optional[dict] = None
+ is_public: bool = False
+
+
+class ToolResponse(BaseModel):
+ id: str
name: str
description: str
category: Optional[str] = None
@@ -23,86 +44,39 @@ class ToolCreate(BaseModel):
implementation_type: str
implementation_config: Optional[dict] = None
is_public: bool = False
+ use_count: int = 0
+ user_id: Optional[str] = None
+ created_at: str = ""
+ updated_at: str = ""
-class ToolResponse(BaseModel):
- """工具响应"""
- id: str
- name: str
- description: str
- category: Optional[str]
- function_schema: dict
- implementation_type: str
- implementation_config: Optional[dict]
- is_public: bool
- use_count: int
- user_id: Optional[str]
- created_at: str
- updated_at: str
-
- class Config:
- from_attributes = True
+class TestHTTPRequest(BaseModel):
+ url: str
+ method: str = "GET"
+ headers: Dict[str, str] = {}
+ body: Optional[Dict[str, Any]] = None
+ args: Dict[str, Any] = {}
+ timeout: int = 30
-@router.get("", response_model=List[ToolResponse])
-async def list_tools(
- category: Optional[str] = Query(None, description="工具分类"),
- search: Optional[str] = Query(None, description="搜索关键词"),
- db: Session = Depends(get_db)
-):
- """获取工具列表"""
- query = db.query(Tool).filter(Tool.is_public == True)
-
- if category:
- query = query.filter(Tool.category == category)
-
- if search:
- query = query.filter(
- Tool.name.contains(search) |
- Tool.description.contains(search)
- )
-
- tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()
-
- # 转换为响应格式,确保日期时间字段转换为字符串
- result = []
- for tool in tools:
- result.append({
- "id": tool.id,
- "name": tool.name,
- "description": tool.description,
- "category": tool.category,
- "function_schema": tool.function_schema,
- "implementation_type": tool.implementation_type,
- "implementation_config": tool.implementation_config,
- "is_public": tool.is_public,
- "use_count": tool.use_count,
- "user_id": tool.user_id,
- "created_at": tool.created_at.isoformat() if tool.created_at else "",
- "updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
- })
-
- return result
+class TestCodeRequest(BaseModel):
+ source: str
+ args: Dict[str, Any] = {}
-@router.get("/builtin")
-async def list_builtin_tools():
- """获取内置工具列表"""
- schemas = tool_registry.get_all_tool_schemas()
- return schemas
+class TestResponse(BaseModel):
+ success: bool
+ elapsed_ms: Optional[int] = None
+ result: Optional[Any] = None
+ status_code: Optional[int] = None
+ body: Optional[str] = None
+ error: Optional[str] = None
-@router.get("/{tool_id}", response_model=ToolResponse)
-async def get_tool(
- tool_id: str,
- db: Session = Depends(get_db)
-):
- """获取工具详情"""
- tool = db.query(Tool).filter(Tool.id == tool_id).first()
- if not tool:
- raise HTTPException(status_code=404, detail="工具不存在")
-
- # 转换为响应格式,确保日期时间字段转换为字符串
+# ─── 工具函数 ──────────────────────────────────────────────────
+
+
+def _tool_to_dict(tool: Tool) -> dict:
return {
"id": tool.id,
"name": tool.name,
@@ -115,22 +89,89 @@ async def get_tool(
"use_count": tool.use_count,
"user_id": tool.user_id,
"created_at": tool.created_at.isoformat() if tool.created_at else "",
- "updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
+ "updated_at": tool.updated_at.isoformat() if tool.updated_at else "",
}
+# ─── 工具市场浏览 ──────────────────────────────────────────────
+
+
+@router.get("", response_model=List[ToolResponse])
+async def list_tools(
+ category: Optional[str] = Query(None, description="按分类筛选"),
+ search: Optional[str] = Query(None, description="搜索关键词"),
+ scope: Optional[str] = Query("public", description="public / mine / all"),
+ db: Session = Depends(get_db),
+ current_user: Optional[User] = Depends(get_current_user),
+):
+ """浏览工具市场。"""
+ query = db.query(Tool)
+
+ if scope == "public":
+ query = query.filter(Tool.is_public == True)
+ elif scope == "mine":
+ if not current_user:
+ raise HTTPException(status_code=401, detail="需登录")
+ query = query.filter(Tool.user_id == current_user.id)
+
+ if category:
+ query = query.filter(Tool.category == category)
+ if search:
+ query = query.filter(
+ Tool.name.contains(search) | Tool.description.contains(search)
+ )
+
+ tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()
+ return [_tool_to_dict(t) for t in tools]
+
+
+@router.get("/categories", response_model=List[str])
+async def list_categories(db: Session = Depends(get_db)):
+ """列出所有工具分类。"""
+ rows = db.query(Tool.category).filter(Tool.category.isnot(None)).distinct().all()
+ cats = sorted(set(r[0] for r in rows if r[0]))
+ # 加上常用分类
+ defaults = ["数据处理", "网络请求", "文件操作", "AI服务", "数据库", "通知", "自定义"]
+ for d in defaults:
+ if d not in cats:
+ cats.append(d)
+ return cats
+
+
+@router.get("/builtin")
+async def list_builtin_tools():
+ """列出所有内置工具(OpenAI Function 格式)。"""
+ return tool_registry.get_all_tool_schemas()
+
+
+@router.get("/{tool_id}", response_model=ToolResponse)
+async def get_tool(tool_id: str, db: Session = Depends(get_db)):
+ """获取工具详情。"""
+ tool = db.query(Tool).filter(Tool.id == tool_id).first()
+ if not tool:
+ raise HTTPException(status_code=404, detail="工具不存在")
+ return _tool_to_dict(tool)
+
+
+# ─── 工具创建 / 更新 / 删除 ──────────────────────────────────
+
+
@router.post("", response_model=ToolResponse, status_code=201)
async def create_tool(
tool_data: ToolCreate,
db: Session = Depends(get_db),
- current_user: User = Depends(get_current_user)
+ current_user: User = Depends(get_current_user),
):
- """创建工具"""
- # 检查工具名称是否已存在
+ """创建自定义工具。"""
existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
if existing:
- raise HTTPException(status_code=400, detail=f"工具名称 '{tool_data.name}' 已存在")
-
+ raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
+
+ valid_types = {"builtin", "http", "code", "workflow"}
+ if tool_data.implementation_type not in valid_types:
+ raise HTTPException(status_code=400,
+ detail=f"无效的实现类型: {tool_data.implementation_type}")
+
tool = Tool(
name=tool_data.name,
description=tool_data.description,
@@ -139,28 +180,22 @@ async def create_tool(
implementation_type=tool_data.implementation_type,
implementation_config=tool_data.implementation_config,
is_public=tool_data.is_public,
- user_id=current_user.id
+ user_id=current_user.id,
)
-
db.add(tool)
db.commit()
db.refresh(tool)
-
- # 转换为响应格式,确保日期时间字段转换为字符串
- return {
- "id": tool.id,
- "name": tool.name,
- "description": tool.description,
- "category": tool.category,
- "function_schema": tool.function_schema,
- "implementation_type": tool.implementation_type,
- "implementation_config": tool.implementation_config,
- "is_public": tool.is_public,
- "use_count": tool.use_count,
- "user_id": tool.user_id,
- "created_at": tool.created_at.isoformat() if tool.created_at else "",
- "updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
+ logger.info("工具已创建: %s (type=%s)", tool.name, tool.implementation_type)
+
+ # 刷新注册表
+ tool_registry._custom_tool_configs[tool.name] = {
+ **(tool.implementation_config or {}),
+ "_type": tool.implementation_type,
+ "_db_id": tool.id,
}
+ tool_registry._tool_schemas[tool.name] = tool.function_schema
+
+ return _tool_to_dict(tool)
@router.put("/{tool_id}", response_model=ToolResponse)
@@ -168,23 +203,20 @@ async def update_tool(
tool_id: str,
tool_data: ToolCreate,
db: Session = Depends(get_db),
- current_user: User = Depends(get_current_user)
+ current_user: User = Depends(get_current_user),
):
- """更新工具"""
+ """更新工具。"""
tool = db.query(Tool).filter(Tool.id == tool_id).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
-
- # 检查权限(只有创建者可以更新)
if tool.user_id != current_user.id:
raise HTTPException(status_code=403, detail="无权更新此工具")
-
- # 检查名称冲突
+
if tool_data.name != tool.name:
existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
if existing:
- raise HTTPException(status_code=400, detail=f"工具名称 '{tool_data.name}' 已存在")
-
+ raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
+
tool.name = tool_data.name
tool.description = tool_data.description
tool.category = tool_data.category
@@ -192,47 +224,94 @@ async def update_tool(
tool.implementation_type = tool_data.implementation_type
tool.implementation_config = tool_data.implementation_config
tool.is_public = tool_data.is_public
-
+
db.commit()
db.refresh(tool)
-
- # 转换为响应格式,确保日期时间字段转换为字符串
- return {
- "id": tool.id,
- "name": tool.name,
- "description": tool.description,
- "category": tool.category,
- "function_schema": tool.function_schema,
- "implementation_type": tool.implementation_type,
- "implementation_config": tool.implementation_config,
- "is_public": tool.is_public,
- "use_count": tool.use_count,
- "user_id": tool.user_id,
- "created_at": tool.created_at.isoformat() if tool.created_at else "",
- "updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
- }
+
+ # 刷新注册表
+ if tool.name in tool_registry._custom_tool_configs:
+ tool_registry._custom_tool_configs[tool.name] = {
+ **(tool.implementation_config or {}),
+ "_type": tool.implementation_type,
+ "_db_id": tool.id,
+ }
+ if tool.name in tool_registry._tool_schemas:
+ tool_registry._tool_schemas[tool.name] = tool.function_schema
+
+ return _tool_to_dict(tool)
-@router.delete("/{tool_id}", status_code=200)
+@router.delete("/{tool_id}")
async def delete_tool(
tool_id: str,
db: Session = Depends(get_db),
- current_user: User = Depends(get_current_user)
+ current_user: User = Depends(get_current_user),
):
- """删除工具"""
+ """删除工具。"""
tool = db.query(Tool).filter(Tool.id == tool_id).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
-
- # 检查权限(只有创建者可以删除)
if tool.user_id != current_user.id:
raise HTTPException(status_code=403, detail="无权删除此工具")
-
- # 内置工具不允许删除
if tool.implementation_type == "builtin":
raise HTTPException(status_code=400, detail="内置工具不允许删除")
-
+
db.delete(tool)
db.commit()
-
+
+ # 清理注册表
+ tool_registry._custom_tool_configs.pop(tool.name, None)
+ tool_registry._tool_schemas.pop(tool.name, None)
+
return {"message": "工具已删除"}
+
+
+# ─── 工具测试 ──────────────────────────────────────────────────
+
+
+@router.post("/test/http", response_model=TestResponse)
+async def test_http_tool(
+ req: TestHTTPRequest,
+ current_user: User = Depends(get_current_user),
+):
+ """测试 HTTP 工具(不保存到数据库)。"""
+ result = await tool_registry.test_http_tool(
+ url=req.url,
+ method=req.method,
+ headers=req.headers,
+ body=req.body,
+ args=req.args,
+ timeout=req.timeout,
+ )
+ return TestResponse(**result)
+
+
+@router.post("/test/code", response_model=TestResponse)
+async def test_code_tool(
+ req: TestCodeRequest,
+ current_user: User = Depends(get_current_user),
+):
+ """测试代码工具(不保存到数据库)。"""
+ result = await tool_registry.test_code_tool(
+ source=req.source,
+ args=req.args,
+ )
+ return TestResponse(**result)
+
+
+# ─── 使用计数 ──────────────────────────────────────────────────
+
+
+@router.post("/{tool_id}/use")
+async def record_tool_use(
+ tool_id: str,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user),
+):
+ """记录工具使用次数(Agent 执行时自动调用)。"""
+ tool = db.query(Tool).filter(Tool.id == tool_id).first()
+ if not tool:
+ raise HTTPException(status_code=404, detail="工具不存在")
+ tool.use_count = (tool.use_count or 0) + 1
+ db.commit()
+ return {"use_count": tool.use_count}
diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index 72630ce..cc33173 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -56,6 +56,11 @@ class Settings(BaseSettings):
# DeepSeek配置
DEEPSEEK_API_KEY: str = ""
DEEPSEEK_BASE_URL: str = "https://api.deepseek.com"
+
+ # SiliconFlow配置(Embedding 推荐使用 SiliconFlow)
+ SILICONFLOW_API_KEY: str = ""
+ SILICONFLOW_BASE_URL: str = "https://api.siliconflow.cn/v1"
+ SILICONFLOW_EMBEDDING_MODEL: str = "netease-youdao/bce-embedding-base_v1"
# Anthropic配置
ANTHROPIC_API_KEY: str = ""
diff --git a/backend/app/core/database.py b/backend/app/core/database.py
index e48b074..4cb9ad5 100644
--- a/backend/app/core/database.py
+++ b/backend/app/core/database.py
@@ -47,4 +47,6 @@ def init_db():
import app.models.permission
import app.models.alert_rule
import app.models.agent_llm_log
+ import app.models.agent_vector_memory
+ import app.models.knowledge_base
Base.metadata.create_all(bind=engine)
diff --git a/backend/app/main.py b/backend/app/main.py
index c756e59..370a0cd 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -201,7 +201,7 @@ async def startup_event():
# 不抛出异常,允许应用继续启动
# 注册路由
-from app.api import auth, uploads, workflows, executions, websocket, execution_logs, data_sources, agents, platform_templates, model_configs, webhooks, template_market, batch_operations, collaboration, permissions, monitoring, alert_rules, node_test, node_templates, tools, agent_chat, agent_monitoring
+from app.api import auth, uploads, workflows, executions, websocket, execution_logs, data_sources, agents, platform_templates, model_configs, webhooks, template_market, batch_operations, collaboration, permissions, monitoring, alert_rules, node_test, node_templates, tools, agent_chat, agent_monitoring, knowledge_base
app.include_router(auth.router)
app.include_router(uploads.router)
@@ -225,6 +225,7 @@ app.include_router(node_templates.router)
app.include_router(tools.router)
app.include_router(agent_chat.router)
app.include_router(agent_monitoring.router)
+app.include_router(knowledge_base.router)
if __name__ == "__main__":
import uvicorn
diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py
index 6b50fac..13b3fda 100644
--- a/backend/app/models/__init__.py
+++ b/backend/app/models/__init__.py
@@ -13,5 +13,7 @@ from app.models.permission import Role, Permission, WorkflowPermission, AgentPer
from app.models.alert_rule import AlertRule, AlertLog
from app.models.persistent_user_memory import PersistentUserMemory
from app.models.agent_llm_log import AgentLLMLog
+from app.models.agent_vector_memory import AgentVectorMemory
+from app.models.knowledge_base import KnowledgeBase, Document, DocumentChunk
-__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog"]
\ No newline at end of file
+__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog", "AgentVectorMemory", "KnowledgeBase", "Document", "DocumentChunk"]
\ No newline at end of file
diff --git a/backend/app/models/agent_vector_memory.py b/backend/app/models/agent_vector_memory.py
new file mode 100644
index 0000000..35da35a
--- /dev/null
+++ b/backend/app/models/agent_vector_memory.py
@@ -0,0 +1,45 @@
+"""
+Agent 向量记忆模型。
+
+存储对话片段的 embedding 向量,支持语义检索。
+"""
+from __future__ import annotations
+
+import uuid
+from datetime import datetime
+
+from sqlalchemy import Column, String, Text, DateTime, Index
+from sqlalchemy.dialects.mysql import JSON as MySQLJSON
+
+from app.core.database import Base
+
+
+class AgentVectorMemory(Base):
+ """Agent 向量记忆 — 存储对话文本及 embedding"""
+
+ __tablename__ = "agent_vector_memories"
+
+ id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
+ scope_kind = Column(String(16), nullable=False, index=True, comment="作用域类型: agent / bare")
+ scope_id = Column(String(36), nullable=False, index=True, comment="作用域 ID: agent_id / user_id")
+ session_key = Column(String(128), nullable=False, default="", comment="会话标识")
+ content_text = Column(Text, nullable=False, comment="原始对话文本")
+ embedding = Column(Text, nullable=True, comment="JSON 序列化的 embedding 向量")
+ metadata_ = Column("metadata", MySQLJSON, nullable=True, comment="元数据: {type, iteration, ...}")
+ created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
+
+ __table_args__ = (
+ Index("ix_agent_vector_memory_scope", "scope_kind", "scope_id"),
+ )
+
+ def to_dict(self) -> dict:
+ return {
+ "id": self.id,
+ "scope_kind": self.scope_kind,
+ "scope_id": self.scope_id,
+ "session_key": self.session_key,
+ "content_text": self.content_text,
+ "embedding": self.embedding,
+ "metadata": self.metadata_ or {},
+ "created_at": self.created_at.isoformat() if self.created_at else None,
+ }
diff --git a/backend/app/models/knowledge_base.py b/backend/app/models/knowledge_base.py
new file mode 100644
index 0000000..c90fac1
--- /dev/null
+++ b/backend/app/models/knowledge_base.py
@@ -0,0 +1,116 @@
+"""
+知识库 RAG 模型。
+
+KnowledgeBase — 知识库容器
+Document — 上传的源文档
+DocumentChunk — 文档切片(含 embedding)
+"""
+from __future__ import annotations
+
+import uuid
+from datetime import datetime
+
+from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey, Index
+from sqlalchemy.dialects.mysql import JSON as MySQLJSON
+from sqlalchemy.orm import relationship
+
+from app.core.database import Base
+
+
+class KnowledgeBase(Base):
+ """知识库"""
+ __tablename__ = "knowledge_bases"
+
+ id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
+ name = Column(String(200), nullable=False, comment="知识库名称")
+ description = Column(Text, nullable=True, comment="描述")
+ user_id = Column(String(36), nullable=True, index=True, comment="创建者 ID")
+ chunk_size = Column(Integer, default=500, comment="分块大小(字符数)")
+ chunk_overlap = Column(Integer, default=50, comment="分块重叠(字符数)")
+ doc_count = Column(Integer, default=0, comment="文档数量")
+ created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
+ updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
+
+ documents = relationship("Document", back_populates="kb", cascade="all, delete-orphan",
+ passive_deletes=True)
+
+ def to_dict(self) -> dict:
+ return {
+ "id": self.id,
+ "name": self.name,
+ "description": self.description,
+ "user_id": self.user_id,
+ "chunk_size": self.chunk_size,
+ "chunk_overlap": self.chunk_overlap,
+ "doc_count": self.doc_count,
+ "created_at": self.created_at.isoformat() if self.created_at else None,
+ "updated_at": self.updated_at.isoformat() if self.updated_at else None,
+ }
+
+
+class Document(Base):
+ """知识库文档"""
+ __tablename__ = "kb_documents"
+
+ id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
+ kb_id = Column(String(36), ForeignKey("knowledge_bases.id", ondelete="CASCADE"),
+ nullable=False, index=True, comment="所属知识库")
+ filename = Column(String(500), nullable=False, comment="原始文件名")
+ file_type = Column(String(20), nullable=False, comment="文件类型: txt/pdf/docx/md/csv")
+ file_size = Column(Integer, default=0, comment="文件大小(字节)")
+ status = Column(String(20), default="pending",
+ comment="状态: pending/processing/ready/failed")
+ error_message = Column(Text, nullable=True, comment="处理失败信息")
+ chunk_count = Column(Integer, default=0, comment="分块数量")
+ created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
+
+ kb = relationship("KnowledgeBase", back_populates="documents")
+ chunks = relationship("DocumentChunk", back_populates="document",
+ cascade="all, delete-orphan", passive_deletes=True)
+
+ def to_dict(self) -> dict:
+ return {
+ "id": self.id,
+ "kb_id": self.kb_id,
+ "filename": self.filename,
+ "file_type": self.file_type,
+ "file_size": self.file_size,
+ "status": self.status,
+ "error_message": self.error_message,
+ "chunk_count": self.chunk_count,
+ "created_at": self.created_at.isoformat() if self.created_at else None,
+ }
+
+
+class DocumentChunk(Base):
+ """文档切片 — 含 embedding 向量"""
+ __tablename__ = "kb_document_chunks"
+
+ id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
+ document_id = Column(String(36), ForeignKey("kb_documents.id", ondelete="CASCADE"),
+ nullable=False, index=True, comment="所属文档")
+ kb_id = Column(String(36), ForeignKey("knowledge_bases.id", ondelete="CASCADE"),
+ nullable=False, index=True, comment="所属知识库")
+ chunk_index = Column(Integer, default=0, comment="块序号")
+ content = Column(Text, nullable=False, comment="切片文本内容")
+ embedding = Column(Text, nullable=True, comment="JSON 序列化的 embedding 向量")
+ metadata_ = Column("metadata", MySQLJSON, nullable=True, comment="元数据")
+ created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
+
+ document = relationship("Document", back_populates="chunks")
+
+ __table_args__ = (
+ Index("ix_kb_chunks_kb_id", "kb_id"),
+ Index("ix_kb_chunks_doc_id", "document_id"),
+ )
+
+ def to_dict(self) -> dict:
+ return {
+ "id": self.id,
+ "document_id": self.document_id,
+ "kb_id": self.kb_id,
+ "chunk_index": self.chunk_index,
+ "content": self.content[:200] + "..." if len(self.content) > 200 else self.content,
+ "metadata": self.metadata_ or {},
+ "created_at": self.created_at.isoformat() if self.created_at else None,
+ }
diff --git a/backend/app/services/document_parser.py b/backend/app/services/document_parser.py
new file mode 100644
index 0000000..c084399
--- /dev/null
+++ b/backend/app/services/document_parser.py
@@ -0,0 +1,101 @@
+"""
+文档解析器 — 支持 txt / pdf / docx / md / csv。
+
+解析结果统一返回纯文本字符串,由调用方做分块处理。
+"""
+from __future__ import annotations
+
+import csv
+import io
+import logging
+import os
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+def parse_document(file_path: str, file_type: str) -> Optional[str]:
+ """
+ 解析文档为纯文本。
+
+ Args:
+ file_path: 文件绝对路径
+ file_type: 文件类型(txt / pdf / docx / md / csv)
+
+ Returns:
+ 文本内容,解析失败返回 None
+ """
+ if not os.path.isfile(file_path):
+ logger.warning("文件不存在: %s", file_path)
+ return None
+
+ parsers = {
+ "txt": _parse_text,
+ "md": _parse_text,
+ "pdf": _parse_pdf,
+ "docx": _parse_docx,
+ "csv": _parse_csv,
+ }
+ parser = parsers.get(file_type)
+ if not parser:
+ logger.warning("不支持的文件类型: %s", file_type)
+ return None
+
+ try:
+ text = parser(file_path)
+ if text:
+ text = text.strip()
+ logger.info("文档解析完成: %s (%s, %d 字符)", file_path, file_type, len(text or ""))
+ return text
+ except Exception as e:
+ logger.error("文档解析失败: %s (%s)", file_path, e, exc_info=True)
+ return None
+
+
+def _parse_text(file_path: str) -> str:
+ """解析纯文本 / Markdown 文件。"""
+ with open(file_path, "r", encoding="utf-8", errors="replace") as f:
+ return f.read()
+
+
+def _parse_pdf(file_path: str) -> str:
+ """解析 PDF 文件。"""
+ from pypdf import PdfReader
+
+ reader = PdfReader(file_path)
+ pages = []
+ for page in reader.pages:
+ text = page.extract_text()
+ if text:
+ pages.append(text)
+ return "\n\n".join(pages)
+
+
+def _parse_docx(file_path: str) -> str:
+ """解析 Word 文档。"""
+ from docx import Document as DocxDocument
+
+ doc = DocxDocument(file_path)
+ paragraphs = []
+ for para in doc.paragraphs:
+ if para.text and para.text.strip():
+ paragraphs.append(para.text.strip())
+
+ # 也提取表格内容
+ for table in doc.tables:
+ for row in table.rows:
+ cells = [cell.text.strip() for cell in row.cells if cell.text.strip()]
+ if cells:
+ paragraphs.append(" | ".join(cells))
+
+ return "\n\n".join(paragraphs)
+
+
+def _parse_csv(file_path: str) -> str:
+ """解析 CSV 文件为文本表格。"""
+ rows = []
+ with open(file_path, "r", encoding="utf-8", errors="replace") as f:
+ reader = csv.reader(f)
+ for row in reader:
+ rows.append(" | ".join(cell.strip() for cell in row))
+ return "\n".join(rows)
diff --git a/backend/app/services/embedding_service.py b/backend/app/services/embedding_service.py
new file mode 100644
index 0000000..f301a26
--- /dev/null
+++ b/backend/app/services/embedding_service.py
@@ -0,0 +1,229 @@
+"""
+Embedding 生成与语义检索服务。
+
+通过 OpenAI 兼容接口(优先级:SiliconFlow > OpenAI > DeepSeek)生成文本向量,
+在内存中计算余弦相似度实现语义搜索。
+
+如未配置任何可用的 API Key,所有方法静默降级返回空结果。
+"""
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+import time
+from typing import Any, Dict, List, Optional, TypedDict
+
+from app.core.config import settings
+
+logger = logging.getLogger(__name__)
+
+# 默认 embedding 模型
+EMBEDDING_MODEL = "text-embedding-3-small"
+EMBEDDING_DIMENSIONS = 1536
+
+
+class VectorEntry(TypedDict, total=False):
+ """向量条目"""
+ id: str
+ scope_kind: str
+ scope_id: str
+ content_text: str
+ embedding: List[float]
+ metadata: Dict[str, Any]
+ score: float # 余弦相似度,仅检索结果包含
+
+
+class EmbeddingService:
+ """
+ Embedding 服务。
+
+ 用法:
+ svc = EmbeddingService()
+ emb = await svc.generate_embedding("你好")
+ results = await svc.similarity_search(query_emb, entries, top_k=5)
+ """
+
+ def __init__(self):
+ self._client: Optional[Any] = None
+ self._client_lock = asyncio.Lock()
+
+ async def _get_client(self):
+ """延迟初始化 OpenAI 客户端(仅在首次调用时创建)。
+
+ 优先级:SiliconFlow > OpenAI > DeepSeek。
+ SiliconFlow 的 netease-youdao/bce-embedding-base_v1 在国内可直连且免费。
+ """
+ if self._client is not None:
+ return self._client
+
+ async with self._client_lock:
+ if self._client is not None:
+ return self._client
+
+ from openai import AsyncOpenAI
+
+ api_key: Optional[str] = None
+ base_url: Optional[str] = None
+ backend_label = "none"
+
+ # 1) SiliconFlow(国内直连,推荐)
+ if settings.SILICONFLOW_API_KEY:
+ api_key = settings.SILICONFLOW_API_KEY
+ base_url = settings.SILICONFLOW_BASE_URL or "https://api.siliconflow.cn/v1"
+ backend_label = "siliconflow"
+ logger.info("Embedding 后端: SiliconFlow (model=%s)", settings.SILICONFLOW_EMBEDDING_MODEL)
+
+ # 2) OpenAI
+ if not api_key:
+ if settings.OPENAI_API_KEY and settings.OPENAI_API_KEY != "your-openai-api-key":
+ api_key = settings.OPENAI_API_KEY
+ base_url = settings.OPENAI_BASE_URL or "https://api.openai.com/v1"
+ backend_label = "openai"
+
+ # 3) DeepSeek(部分代理可能支持 embedding)
+ if not api_key:
+ if settings.DEEPSEEK_API_KEY:
+ api_key = settings.DEEPSEEK_API_KEY
+ base_url = settings.DEEPSEEK_BASE_URL or "https://api.deepseek.com"
+ backend_label = "deepseek"
+
+ if not api_key:
+ logger.info("未配置任何 API Key,向量记忆功能已禁用(请配置 SILICONFLOW_API_KEY 或 OPENAI_API_KEY)")
+ self._client = None
+ return None
+
+ self._client = AsyncOpenAI(
+ api_key=api_key,
+ base_url=base_url,
+ )
+ logger.info("EmbeddingService 已初始化 (backend=%s)", backend_label)
+ return self._client
+
+ def _get_model(self) -> str:
+ """根据后端选择对应的 embedding 模型。"""
+ if settings.SILICONFLOW_API_KEY:
+ return settings.SILICONFLOW_EMBEDDING_MODEL
+ return EMBEDDING_MODEL
+
+ async def generate_embedding(self, text: str) -> Optional[List[float]]:
+ """
+ 为单段文本生成 embedding 向量。
+
+ 返回 float 列表,无 API key 或出错时返回 None。
+ """
+ if not text or not text.strip():
+ return None
+
+ client = await self._get_client()
+ if not client:
+ return None
+
+ model = self._get_model()
+ try:
+ kwargs: Dict[str, Any] = {
+ "model": model,
+ "input": text.strip()[:8000],
+ }
+ # OpenAI text-embedding-3-small 支持指定 dimensions;其它模型可能不支持
+ if model == EMBEDDING_MODEL:
+ kwargs["dimensions"] = EMBEDDING_DIMENSIONS
+ resp = await client.embeddings.create(**kwargs)
+ return resp.data[0].embedding
+ except Exception as e:
+ logger.warning("生成 embedding 失败: %s", e)
+ return None
+
+ async def generate_embeddings(self, texts: List[str]) -> List[Optional[List[float]]]:
+ """
+ 批量生成 embedding 向量。
+ """
+ if not texts:
+ return []
+
+ client = await self._get_client()
+ if not client:
+ return [None] * len(texts)
+
+ # 清理空文本
+ valid_indices = [i for i, t in enumerate(texts) if t and t.strip()]
+ if not valid_indices:
+ return [None] * len(texts)
+
+ model = self._get_model()
+ try:
+ inputs = [texts[i].strip()[:8000] for i in valid_indices]
+ kwargs: Dict[str, Any] = {
+ "model": model,
+ "input": inputs,
+ }
+ if model == EMBEDDING_MODEL:
+ kwargs["dimensions"] = EMBEDDING_DIMENSIONS
+ resp = await client.embeddings.create(**kwargs)
+ embeddings = [r.embedding for r in resp.data]
+ except Exception as e:
+ logger.warning("批量生成 embedding 失败: %s", e)
+ return [None] * len(texts)
+
+ # 按原始顺序排列结果
+ result: List[Optional[List[float]]] = [None] * len(texts)
+ for idx, emb in zip(valid_indices, embeddings):
+ result[idx] = emb
+ return result
+
+ @staticmethod
+ def cosine_similarity(a: List[float], b: List[float]) -> float:
+ """计算两个向量的余弦相似度。"""
+ if not a or not b or len(a) != len(b):
+ return 0.0
+
+ dot = sum(x * y for x, y in zip(a, b))
+ norm_a = sum(x * x for x in a) ** 0.5
+ norm_b = sum(y * y for y in b) ** 0.5
+
+ if norm_a == 0 or norm_b == 0:
+ return 0.0
+ return dot / (norm_a * norm_b)
+
+ async def similarity_search(
+ self,
+ query_embedding: List[float],
+ entries: List[VectorEntry],
+ top_k: int = 5,
+ min_score: float = 0.3,
+ ) -> List[VectorEntry]:
+ """
+ 在内存中对 entries 做余弦相似度搜索,返回 Top-K 结果(按得分降序)。
+ min_score: 最低相似度阈值,低于该值的结果被过滤。
+ """
+ scored: List[VectorEntry] = []
+ for entry in entries:
+ emb = entry.get("embedding")
+ if not emb:
+ continue
+ score = self.cosine_similarity(query_embedding, emb)
+ if score >= min_score:
+ entry["score"] = score
+ scored.append(entry)
+
+ scored.sort(key=lambda x: x["score"], reverse=True)
+ return scored[:top_k]
+
+ @staticmethod
+ def serialize_embedding(embedding: List[float]) -> str:
+ """将 embedding 序列化为 JSON 字符串。"""
+ return json.dumps(embedding, ensure_ascii=False)
+
+ @staticmethod
+ def deserialize_embedding(data: str) -> List[float]:
+ """从 JSON 字符串反序列化 embedding。"""
+ if isinstance(data, list):
+ return data
+ try:
+ return json.loads(data)
+ except (json.JSONDecodeError, TypeError):
+ return []
+
+
+# 全局单例
+embedding_service = EmbeddingService()
diff --git a/backend/app/services/knowledge_service.py b/backend/app/services/knowledge_service.py
new file mode 100644
index 0000000..239ec0f
--- /dev/null
+++ b/backend/app/services/knowledge_service.py
@@ -0,0 +1,375 @@
+"""
+知识库服务 — 文档管理、分块、Embedding、语义检索、RAG。
+
+流程:
+ 上传 → 解析 → 分块 → Embedding → 存储
+ 检索 → 查询 → Embedding → 余弦相似度 → Top-K
+ RAG → 检索 + 格式化上下文 → 返回给 Agent/用户
+"""
+from __future__ import annotations
+
+import json
+import logging
+import os
+import uuid
+from typing import Any, Dict, List, Optional
+
+from sqlalchemy.orm import Session
+
+from app.core.config import settings
+from app.models.knowledge_base import KnowledgeBase, Document, DocumentChunk
+from app.services.document_parser import parse_document
+from app.services.embedding_service import embedding_service, VectorEntry
+from app.services.text_chunker import chunk_text
+
+logger = logging.getLogger(__name__)
+
+# 上传文件存储根目录
+UPLOAD_DIR = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
+ "kb_uploads",
+)
+
+
+def _ensure_upload_dir():
+ """确保上传目录存在。"""
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
+
+
+def _get_kb_dir(kb_id: str) -> str:
+ """返回知识库文件存放目录。"""
+ d = os.path.join(UPLOAD_DIR, kb_id)
+ os.makedirs(d, exist_ok=True)
+ return d
+
+
+# ─── 知识库 CRUD ────────────────────────────────────────────────
+
+
+def create_knowledge_base(
+ db: Session,
+ name: str,
+ user_id: str,
+ description: str = "",
+ chunk_size: int = 500,
+ chunk_overlap: int = 50,
+) -> KnowledgeBase:
+ """创建知识库。"""
+ kb = KnowledgeBase(
+ name=name,
+ description=description,
+ user_id=user_id,
+ chunk_size=max(50, min(2000, chunk_size)),
+ chunk_overlap=max(0, min(chunk_size // 2, chunk_overlap)),
+ )
+ db.add(kb)
+ db.commit()
+ db.refresh(kb)
+ logger.info("知识库已创建: %s (%s)", kb.name, kb.id)
+ return kb
+
+
+def list_knowledge_bases(db: Session, user_id: Optional[str] = None) -> List[KnowledgeBase]:
+ """列出知识库。"""
+ q = db.query(KnowledgeBase)
+ if user_id:
+ q = q.filter(KnowledgeBase.user_id == user_id)
+ return q.order_by(KnowledgeBase.updated_at.desc()).all()
+
+
+def get_knowledge_base(db: Session, kb_id: str) -> Optional[KnowledgeBase]:
+ """获取知识库详情。"""
+ return db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
+
+
+def delete_knowledge_base(db: Session, kb_id: str) -> bool:
+ """删除知识库(连带文档和分块)。"""
+ kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
+ if not kb:
+ return False
+ # 删除磁盘文件
+ kb_dir = _get_kb_dir(kb_id)
+ if os.path.isdir(kb_dir):
+ import shutil
+ shutil.rmtree(kb_dir, ignore_errors=True)
+ db.delete(kb)
+ db.commit()
+ logger.info("知识库已删除: %s", kb_id)
+ return True
+
+
+# ─── 文档管理 ───────────────────────────────────────────────────
+
+
+async def upload_document(
+ db: Session,
+ kb_id: str,
+ filename: str,
+ file_content: bytes,
+) -> Document:
+ """
+ 上传文档到知识库:
+ 1. 保存原始文件
+ 2. 解析为文本
+ 3. 分块
+ 4. 生成 Embedding
+ 5. 存储分块
+ """
+ kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
+ if not kb:
+ raise ValueError(f"知识库不存在: {kb_id}")
+
+ _ensure_upload_dir()
+ kb_dir = _get_kb_dir(kb_id)
+
+ # 提取文件类型
+ ext = os.path.splitext(filename)[1].lower().lstrip(".")
+ if ext in ("txt", "md", "pdf", "docx", "csv"):
+ file_type = ext
+ else:
+ file_type = "txt"
+
+ # 保存原始文件
+    file_path = os.path.join(kb_dir, f"{uuid.uuid4()}_{filename}")
+ with open(file_path, "wb") as f:
+ f.write(file_content)
+ file_size = len(file_content)
+
+ # 创建文档记录
+ doc = Document(
+ kb_id=kb_id,
+ filename=filename,
+ file_type=file_type,
+ file_size=file_size,
+ status="processing",
+ )
+ db.add(doc)
+ db.commit()
+ db.refresh(doc)
+
+ try:
+ # 解析文本
+ text = parse_document(file_path, file_type)
+ if not text:
+ doc.status = "failed"
+ doc.error_message = "文档解析失败或内容为空"
+ db.commit()
+ return doc
+
+ # 分块
+ chunks = chunk_text(
+ text,
+ chunk_size=kb.chunk_size,
+ chunk_overlap=kb.chunk_overlap,
+ )
+ if not chunks:
+ doc.status = "failed"
+ doc.error_message = "文档分块后为空"
+ db.commit()
+ return doc
+
+ logger.info("文档分块完成: %s → %d 块", filename, len(chunks))
+
+ # 批量生成 Embedding
+ embeddings: List[Optional[List[float]]] = []
+ try:
+ embeddings = await embedding_service.generate_embeddings(chunks)
+ except Exception as e:
+ logger.warning("批量生成 embedding 失败,逐块回退: %s", e)
+ for c in chunks:
+ try:
+ emb = await embedding_service.generate_embedding(c)
+ embeddings.append(emb)
+ except Exception:
+ embeddings.append(None)
+
+ # 存储分块
+ chunk_records = []
+ for i, (chunk_text_content, emb) in enumerate(zip(chunks, embeddings)):
+ record = DocumentChunk(
+ document_id=doc.id,
+ kb_id=kb_id,
+ chunk_index=i,
+ content=chunk_text_content,
+ embedding=json.dumps(emb) if emb else None,
+ metadata_={
+ "filename": filename,
+ "file_type": file_type,
+ "chunk_index": i,
+ "has_embedding": emb is not None,
+ },
+ )
+ chunk_records.append(record)
+
+ db.add_all(chunk_records)
+
+ # 更新文档状态
+ doc.status = "ready"
+ doc.chunk_count = len(chunks)
+ kb.doc_count = db.query(Document).filter(
+ Document.kb_id == kb_id, Document.status == "ready"
+ ).count()
+
+ db.commit()
+ logger.info("文档处理完成: %s (%d 块, embedding=%s)",
+ filename, len(chunks),
+ "yes" if any(e for e in embeddings) else "no")
+
+ except Exception as e:
+ db.rollback()
+ doc = db.query(Document).filter(Document.id == doc.id).first()
+ if doc:
+ doc.status = "failed"
+ doc.error_message = str(e)[:500]
+ db.commit()
+ logger.error("文档处理失败: %s: %s", filename, e, exc_info=True)
+
+ return doc
+
+
+def list_documents(db: Session, kb_id: str) -> List[Document]:
+ """列出知识库中的文档。"""
+ return (
+ db.query(Document)
+ .filter(Document.kb_id == kb_id)
+ .order_by(Document.created_at.desc())
+ .all()
+ )
+
+
+def delete_document(db: Session, doc_id: str) -> bool:
+ """删除文档(连带分块)。"""
+ doc = db.query(Document).filter(Document.id == doc_id).first()
+ if not doc:
+ return False
+ # 减少知识库文档计数
+ kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == doc.kb_id).first()
+ db.delete(doc)
+ if kb:
+ kb.doc_count = db.query(Document).filter(
+ Document.kb_id == kb.id, Document.status == "ready"
+ ).count()
+ db.commit()
+ logger.info("文档已删除: %s", doc_id)
+ return True
+
+
+# ─── 语义检索 ───────────────────────────────────────────────────
+
+
+async def search(
+ db: Session,
+ kb_id: str,
+ query: str,
+ top_k: int = 5,
+ min_score: float = 0.3,
+) -> List[Dict[str, Any]]:
+ """
+ 语义搜索知识库。
+
+ 流程:查询文本 → Embedding → 余弦相似度匹配所有分块 → Top-K
+ """
+ # 生成查询 embedding
+ query_emb = await embedding_service.generate_embedding(query)
+ if not query_emb:
+ logger.warning("搜索失败:无法生成查询 embedding")
+ return []
+
+ # 加载该知识库所有分块
+ chunks = (
+ db.query(DocumentChunk)
+ .filter(DocumentChunk.kb_id == kb_id)
+ .all()
+ )
+
+ if not chunks:
+ return []
+
+ # 构建 VectorEntry 列表
+ entries: List[VectorEntry] = []
+ for c in chunks:
+ if not c.embedding:
+ continue
+ try:
+ emb = json.loads(c.embedding) if isinstance(c.embedding, str) else c.embedding
+ except (json.JSONDecodeError, TypeError):
+ continue
+ entries.append({
+ "id": c.id,
+ "scope_kind": "kb",
+ "scope_id": kb_id,
+ "content_text": c.content,
+ "embedding": emb,
+ "metadata": c.metadata_ or {},
+ })
+
+ if not entries:
+ return []
+
+ # 相似度搜索
+ matched = await embedding_service.similarity_search(
+ query_emb, entries, top_k=top_k, min_score=min_score
+ )
+
+ results = []
+ for m in matched:
+ results.append({
+ "chunk_id": m["id"],
+ "content": m["content_text"],
+ "score": m["score"],
+ "metadata": m.get("metadata", {}),
+ })
+
+ return results
+
+
+async def rag_query(
+ db: Session,
+ kb_id: str,
+ query: str,
+ top_k: int = 5,
+ min_score: float = 0.3,
+) -> Dict[str, Any]:
+ """
+ RAG 查询:搜索相关片段 + 格式化为上下文。
+
+ 返回:
+ {
+ "query": "...",
+ "context": "根据以下资料回答:\n\n[片段1]\n[片段2]\n...",
+ "sources": [...],
+ }
+ """
+ results = await search(db, kb_id, query, top_k=top_k, min_score=min_score)
+
+ if not results:
+ return {
+ "query": query,
+ "context": "",
+ "sources": [],
+ "found": False,
+ }
+
+ # 格式化上下文
+ lines = ["根据以下资料回答用户问题:\n"]
+ for i, r in enumerate(results, 1):
+ source = r["metadata"].get("filename", "未知来源")
+ lines.append(f"[{i}] (来源: {source}):\n{r['content']}\n")
+
+ context = "\n".join(lines)
+
+ sources = [
+ {
+ "content": r["content"],
+ "score": r["score"],
+ "source": r["metadata"].get("filename", "未知"),
+ }
+ for r in results
+ ]
+
+ return {
+ "query": query,
+ "context": context,
+ "sources": sources,
+ "found": True,
+ }
diff --git a/backend/app/services/text_chunker.py b/backend/app/services/text_chunker.py
new file mode 100644
index 0000000..d9b221f
--- /dev/null
+++ b/backend/app/services/text_chunker.py
@@ -0,0 +1,132 @@
+"""
+文本分块器 — 将长文本切分为适合 Embedding + 检索的块。
+
+策略:
+1. 按段落分割(连续换行)
+2. 超长段落按句子切分(句号、问号、感叹号、换行)
+3. 短段落合并,直到接近 chunk_size
+4. chunk 之间保留 overlap 字符的重叠
+"""
+from __future__ import annotations
+
+import re
+from typing import List
+
+
+def chunk_text(
+ text: str,
+ chunk_size: int = 500,
+ chunk_overlap: int = 50,
+) -> List[str]:
+ """
+ 将文本分割为语义完整的块。
+
+ Args:
+ text: 输入文本
+ chunk_size: 每块目标字符数
+ chunk_overlap: 相邻块重叠字符数
+
+ Returns:
+ 文本块列表
+ """
+ if not text or not text.strip():
+ return []
+
+ # 1. 按段落分割
+ paragraphs = _split_paragraphs(text)
+ if not paragraphs:
+ return []
+
+ # 2. 处理每个段落:超长段落进一步拆分
+ segments: List[str] = []
+ for para in paragraphs:
+ if len(para) <= chunk_size:
+ segments.append(para)
+ else:
+ segments.extend(_split_long_paragraph(para, chunk_size))
+
+ # 3. 合并短段落为块
+ chunks = _merge_segments(segments, chunk_size, chunk_overlap)
+
+ return chunks
+
+
+def _split_paragraphs(text: str) -> List[str]:
+ """按连续换行分割段落,过滤空白段落。"""
+ # 按两个以上换行分割
+ raw = re.split(r"\n\s*\n", text)
+ paras = []
+ for p in raw:
+ p = p.strip()
+ if p:
+ paras.append(p)
+ return paras
+
+
+def _split_long_paragraph(para: str, chunk_size: int) -> List[str]:
+ """将超长段落按句子切分。"""
+ # 先按句子结束符分割(去掉 look-behind 长度限制)
+ sentences = re.split(r"(?<=[。!?])|(?<=[.!?])\s*", para)
+ sentences = [s.strip() for s in sentences if s.strip()]
+
+ if not sentences:
+ # 无法按句子分割,按字符硬切
+ return [para[i : i + chunk_size] for i in range(0, len(para), chunk_size)]
+
+ chunks: List[str] = []
+ current = ""
+ for sent in sentences:
+ if len(current) + len(sent) <= chunk_size:
+ current += sent
+ else:
+ if current:
+ chunks.append(current.strip())
+ current = sent
+ if current:
+ chunks.append(current.strip())
+
+ # 如果某个块还是太长(单句超长),按字符硬切
+ result: List[str] = []
+ for c in chunks:
+ if len(c) <= chunk_size:
+ result.append(c)
+ else:
+ for i in range(0, len(c), chunk_size):
+ result.append(c[i : i + chunk_size])
+ return result
+
+
+def _merge_segments(segments: List[str], chunk_size: int, overlap: int) -> List[str]:
+ """合并短段落为块,块间带重叠。"""
+ if not segments:
+ return []
+
+ chunks: List[str] = []
+ current = ""
+
+ for seg in segments:
+ # 如果当前块 + 新段不超过上限,追加
+ if not current:
+ current = seg
+ elif len(current) + 1 + len(seg) <= chunk_size:
+ current += "\n\n" + seg
+ else:
+ # 当前块已满
+ chunks.append(current.strip())
+
+ # 重叠:取前一块末尾的 overlap 字符
+ if overlap > 0 and len(current) > overlap:
+ # 从上一个块末尾取 overlap 字符作为新块的起始
+ tail = current[-overlap:]
+ # 从上一个换行处开始,避免截断句子
+ newline_pos = tail.find("\n")
+ if newline_pos > 0 and newline_pos < overlap // 2:
+ tail = tail[newline_pos + 1 :]
+ current = tail + "\n\n" + seg
+ else:
+ current = seg
+
+ if current.strip():
+ chunks.append(current.strip())
+
+ return chunks
diff --git a/backend/app/services/tool_registry.py b/backend/app/services/tool_registry.py
index ce0a9b5..f49e625 100644
--- a/backend/app/services/tool_registry.py
+++ b/backend/app/services/tool_registry.py
@@ -1,123 +1,360 @@
"""
-工具注册表 - 管理所有可用工具
+工具注册表 — 管理所有可用工具(内置 / HTTP / 代码段)。
+
+提供统一的工具注册、查找和执行接口。
"""
-from typing import Dict, Any, Callable, Optional, List
+from __future__ import annotations
+
import json
import logging
+import traceback
+from typing import Any, Callable, Dict, List, Optional
+
from app.models.tool import Tool
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
+# 代码工具的安全内置模块
+_CODE_SAFE_GLOBALS = {
+ "json": json,
+ "dict": dict,
+ "list": list,
+ "str": str,
+ "int": int,
+ "float": float,
+ "bool": bool,
+ "len": len,
+ "range": range,
+ "enumerate": enumerate,
+ "zip": zip,
+ "map": map,
+ "filter": filter,
+ "sorted": sorted,
+ "min": min,
+ "max": max,
+ "sum": sum,
+ "abs": abs,
+ "round": round,
+ "isinstance": isinstance,
+ "type": type,
+ "True": True,
+ "False": False,
+ "None": None,
+}
+
+
class ToolRegistry:
- """工具注册表 - 管理所有可用工具"""
-
+ """工具注册表 — 管理所有可用工具。"""
+
def __init__(self):
self._builtin_tools: Dict[str, Callable] = {}
self._tool_schemas: Dict[str, Dict[str, Any]] = {}
-
+ # 自定义工具配置(从 DB 加载的非 builtin 工具)
+ self._custom_tool_configs: Dict[str, Dict[str, Any]] = {}
+
+ # ─── 内置工具注册 ─────────────────────────────────────────
+
def register_builtin_tool(self, name: str, func: Callable, schema: Dict[str, Any]):
"""
- 注册内置工具
-
+ 注册内置工具。
+
Args:
name: 工具名称
- func: 工具函数(可以是同步或异步函数)
- schema: 工具定义(OpenAI Function格式)
+ func: 工具函数(同步或异步)
+ schema: OpenAI function calling 格式 schema
"""
self._builtin_tools[name] = func
self._tool_schemas[name] = schema
logger.debug("注册内置工具: %s", name)
-
+
+ # ─── 工具信息查询 ─────────────────────────────────────────
+
def get_tool_schema(self, name: str) -> Optional[Dict[str, Any]]:
- """获取工具定义"""
return self._tool_schemas.get(name)
-
+
def get_tool_function(self, name: str) -> Optional[Callable]:
- """获取工具函数"""
return self._builtin_tools.get(name)
-
+
def get_all_tool_schemas(self) -> List[Dict[str, Any]]:
- """获取所有工具定义(用于LLM)"""
return list(self._tool_schemas.values())
def builtin_tool_count(self) -> int:
- """已注册的内置工具数量(用于健康检查 / 启动自检)。"""
return len(self._builtin_tools)
def builtin_tool_names(self) -> List[str]:
- """已注册的内置工具名称,有序列表。"""
return sorted(self._builtin_tools.keys())
- def load_tools_from_db(self, db: Session, tool_names: List[str] = None):
- """
- 从数据库加载工具
-
- Args:
- db: 数据库会话
- tool_names: 工具名称列表(可选,如果为None则加载所有公开工具)
- """
- query = db.query(Tool).filter(Tool.is_public == True)
- if tool_names:
- query = query.filter(Tool.name.in_(tool_names))
-
- tools = query.all()
- for tool in tools:
- self._tool_schemas[tool.name] = tool.function_schema
- # 根据implementation_type加载工具实现
- if tool.implementation_type == 'builtin':
- # 从内置工具中查找
- if tool.name in self._builtin_tools:
- logger.debug(f"工具 {tool.name} 已在内置工具中注册")
- else:
- logger.warning(f"工具 {tool.name} 标记为builtin但未在内置工具中找到")
- elif tool.implementation_type == 'http':
- # HTTP工具需要特殊处理
- self._register_http_tool(tool)
- elif tool.implementation_type == 'workflow':
- # 工作流工具
- self._register_workflow_tool(tool)
- elif tool.implementation_type == 'code':
- # 代码执行工具
- self._register_code_tool(tool)
-
- logger.info(f"从数据库加载了 {len(tools)} 个工具")
-
- def _register_http_tool(self, tool: Tool):
- """注册HTTP工具"""
- # TODO: 实现HTTP工具的动态注册
- logger.warning(f"HTTP工具 {tool.name} 的动态注册尚未实现")
-
- def _register_workflow_tool(self, tool: Tool):
- """注册工作流工具"""
- # TODO: 实现工作流工具的动态注册
- logger.warning(f"工作流工具 {tool.name} 的动态注册尚未实现")
-
- def _register_code_tool(self, tool: Tool):
- """注册代码执行工具"""
- # TODO: 实现代码执行工具的动态注册
- logger.warning(f"代码执行工具 {tool.name} 的动态注册尚未实现")
-
def get_tools_by_names(self, tool_names: List[str]) -> List[Dict[str, Any]]:
- """
- 根据工具名称列表获取工具定义
-
- Args:
- tool_names: 工具名称列表
-
- Returns:
- 工具定义列表(OpenAI Function格式)
- """
tools = []
for name in tool_names:
schema = self.get_tool_schema(name)
if schema:
tools.append(schema)
else:
- logger.warning(f"工具 {name} 未找到")
+ logger.warning("工具 %s 未找到", name)
return tools
+ # ─── 从数据库加载自定义工具 ──────────────────────────────
-# 全局工具注册表实例
+ def load_tools_from_db(self, db: Session, tool_names: Optional[List[str]] = None):
+ """
+ 从数据库加载工具定义。
+
+ Args:
+ db: 数据库会话
+ tool_names: 指定名称列表;None 则加载所有公开工具
+ """
+ query = db.query(Tool).filter(Tool.is_public == True)
+ if tool_names:
+ query = query.filter(Tool.name.in_(tool_names))
+
+ tools = query.all()
+ count = 0
+ for tool in tools:
+ self._tool_schemas[tool.name] = tool.function_schema
+ if tool.implementation_type == "builtin":
+ if tool.name not in self._builtin_tools:
+ logger.warning("工具 %s 标记为 builtin 但未注册", tool.name)
+ else:
+ # 存储自定义工具配置
+ config = tool.implementation_config or {}
+ config["_type"] = tool.implementation_type
+ config["_db_id"] = tool.id
+ self._custom_tool_configs[tool.name] = config
+ count += 1
+
+ if count:
+ logger.info("从数据库加载了 %d 个自定义工具", count)
+
+ # ─── 工具执行 ─────────────────────────────────────────────
+
+ async def execute_tool(self, name: str, args: Dict[str, Any]) -> str:
+ """
+ 执行任意工具(内置 / HTTP / 代码段)。
+
+ Args:
+ name: 工具名称
+ args: 参数字典
+
+ Returns:
+ 结果字符串
+ """
+ # 1. 内置工具
+ func = self._builtin_tools.get(name)
+ if func:
+ return await self._run_function(func, name, args)
+
+ # 2. 自定义工具
+ config = self._custom_tool_configs.get(name)
+ if not config:
+ return json.dumps({"error": f"工具 '{name}' 不存在"}, ensure_ascii=False)
+
+ impl_type = config.get("_type", "")
+ try:
+ if impl_type == "http":
+ return await self._execute_http_tool(name, config, args)
+ elif impl_type == "code":
+ return await self._execute_code_tool(name, config, args)
+ elif impl_type == "workflow":
+ return json.dumps({"error": "工作流工具暂不支持动态执行"},
+ ensure_ascii=False)
+ else:
+ return json.dumps({"error": f"不支持的实现类型: {impl_type}"},
+ ensure_ascii=False)
+ except Exception as e:
+ logger.error("工具 %s 执行失败: %s", name, e, exc_info=True)
+ return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"},
+ ensure_ascii=False)
+
+ @staticmethod
+ async def _run_function(func: Callable, name: str, args: Dict[str, Any]) -> str:
+ """执行内置工具函数。"""
+ import asyncio
+ try:
+ if asyncio.iscoroutinefunction(func):
+ result = await func(**args)
+ else:
+ result = func(**args)
+ if isinstance(result, (dict, list)):
+ return json.dumps(result, ensure_ascii=False)
+ return str(result)
+ except Exception as e:
+ logger.error("工具 '%s' 执行失败: %s", name, e, exc_info=True)
+ return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"},
+ ensure_ascii=False)
+
+ # ─── HTTP 工具 ────────────────────────────────────────────
+
+ async def _execute_http_tool(
+ self, name: str, config: Dict[str, Any], args: Dict[str, Any]
+ ) -> str:
+ """
+ 执行 HTTP 工具。
+
+ implementation_config 格式:
+ {
+ "url": "https://api.example.com/{path_param}",
+ "method": "GET",
+ "headers": {"Authorization": "Bearer xxx"},
+ "body_template": {"key": "{param}"}, # 可选
+ "timeout": 30
+ }
+ URL 和 body_template 中的 {param} 会被 args 替换。
+ """
+ import httpx
+
+ url = config.get("url", "")
+ method = (config.get("method") or "GET").upper()
+ headers = config.get("headers") or {}
+ body_template = config.get("body_template")
+ timeout = config.get("timeout", 30)
+
+ if not url:
+ return json.dumps({"error": "HTTP 工具未配置 URL"}, ensure_ascii=False)
+
+ # 模板替换:{param} → args["param"]
+ def _fmt(template: str) -> str:
+ try:
+ return template.format(**args)
+ except KeyError:
+ return template
+
+ url = _fmt(url)
+ body: Any = None
+ if body_template:
+ body_str = json.dumps(body_template)
+ body_str = _fmt(body_str)
+ try:
+ body = json.loads(body_str)
+ except json.JSONDecodeError:
+ body = body_str
+
+ async with httpx.AsyncClient(timeout=timeout) as client:
+ response = await client.request(method, url, headers=headers, json=body)
+ text = response.text[:10000] # 截断过长的响应
+
+ result = {
+ "status_code": response.status_code,
+ "body": text,
+ }
+ return json.dumps(result, ensure_ascii=False)
+
+ # ─── 代码工具 ─────────────────────────────────────────────
+
+ async def _execute_code_tool(
+ self, name: str, config: Dict[str, Any], args: Dict[str, Any]
+ ) -> str:
+ """
+ 执行代码工具。
+
+ implementation_config 格式:
+ {
+ "source": "def run(args):\\n return {'sum': args['a'] + args['b']}",
+ "language": "python"
+ }
+ 代码工具需定义一个 run(args) 函数作为入口。
+    source 在受限的全局命名空间中执行(仅清空 __builtins__,并非安全沙箱,请勿运行不可信代码)。
+ """
+ source = config.get("source", "")
+ if not source:
+ return json.dumps({"error": "代码工具未配置 source"}, ensure_ascii=False)
+
+ # 编译代码,限制可访问的全局变量
+ safe_globals = _CODE_SAFE_GLOBALS.copy()
+ safe_globals["__builtins__"] = {} # 禁用所有内置函数
+
+ try:
+ exec(source, safe_globals)
+ except Exception as e:
+ return json.dumps({
+ "error": f"代码编译失败: {e}",
+ "traceback": traceback.format_exc(),
+ }, ensure_ascii=False)
+
+ run_func = safe_globals.get("run")
+ if not run_func or not callable(run_func):
+ return json.dumps({
+ "error": "代码工具必须定义一个 run(args) 函数"
+ }, ensure_ascii=False)
+
+ try:
+ result = run_func(args)
+ if isinstance(result, (dict, list)):
+ return json.dumps(result, ensure_ascii=False)
+ return str(result)
+ except Exception as e:
+ return json.dumps({
+ "error": f"代码执行失败: {e}",
+ "traceback": traceback.format_exc(),
+ }, ensure_ascii=False)
+
+ # ─── 测试工具(不保存,直接执行)─────────────────────────
+
+ async def test_http_tool(
+ self, url: str, method: str, headers: Dict[str, str],
+ body: Optional[Dict[str, Any]], args: Dict[str, Any],
+ timeout: int = 30,
+ ) -> Dict[str, Any]:
+ """测试 HTTP 工具配置(不保存到 DB)。"""
+ import httpx
+
+ def _fmt(template: str) -> str:
+ try:
+ return template.format(**args)
+ except KeyError:
+ return template
+
+ url = _fmt(url)
+ body_sent = None
+ if body:
+ body_str = json.dumps(body)
+ body_str = _fmt(body_str)
+ try:
+ body_sent = json.loads(body_str)
+ except json.JSONDecodeError:
+ body_sent = body_str
+
+ start = __import__("time").time()
+ try:
+ async with httpx.AsyncClient(timeout=timeout) as client:
+ response = await client.request(method.upper(), url,
+ headers=headers, json=body_sent)
+ elapsed_ms = int((__import__("time").time() - start) * 1000)
+ return {
+ "success": True,
+ "status_code": response.status_code,
+ "elapsed_ms": elapsed_ms,
+ "body": response.text[:10000],
+ }
+ except Exception as e:
+ return {"success": False, "error": str(e)}
+
+ async def test_code_tool(
+ self, source: str, args: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ """测试代码工具(不保存到 DB)。"""
+ safe_globals = _CODE_SAFE_GLOBALS.copy()
+ safe_globals["__builtins__"] = {}
+
+ try:
+ exec(source, safe_globals)
+ except Exception as e:
+ return {"success": False, "error": f"编译失败: {e}"}
+
+ run_func = safe_globals.get("run")
+ if not run_func:
+ return {"success": False, "error": "代码须定义 run(args) 函数"}
+
+ try:
+ start = __import__("time").time()
+ result = run_func(args)
+ elapsed_ms = int((__import__("time").time() - start) * 1000)
+ return {"success": True, "elapsed_ms": elapsed_ms, "result": result}
+ except Exception as e:
+ return {"success": False, "error": f"执行失败: {e}"}
+
+
+# 全局单例
tool_registry = ToolRegistry()
diff --git a/backend/app/services/workflow_engine.py b/backend/app/services/workflow_engine.py
index ed6a37c..e6d72cf 100644
--- a/backend/app/services/workflow_engine.py
+++ b/backend/app/services/workflow_engine.py
@@ -1917,12 +1917,25 @@ class WorkflowEngine:
if hasattr(self, '_on_tool_executed_budget'):
_agent_on_tool = self._on_tool_executed_budget
+ # Agent 的 LLM 调用计入工作流预算
+ def _on_agent_llm():
+ self._llm_invocations += 1
+ if self._llm_invocations > self._cap_llm:
+ raise WorkflowExecutionError(
+ detail=f"已超过 LLM 节点调用预算({self._cap_llm} 次)",
+ )
+
result = await run_agent_node(
node_data=node.get("data", {}),
input_data=input_data,
execution_logger=self.logger,
user_id=self.trusted_model_config_user_id,
on_tool_executed=_agent_on_tool,
+ on_llm_invocation=_on_agent_llm,
+ budget_limits={
+ "max_llm_invocations": self._cap_llm,
+ "max_tool_calls": self._cap_tool,
+ },
)
if self.logger:
duration = int((time.time() - start_time) * 1000)
diff --git a/backend/tests/README.md b/backend/tests/README.md
index c7c3278..e6d7f6f 100644
--- a/backend/tests/README.md
+++ b/backend/tests/README.md
@@ -46,23 +46,31 @@ pytest --cov=app --cov-report=html
## 测试标记
-- `@pytest.mark.unit` - 单元测试
-- `@pytest.mark.integration` - 集成测试
+- `@pytest.mark.unit` - 纯单元测试(不依赖网络/数据库)
+- `@pytest.mark.integration` - 集成测试(需要网络或可选依赖)
- `@pytest.mark.slow` - 慢速测试(需要网络或数据库)
- `@pytest.mark.api` - API测试
- `@pytest.mark.workflow` - 工作流测试
- `@pytest.mark.auth` - 认证测试
+- `@pytest.mark.asyncio` - 异步测试(使用 `pytest-asyncio`)
## 测试结构
```
tests/
├── __init__.py
-├── conftest.py # 共享fixtures和配置
-├── test_auth.py # 认证API测试
-├── test_workflows.py # 工作流API测试
-├── test_workflow_engine.py # 工作流引擎测试
-└── test_workflow_validator.py # 工作流验证器测试
+├── conftest.py # 共享fixtures和配置
+├── test_auth.py # 认证API测试
+├── test_workflows.py # 工作流API测试
+├── test_workflow_engine.py # 工作流引擎测试
+├── test_workflow_validator.py # 工作流验证器测试
+├── test_text_chunker.py # 文本分块器单元测试
+├── test_document_parser.py # 文档解析器单元测试
+├── test_tool_registry.py # 工具注册表单元测试
+├── test_tools_api.py # 工具市场API测试
+├── test_agent_memory.py # Agent记忆/上下文/Schema测试
+├── test_embedding_service.py # Embedding服务单元测试
+└── test_knowledge_base.py # 知识库RAG单元测试
```
## Fixtures
@@ -84,13 +92,31 @@ tests/
## 测试数据库
-测试使用SQLite内存数据库,每个测试函数都会:
+测试使用SQLite临时文件数据库(而非 `:memory:`),每个测试函数都会:
1. 创建所有表
2. 执行测试
3. 删除所有表
这样可以确保测试之间的隔离性。
+**注意**:使用临时文件而非 `:memory:` 是因为 FastAPI + TestClient 在异步/多线程请求中,SQLite 内存数据库会发生"连接隔离"问题——每个连接看到不同的数据库实例。文件数据库确保了所有连接共享同一份数据。
+
+**注意**:纯服务层测试(如 `test_tool_registry.py`、`test_embedding_service.py`、`test_text_chunker.py`)不依赖 `db_session` fixture,它们直接对服务类打桩或使用临时文件,运行速度更快。
+
+## 异步测试配置
+
+使用 `pytest-asyncio` 支持异步测试。所有被 `@pytest.mark.asyncio` 标记的测试函数必须使用 `async def` 声明:
+
+```python
+@pytest.mark.unit
+@pytest.mark.asyncio
+async def test_my_async_method(self):
+ result = await my_service.my_method()
+ assert result is not None
+```
+
+**注意**:`conftest.py` 中的 fixtures 默认使用 `function` 作用域,异步 fixture 不需要额外配置。
+
## 编写新测试
### 示例:API测试
@@ -128,6 +154,14 @@ class TestMyService:
5. **测试数据**:使用 fixtures 提供测试数据,避免硬编码。
+6. **已知问题**:有 9 个历史遗留测试失败(主要在 `test_auth.py`、`test_workflow_engine.py`、`test_workflow_validator.py`、`test_nodes_all.py`、`test_nodes_phase4.py` 中),这些失败与本次改造无关,可以忽略。检查新建的测试文件时,可用以下命令排除它们:
+ ```bash
+ pytest tests/test_text_chunker.py tests/test_document_parser.py \
+ tests/test_tool_registry.py tests/test_tools_api.py \
+ tests/test_agent_memory.py tests/test_embedding_service.py \
+ tests/test_knowledge_base.py -v
+ ```
+
## CI/CD集成
在CI/CD流程中运行测试:
diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py
index 97ccf16..0fa5b34 100644
--- a/backend/tests/conftest.py
+++ b/backend/tests/conftest.py
@@ -6,12 +6,35 @@ from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from fastapi.testclient import TestClient
from app.core.database import Base, get_db, SessionLocal
-from app.main import app
+from app.main import app as fastapi_app
from app.core.config import settings
-import os
-# 测试数据库URL(使用SQLite内存数据库)
-TEST_DATABASE_URL = "sqlite:///:memory:"
+# 导入所有模型,确保 Base.metadata 包含所有表
+from app.core.database import Base as _Base
+import app.models.user # noqa: F401
+import app.models.workflow # noqa: F401
+import app.models.agent # noqa: F401
+import app.models.execution # noqa: F401
+import app.models.model_config # noqa: F401
+import app.models.workflow_template # noqa: F401
+import app.models.permission # noqa: F401
+import app.models.alert_rule # noqa: F401
+import app.models.agent_llm_log # noqa: F401
+import app.models.agent_vector_memory # noqa: F401
+import app.models.knowledge_base # noqa: F401
+import app.models.tool # noqa: F401
+
+assert len(_Base.metadata.tables) > 0, "没有模型表被注册"
+
+# 测试数据库URL
+# 使用临时文件而非 :memory: 避免 FastAPI 异步/多线程请求中的
+# SQLite 内存数据库连接隔离问题(每个连接看到不同的数据库)
+import tempfile as _tempfile
+import os as _os
+import atexit as _atexit
+_test_db_fd, _test_db_path = _tempfile.mkstemp(suffix=".db")
+_os.close(_test_db_fd)
+TEST_DATABASE_URL = f"sqlite:///{_test_db_path}"
# 创建测试数据库引擎
test_engine = create_engine(
@@ -23,20 +46,27 @@ test_engine = create_engine(
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=test_engine)
+# 在进程退出时清理临时数据库文件
+def _cleanup_test_db():
+ test_engine.dispose()
+ if _os.path.exists(_test_db_path):
+ _os.unlink(_test_db_path)
+
+
+_atexit.register(_cleanup_test_db)
+
+
@pytest.fixture(scope="function")
def db_session():
"""创建测试数据库会话"""
- # 创建所有表
- Base.metadata.create_all(bind=test_engine)
-
- # 创建会话
+ _Base.metadata.create_all(bind=test_engine)
+
session = TestingSessionLocal()
try:
yield session
finally:
session.close()
- # 删除所有表
- Base.metadata.drop_all(bind=test_engine)
+ _Base.metadata.drop_all(bind=test_engine)
@pytest.fixture(scope="function")
@@ -47,13 +77,13 @@ def client(db_session):
yield db_session
finally:
pass
-
- app.dependency_overrides[get_db] = override_get_db
-
- with TestClient(app) as test_client:
+
+ fastapi_app.dependency_overrides[get_db] = override_get_db
+
+ with TestClient(fastapi_app) as test_client:
yield test_client
-
- app.dependency_overrides.clear()
+
+ fastapi_app.dependency_overrides.clear()
@pytest.fixture
@@ -72,7 +102,7 @@ def authenticated_client(client, test_user_data):
# 注册用户
response = client.post("/api/v1/auth/register", json=test_user_data)
assert response.status_code == 201
-
+
# 登录获取token
login_response = client.post(
"/api/v1/auth/login",
@@ -83,10 +113,10 @@ def authenticated_client(client, test_user_data):
)
assert login_response.status_code == 200
token = login_response.json()["access_token"]
-
+
# 设置认证头
client.headers.update({"Authorization": f"Bearer {token}"})
-
+
return client
diff --git a/backend/tests/test_agent_memory.py b/backend/tests/test_agent_memory.py
new file mode 100644
index 0000000..4cf1322
--- /dev/null
+++ b/backend/tests/test_agent_memory.py
@@ -0,0 +1,321 @@
+"""
+Agent 记忆系统单元测试
+"""
+from __future__ import annotations
+
+import json
+import pytest
+from unittest.mock import AsyncMock, patch, MagicMock
+from datetime import datetime
+
+
+class TestAgentContext:
+ """AgentContext 基础功能测试"""
+
+ @pytest.mark.unit
+ def test_context_initialization(self):
+ from app.agent_runtime.context import AgentContext
+
+ ctx = AgentContext(
+ system_prompt="You are a helpful assistant.",
+ user_id="test-user",
+ session_id="test-session",
+ )
+ assert ctx.session_id == "test-session"
+ assert ctx.user_id == "test-user"
+ assert ctx.iteration == 0
+ assert ctx.tool_calls_made == 0
+ # system prompt 在 messages 中
+ msgs = ctx.messages
+ assert msgs[0]["role"] == "system"
+ assert msgs[0]["content"] == "You are a helpful assistant."
+
+ @pytest.mark.unit
+ def test_add_user_message(self):
+ from app.agent_runtime.context import AgentContext
+
+ ctx = AgentContext(session_id="s1")
+ ctx.add_user_message("Hello")
+ assert len(ctx.messages) == 2 # system + user
+ assert ctx.messages[1]["role"] == "user"
+ assert ctx.messages[1]["content"] == "Hello"
+
+ @pytest.mark.unit
+ def test_add_assistant_message(self):
+ from app.agent_runtime.context import AgentContext
+
+ ctx = AgentContext(session_id="s1")
+ ctx.add_user_message("Hi")
+ ctx.add_assistant_message("Hello!", tool_calls=[{
+ "id": "call_1",
+ "type": "function",
+ "function": {"name": "test", "arguments": "{}"},
+ }])
+ msgs = ctx.messages
+ assistant_msgs = [m for m in msgs if m["role"] == "assistant"]
+ assert len(assistant_msgs) == 1
+ assert assistant_msgs[0]["content"] == "Hello!"
+ assert "tool_calls" in assistant_msgs[0]
+
+ @pytest.mark.unit
+ def test_add_tool_result(self):
+ from app.agent_runtime.context import AgentContext
+
+ ctx = AgentContext(session_id="s1")
+ ctx.add_tool_result("call_1", "test_tool", '{"result": "ok"}')
+ tool_msgs = [m for m in ctx.messages if m["role"] == "tool"]
+ assert len(tool_msgs) == 1
+ assert tool_msgs[0]["tool_call_id"] == "call_1"
+ assert tool_msgs[0]["name"] == "test_tool"
+
+ @pytest.mark.unit
+ def test_iteration_tracking(self):
+ from app.agent_runtime.context import AgentContext
+
+ ctx = AgentContext("s1")
+ assert ctx.iteration == 0
+ ctx.iteration += 1
+ assert ctx.iteration == 1
+ ctx.tool_calls_made += 2
+ assert ctx.tool_calls_made == 2
+
+ @pytest.mark.unit
+ def test_context_reset(self):
+ from app.agent_runtime.context import AgentContext
+
+ ctx = AgentContext(system_prompt="Helpful.", session_id="s1")
+ ctx.add_user_message("Hello")
+ ctx.add_assistant_message("Hi")
+ ctx.iteration = 5
+ ctx.tool_calls_made = 3
+ ctx.reset()
+ assert ctx.iteration == 0
+ assert ctx.tool_calls_made == 0
+ # 重置后 messages 应仅含 system
+ msgs = ctx.messages
+ assert len(msgs) == 1
+ assert msgs[0]["role"] == "system"
+
+ @pytest.mark.unit
+ def test_set_system_prompt(self):
+ from app.agent_runtime.context import AgentContext
+
+ ctx = AgentContext(system_prompt="Original.", session_id="s1")
+ ctx.set_system_prompt("Updated.")
+ # 未发送过消息,可以更新
+ assert ctx.messages[0]["content"] == "Updated."
+
+
+class TestAgentMemory:
+ """AgentMemory 分层记忆测试"""
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_initialize_no_persist(self):
+ from app.agent_runtime.memory import AgentMemory
+
+ memory = AgentMemory(
+ scope_kind="agent",
+ scope_id="test-agent",
+ session_key="test-session",
+ persist=False,
+ )
+ result = await memory.initialize(query="Hello")
+ assert result == ""
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_save_context_no_persist(self):
+ from app.agent_runtime.memory import AgentMemory
+
+ memory = AgentMemory(
+ scope_kind="agent",
+ scope_id="test-agent",
+ session_key="test-session",
+ persist=False,
+ )
+ await memory.save_context(
+ user_message="Hello",
+ assistant_reply="Hi there!",
+ )
+ # 不报错即为通过
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_vector_memory_disabled(self):
+ from app.agent_runtime.memory import AgentMemory
+
+ memory = AgentMemory(
+ scope_kind="agent",
+ scope_id="test-agent",
+ session_key="test-session",
+ persist=False,
+ vector_memory_enabled=False,
+ )
+ await memory.save_context(
+ user_message="No vector",
+ assistant_reply="OK",
+ )
+ # 不报错即为通过
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_vector_search_no_results(self):
+ from app.agent_runtime.memory import AgentMemory
+
+ memory = AgentMemory(
+ scope_kind="agent",
+ scope_id="nonexistent",
+ persist=False,
+ vector_memory_enabled=True,
+ )
+ result = await memory._vector_search(query="test")
+ assert result == ""
+
+ @pytest.mark.unit
+ def test_trim_messages(self):
+ from app.agent_runtime.memory import AgentMemory
+
+ memory = AgentMemory(persist=False)
+ messages = [{"role": "system", "content": "You are helpful."}]
+ for i in range(30):
+ messages.append({"role": "user", "content": f"msg {i}"})
+ messages.append({"role": "assistant", "content": f"reply {i}"})
+ trimmed = memory.trim_messages(messages)
+ assert len(trimmed) <= memory.max_history + 1 # +1 for system
+ assert trimmed[0]["role"] == "system"
+
+ @pytest.mark.unit
+ def test_summarize_history(self):
+ from app.agent_runtime.memory import AgentMemory
+
+ history = [
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello"},
+ {"role": "user", "content": "How are you?"},
+ ]
+ summary = AgentMemory._summarize_history(history)
+ assert "2 轮" in summary
+
+
+class TestToolManager:
+ """AgentToolManager 测试"""
+
+ @pytest.mark.unit
+ def test_include_filter(self):
+ from app.agent_runtime.tool_manager import AgentToolManager
+
+ mgr = AgentToolManager(include_tools=["math", "file_read"])
+ assert mgr._include_tools == {"math", "file_read"}
+ assert mgr._exclude_tools == set()
+
+ @pytest.mark.unit
+ def test_exclude_filter(self):
+ from app.agent_runtime.tool_manager import AgentToolManager
+
+ mgr = AgentToolManager(exclude_tools=["dangerous_tool"])
+ assert "dangerous_tool" in mgr._exclude_tools
+
+ @pytest.mark.unit
+ def test_tool_name_extraction(self):
+ from app.agent_runtime.tool_manager import AgentToolManager
+
+ name = AgentToolManager._extract_tool_name({
+ "type": "function",
+ "function": {"name": "my_tool"},
+ })
+ assert name == "my_tool"
+
+ @pytest.mark.unit
+ def test_tool_name_extraction_flat(self):
+ from app.agent_runtime.tool_manager import AgentToolManager
+
+ name = AgentToolManager._extract_tool_name({"name": "flat_tool"})
+ assert name == "flat_tool"
+
+ @pytest.mark.unit
+ def test_tool_name_extraction_empty(self):
+ from app.agent_runtime.tool_manager import AgentToolManager
+
+ name = AgentToolManager._extract_tool_name({})
+ assert name is None
+
+ @pytest.mark.unit
+ def test_has_tools(self):
+ from app.agent_runtime.tool_manager import AgentToolManager
+
+ mgr = AgentToolManager()
+ assert mgr.has_tools() is True
+
+ @pytest.mark.unit
+ def test_tool_names(self):
+ from app.agent_runtime.tool_manager import AgentToolManager
+
+ mgr = AgentToolManager()
+ names = mgr.tool_names()
+ assert isinstance(names, list)
+ assert len(names) >= 1
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_execute_delegates_to_registry(self):
+ from app.agent_runtime.tool_manager import AgentToolManager
+ from app.services.tool_registry import tool_registry
+
+ mgr = AgentToolManager()
+ # 执行一个内置工具
+ result = await mgr.execute("datetime", {"format": "%Y"})
+ assert isinstance(result, str)
+ assert len(result) > 0
+
+
+class TestAgentSchemas:
+ """Agent Schemas 测试"""
+
+ @pytest.mark.unit
+ def test_agent_config_defaults(self):
+ from app.agent_runtime.schemas import AgentConfig
+
+ config = AgentConfig(name="Test Agent")
+ assert config.name == "Test Agent"
+ assert config.system_prompt == "你是一个有用的AI助手。请使用可用工具来帮助用户完成任务。"
+ assert config.llm.model == "gpt-4o-mini"
+ assert config.llm.temperature == 0.7
+ assert config.llm.max_iterations == 10
+
+ @pytest.mark.unit
+ def test_agent_memory_config_defaults(self):
+ from app.agent_runtime.schemas import AgentMemoryConfig
+
+ cfg = AgentMemoryConfig()
+ assert cfg.enabled is True
+ assert cfg.vector_memory_enabled is True
+ assert cfg.vector_memory_top_k == 5
+ assert cfg.max_history_messages == 20
+
+ @pytest.mark.unit
+ def test_agent_budget_config_defaults(self):
+ from app.agent_runtime.schemas import AgentBudgetConfig
+
+ cfg = AgentBudgetConfig()
+ assert cfg.max_llm_invocations == 200
+ assert cfg.max_tool_calls == 500
+
+ @pytest.mark.unit
+ def test_agent_result_fields(self):
+ from app.agent_runtime.schemas import AgentResult, AgentStep
+
+ result = AgentResult(
+ content="Test result",
+ iterations_used=3,
+ tool_calls_made=5,
+ truncated=False,
+ steps=[
+ AgentStep(iteration=1, type="think", content="Thinking..."),
+ AgentStep(iteration=2, type="tool_call", content="Calling tool", tool_name="test"),
+ ],
+ )
+ assert result.success is True
+ assert result.content == "Test result"
+ assert result.iterations_used == 3
+ assert len(result.steps) == 2
diff --git a/backend/tests/test_document_parser.py b/backend/tests/test_document_parser.py
new file mode 100644
index 0000000..ba7305d
--- /dev/null
+++ b/backend/tests/test_document_parser.py
@@ -0,0 +1,137 @@
+"""
+文档解析器单元测试
+"""
+from __future__ import annotations
+
+import os
+import tempfile
+import pytest
+
+from app.services.document_parser import parse_document
+
+
+class TestDocumentParser:
+ """文档解析器测试"""
+
+ @pytest.mark.unit
+ def test_parse_txt(self):
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", encoding="utf-8", delete=False) as f:
+ f.write("Hello world\nThis is a test.")
+ tmp_path = f.name
+ try:
+ result = parse_document(tmp_path, "txt")
+ assert "Hello world" in result
+ assert "This is a test" in result
+ finally:
+ os.unlink(tmp_path)
+
+ @pytest.mark.unit
+ def test_parse_txt_utf8_bom(self):
+ """带 BOM 的 UTF-8 文件"""
+ content = "\ufeff你好世界\n测试内容"
+ with tempfile.NamedTemporaryFile(mode="wb", suffix=".txt", delete=False) as f:
+ f.write(content.encode("utf-8-sig"))
+ tmp_path = f.name
+ try:
+ result = parse_document(tmp_path, "txt")
+ assert "你好" in result
+ assert "测试" in result
+ finally:
+ os.unlink(tmp_path)
+
+ @pytest.mark.unit
+ def test_parse_md(self):
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".md", encoding="utf-8", delete=False) as f:
+ f.write("# Title\n\nThis is **bold** text.")
+ tmp_path = f.name
+ try:
+ result = parse_document(tmp_path, "md")
+ assert "Title" in result
+ assert "bold" in result
+ finally:
+ os.unlink(tmp_path)
+
+ @pytest.mark.unit
+ def test_parse_csv(self):
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", encoding="utf-8", delete=False) as f:
+ f.write("name,age,city\nAlice,30,Beijing\nBob,25,Shanghai")
+ tmp_path = f.name
+ try:
+ result = parse_document(tmp_path, "csv")
+ assert "Alice" in result
+ assert "Beijing" in result
+ assert "Bob" in result
+ finally:
+ os.unlink(tmp_path)
+
+ @pytest.mark.unit
+ def test_parse_csv_with_different_delimiter(self):
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", encoding="utf-8", delete=False) as f:
+ f.write("a|b|c\n1|2|3\n4|5|6")
+ tmp_path = f.name
+ try:
+ result = parse_document(tmp_path, "csv")
+ assert "1" in result or "a" in result
+ finally:
+ os.unlink(tmp_path)
+
+ @pytest.mark.unit
+ def test_unsupported_format(self):
+ with tempfile.NamedTemporaryFile(suffix=".xyz", delete=False) as f:
+ tmp_path = f.name
+ try:
+ result = parse_document(tmp_path, "xyz")
+ assert result is None # 不支持的格式返回 None
+ finally:
+ os.unlink(tmp_path)
+
+ @pytest.mark.unit
+ def test_file_not_found(self):
+ result = parse_document("/nonexistent/file.txt", "txt")
+ assert result is None # 文件不存在返回 None
+
+ @pytest.mark.unit
+ def test_empty_file(self):
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", encoding="utf-8", delete=False) as f:
+ tmp_path = f.name
+ try:
+ result = parse_document(tmp_path, "txt")
+ assert result is not None
+ finally:
+ os.unlink(tmp_path)
+
+ @pytest.mark.integration
+ def test_parse_docx(self):
+ """需要 python-docx 包"""
+ pytest.importorskip("docx")
+ from docx import Document
+
+ doc = Document()
+ doc.add_paragraph("Hello from docx")
+ doc.add_paragraph("第二段落")
+ with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as f:
+ doc.save(f.name)
+ tmp_path = f.name
+ try:
+ result = parse_document(tmp_path, "docx")
+ assert "Hello from docx" in result
+ assert "第二段落" in result
+ finally:
+ os.unlink(tmp_path)
+
+ @pytest.mark.integration
+ def test_parse_pdf(self):
+ """需要 PyPDF2 包"""
+ pytest.importorskip("PyPDF2")
+ from PyPDF2 import PdfWriter
+
+ writer = PdfWriter()
+ writer.add_blank_page(72, 72) # 1 inch page
+ with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
+ writer.write(f)
+ tmp_path = f.name
+ try:
+ result = parse_document(tmp_path, "pdf")
+ assert result is not None
+ finally:
+ os.unlink(tmp_path)
diff --git a/backend/tests/test_embedding_service.py b/backend/tests/test_embedding_service.py
new file mode 100644
index 0000000..e7a4a40
--- /dev/null
+++ b/backend/tests/test_embedding_service.py
@@ -0,0 +1,146 @@
+"""
+Embedding 服务单元测试
+"""
+from __future__ import annotations
+
+import pytest
+from unittest.mock import patch, AsyncMock
+from app.services.embedding_service import EmbeddingService
+
+
+class TestEmbeddingService:
+ """Embedding 服务测试"""
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_cosine_similarity_identical(self):
+ from app.services.embedding_service import EmbeddingService
+
+ a = [1.0, 0.0, 0.0]
+ b = [1.0, 0.0, 0.0]
+ sim = EmbeddingService.cosine_similarity(a, b)
+ assert sim == pytest.approx(1.0, abs=1e-6)
+
+ @pytest.mark.unit
+ def test_cosine_similarity_orthogonal(self):
+ from app.services.embedding_service import EmbeddingService
+
+ a = [1.0, 0.0]
+ b = [0.0, 1.0]
+ sim = EmbeddingService.cosine_similarity(a, b)
+ assert sim == pytest.approx(0.0, abs=1e-6)
+
+ @pytest.mark.unit
+ def test_cosine_similarity_opposite(self):
+ from app.services.embedding_service import EmbeddingService
+
+ a = [1.0, 0.0]
+ b = [-1.0, 0.0]
+ sim = EmbeddingService.cosine_similarity(a, b)
+ assert sim == pytest.approx(-1.0, abs=1e-6)
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_similarity_search_empty(self):
+ from app.services.embedding_service import EmbeddingService
+
+ svc = EmbeddingService()
+ results = await svc.similarity_search(
+ [1.0, 0.0], [], top_k=5
+ )
+ assert results == []
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_similarity_search_ordering(self):
+ from app.services.embedding_service import EmbeddingService
+
+ svc = EmbeddingService()
+ entries = [
+ {"content_text": "dogs are great pets", "embedding": [0.9, 0.0, 0.0]},
+ {"content_text": "cats are independent", "embedding": [0.1, 0.0, 0.0]},
+ ]
+ query = [0.8, 0.0, 0.0]
+ results = await svc.similarity_search(query, entries, top_k=3, min_score=0.0)
+ assert len(results) == 2
+ assert results[0]["content_text"] == "dogs are great pets"
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_similarity_search_top_k(self):
+ from app.services.embedding_service import EmbeddingService
+
+ svc = EmbeddingService()
+ entries = [
+ {"content_text": f"entry {i}", "embedding": [float(i) / 10, 0.0]} for i in range(10)
+ ]
+ query = [1.0, 0.0]
+ results = await svc.similarity_search(query, entries, top_k=3)
+ assert len(results) == 3
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_similarity_search_min_score(self):
+ from app.services.embedding_service import EmbeddingService
+
+ svc = EmbeddingService()
+ entries = [
+ {"content_text": "close match", "embedding": [0.9, 0.0]},
+ {"content_text": "distant match", "embedding": [-0.5, 0.0]},
+ ]
+ query = [1.0, 0.0]
+ results = await svc.similarity_search(query, entries, top_k=5, min_score=0.5)
+ assert len(results) == 1
+ assert results[0]["content_text"] == "close match"
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_generate_embedding_empty(self):
+ from app.services.embedding_service import EmbeddingService
+
+ svc = EmbeddingService()
+ result = await svc.generate_embedding("")
+ assert result is None
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_generate_embedding_no_api_key(self):
+ """无 API Key 时返回 None"""
+ from app.services.embedding_service import EmbeddingService
+
+ svc = EmbeddingService()
+ with patch.object(svc, "_get_client", AsyncMock(return_value=None)):
+ result = await svc.generate_embedding("test")
+ assert result is None
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_generate_embeddings_empty(self):
+ from app.services.embedding_service import EmbeddingService
+
+ svc = EmbeddingService()
+ result = await svc.generate_embeddings([])
+ assert result == []
+
+ @pytest.mark.unit
+ def test_serialize_deserialize(self):
+ from app.services.embedding_service import EmbeddingService
+
+ emb = [0.1, 0.2, 0.3]
+ serialized = EmbeddingService.serialize_embedding(emb)
+ deserialized = EmbeddingService.deserialize_embedding(serialized)
+ assert deserialized == emb
+
+ @pytest.mark.unit
+ def test_deserialize_invalid(self):
+ from app.services.embedding_service import EmbeddingService
+
+ result = EmbeddingService.deserialize_embedding("invalid json")
+ assert result == []
+
+ @pytest.mark.unit
+ def test_deserialize_list_already(self):
+ from app.services.embedding_service import EmbeddingService
+
+ result = EmbeddingService.deserialize_embedding([1.0, 2.0])
+ assert result == [1.0, 2.0]
diff --git a/backend/tests/test_knowledge_base.py b/backend/tests/test_knowledge_base.py
new file mode 100644
index 0000000..2330ef8
--- /dev/null
+++ b/backend/tests/test_knowledge_base.py
@@ -0,0 +1,124 @@
+"""
+知识库 RAG 单元测试
+"""
+from __future__ import annotations
+
+import pytest
+from unittest.mock import patch, AsyncMock, MagicMock
+
+
+class TestKnowledgeService:
+ """知识库服务测试"""
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_search_empty_kb(self):
+ """空知识库搜索返回空列表"""
+ from app.services.knowledge_service import search
+
+ mock_db = MagicMock()
+ mock_query = MagicMock()
+ mock_db.query.return_value = mock_query
+ mock_query.filter.return_value = mock_query
+ mock_query.all.return_value = []
+ mock_query.first.return_value = None
+
+ results = await search(mock_db, kb_id="nonexistent", query="test")
+ assert results == []
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_rag_query_no_results(self):
+ """无检索结果时返回空上下文"""
+ from app.services.knowledge_service import rag_query
+
+ mock_db = MagicMock()
+ mock_query = MagicMock()
+ mock_db.query.return_value = mock_query
+ mock_query.filter.return_value = mock_query
+ mock_query.all.return_value = []
+ mock_query.first.return_value = None
+
+ result = await rag_query(mock_db, kb_id="test", query="no results")
+ assert result["found"] is False
+ assert result["context"] == ""
+
+ @pytest.mark.unit
+ @pytest.mark.asyncio
+ async def test_search_with_content(self):
+ """模拟有内容的搜索结果"""
+ from app.services.knowledge_service import search
+
+ from app.models.knowledge_base import DocumentChunk
+ mock_chunk = MagicMock(spec=DocumentChunk)
+ mock_chunk.id = "chunk-1"
+ mock_chunk.content = "test content about Python programming"
+ mock_chunk.chunk_index = 0
+ mock_chunk.document_id = "doc-1"
+ mock_chunk.metadata = {"filename": "test.txt"}
+
+ mock_db = MagicMock()
+ mock_query = MagicMock()
+ mock_db.query.return_value = mock_query
+ mock_query.filter.return_value = mock_query
+ mock_query.all.return_value = []
+
+ with patch("app.services.knowledge_service.embedding_service.generate_embedding",
+ AsyncMock(return_value=[0.1, 0.2, 0.3])):
+ with patch("app.services.knowledge_service.embedding_service.similarity_search",
+ AsyncMock(return_value=[
+ {"content_text": "test content about Python", "score": 0.85, "metadata": {}}
+ ])):
+ results = await search(mock_db, kb_id="test", query="Python")
+ # 可能返回空(chunks filter 不匹配)但不报错
+ assert isinstance(results, list)
+
+
+class TestKnowledgeModels:
+ """知识库模型测试"""
+
+ @pytest.mark.unit
+ def test_knowledge_base_model(self):
+ from app.models.knowledge_base import KnowledgeBase
+
+ kb = KnowledgeBase(
+ name="Test KB",
+ description="Test",
+ user_id="user-1",
+ chunk_size=500,
+ chunk_overlap=50,
+ )
+ assert kb.name == "Test KB"
+ assert kb.chunk_size == 500
+ assert kb.chunk_overlap == 50
+
+ @pytest.mark.unit
+ def test_document_model(self):
+ from app.models.knowledge_base import Document
+
+ doc = Document(
+ kb_id="kb-1",
+ filename="test.txt",
+ file_type="txt",
+ file_size=1024,
+ status="completed",
+ chunk_count=5,
+ )
+ assert doc.filename == "test.txt"
+ assert doc.status == "completed"
+ assert doc.chunk_count == 5
+
+ @pytest.mark.unit
+ def test_document_chunk_model(self):
+ from app.models.knowledge_base import DocumentChunk
+
+ chunk = DocumentChunk(
+ document_id="doc-1",
+ kb_id="kb-1",
+ chunk_index=0,
+ content="test content",
+ embedding="[0.1, 0.2, 0.3]",
+ metadata={"source": "test"},
+ )
+ assert chunk.chunk_index == 0
+ assert chunk.content == "test content"
diff --git a/backend/tests/test_text_chunker.py b/backend/tests/test_text_chunker.py
new file mode 100644
index 0000000..a911aca
--- /dev/null
+++ b/backend/tests/test_text_chunker.py
@@ -0,0 +1,139 @@
+"""
+文本分块器单元测试
+"""
+from __future__ import annotations
+
+import pytest
+from app.services.text_chunker import chunk_text, _split_paragraphs, _split_long_paragraph, _merge_segments
+
+
+class TestSplitParagraphs:
+ """段落分割测试"""
+
+ def test_empty_text(self):
+ assert _split_paragraphs("") == []
+ assert _split_paragraphs(" ") == []
+ assert _split_paragraphs("\n\n\n") == []
+
+ def test_single_paragraph(self):
+ result = _split_paragraphs("Hello world")
+ assert result == ["Hello world"]
+
+ def test_multiple_paragraphs(self):
+ text = "第一段内容。\n\n第二段内容。\n\n第三段内容。"
+ result = _split_paragraphs(text)
+ assert len(result) == 3
+ assert "第一段" in result[0]
+ assert "第二段" in result[1]
+ assert "第三段" in result[2]
+
+ def test_mixed_newlines(self):
+ text = "段1\n\n\n段2\n\n段3"
+ result = _split_paragraphs(text)
+ assert len(result) == 3
+
+
+class TestSplitLongParagraph:
+ """超长段落分割测试"""
+
+ def test_short_paragraph(self):
+ result = _split_long_paragraph("短文本", chunk_size=500)
+ assert result == ["短文本"]
+
+ def test_long_paragraph_chinese(self):
+ para = "第一句。" * 200 # 600 chars, exceeds 500
+ result = _split_long_paragraph(para, chunk_size=500)
+ assert len(result) >= 2
+ assert all(len(c) <= 500 for c in result)
+
+ def test_long_paragraph_english(self):
+ para = "Sentence. " * 200
+ result = _split_long_paragraph(para, chunk_size=500)
+ assert len(result) >= 2
+ assert all(len(c) <= 500 for c in result)
+
+ def test_no_sentence_boundary(self):
+ """无句号可分割时按字符硬切"""
+ para = "a" * 1000
+ result = _split_long_paragraph(para, chunk_size=500)
+ assert len(result) == 2
+ assert len(result[0]) == 500
+ assert len(result[1]) == 500
+
+
+class TestMergeSegments:
+ """段落合并测试"""
+
+ def test_empty(self):
+ assert _merge_segments([], 500, 0) == []
+
+ def test_single_segment(self):
+ result = _merge_segments(["hello"], 500, 0)
+ assert result == ["hello"]
+
+ def test_merge_short_segments(self):
+ segs = ["a" * 100, "b" * 100, "c" * 100] # each < 500, total ~300 < 500
+ result = _merge_segments(segs, 500, 0)
+ assert len(result) == 1
+ assert len(result[0]) > 200
+
+ def test_split_large_segments(self):
+ segs = ["a" * 300, "b" * 300, "c" * 300] # 需要分为多个chunk
+ result = _merge_segments(segs, 500, 0)
+ assert len(result) >= 2
+
+ def test_overlap(self):
+ segs = ["a" * 300, "b" * 300]
+ result = _merge_segments(segs, 500, 50)
+ assert len(result) >= 1
+
+
+class TestChunkText:
+ """chunk_text 整体测试"""
+
+ def test_empty_text(self):
+ assert chunk_text("") == []
+ assert chunk_text(" ") == []
+ assert chunk_text(None) == [] # type: ignore
+
+ def test_short_text(self):
+ result = chunk_text("Hello world", chunk_size=500, chunk_overlap=0)
+ assert len(result) == 1
+ assert "Hello world" in result[0]
+
+ def test_normal_text(self):
+ text = """
+这是第一段。它包含一些内容。
+
+这是第二段。它也有一些内容。而且更长一些。
+
+这是第三段。最后一段内容。
+"""
+ result = chunk_text(text, chunk_size=500, chunk_overlap=0)
+ assert len(result) >= 1
+ # 所有内容都应该在结果中
+ all_text = "".join(result)
+ assert "第一段" in all_text
+ assert "第二段" in all_text
+ assert "第三段" in all_text
+
+ def test_chinese_text(self):
+ """中文文本测试"""
+ text = "我喜欢吃川菜。特别是麻辣火锅。还有水煮鱼。这些都很美味。"
+ result = chunk_text(text, chunk_size=100, chunk_overlap=0)
+ assert len(result) >= 1
+
+ def test_overlap_between_chunks(self):
+ """验证块间重叠"""
+ para = "这是一个很长的段落。" * 50
+ result = chunk_text(para, chunk_size=200, chunk_overlap=50)
+ if len(result) > 1:
+ # 相邻块应该有重叠内容
+ assert len(result[0]) > 0
+ assert len(result[1]) > 0
+
+ @pytest.mark.unit
+ def test_with_mixed_punctuation(self):
+ text = "Hello! How are you? I am fine. Thank you. 你好!最近怎么样?我很好。谢谢。"
+ result = chunk_text(text, chunk_size=200, chunk_overlap=0)
+ assert len(result) >= 1
diff --git a/backend/tests/test_tool_registry.py b/backend/tests/test_tool_registry.py
new file mode 100644
index 0000000..da60eef
--- /dev/null
+++ b/backend/tests/test_tool_registry.py
@@ -0,0 +1,293 @@
+"""
+工具注册表单元测试
+"""
+from __future__ import annotations
+
+import json
+import pytest
+from unittest.mock import patch, AsyncMock
+
+from app.services.tool_registry import ToolRegistry, _CODE_SAFE_GLOBALS
+
+
@pytest.fixture
def registry():
    """Provide a fresh, empty ToolRegistry for each test."""
    return ToolRegistry()
+
+
class TestToolRegistryBuiltin:
    """Registration and lookup of built-in tools."""

    @pytest.mark.unit
    def test_register_and_get(self, registry):
        """A registered tool is retrievable by name with its schema and function."""
        def my_tool(**kwargs):
            return {"result": kwargs.get("x", 0) + kwargs.get("y", 0)}

        schema = {
            "type": "function",
            "function": {
                "name": "add",
                "description": "加法",
                "parameters": {"type": "object", "properties": {"x": {"type": "number"}, "y": {"type": "number"}}},
            },
        }
        registry.register_builtin_tool("add", my_tool, schema)
        assert registry.get_tool_schema("add") == schema
        assert registry.get_tool_function("add") == my_tool
        assert registry.builtin_tool_count() == 1
        assert "add" in registry.builtin_tool_names()

    @pytest.mark.unit
    def test_get_missing_tool(self, registry):
        """Unknown tool names resolve to None rather than raising."""
        assert registry.get_tool_schema("nonexistent") is None
        assert registry.get_tool_function("nonexistent") is None

    @pytest.mark.unit
    def test_get_all_schemas(self, registry):
        """get_all_tool_schemas returns one schema per registered tool."""
        schema1 = {"type": "function", "function": {"name": "tool1"}}
        schema2 = {"type": "function", "function": {"name": "tool2"}}
        registry.register_builtin_tool("tool1", lambda: None, schema1)
        registry.register_builtin_tool("tool2", lambda: None, schema2)
        assert len(registry.get_all_tool_schemas()) == 2

    @pytest.mark.unit
    def test_get_tools_by_names(self, registry):
        """Unknown names are skipped; order of known names is preserved."""
        registry.register_builtin_tool("a", lambda: None, {"function": {"name": "a"}})
        registry.register_builtin_tool("b", lambda: None, {"function": {"name": "b"}})
        tools = registry.get_tools_by_names(["a", "b", "c"])
        assert len(tools) == 2
        assert tools[0]["function"]["name"] == "a"

    @pytest.mark.unit
    def test_sync_function_execution(self, registry):
        """_run_function wraps a sync callable; being a coroutine it must be run, not called bare."""
        import asyncio

        def sync_func(x=0, y=0):
            return x + y

        registry.register_builtin_tool("add", sync_func, {"function": {"name": "add"}})
        # BUG FIX: the previous version first called registry._run_function(...)
        # without awaiting it, creating an un-awaited coroutine (RuntimeWarning)
        # whose result was immediately discarded. Run it exactly once instead.
        result = asyncio.run(registry._run_function(sync_func, "add", {"x": 3, "y": 4}))
        parsed = json.loads(result)
        assert parsed == 7
+
+
class TestToolRegistryHTTP:
    """HTTP tool execution tests (httpx mocked out).

    NOTE: the previous ``@pytest.mark.usefixtures("registry")`` class decorator
    was removed — every test already requests ``registry`` as a parameter, so
    the marker was a no-op.
    """

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_http(self, registry):
        """A GET tool substitutes args into the URL template and reports the status code."""
        config = {
            "url": "https://api.example.com/data?q={query}",
            "method": "GET",
            "headers": {"Authorization": "Bearer token"},
            "timeout": 10,
            "_type": "http",
        }
        registry._custom_tool_configs["test_http"] = config

        with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
            mock_response = AsyncMock()
            mock_response.status_code = 200
            mock_response.text = '{"result": "ok"}'
            mock_request.return_value = mock_response

            result = await registry.execute_tool("test_http", {"query": "hello"})
            parsed = json.loads(result)
            assert parsed["status_code"] == 200
            # Verify URL template substitution.
            # NOTE(review): assumes request(method, url, ...) is invoked with
            # positional args — confirm against the tool_registry implementation.
            called_url = mock_request.call_args[0][1]
            assert "hello" in called_url

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_http_post_with_body(self, registry):
        """A POST tool renders the body template and uses the POST method."""
        config = {
            "url": "https://api.example.com/submit",
            "method": "POST",
            "headers": {},
            "body_template": {"name": "{name}", "age": "{age}"},
            "timeout": 10,
            "_type": "http",
        }
        registry._custom_tool_configs["test_post"] = config

        with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
            mock_response = AsyncMock()
            mock_response.status_code = 201
            mock_response.text = '{"id": 1}'
            mock_request.return_value = mock_response

            result = await registry.execute_tool("test_post", {"name": "Alice", "age": 30})
            parsed = json.loads(result)
            assert parsed["status_code"] == 201
            # Verify the POST method was used.
            # NOTE(review): positional-arg assumption as above — confirm.
            assert mock_request.call_args[0][0] == "POST"

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_http_no_url(self, registry):
        """An http-type config with no URL yields an error payload."""
        config = {"_type": "http"}
        registry._custom_tool_configs["bad_http"] = config
        result = await registry.execute_tool("bad_http", {})
        assert "error" in result
+
+
class TestToolRegistryCode:
    """Sandboxed execution of code-type tools."""

    @staticmethod
    def _install(registry, name, config):
        """Register a raw custom-tool config under *name*."""
        registry._custom_tool_configs[name] = config

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_code_simple(self, registry):
        """A run() returning a dict round-trips through JSON."""
        self._install(registry, "calc", {
            "source": "def run(args):\n return {'sum': args['a'] + args['b']}",
            "_type": "code",
        })
        payload = json.loads(await registry.execute_tool("calc", {"a": 10, "b": 20}))
        assert payload["sum"] == 30

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_code_text_stats(self, registry):
        """run() can compute derived values from string args."""
        self._install(registry, "stats", {
            "source": "def run(args):\n text = args.get('text', '')\n return {'len': len(text), 'words': len(text.split())}",
            "_type": "code",
        })
        payload = json.loads(await registry.execute_tool("stats", {"text": "hello world test"}))
        assert payload["len"] == 16
        assert payload["words"] == 3

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_source_missing(self, registry):
        """A code-type config without source reports an error."""
        self._install(registry, "no_source", {"_type": "code"})
        assert "error" in await registry.execute_tool("no_source", {})

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_no_run_function(self, registry):
        """Source that never defines run() is rejected with a message naming it."""
        self._install(registry, "no_run", {"source": "x = 1", "_type": "code"})
        assert "run" in await registry.execute_tool("no_run", {})

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_runtime_error(self, registry):
        """Exceptions raised inside run() surface as an error payload."""
        self._install(registry, "err", {
            "source": "def run(args):\n raise ValueError('test error')",
            "_type": "code",
        })
        assert "error" in await registry.execute_tool("err", {})

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_sandbox_restriction(self, registry):
        """import must fail inside the sandbox (__builtins__ disabled)."""
        self._install(registry, "unsafe", {
            "source": "def run(args):\n import os\n return os.name",
            "_type": "code",
        })
        assert "error" in await registry.execute_tool("unsafe", {})

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_sandbox_no_file_access(self, registry):
        """open() must fail inside the sandbox, blocking file-system access."""
        self._install(registry, "file_access", {
            "source": "def run(args):\n open('/etc/passwd')\n return 'ok'",
            "_type": "code",
        })
        assert "error" in await registry.execute_tool("file_access", {})
+
+
class TestToolRegistryTestHelpers:
    """One-shot tool test helpers (nothing persisted to the DB)."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_code_tool_success(self, registry):
        """Valid source runs and reports its return value."""
        source = "def run(args):\n return {'result': args['x'] * 2}"
        result = await registry.test_code_tool(source, {"x": 5})
        assert result["success"] is True
        assert result["result"]["result"] == 10

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_code_tool_compile_error(self, registry):
        """Syntactically invalid source fails with an error message."""
        source = "def run(args):\n invalid syntax{{{"
        result = await registry.test_code_tool(source, {})
        assert result["success"] is False
        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_http_tool(self, registry):
        """HTTP test helper reports success and status code (httpx mocked)."""
        with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
            mock_response = AsyncMock()
            mock_response.status_code = 200
            mock_response.text = '{"ip": "8.8.8.8"}'
            mock_request.return_value = mock_response

            # BUG FIX: removed a redundant, unused local re-import of
            # _CODE_SAFE_GLOBALS (it is already imported at module top and was
            # never referenced in this test).
            result = await registry.test_http_tool(
                url="https://httpbin.org/get?ip={ip}",
                method="GET",
                headers={},
                body=None,
                args={"ip": "8.8.8.8"},
                timeout=5,
            )
            assert result["success"] is True
            assert result["status_code"] == 200
+
+
class TestToolRegistryExecute:
    """End-to-end execute_tool dispatch."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_builtin(self, registry):
        """Registered builtin functions are invoked with the given kwargs."""
        def hello(**kwargs):
            return f"Hello, {kwargs.get('name', 'world')}!"

        registry.register_builtin_tool("hello", hello, {"function": {"name": "hello"}})
        assert "Hello, Test!" in await registry.execute_tool("hello", {"name": "Test"})

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_nonexistent(self, registry):
        """Unknown tool names yield an error payload."""
        assert "error" in await registry.execute_tool("no_such_tool", {})

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_unsupported_type(self, registry):
        """Configs with an unrecognised _type yield an error payload."""
        registry._custom_tool_configs["weird"] = {"_type": "unsupported"}
        assert "error" in await registry.execute_tool("weird", {})

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_workflow_not_supported(self, registry):
        """Workflow-type tools are explicitly unsupported for now."""
        registry._custom_tool_configs["wf"] = {"_type": "workflow"}
        assert "暂不支持" in await registry.execute_tool("wf", {})
diff --git a/backend/tests/test_tools_api.py b/backend/tests/test_tools_api.py
new file mode 100644
index 0000000..68171b1
--- /dev/null
+++ b/backend/tests/test_tools_api.py
@@ -0,0 +1,263 @@
+"""
+工具市场 API 集成测试
+"""
+from __future__ import annotations
+
+import json
+import pytest
+from unittest.mock import patch, AsyncMock
+
+
class TestToolsAPI:
    """Integration tests for the tool-marketplace API."""

    @pytest.mark.api
    def test_list_public_tools(self, authenticated_client):
        """Default listing (scope=public) returns a JSON array."""
        resp = authenticated_client.get("/api/v1/tools")
        assert resp.status_code == 200
        assert isinstance(resp.json(), list)

    @pytest.mark.api
    def test_list_categories(self, authenticated_client):
        """Category listing includes seeded default categories."""
        resp = authenticated_client.get("/api/v1/tools/categories")
        assert resp.status_code == 200
        cats = resp.json()
        assert isinstance(cats, list)
        # Should contain at least one of the seeded default categories.
        assert "数据处理" in cats or "网络请求" in cats

    @pytest.mark.api
    def test_list_without_auth(self, client):
        """Anonymous access is rejected with 401."""
        resp = client.get("/api/v1/tools")
        # FIXED COMMENT: the old note said "401 or public tools", contradicting
        # this exact assertion — the endpoint mandates authentication.
        assert resp.status_code == 401

    @pytest.mark.api
    def test_list_builtin(self, authenticated_client):
        """Builtin listing exposes the registered builtin tool set."""
        resp = authenticated_client.get("/api/v1/tools/builtin")
        assert resp.status_code == 200
        tools = resp.json()
        assert isinstance(tools, list)
        assert len(tools) >= 10  # expect at least 10 builtin tools

    @pytest.mark.api
    def test_create_and_get_tool(self, authenticated_client):
        """An HTTP tool can be created and then fetched by id."""
        create_resp = authenticated_client.post("/api/v1/tools", json={
            "name": "echo_test",
            "description": "Echo test tool",
            "category": "network",
            "function_schema": {
                "name": "echo_test",
                "description": "Returns the input",
                "parameters": {
                    "type": "object",
                    "properties": {"msg": {"type": "string"}},
                    "required": ["msg"],
                },
            },
            "implementation_type": "http",
            "implementation_config": {
                "url": "https://httpbin.org/post",
                "method": "POST",
                "headers": {},
                "timeout": 10,
            },
            "is_public": True,
        })
        assert create_resp.status_code == 201
        data = create_resp.json()
        assert data["name"] == "echo_test"
        assert data["implementation_type"] == "http"
        tool_id = data["id"]

        # Fetch the tool detail back.
        get_resp = authenticated_client.get(f"/api/v1/tools/{tool_id}")
        assert get_resp.status_code == 200
        assert get_resp.json()["name"] == "echo_test"

    @pytest.mark.api
    def test_create_code_tool(self, authenticated_client):
        """A code-type tool can be created."""
        create_resp = authenticated_client.post("/api/v1/tools", json={
            "name": "double_test",
            "description": "Double a number",
            "category": "math",
            "function_schema": {
                "name": "double_test",
                "description": "Doubles a number",
                "parameters": {
                    "type": "object",
                    "properties": {"n": {"type": "number"}},
                    "required": ["n"],
                },
            },
            "implementation_type": "code",
            "implementation_config": {
                "source": "def run(args):\n n = args.get('n', 0)\n return {'result': n * 2}",
                "language": "python",
            },
            "is_public": True,
        })
        assert create_resp.status_code == 201
        data = create_resp.json()
        assert data["name"] == "double_test"
        assert data["implementation_type"] == "code"

    @pytest.mark.api
    def test_create_duplicate_name(self, authenticated_client):
        """Creating two tools with the same name is rejected with 400."""
        # First creation succeeds.
        authenticated_client.post("/api/v1/tools", json={
            "name": "dup_tool",
            "description": "Dup",
            "function_schema": {"name": "dup_tool", "parameters": {"type": "object", "properties": {}}},
            "implementation_type": "code",
            "implementation_config": {"source": "def run(args):\n return {}"},
            "is_public": False,
        })
        # Duplicate creation must fail.
        resp = authenticated_client.post("/api/v1/tools", json={
            "name": "dup_tool",
            "description": "Dup again",
            "function_schema": {"name": "dup_tool", "parameters": {"type": "object", "properties": {}}},
            "implementation_type": "code",
            "implementation_config": {"source": "def run(args):\n return {}"},
            "is_public": False,
        })
        assert resp.status_code == 400
        assert "已存在" in resp.json().get("detail", "")

    @pytest.mark.api
    def test_invalid_implementation_type(self, authenticated_client):
        """Unknown implementation types are rejected with 400."""
        resp = authenticated_client.post("/api/v1/tools", json={
            "name": "bad_tool",
            "description": "Bad",
            "function_schema": {"name": "bad_tool", "parameters": {"type": "object", "properties": {}}},
            "implementation_type": "invalid_type",
            "is_public": False,
        })
        assert resp.status_code == 400

    @pytest.mark.api
    def test_mine_scope(self, authenticated_client):
        """scope=mine is accepted for authenticated users."""
        resp = authenticated_client.get("/api/v1/tools?scope=mine")
        assert resp.status_code == 200

    @pytest.mark.api
    def test_get_nonexistent_tool(self, authenticated_client):
        """Fetching an unknown id returns 404."""
        resp = authenticated_client.get("/api/v1/tools/nonexistent-id")
        assert resp.status_code == 404

    @pytest.mark.api
    def test_test_http_endpoint(self, authenticated_client):
        """Ad-hoc HTTP tool test endpoint (outbound httpx call mocked).

        BUG FIX: this was an ``async def`` with ``@pytest.mark.asyncio``
        despite containing no ``await`` and using the synchronous test client;
        it is now a plain sync test.
        """
        with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
            mock_response = AsyncMock()
            mock_response.status_code = 200
            mock_response.text = '{"origin": "1.2.3.4"}'
            mock_request.return_value = mock_response

            resp = authenticated_client.post("/api/v1/tools/test/http", json={
                "url": "https://httpbin.org/get",
                "method": "GET",
                "headers": {},
                "body": None,
                "args": {"ip": "8.8.8.8"},
                "timeout": 5,
            })
            assert resp.status_code == 200
            data = resp.json()
            assert data["success"] is True

    @pytest.mark.api
    def test_test_code_endpoint(self, authenticated_client):
        """Ad-hoc code tool test endpoint returns the run() result."""
        resp = authenticated_client.post("/api/v1/tools/test/code", json={
            "source": "def run(args):\n return args.get('x', 0) + args.get('y', 0)",
            "args": {"x": 10, "y": 20},
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["success"] is True
        assert data["result"] == 30

    @pytest.mark.api
    def test_test_code_compile_error(self, authenticated_client):
        """Invalid source reports success=False with an error message."""
        resp = authenticated_client.post("/api/v1/tools/test/code", json={
            "source": "invalid python {{{",
            "args": {},
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["success"] is False
        assert "error" in data

    @pytest.mark.api
    def test_use_count(self, authenticated_client):
        """Each /use call increments the tool's use_count by one."""
        # Create the tool first.
        create_resp = authenticated_client.post("/api/v1/tools", json={
            "name": "count_test",
            "description": "Count test",
            "function_schema": {"name": "count_test", "parameters": {"type": "object", "properties": {}}},
            "implementation_type": "code",
            "implementation_config": {"source": "def run(args):\n return {}"},
            "is_public": True,
        })
        assert create_resp.status_code == 201
        tool_id = create_resp.json()["id"]

        # First use.
        use_resp = authenticated_client.post(f"/api/v1/tools/{tool_id}/use")
        assert use_resp.status_code == 200
        assert use_resp.json()["use_count"] == 1

        # Second use.
        use_resp2 = authenticated_client.post(f"/api/v1/tools/{tool_id}/use")
        assert use_resp2.status_code == 200
        assert use_resp2.json()["use_count"] == 2

    @pytest.mark.api
    def test_delete_tool(self, authenticated_client):
        """Deleting a tool makes subsequent fetches return 404."""
        # Create.
        create_resp = authenticated_client.post("/api/v1/tools", json={
            "name": "del_test",
            "description": "Delete test",
            "function_schema": {"name": "del_test", "parameters": {"type": "object", "properties": {}}},
            "implementation_type": "code",
            "implementation_config": {"source": "def run(args):\n return {}"},
            "is_public": False,
        })
        assert create_resp.status_code == 201
        tool_id = create_resp.json()["id"]

        # Delete.
        del_resp = authenticated_client.delete(f"/api/v1/tools/{tool_id}")
        assert del_resp.status_code == 200

        # Confirm it is gone.
        get_resp = authenticated_client.get(f"/api/v1/tools/{tool_id}")
        assert get_resp.status_code == 404

    @pytest.mark.api
    def test_update_tool(self, authenticated_client):
        """PUT replaces description/config and can flip visibility."""
        create_resp = authenticated_client.post("/api/v1/tools", json={
            "name": "update_test",
            "description": "Original desc",
            "function_schema": {"name": "update_test", "parameters": {"type": "object", "properties": {}}},
            "implementation_type": "code",
            "implementation_config": {"source": "def run(args):\n return {}"},
            "is_public": False,
        })
        assert create_resp.status_code == 201
        tool_id = create_resp.json()["id"]

        update_resp = authenticated_client.put(f"/api/v1/tools/{tool_id}", json={
            "name": "update_test",
            "description": "Updated desc",
            "function_schema": {"name": "update_test", "parameters": {"type": "object", "properties": {}}},
            "implementation_type": "code",
            "implementation_config": {"source": "def run(args):\n return {'updated': True}"},
            "is_public": True,
        })
        assert update_resp.status_code == 200
        assert update_resp.json()["description"] == "Updated desc"
        assert update_resp.json()["is_public"] is True
diff --git a/frontend/src/components/MainLayout.vue b/frontend/src/components/MainLayout.vue
index e7cc0b7..950634b 100644
--- a/frontend/src/components/MainLayout.vue
+++ b/frontend/src/components/MainLayout.vue
@@ -78,7 +78,9 @@
-