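feat: vector-memory RAG, tool marketplace, SSE streaming responses, frontend integration, and test coverage

- Add embedding_service (semantic retrieval), knowledge_service (RAG), text_chunker, and document_parser
- Add tool_registry (custom tool registry) and complete the tool marketplace API (CRUD + code/http execution)
- Add the agent_vector_memory / knowledge_base models and their database tables
- Implement SSE streaming responses and Agent budget control
- Integrate the MainLayout navigation layout into AgentChat.vue
- Expand the test suite: 7 new test files containing 110 tests
- Fix the SQLite in-memory database connection isolation problem in conftest.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>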
@@ -87,6 +87,7 @@
 - **AI framework**: LangChain
 - **Agent Runtime**: in-house ReAct loop (zero refactoring; piggybacks on the existing services)
 - **Agent Orchestrator**: multi-Agent orchestration engine (three modes: route / sequential / debate)
+- **Agent monitoring**: LLM call instrumentation plus a dedicated Dashboard (Token / tool / Agent usage statistics)
 - **Database ORM**: SQLAlchemy
 - **Migration tool**: Alembic
 - **Authentication**: JWT
@@ -212,11 +213,14 @@ aiagent/
 │   │   │   ├── orchestrator.py          # multi-Agent orchestration engine (route/sequential/debate)
 │   │   │   └── workflow_integration.py  # workflow bridge
 │   │   ├── api/                         # API routes
-│   │   │   └── agent_chat.py            # standalone Agent chat API (new)
+│   │   │   ├── agent_chat.py            # Agent chat + multi-Agent orchestration + LLM call logging (new)
+│   │   │   └── agent_monitoring.py      # Agent monitoring API, 5 endpoints (new)
 │   │   ├── core/                        # core modules
 │   │   ├── models/                      # database models
+│   │   │   └── agent_llm_log.py         # Agent LLM call log model (new)
 │   │   ├── schemas/                     # Pydantic schemas
 │   │   ├── services/                    # business logic
+│   │   │   └── agent_monitoring_service.py  # Agent monitoring service (new)
 │   │   └── main.py                      # application entry point
 │   ├── alembic/                         # database migrations
 │   ├── tests/                           # tests
@@ -328,17 +332,19 @@ pnpm dev
 - `POST /api/v1/data-sources/{id}/test` - test the data source connection
 - `POST /api/v1/data-sources/{id}/query` - run a data query
 
-### Agent chat API (new)
+### Agent and orchestration API (new)
 - `POST /api/v1/agent-chat/bare` - chat directly with the default Agent (no prior configuration required)
 - `POST /api/v1/agent-chat/{agent_id}` - chat with a specific Agent (reuses its workflow configuration)
 - `POST /api/v1/agent-chat/orchestrate` - multi-Agent orchestration (three modes: route/sequential/debate)
-- `GET /api/v1/agent-monitoring/overview` - Agent overview statistics
-- `GET /api/v1/agent-monitoring/llm-calls` - list of LLM call records (supports the days/limit parameters)
-- `GET /api/v1/agent-monitoring/agents-stats` - per-Agent usage ranking
-- `GET /api/v1/agent-monitoring/tool-usage` - tool call frequency statistics
-- `GET /api/v1/agent-monitoring/daily-trend` - daily LLM call trend
-  - The request body contains message, mode, and an agents list (each Agent can set its own system_prompt/model/temperature/tools)
-  - Returns final_answer, a steps trace, and agent_results
+  - Request body: `{ message, mode, agents[] }`; each Agent can set its own system_prompt/model/temperature/tools
+  - Returns: `{ final_answer, steps[], agent_results[] }`
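
The request shape documented above for `POST /api/v1/agent-chat/orchestrate` can be sketched as a small payload builder. This is only an illustration: the field names (`message`, `mode`, `agents`, and the per-Agent `system_prompt`/`model`/`temperature`/`tools`) come from the README text, while the helper name `build_orchestrate_payload` and the rule that at least one Agent is required are assumptions, not part of the project.

```python
import json

# The three modes named in the README.
ALLOWED_MODES = {"route", "sequential", "debate"}


def build_orchestrate_payload(message, mode, agents):
    """Assemble a request body for the orchestrate endpoint (hypothetical helper)."""
    if mode not in ALLOWED_MODES:
        raise ValueError(f"unsupported mode: {mode!r}")
    if not agents:
        # Assumption: the server needs at least one Agent to orchestrate.
        raise ValueError("at least one agent is required")
    return {"message": message, "mode": mode, "agents": list(agents)}


payload = build_orchestrate_payload(
    "Compare SQLite and MySQL for local testing",
    "debate",
    [
        {"system_prompt": "Argue for SQLite.", "model": "gpt-4o-mini",
         "temperature": 0.3, "tools": []},
        {"system_prompt": "Argue for MySQL.", "model": "gpt-4o-mini",
         "temperature": 0.3, "tools": []},
    ],
)
# ensure_ascii=False keeps non-ASCII prompts readable in the serialized body.
body = json.dumps(payload, ensure_ascii=False)
```

The serialized `body` would be sent with an authenticated POST; the response is documented as `{ final_answer, steps[], agent_results[] }`.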
+
+### Agent monitoring API (new)
+- `GET /api/v1/agent-monitoring/overview` - Agent overview statistics (Agent count, chat / LLM / tool call counts, Token usage)
+- `GET /api/v1/agent-monitoring/llm-calls?days=7&limit=50` - list of LLM call records (model, tokens, latency, status)
+- `GET /api/v1/agent-monitoring/agents-stats?days=7` - per-Agent usage ranking (call count, Token, latency)
+- `GET /api/v1/agent-monitoring/tool-usage?days=7` - tool call frequency statistics
+- `GET /api/v1/agent-monitoring/daily-trend?days=7` - daily LLM call trend
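
The monitoring endpoints above differ only in their path suffix and in the optional `days` / `limit` query parameters, so the URLs can be built with one helper. A minimal sketch: `monitoring_url` is a hypothetical name, and the base URL is an assumption borrowed from the host and port of the WebSocket example in the same README.

```python
from urllib.parse import urlencode

# Assumption: the HTTP API is served on the same host/port as the WebSocket example.
BASE_URL = "http://localhost:8037"


def monitoring_url(endpoint, **params):
    """Build a monitoring URL, dropping parameters that are None."""
    query = urlencode({k: v for k, v in params.items() if v is not None})
    url = f"{BASE_URL}/api/v1/agent-monitoring/{endpoint}"
    return f"{url}?{query}" if query else url


llm_calls_url = monitoring_url("llm-calls", days=7, limit=50)
overview_url = monitoring_url("overview")
```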
 
 ### WebSocket API
 - `ws://localhost:8037/ws/execution/{execution_id}` - real-time execution status push
@@ -457,7 +463,8 @@ alembic downgrade -1
 - **Phase 4-7 features**: 100% ✅
 - **Autonomous Agent Runtime**: 100% ✅ (added 2026-04)
 - **Multi-Agent orchestration**: 100% ✅ (added 2026-05)
-- **Overall project**: about 95-97%
+- **Agent monitoring Dashboard**: 100% ✅ (added 2026-05)
+- **Overall project**: about 97-98%
 
 ### Completed core features
 1. **Complete user authentication system** - registration, login, JWT authentication
@@ -481,8 +488,8 @@ alembic downgrade -1
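 
 ### Near-term development focus (high priority)
 1. **Budget integration** - count an Agent's internal LLM calls against the workflow execution budget
-2. **Agent Dashboard** - LLM call-chain tracing, Token consumption statistics, execution history
-3. **User experience improvements** - workflow editor improvements, Agent usability improvements
+2. **User experience improvements** - workflow editor improvements, Agent usability improvements
+3. **Streaming output** - push the Agent's reasoning process to the frontend in real time
 
 ### Mid-term plans
 1. **Vector memory** - integrate an Embedding API with vector retrieval (semantic memory)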
@@ -550,6 +557,6 @@ alembic downgrade -1
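 ---
 
 **Last updated**: 2026-05-01
-**Document version**: 1.5
+**Document version**: 1.6
 
 *This document was compiled from the project's existing documentation and covers the core project information. For the detailed technical design, see [方案-优化版.md](./方案-优化版.md). DeepSeek model names and Base URLs follow the official documentation; update this section whenever they change.*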
@@ -15,6 +15,7 @@ from app.agent_runtime.schemas import (
|
||||
AgentLLMConfig,
|
||||
AgentToolConfig,
|
||||
AgentMemoryConfig,
|
||||
AgentBudgetConfig,
|
||||
AgentStep,
|
||||
)
|
||||
from app.agent_runtime.context import AgentContext
|
||||
@@ -34,6 +35,7 @@ __all__ = [
|
||||
"AgentLLMConfig",
|
||||
"AgentToolConfig",
|
||||
"AgentMemoryConfig",
|
||||
"AgentBudgetConfig",
|
||||
"AgentContext",
|
||||
"AgentMemory",
|
||||
"AgentToolManager",
|
||||
|
||||
@@ -13,7 +13,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, TypedDict
|
||||
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Protocol, TypedDict
|
||||
|
||||
from app.agent_runtime.schemas import (
|
||||
AgentConfig,
|
||||
@@ -23,6 +23,7 @@ from app.agent_runtime.schemas import (
|
||||
from app.agent_runtime.context import AgentContext
|
||||
from app.agent_runtime.memory import AgentMemory
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
from app.core.exceptions import WorkflowExecutionError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -95,6 +96,11 @@ class AgentRuntime:
         self.on_tool_executed = on_tool_executed
         self.on_llm_call = on_llm_call
         self._memory_context_loaded = False
+        self._llm_invocations = 0
+
+        # Budget callback: injected by WorkflowEngine so the Agent's internal count is charged to the workflow budget
+        # Returns True when budget remains; returning False or raising an exception means the limit was exceeded
+        self.on_llm_invocation: Optional[Callable[[], Any]] = None
 
     async def run(self, user_input: str) -> AgentResult:
         """
@@ -108,7 +114,7 @@ class AgentRuntime:
|
||||
|
||||
# 1. 首次运行时加载长期记忆到 system prompt
|
||||
if not self._memory_context_loaded:
|
||||
await self._inject_memory_context()
|
||||
await self._inject_memory_context(user_input)
|
||||
self._memory_context_loaded = True
|
||||
|
||||
# 2. 追加用户消息
|
||||
@@ -139,6 +145,32 @@ class AgentRuntime:
|
||||
# 裁剪过长历史
|
||||
messages = self.memory.trim_messages(self.context.messages)
|
||||
|
||||
# 预算检查:LLM 调用次数(在调用 LLM 之前检查,避免浪费额度)
|
||||
budget = self.config.budget
|
||||
if self._llm_invocations >= budget.max_llm_invocations:
|
||||
err = f"已超过 LLM 调用预算({budget.max_llm_invocations} 次)"
|
||||
logger.warning(err)
|
||||
steps.append(AgentStep(iteration=self.context.iteration, type="final", content=err))
|
||||
await self.memory.save_context(user_input, err, self.context.messages)
|
||||
return AgentResult(success=False, content=err, truncated=True,
|
||||
iterations_used=self.context.iteration,
|
||||
tool_calls_made=self.context.tool_calls_made,
|
||||
steps=steps, error=err)
|
||||
|
||||
# 调用外部 LLM 预算回调(WorkflowEngine 注入,将 Agent 的 LLM 计入工作流预算)
|
||||
if self.on_llm_invocation:
|
||||
try:
|
||||
self.on_llm_invocation()
|
||||
except Exception as e:
|
||||
err = f"LLM 调用超出工作流预算: {e}"
|
||||
logger.warning(err)
|
||||
steps.append(AgentStep(iteration=self.context.iteration, type="final", content=err))
|
||||
await self.memory.save_context(user_input, err, self.context.messages)
|
||||
return AgentResult(success=False, content=err, truncated=True,
|
||||
iterations_used=self.context.iteration,
|
||||
tool_calls_made=self.context.tool_calls_made,
|
||||
steps=steps, error=str(e))
|
||||
|
||||
# 调用 LLM
|
||||
try:
|
||||
response = await llm.chat(
|
||||
@@ -166,6 +198,9 @@ class AgentRuntime:
|
||||
error=err_str,
|
||||
)
|
||||
|
||||
# 记录 LLM 调用次数(内部计数)
|
||||
self._llm_invocations += 1
|
||||
|
||||
# 解析工具调用
|
||||
tool_calls = self._extract_tool_calls(response)
|
||||
content = self._extract_content(response)
|
||||
@@ -247,9 +282,22 @@ class AgentRuntime:
|
||||
self.context.add_tool_result(tcid, tname, result)
|
||||
self.context.tool_calls_made += 1
|
||||
|
||||
# 预算检查:工具调用次数
|
||||
if self.context.tool_calls_made > budget.max_tool_calls:
|
||||
err = f"已超过工具调用预算({budget.max_tool_calls} 次)"
|
||||
logger.warning(err)
|
||||
steps.append(AgentStep(iteration=self.context.iteration, type="tool_result",
|
||||
content=err, tool_name=tname))
|
||||
return AgentResult(success=False, content=err, truncated=True,
|
||||
iterations_used=self.context.iteration,
|
||||
tool_calls_made=self.context.tool_calls_made,
|
||||
steps=steps, error=err)
|
||||
|
||||
if self.on_tool_executed:
|
||||
try:
|
||||
await self.on_tool_executed(tname)
|
||||
except WorkflowExecutionError:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
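
The two budget checks in the hunks above behave slightly differently: the LLM check runs before the call and uses `>=`, while the tool check runs after the tool has executed and uses `>`. The sketch below isolates that counting logic in a standalone helper so the boundary behaviour is easy to see. `Budget` and `BudgetTracker` are hypothetical stand-ins written for this illustration; they are not the project's `AgentBudgetConfig` or `AgentRuntime`.

```python
from dataclasses import dataclass


@dataclass
class Budget:
    # Defaults copied from AgentBudgetConfig in this diff.
    max_llm_invocations: int = 200
    max_tool_calls: int = 500


class BudgetTracker:
    """Hypothetical helper mirroring the two checks shown in the diff."""

    def __init__(self, budget):
        self.budget = budget
        self.llm_invocations = 0
        self.tool_calls_made = 0

    def can_call_llm(self):
        # Checked BEFORE the call: refuse once the counter has reached the cap.
        return self.llm_invocations < self.budget.max_llm_invocations

    def record_llm_call(self):
        self.llm_invocations += 1

    def record_tool_call(self):
        # Checked AFTER the call: the tool that trips the cap has already run.
        self.tool_calls_made += 1
        return self.tool_calls_made <= self.budget.max_tool_calls


tracker = BudgetTracker(Budget(max_llm_invocations=2, max_tool_calls=1))
llm_allowed = []
for _ in range(3):
    ok = tracker.can_call_llm()
    llm_allowed.append(ok)
    if ok:
        tracker.record_llm_call()
tool_ok = [tracker.record_tool_call(), tracker.record_tool_call()]
```

With these small limits, the third LLM call is refused up front, whereas the second tool call is only flagged after it has been counted.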
@@ -284,9 +332,240 @@ class AgentRuntime:
|
||||
steps=steps,
|
||||
)
|
||||
|
||||
async def _inject_memory_context(self) -> None:
|
||||
    async def run_stream(self, user_input: str) -> AsyncGenerator[dict, None]:
        """
        Run a single Agent conversation turn in streaming mode.

        Same logic as run(), but yields an SSE event at each key step:
        - think: the LLM is reasoning and is about to call tools
        - tool_call: a tool is about to be executed
        - tool_result: a tool has finished executing
        - final: the final answer
        - error: a failure or an exceeded budget
        """
|
||||
max_iter = max(1, self.config.llm.max_iterations)
|
||||
self.context.iteration = 0
|
||||
self.context.tool_calls_made = 0
|
||||
|
||||
# 1. 首次运行时加载长期记忆到 system prompt
|
||||
if not self._memory_context_loaded:
|
||||
await self._inject_memory_context(user_input)
|
||||
self._memory_context_loaded = True
|
||||
|
||||
# 2. 追加用户消息
|
||||
self.context.add_user_message(user_input)
|
||||
|
||||
# 3. ReAct 循环
|
||||
llm = _LLMClient(self.config.llm)
|
||||
tool_schemas = self.tool_manager.get_tool_schemas()
|
||||
has_tools = self.tool_manager.has_tools()
|
||||
steps: List[AgentStep] = []
|
||||
|
||||
llm_callback_ctx = {"step_type": "think", "tool_name": None}
|
||||
|
||||
def _llm_callback(metrics: Dict[str, Any]):
|
||||
if self.on_llm_call:
|
||||
metrics.update({
|
||||
"session_id": self.context.session_id,
|
||||
"user_id": self.config.user_id,
|
||||
"step_type": llm_callback_ctx["step_type"],
|
||||
"tool_name": llm_callback_ctx["tool_name"],
|
||||
})
|
||||
self.on_llm_call(metrics)
|
||||
|
||||
while self.context.iteration < max_iter:
|
||||
self.context.iteration += 1
|
||||
messages = self.memory.trim_messages(self.context.messages)
|
||||
|
||||
# 预算检查:LLM 调用次数(在调用 LLM 之前检查,避免浪费额度)
|
||||
budget = self.config.budget
|
||||
if self._llm_invocations >= budget.max_llm_invocations:
|
||||
err = f"已超过 LLM 调用预算({budget.max_llm_invocations} 次)"
|
||||
logger.warning(err)
|
||||
yield {"type": "error", "content": err, "iteration": self.context.iteration,
|
||||
"truncated": True}
|
||||
await self.memory.save_context(user_input, err, self.context.messages)
|
||||
return
|
||||
|
||||
# 调用外部 LLM 预算回调(WorkflowEngine 注入)
|
||||
if self.on_llm_invocation:
|
||||
try:
|
||||
self.on_llm_invocation()
|
||||
except Exception as e:
|
||||
err = f"LLM 调用超出工作流预算: {e}"
|
||||
logger.warning(err)
|
||||
yield {"type": "error", "content": err, "iteration": self.context.iteration,
|
||||
"truncated": True}
|
||||
return
|
||||
|
||||
# 调用 LLM
|
||||
try:
|
||||
response = await llm.chat(
|
||||
messages=messages,
|
||||
tools=tool_schemas if has_tools else None,
|
||||
iteration=self.context.iteration,
|
||||
on_completion=_llm_callback,
|
||||
)
|
||||
except Exception as e:
|
||||
err_str = str(e)
|
||||
logger.error("LLM 调用失败 (iteration=%s): %s", self.context.iteration, err_str)
|
||||
if self.context.iteration < max_iter and self._is_retryable(err_str):
|
||||
yield {"type": "error", "content": f"LLM 调用失败(可重试): {err_str}",
|
||||
"iteration": self.context.iteration}
|
||||
continue
|
||||
yield {"type": "error", "content": f"LLM 调用失败: {err_str}",
|
||||
"iteration": self.context.iteration}
|
||||
return
|
||||
|
||||
# 记录 LLM 调用次数(内部计数)
|
||||
self._llm_invocations += 1
|
||||
|
||||
# 解析工具调用
|
||||
tool_calls = self._extract_tool_calls(response)
|
||||
content = self._extract_content(response)
|
||||
reasoning = getattr(response, "reasoning_content", None) or (
|
||||
response.get("reasoning_content") if isinstance(response, dict) else None
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
# LLM 直接返回文本 → 结束
|
||||
self.context.add_assistant_message(content)
|
||||
final_text = content or "(模型未返回有效内容)"
|
||||
yield {
|
||||
"type": "final",
|
||||
"content": final_text,
|
||||
"reasoning": reasoning,
|
||||
"iteration": self.context.iteration,
|
||||
"iterations_used": self.context.iteration,
|
||||
"tool_calls_made": self.context.tool_calls_made,
|
||||
"session_id": self.context.session_id,
|
||||
}
|
||||
await self.memory.save_context(user_input, final_text, self.context.messages)
|
||||
return
|
||||
|
||||
# 有工具调用 → 先记录 assistant 消息
|
||||
self.context.add_assistant_message(content or "", tool_calls, reasoning)
|
||||
|
||||
# yield think 事件
|
||||
tc_names = [tc["function"]["name"] for tc in tool_calls]
|
||||
tc_args_list = []
|
||||
for tc in tool_calls:
|
||||
try:
|
||||
tc_args_list.append(json.loads(tc["function"].get("arguments", "{}")))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
tc_args_list.append({})
|
||||
|
||||
yield {
|
||||
"type": "think",
|
||||
"content": content or f"调用工具: {', '.join(tc_names)}",
|
||||
"reasoning": reasoning,
|
||||
"tool_names": tc_names,
|
||||
"iteration": self.context.iteration,
|
||||
}
|
||||
|
||||
steps.append(AgentStep(
|
||||
iteration=self.context.iteration,
|
||||
type="think",
|
||||
content=content or f"调用工具: {', '.join(tc_names)}",
|
||||
reasoning=reasoning,
|
||||
tool_name=tc_names[0] if len(tc_names) == 1 else None,
|
||||
tool_input=tc_args_list[0] if len(tc_args_list) == 1 else None,
|
||||
))
|
||||
|
||||
if self.execution_logger:
|
||||
self.execution_logger.info(
|
||||
f"Agent 调用 {len(tool_calls)} 个工具",
|
||||
data={"tool_calls": tc_names, "iteration": self.context.iteration},
|
||||
)
|
||||
|
||||
# 逐一执行工具
|
||||
for tc in tool_calls:
|
||||
tfn = tc.get("function", {})
|
||||
tname = tfn.get("name", "unknown")
|
||||
tcid = tc.get("id", f"call_{self.context.iteration}_{self.context.tool_calls_made}")
|
||||
|
||||
try:
|
||||
targs = json.loads(tfn.get("arguments", "{}"))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
targs = {}
|
||||
|
||||
# yield tool_call 事件
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"name": tname,
|
||||
"input": targs,
|
||||
"iteration": self.context.iteration,
|
||||
}
|
||||
|
||||
logger.info("Agent 执行工具 [%s]: %s", tname, targs)
|
||||
result = await self.tool_manager.execute(tname, targs)
|
||||
|
||||
# yield tool_result 事件
|
||||
yield {
|
||||
"type": "tool_result",
|
||||
"name": tname,
|
||||
"result": result[:500] + "..." if len(result) > 500 else result,
|
||||
"iteration": self.context.iteration,
|
||||
}
|
||||
|
||||
steps.append(AgentStep(
|
||||
iteration=self.context.iteration,
|
||||
type="tool_result",
|
||||
content=f"工具 {tname} 返回结果",
|
||||
tool_name=tname,
|
||||
tool_input=targs,
|
||||
tool_result=result[:500] + "..." if len(result) > 500 else result,
|
||||
))
|
||||
|
||||
self.context.add_tool_result(tcid, tname, result)
|
||||
self.context.tool_calls_made += 1
|
||||
|
||||
# 预算检查:工具调用次数
|
||||
if self.context.tool_calls_made > budget.max_tool_calls:
|
||||
err = f"已超过工具调用预算({budget.max_tool_calls} 次)"
|
||||
logger.warning(err)
|
||||
yield {"type": "error", "content": err, "iteration": self.context.iteration,
|
||||
"truncated": True}
|
||||
return
|
||||
|
||||
if self.on_tool_executed:
|
||||
try:
|
||||
await self.on_tool_executed(tname)
|
||||
except WorkflowExecutionError:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if self.execution_logger:
|
||||
preview = result[:300] + "..." if len(result) > 300 else result
|
||||
self.execution_logger.info(
|
||||
f"工具 {tname} 执行完成",
|
||||
data={"tool_name": tname, "result_preview": preview},
|
||||
)
|
||||
|
||||
# 达到最大迭代次数
|
||||
last_content = ""
|
||||
for m in reversed(self.context.messages):
|
||||
if m.get("role") == "assistant" and m.get("content"):
|
||||
last_content = m["content"]
|
||||
break
|
||||
|
||||
logger.warning("Agent 达到最大迭代次数 (%s)", max_iter)
|
||||
await self.memory.save_context(user_input, last_content or "(已达最大迭代次数)", self.context.messages)
|
||||
yield {
|
||||
"type": "final",
|
||||
"content": last_content or "已达最大迭代次数,但模型未返回最终回答。",
|
||||
"iteration": self.context.iteration,
|
||||
"iterations_used": self.context.iteration,
|
||||
"tool_calls_made": self.context.tool_calls_made,
|
||||
"truncated": True,
|
||||
"session_id": self.context.session_id,
|
||||
}
|
||||
|
||||
async def _inject_memory_context(self, query: str = "") -> None:
|
||||
"""加载长期记忆并注入 system prompt。"""
|
||||
mem_text = await self.memory.initialize()
|
||||
mem_text = await self.memory.initialize(query=query)
|
||||
if mem_text:
|
||||
enriched = (
|
||||
self.config.system_prompt.rstrip("\n")
|
||||
|
||||
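
A caller of `run_stream` receives plain dict events and has to decide what to do with each `type`. The sketch below shows one way to drain such a stream. `fake_run_stream` is a stand-in generator that only imitates the event shapes seen in the hunk above (it is not the real `AgentRuntime.run_stream`), and `collect` is a hypothetical consumer. Because the diff yields an `error` event for a retryable LLM failure and then continues the loop, the consumer records errors without stopping at the first one.

```python
import asyncio


async def fake_run_stream(user_input):
    """Stand-in for run_stream; yields events shaped like those in the diff."""
    yield {"type": "think", "content": "calling tool: search",
           "tool_names": ["search"], "iteration": 1}
    yield {"type": "tool_call", "name": "search",
           "input": {"q": user_input}, "iteration": 1}
    yield {"type": "tool_result", "name": "search",
           "result": "3 hits", "iteration": 1}
    yield {"type": "final", "content": "Found 3 hits.",
           "iteration": 2, "tool_calls_made": 1}


async def collect(stream):
    """Drain the stream and return (final_text, tool_names, last_error)."""
    final_text, tools, error = None, [], None
    async for event in stream:
        kind = event["type"]
        if kind == "tool_call":
            tools.append(event["name"])
        elif kind == "final":
            final_text = event["content"]
        elif kind == "error":
            # Keep reading: a retryable error may be followed by more events.
            error = event["content"]
    return final_text, tools, error


result = asyncio.run(collect(fake_run_stream("sse")))
```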
@@ -15,6 +15,7 @@ from app.services.persistent_memory_service import (
|
||||
save_persistent_memory,
|
||||
persist_enabled,
|
||||
)
|
||||
from app.services.embedding_service import embedding_service, VectorEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,25 +36,31 @@ class AgentMemory:
|
||||
session_key: Optional[str] = None,
|
||||
persist: bool = True,
|
||||
max_history: int = 20,
|
||||
vector_memory_enabled: bool = True,
|
||||
vector_memory_top_k: int = 5,
|
||||
):
|
||||
self.scope_kind = scope_kind
|
||||
self.scope_id = scope_id or "default"
|
||||
self.session_key = session_key or "default_session"
|
||||
self.persist = persist and persist_enabled()
|
||||
self.max_history = max_history
|
||||
self.vector_memory_enabled = vector_memory_enabled
|
||||
self.vector_memory_top_k = vector_memory_top_k
|
||||
# 从长期记忆加载的上下文(启动时加载)
|
||||
self._long_term_context: Dict[str, Any] = {}
|
||||
# 记录已压缩的消息数,避免重复压缩
|
||||
self._last_compressed_msg_count = 0
|
||||
|
||||
async def initialize(self) -> str:
|
||||
async def initialize(self, query: str = "") -> str:
|
||||
"""
|
||||
初始化记忆:从 DB/Redis 加载长期记忆,构造初始上下文文本。
|
||||
初始化记忆:从 DB/Redis 加载长期记忆 + 向量检索相关历史。
|
||||
返回注入 system prompt 的记忆文本块。
|
||||
"""
|
||||
if not self.persist or not self.scope_id:
|
||||
return ""
|
||||
|
||||
parts: List[str] = []
|
||||
|
||||
db: Optional[Session] = None
|
||||
try:
|
||||
db = SessionLocal()
|
||||
@@ -62,8 +69,6 @@ class AgentMemory:
|
||||
)
|
||||
if payload and isinstance(payload, dict):
|
||||
self._long_term_context = payload
|
||||
# 构建注入 system prompt 的记忆文本
|
||||
parts = []
|
||||
profile = payload.get("user_profile")
|
||||
if profile and isinstance(profile, dict):
|
||||
profile_text = json.dumps(profile, ensure_ascii=False)
|
||||
@@ -78,15 +83,93 @@ class AgentMemory:
|
||||
if history and isinstance(history, list) and len(history) > 0:
|
||||
summary = self._summarize_history(history)
|
||||
parts.append(f"## 历史对话摘要\n{summary}")
|
||||
|
||||
if parts:
|
||||
return "\n\n".join(parts)
|
||||
except Exception as e:
|
||||
logger.warning("加载长期记忆失败: %s", e)
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
return ""
|
||||
|
||||
# 2. 向量检索:查找语义相关的历史对话
|
||||
if self.vector_memory_enabled and self.scope_kind and self.scope_id:
|
||||
vector_text = await self._vector_search(query)
|
||||
if vector_text:
|
||||
parts.append(vector_text)
|
||||
|
||||
return "\n\n".join(parts) if parts else ""
|
||||
|
||||
async def _vector_search(self, query: str = "") -> str:
|
||||
"""
|
||||
向量检索语义相关的历史记忆,返回格式化的文本块。
|
||||
若无 query 则返回最近 Top-5 条记忆。
|
||||
"""
|
||||
from app.models.agent_vector_memory import AgentVectorMemory
|
||||
|
||||
db: Optional[Session] = None
|
||||
try:
|
||||
db = SessionLocal()
|
||||
# 查询当前 scope 的所有向量记忆(按时间倒序)
|
||||
rows = (
|
||||
db.query(AgentVectorMemory)
|
||||
.filter(
|
||||
AgentVectorMemory.scope_kind == self.scope_kind,
|
||||
AgentVectorMemory.scope_id == self.scope_id,
|
||||
)
|
||||
.order_by(AgentVectorMemory.created_at.desc())
|
||||
.limit(50) # 最多取最近 50 条做相似度计算
|
||||
.all()
|
||||
)
|
||||
|
||||
if not rows:
|
||||
return ""
|
||||
|
||||
entries: List[VectorEntry] = []
|
||||
for row in rows:
|
||||
emb = embedding_service.deserialize_embedding(row.embedding) if row.embedding else []
|
||||
entries.append({
|
||||
"id": row.id,
|
||||
"scope_kind": row.scope_kind,
|
||||
"scope_id": row.scope_id,
|
||||
"content_text": row.content_text,
|
||||
"embedding": emb,
|
||||
"metadata": row.metadata_ or {},
|
||||
})
|
||||
|
||||
matched: List[VectorEntry] = []
|
||||
|
||||
if query and query.strip():
|
||||
# 有 query:生成 embedding 做语义搜索
|
||||
query_emb = await embedding_service.generate_embedding(query)
|
||||
if query_emb:
|
||||
matched = await embedding_service.similarity_search(
|
||||
query_emb, entries, top_k=self.vector_memory_top_k
|
||||
)
|
||||
else:
|
||||
# 无 query:返回最近几条
|
||||
matched = entries[: self.vector_memory_top_k]
|
||||
for m in matched:
|
||||
m["score"] = 1.0
|
||||
|
||||
if not matched:
|
||||
return ""
|
||||
|
||||
# 格式化为文本块
|
||||
lines = ["## 相关历史记忆"]
|
||||
for i, m in enumerate(matched, 1):
|
||||
text = m.get("content_text", "")[:500]
|
||||
meta = m.get("metadata", {})
|
||||
entry_type = meta.get("type", "对话")
|
||||
lines.append(f"{i}. [{entry_type}] {text}")
|
||||
if m.get("score", 1.0) < 1.0:
|
||||
lines[-1] += f" (匹配度: {m['score']:.2f})"
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("向量检索失败: %s", e)
|
||||
return ""
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
async def save_context(
|
||||
self, user_message: str, assistant_reply: str,
|
||||
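
`_vector_search` above loads up to 50 recent rows, deserializes their embeddings, and hands them to `embedding_service.similarity_search`, whose implementation is not part of this diff. The sketch below shows what such a ranking step could look like, assuming cosine similarity over in-memory lists. The functions `cosine` and `top_k` are hypothetical and are not the project's API; rows saved without an embedding score 0.0 here, which matches the diff's behaviour of deserializing a missing embedding to an empty list.

```python
import math


def cosine(a, b):
    """Cosine similarity; returns 0.0 for empty or mismatched vectors."""
    if not a or not b or len(a) != len(b):
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def top_k(query_emb, entries, k=5):
    """Return the k entries most similar to query_emb, each with a 'score' key."""
    scored = [dict(e, score=cosine(query_emb, e["embedding"])) for e in entries]
    scored.sort(key=lambda e: e["score"], reverse=True)
    return scored[:k]


entries = [
    {"id": 1, "content_text": "user asked about SSE", "embedding": [1.0, 0.0]},
    {"id": 2, "content_text": "user asked about budgets", "embedding": [0.0, 1.0]},
    {"id": 3, "content_text": "row saved without an embedding", "embedding": []},
]
best = top_k([0.9, 0.1], entries, k=2)
```

A linear scan like this is adequate for the 50-row window used in the diff; a larger corpus would normally call for a vector index instead.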
@@ -114,12 +197,50 @@ class AgentMemory:
|
||||
db, self.scope_kind, self.scope_id,
|
||||
self.session_key, self._long_term_context,
|
||||
)
|
||||
|
||||
# 保存向量记忆(异步生成 embedding 并存储)
|
||||
if self.vector_memory_enabled:
|
||||
await self._save_vector_memory(
|
||||
db, user_message, assistant_reply
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("保存长期记忆失败: %s", e)
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
async def _save_vector_memory(
|
||||
self, db: Session, user_message: str, assistant_reply: str,
|
||||
) -> None:
|
||||
"""生成 embedding 并保存到向量记忆表。"""
|
||||
from app.models.agent_vector_memory import AgentVectorMemory
|
||||
|
||||
content_text = f"用户: {user_message}\n助手: {assistant_reply}"
|
||||
if len(content_text) > 8000:
|
||||
content_text = content_text[:8000]
|
||||
|
||||
try:
|
||||
# 生成 embedding
|
||||
embedding = await embedding_service.generate_embedding(content_text)
|
||||
embedding_json = embedding_service.serialize_embedding(embedding) if embedding else ""
|
||||
|
||||
record = AgentVectorMemory(
|
||||
scope_kind=self.scope_kind,
|
||||
scope_id=self.scope_id,
|
||||
session_key=self.session_key,
|
||||
content_text=content_text[:2000],
|
||||
embedding=embedding_json or None,
|
||||
metadata_={
|
||||
"type": "conversation_turn",
|
||||
},
|
||||
)
|
||||
db.add(record)
|
||||
db.commit()
|
||||
logger.debug("已保存向量记忆 (scope=%s/%s)", self.scope_kind, self.scope_id)
|
||||
except Exception as e:
|
||||
logger.warning("保存向量记忆失败: %s", e)
|
||||
db.rollback()
|
||||
|
||||
async def _compress_and_summarize(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
|
||||
@@ -20,6 +20,8 @@ class AgentMemoryConfig(BaseModel):
     max_history_messages: int = 20  # maximum number of prior messages injected into the LLM context
     session_key: Optional[str] = None  # session identifier; auto-generated by default
     persist_to_db: bool = True  # whether to write long-term memory to MySQL
+    vector_memory_enabled: bool = True  # whether to enable vector memory (semantic retrieval)
+    vector_memory_top_k: int = 5  # Top-K for vector retrieval
 
 
 class AgentLLMConfig(BaseModel):
@@ -35,6 +37,12 @@ class AgentLLMConfig(BaseModel):
     extra_body: Optional[Dict[str, Any]] = None
 
 
+class AgentBudgetConfig(BaseModel):
+    """Agent execution budget configuration."""
+    max_llm_invocations: int = 200  # upper limit on LLM calls
+    max_tool_calls: int = 500  # upper limit on tool calls
+
+
 class AgentConfig(BaseModel):
     """Full Agent configuration."""
     name: str = "default_agent"
@@ -42,6 +50,7 @@ class AgentConfig(BaseModel):
     llm: AgentLLMConfig = Field(default_factory=AgentLLMConfig)
     tools: AgentToolConfig = Field(default_factory=AgentToolConfig)
     memory: AgentMemoryConfig = Field(default_factory=AgentMemoryConfig)
+    budget: AgentBudgetConfig = Field(default_factory=AgentBudgetConfig)
     user_id: Optional[str] = None
 
 
@@ -3,9 +3,8 @@ Agent 工具管理器:包装已有 ToolRegistry,提供 Agent 需要的工具
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.services.tool_registry import tool_registry
|
||||
|
||||
@@ -58,6 +57,8 @@ class AgentToolManager:
|
||||
"""
|
||||
执行工具调用。
|
||||
|
||||
优先查找内置工具,其次查找数据库自定义工具(HTTP / Code)。
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
args: 工具参数字典
|
||||
@@ -65,27 +66,8 @@ class AgentToolManager:
         Returns:
             String representation of the tool's execution result
         """
-        func: Optional[Callable] = tool_registry.get_tool_function(name)
-        if not func:
-            err = f"工具 '{name}' 不存在"
-            logger.error(err)
-            return json.dumps({"error": err}, ensure_ascii=False)
-
-        logger.info("Agent 执行工具: %s, 参数: %s", name, args)
-        try:
-            import asyncio
-            if asyncio.iscoroutinefunction(func):
-                result = await func(**args)
-            else:
-                result = func(**args)
-
-            if isinstance(result, (dict, list)):
-                return json.dumps(result, ensure_ascii=False)
-            return str(result)
-        except Exception as e:
-            err_msg = f"工具 '{name}' 执行失败: {e}"
-            logger.error(err_msg, exc_info=True)
-            return json.dumps({"error": err_msg}, ensure_ascii=False)
+        logger.info("Agent 执行工具: %s", name)
+        return await tool_registry.execute_tool(name, args)
 
     @staticmethod
     def _extract_tool_name(schema: Dict[str, Any]) -> Optional[str]:
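
This hunk moves the dispatch logic out of `AgentToolManager.execute` and into `tool_registry.execute_tool`, whose body is not shown in this diff. The sketch below illustrates the lookup order the docstring describes (built-in tools first, then custom tools), combined with the sync/async handling and JSON error wrapping of the removed code. `SketchRegistry` and its two dictionaries are hypothetical; in the project the custom tools are loaded from the database and executed as HTTP or code tools.

```python
import asyncio
import inspect
import json


class SketchRegistry:
    """Hypothetical registry: built-in tools take precedence over custom ones."""

    def __init__(self, builtin, custom):
        self.builtin = builtin
        self.custom = custom

    async def execute_tool(self, name, args):
        func = self.builtin.get(name) or self.custom.get(name)
        if func is None:
            return json.dumps({"error": f"tool '{name}' not found"})
        try:
            result = func(**args)
            if inspect.isawaitable(result):
                # Async tools return a coroutine that still has to be awaited.
                result = await result
        except Exception as exc:
            return json.dumps({"error": f"tool '{name}' failed: {exc}"})
        if isinstance(result, (dict, list)):
            return json.dumps(result, ensure_ascii=False)
        return str(result)


async def fetch_weather(city):
    return {"city": city, "temp_c": 21}


registry = SketchRegistry(
    builtin={"add": lambda a, b: a + b},
    custom={"add": lambda a, b: -1, "weather": fetch_weather},
)


async def demo():
    return [
        await registry.execute_tool("add", {"a": 2, "b": 3}),
        await registry.execute_tool("weather", {"city": "Oslo"}),
        await registry.execute_tool("missing", {}),
    ]


results = asyncio.run(demo())
```

Returning errors as JSON strings rather than raising keeps the ReAct loop alive, since the model can read the error and try another tool.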
@@ -13,6 +13,7 @@ from app.agent_runtime.schemas import (
|
||||
AgentConfig,
|
||||
AgentLLMConfig,
|
||||
AgentToolConfig,
|
||||
AgentBudgetConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -24,6 +25,8 @@ async def run_agent_node(
|
||||
execution_logger: Optional[Any] = None,
|
||||
user_id: Optional[str] = None,
|
||||
on_tool_executed: Optional[Any] = None,
|
||||
on_llm_invocation: Optional[Any] = None,
|
||||
budget_limits: Optional[Dict[str, int]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
在工作流中执行 Agent 节点。
|
||||
@@ -72,6 +75,14 @@ async def run_agent_node(
|
||||
if node_data.get("base_url"):
|
||||
llm_config.base_url = node_data["base_url"]
|
||||
|
||||
# 3a. 构建预算配置(接收工作流级预算限制)
|
||||
budget = AgentBudgetConfig()
|
||||
if budget_limits:
|
||||
if "max_llm_invocations" in budget_limits:
|
||||
budget.max_llm_invocations = max(1, int(budget_limits["max_llm_invocations"]))
|
||||
if "max_tool_calls" in budget_limits:
|
||||
budget.max_tool_calls = max(1, int(budget_limits["max_tool_calls"]))
|
||||
|
||||
agent_config = AgentConfig(
|
||||
name=node_data.get("label", "agent_node"),
|
||||
system_prompt=formatted_prompt,
|
||||
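
Both the workflow node above and the chat endpoints later in this diff build the budget the same way: start from the defaults, then overwrite each limit that was supplied, coercing it with `max(1, int(...))` so a zero, negative, or string value still yields a usable positive integer. A minimal sketch of that merge, using a plain dict in place of `AgentBudgetConfig`; `merge_budget` is a hypothetical helper, and skipping `None` values follows the chat-endpoint variant of the check.

```python
# Defaults copied from AgentBudgetConfig in this diff.
DEFAULTS = {"max_llm_invocations": 200, "max_tool_calls": 500}


def merge_budget(limits=None):
    """Overlay supplied limits on the defaults, clamping each to at least 1."""
    budget = dict(DEFAULTS)
    for key in DEFAULTS:
        if limits and limits.get(key) is not None:
            budget[key] = max(1, int(limits[key]))
    return budget


clamped = merge_budget({"max_llm_invocations": 0})
from_string = merge_budget({"max_tool_calls": "25"})
untouched = merge_budget()
```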
@@ -84,6 +95,7 @@ async def run_agent_node(
|
||||
"enabled": node_data.get("memory", True),
|
||||
"persist_to_db": node_data.get("memory", True),
|
||||
},
|
||||
budget=budget,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -93,6 +105,9 @@ async def run_agent_node(
|
||||
execution_logger=execution_logger,
|
||||
on_tool_executed=on_tool_executed,
|
||||
)
|
||||
# 注入 LLM 预算回调(使 Agent 内部 LLM 调用计入工作流预算)
|
||||
if on_llm_invocation:
|
||||
runtime.on_llm_invocation = on_llm_invocation
|
||||
|
||||
result = await runtime.run(query)
|
||||
|
||||
|
||||
@@ -8,8 +8,10 @@ POST /api/v1/agent-chat/bare
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
import json
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.database import get_db
|
||||
@@ -23,6 +25,7 @@ from app.agent_runtime import (
|
||||
AgentConfig,
|
||||
AgentLLMConfig,
|
||||
AgentToolConfig,
|
||||
AgentBudgetConfig,
|
||||
AgentStep,
|
||||
AgentOrchestrator,
|
||||
OrchestratorAgentConfig,
|
||||
@@ -64,6 +67,14 @@ def _make_llm_logger(
     return _log
 
 
+async def _sse_stream(gen: AsyncGenerator[dict, None]) -> AsyncGenerator[str, None]:
+    """Format the dict events yielded by run_stream as an SSE text stream."""
+    async for event in gen:
+        event_type = event.get("type", "message")
+        data = {k: v for k, v in event.items() if k != "type"}
+        yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
+
+
 class ChatRequest(BaseModel):
     message: str
     session_id: Optional[str] = None
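
To make the wire format produced by `_sse_stream` concrete, the sketch below applies the same formatting to a single event outside the async generator. `format_sse` is a hypothetical synchronous twin of the per-event logic in the hunk above; the sample event is invented for illustration.

```python
import json


def format_sse(event):
    """Format one event dict as an SSE frame, as _sse_stream does per event."""
    event_type = event.get("type", "message")
    # The "type" key becomes the SSE event name; everything else is the payload.
    data = {k: v for k, v in event.items() if k != "type"}
    return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"


frame = format_sse({"type": "tool_call", "name": "search", "iteration": 1})
untyped = format_sse({"content": "hello"})
```

Since `json.dumps` escapes newlines inside strings, each payload stays on a single `data:` line, and the blank line that ends the frame is unambiguous.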
@@ -205,6 +216,39 @@ async def chat_bare(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/bare/stream")
|
||||
async def chat_bare_stream(
|
||||
req: ChatRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""无需 Agent 配置,使用默认设置直接对话(流式 SSE)。"""
|
||||
config = AgentConfig(
|
||||
name="bare_agent",
|
||||
system_prompt="你是一个有用的AI助手。请使用可用工具来帮助用户完成任务。",
|
||||
llm=AgentLLMConfig(
|
||||
model=req.model or (
|
||||
"gpt-4o-mini" if settings.OPENAI_API_KEY and settings.OPENAI_API_KEY != "your-openai-api-key"
|
||||
else "deepseek-v4-flash"
|
||||
),
|
||||
temperature=req.temperature if req.temperature is not None else 0.7,
|
||||
max_iterations=req.max_iterations or 10,
|
||||
),
|
||||
user_id=current_user.id,
|
||||
)
|
||||
on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
|
||||
return StreamingResponse(
|
||||
_sse_stream(runtime.run_stream(req.message)),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{agent_id}", response_model=ChatResponse)
|
||||
async def chat_with_agent(
|
||||
agent_id: str,
|
||||
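
The streaming endpoints in this diff are POST routes, so a browser's `EventSource` (which only issues GET requests) cannot consume them directly; a client has to read the response body and split the frames itself. The sketch below parses SSE text of the shape emitted by `_sse_stream`. `parse_sse` is a hypothetical helper and the sample frames are invented; a real client would feed it chunks read from the HTTP response and handle frames split across chunk boundaries.

```python
import json


def parse_sse(text):
    """Split SSE text into (event_type, data_dict) pairs."""
    events = []
    for block in text.split("\n\n"):
        event_type, data_lines = "message", []
        for line in block.splitlines():
            if line.startswith("event:"):
                event_type = line[len("event:"):].strip()
            elif line.startswith("data:"):
                data_lines.append(line[len("data:"):].strip())
        if data_lines:
            # Multiple data lines in one frame are joined with newlines.
            events.append((event_type, json.loads("\n".join(data_lines))))
    return events


raw = (
    'event: think\ndata: {"content": "calling tool: search", "iteration": 1}\n\n'
    'event: final\ndata: {"content": "done", "iteration": 2}\n\n'
)
parsed = parse_sse(raw)
```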
@@ -225,9 +269,25 @@ async def chat_with_agent(
|
||||
# 查找 agent 节点的配置(或第一个 llm 节点的配置)
|
||||
agent_node_cfg = _find_agent_node_config(nodes)
|
||||
|
||||
# 构建 system prompt,并自动注入智能体名称
|
||||
system_prompt = agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。"
|
||||
if agent.name:
|
||||
name_prefix = f"你的名字是{agent.name}"
|
||||
if name_prefix not in system_prompt:
|
||||
system_prompt = f"{name_prefix}。\n\n{system_prompt}"
|
||||
|
||||
# 合并执行预算:Agent.budget_config 覆盖默认值
|
||||
budget = AgentBudgetConfig()
|
||||
if agent.budget_config and isinstance(agent.budget_config, dict):
|
||||
bc = agent.budget_config
|
||||
if "max_llm_invocations" in bc and bc["max_llm_invocations"] is not None:
|
||||
budget.max_llm_invocations = max(1, int(bc["max_llm_invocations"]))
|
||||
if "max_tool_calls" in bc and bc["max_tool_calls"] is not None:
|
||||
budget.max_tool_calls = max(1, int(bc["max_tool_calls"]))
|
||||
|
||||
config = AgentConfig(
|
||||
name=agent.name,
|
||||
system_prompt=agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。",
|
||||
system_prompt=system_prompt,
|
||||
llm=AgentLLMConfig(
|
||||
provider=agent_node_cfg.get("provider", "openai"),
|
||||
model=req.model or agent_node_cfg.get("model", "gpt-4o-mini"),
|
||||
@@ -238,6 +298,7 @@ async def chat_with_agent(
|
||||
include_tools=agent_node_cfg.get("tools", []),
|
||||
exclude_tools=agent_node_cfg.get("exclude_tools", []),
|
||||
),
|
||||
budget=budget,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
@@ -256,6 +317,68 @@ async def chat_with_agent(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{agent_id}/stream")
|
||||
async def chat_with_agent_stream(
|
||||
agent_id: str,
|
||||
req: ChatRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
"""Chat with the specified Agent (streaming via SSE)."""
|
||||
agent = db.query(Agent).filter(Agent.id == agent_id).first()
|
||||
if not agent:
|
||||
raise HTTPException(status_code=404, detail="Agent 不存在")
|
||||
if agent.user_id and agent.user_id != current_user.id and current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="无权访问该 Agent")
|
||||
|
||||
wc = agent.workflow_config or {}
|
||||
nodes = wc.get("nodes", [])
|
||||
agent_node_cfg = _find_agent_node_config(nodes)
|
||||
|
||||
system_prompt = agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。"
|
||||
if agent.name:
|
||||
name_prefix = f"你的名字是{agent.name}"
|
||||
if name_prefix not in system_prompt:
|
||||
system_prompt = f"{name_prefix}。\n\n{system_prompt}"
|
||||
|
||||
budget = AgentBudgetConfig()
|
||||
if agent.budget_config and isinstance(agent.budget_config, dict):
|
||||
bc = agent.budget_config
|
||||
if "max_llm_invocations" in bc and bc["max_llm_invocations"] is not None:
|
||||
budget.max_llm_invocations = max(1, int(bc["max_llm_invocations"]))
|
||||
if "max_tool_calls" in bc and bc["max_tool_calls"] is not None:
|
||||
budget.max_tool_calls = max(1, int(bc["max_tool_calls"]))
|
||||
|
||||
config = AgentConfig(
|
||||
name=agent.name,
|
||||
system_prompt=system_prompt,
|
||||
llm=AgentLLMConfig(
|
||||
provider=agent_node_cfg.get("provider", "openai"),
|
||||
model=req.model or agent_node_cfg.get("model", "gpt-4o-mini"),
|
||||
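# An explicit temperature of 0 is valid, so test against None rather than truthiness.
temperature=req.temperature if req.temperature is not None else float(agent_node_cfg.get("temperature", 0.7)),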
|
||||
max_iterations=req.max_iterations or int(agent_node_cfg.get("max_iterations", 10)),
|
||||
),
|
||||
tools=AgentToolConfig(
|
||||
include_tools=agent_node_cfg.get("tools", []),
|
||||
exclude_tools=agent_node_cfg.get("exclude_tools", []),
|
||||
),
|
||||
budget=budget,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
on_llm_call = _make_llm_logger(db, agent_id=agent_id, user_id=current_user.id)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
|
||||
return StreamingResponse(
|
||||
_sse_stream(runtime.run_stream(req.message)),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
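The budget-merging block appears twice in this diff, once in `chat_with_agent` and once in `chat_with_agent_stream`. One way to avoid the duplication is a small helper, sketched here under assumptions: `BudgetSketch` stands in for `AgentBudgetConfig`, its default values are invented for illustration, and `merge_budget` is a hypothetical name.

```python
from dataclasses import dataclass
from typing import Any, Mapping, Optional


@dataclass
class BudgetSketch:
    # Stand-in for AgentBudgetConfig; these defaults are assumptions.
    max_llm_invocations: int = 20
    max_tool_calls: int = 50


def merge_budget(overrides: Optional[Mapping[str, Any]]) -> BudgetSketch:
    """Apply per-agent overrides on top of the defaults, clamping each limit to >= 1."""
    budget = BudgetSketch()
    if not isinstance(overrides, Mapping):
        return budget
    for key in ("max_llm_invocations", "max_tool_calls"):
        value = overrides.get(key)
        if value is not None:
            # Mirror the diff: coerce to int and never allow a limit below 1.
            setattr(budget, key, max(1, int(value)))
    return budget
```

With a helper like this, both endpoints could reduce to a single call such as `budget = merge_budget(agent.budget_config)`, and any future budget field would only need to be added in one place.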
|
||||
|
||||
|
||||
def _find_agent_node_config(nodes: list) -> Dict[str, Any]:
"""Return the config of the first agent-type or llm-type node in the workflow node list."""
|
||||
if not nodes:
|
||||
|
||||
251
backend/app/api/knowledge_base.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
Knowledge base RAG API.

Provides endpoints for knowledge base management, document upload, semantic search, and RAG queries.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.services.knowledge_service import (
|
||||
create_knowledge_base,
|
||||
delete_document,
|
||||
delete_knowledge_base,
|
||||
get_knowledge_base,
|
||||
list_documents,
|
||||
list_knowledge_bases,
|
||||
rag_query,
|
||||
search,
|
||||
upload_document,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/v1/knowledge-bases", tags=["knowledge-base"])
|
||||
|
||||
|
||||
# ─── Schema ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class KBCreateRequest(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
chunk_size: int = 500
|
||||
chunk_overlap: int = 50
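`KBCreateRequest` carries `chunk_size` and `chunk_overlap`, but the `text_chunker` module that consumes them is not shown in this section. The sketch below illustrates one common reading of those two parameters, a sliding character window. `chunk_text` is a hypothetical name, and the real chunker may split on sentences or paragraphs instead.

```python
from typing import List


def chunk_text(text: str, chunk_size: int = 500, chunk_overlap: int = 50) -> List[str]:
    """Split text into character windows of chunk_size that overlap by chunk_overlap."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        # An overlap >= chunk_size would make the window never advance.
        raise ValueError("chunk_overlap must satisfy 0 <= chunk_overlap < chunk_size")
    step = chunk_size - chunk_overlap
    chunks: List[str] = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks
```

Note that the request schema as written accepts any integers, so a guard like the `ValueError` checks above (or equivalent Pydantic field constraints) may be worth considering before the values reach the chunker.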
|
||||
|
||||
|
||||
class KBResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = ""
|
||||
user_id: Optional[str] = ""
|
||||
chunk_size: int = 500
|
||||
chunk_overlap: int = 50
|
||||
doc_count: int = 0
|
||||
created_at: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
|
||||
|
||||
class DocumentResponse(BaseModel):
|
||||
id: str
|
||||
kb_id: str
|
||||
filename: str
|
||||
file_type: str
|
||||
file_size: int = 0
|
||||
status: str = "pending"
|
||||
error_message: Optional[str] = None
|
||||
chunk_count: int = 0
|
||||
created_at: Optional[str] = None
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
query: str
|
||||
top_k: int = 5
|
||||
min_score: float = 0.3
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
chunk_id: str
|
||||
content: str
|
||||
score: float
|
||||
metadata: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
|
||||
results: List[SearchResult] = []
|
||||
|
||||
|
||||
class RAGRequest(BaseModel):
|
||||
query: str
|
||||
top_k: int = 5
|
||||
min_score: float = 0.3
|
||||
|
||||
|
||||
class RAGResponse(BaseModel):
|
||||
query: str
|
||||
context: str = ""
|
||||
sources: List[Dict[str, Any]] = []
|
||||
found: bool = False
|
||||
|
||||
|
||||
# ─── Knowledge base CRUD ──────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("", response_model=KBResponse)
|
||||
async def api_create_kb(
|
||||
req: KBCreateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
"""Create a knowledge base."""
|
||||
kb = create_knowledge_base(
|
||||
db=db,
|
||||
name=req.name,
|
||||
user_id=current_user.id,
|
||||
description=req.description,
|
||||
chunk_size=req.chunk_size,
|
||||
chunk_overlap=req.chunk_overlap,
|
||||
)
|
||||
return KBResponse(**kb.to_dict())
|
||||
|
||||
|
||||
@router.get("", response_model=List[KBResponse])
|
||||
async def api_list_kb(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
"""List the current user's knowledge bases."""
|
||||
kbs = list_knowledge_bases(db, user_id=current_user.id)
|
||||
return [KBResponse(**kb.to_dict()) for kb in kbs]
|
||||
|
||||
|
||||
@router.get("/{kb_id}", response_model=KBResponse)
|
||||
async def api_get_kb(
|
||||
kb_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
"""Get knowledge base details."""
|
||||
kb = get_knowledge_base(db, kb_id)
|
||||
if not kb:
|
||||
raise HTTPException(status_code=404, detail="知识库不存在")
|
||||
return KBResponse(**kb.to_dict())
|
||||
|
||||
|
||||
@router.delete("/{kb_id}")
|
||||
async def api_delete_kb(
|
||||
kb_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
"""Delete a knowledge base."""
|
||||
ok = delete_knowledge_base(db, kb_id)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="知识库不存在")
|
||||
return {"message": "知识库已删除"}
|
||||
|
||||
|
||||
# ─── Document management ──────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/{kb_id}/documents", response_model=DocumentResponse)
|
||||
async def api_upload_document(
|
||||
kb_id: str,
|
||||
file: UploadFile = File(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
"""Upload a document to the knowledge base (parsed, chunked, and embedded automatically)."""
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="文件名不能为空")
|
||||
|
||||
content = await file.read()
|
||||
if not content:
|
||||
raise HTTPException(status_code=400, detail="文件内容为空")
|
||||
|
||||
try:
|
||||
doc = await upload_document(
|
||||
db=db,
|
||||
kb_id=kb_id,
|
||||
filename=file.filename,
|
||||
file_content=content,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
# A document whose processing failed is still returned with HTTP 200;
# its status and error_message fields report the failure to the caller.
return DocumentResponse(**doc.to_dict())
|
||||
|
||||
|
||||
@router.get("/{kb_id}/documents", response_model=List[DocumentResponse])
|
||||
async def api_list_documents(
|
||||
kb_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
"""List the documents in a knowledge base."""
|
||||
docs = list_documents(db, kb_id)
|
||||
return [DocumentResponse(**d.to_dict()) for d in docs]
|
||||
|
||||
|
||||
@router.delete("/{kb_id}/documents/{doc_id}")
|
||||
async def api_delete_document(
|
||||
kb_id: str,
|
||||
doc_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
"""Delete a document."""
|
||||
ok = delete_document(db, doc_id)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="文档不存在")
|
||||
return {"message": "文档已删除"}
|
||||
|
||||
|
||||
# ─── Search & RAG ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/{kb_id}/search", response_model=SearchResponse)
|
||||
async def api_search(
|
||||
kb_id: str,
|
||||
req: SearchRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
"""Semantic search over a knowledge base."""
|
||||
results = await search(
|
||||
db=db,
|
||||
kb_id=kb_id,
|
||||
query=req.query,
|
||||
top_k=req.top_k,
|
||||
min_score=req.min_score,
|
||||
)
|
||||
return SearchResponse(results=[SearchResult(**r) for r in results])
|
||||
|
||||
|
||||
@router.post("/{kb_id}/rag", response_model=RAGResponse)
|
||||
async def api_rag(
|
||||
kb_id: str,
|
||||
req: RAGRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
"""RAG query: retrieve relevant chunks and format them as context."""
|
||||
result = await rag_query(
|
||||
db=db,
|
||||
kb_id=kb_id,
|
||||
query=req.query,
|
||||
top_k=req.top_k,
|
||||
min_score=req.min_score,
|
||||
)
|
||||
return RAGResponse(**result)
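The `search` and `rag_query` services called above live in `knowledge_service`, which is outside this section. Since the models store each embedding as a JSON-serialized vector in a text column, retrieval plausibly reduces to scoring in application code. The sketch below shows that idea with cosine similarity, honoring the `top_k` and `min_score` parameters from `SearchRequest`. The function names and the row shape are assumptions, not the commit's implementation.

```python
import json
import math
from typing import Any, Dict, List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 when either vector has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def rank_chunks(
    query_vec: Sequence[float],
    rows: List[Dict[str, Any]],
    top_k: int = 5,
    min_score: float = 0.3,
) -> List[Dict[str, Any]]:
    """Score rows against the query vector, drop low scores, return the best top_k."""
    scored = []
    for row in rows:
        if not row.get("embedding"):
            continue  # chunk has no embedding yet (column is nullable)
        score = cosine(query_vec, json.loads(row["embedding"]))
        if score >= min_score:
            scored.append({
                "chunk_id": row["id"],
                "content": row["content"],
                "score": score,
                "metadata": row.get("metadata") or {},
            })
    scored.sort(key=lambda item: item["score"], reverse=True)
    return scored[:top_k]
```

A linear scan like this is simple and database-agnostic, but its cost grows with the number of chunks in the knowledge base; a dedicated vector index would be the usual next step if that becomes a bottleneck.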
|
||||
@@ -1,21 +1,42 @@
|
||||
"""
|
||||
Tool management API
|
||||
Tool marketplace API: manage, test, discover, and install tools.

Provides CRUD, testing, and execution for built-in tools and user-defined tools (HTTP / code snippet).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
|
||||
from app.api.auth import get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.models.tool import Tool
|
||||
from app.services.tool_registry import tool_registry
|
||||
from app.api.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from pydantic import BaseModel
|
||||
from app.services.tool_registry import tool_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/v1/tools", tags=["tools"])
|
||||
|
||||
|
||||
# ─── Schema ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ToolCreate(BaseModel):
|
||||
"""创建工具请求"""
|
||||
name: str
|
||||
description: str
|
||||
category: Optional[str] = None
|
||||
function_schema: dict
|
||||
implementation_type: str # builtin / http / code / workflow
|
||||
implementation_config: Optional[dict] = None
|
||||
is_public: bool = False
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
category: Optional[str] = None
|
||||
@@ -23,86 +44,39 @@ class ToolCreate(BaseModel):
|
||||
implementation_type: str
|
||||
implementation_config: Optional[dict] = None
|
||||
is_public: bool = False
|
||||
use_count: int = 0
|
||||
user_id: Optional[str] = None
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
"""工具响应"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
category: Optional[str]
|
||||
function_schema: dict
|
||||
implementation_type: str
|
||||
implementation_config: Optional[dict]
|
||||
is_public: bool
|
||||
use_count: int
|
||||
user_id: Optional[str]
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
class TestHTTPRequest(BaseModel):
|
||||
url: str
|
||||
method: str = "GET"
|
||||
headers: Dict[str, str] = {}
|
||||
body: Optional[Dict[str, Any]] = None
|
||||
args: Dict[str, Any] = {}
|
||||
timeout: int = 30
|
||||
|
||||
|
||||
@router.get("", response_model=List[ToolResponse])
|
||||
async def list_tools(
|
||||
category: Optional[str] = Query(None, description="工具分类"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取工具列表"""
|
||||
query = db.query(Tool).filter(Tool.is_public == True)
|
||||
|
||||
if category:
|
||||
query = query.filter(Tool.category == category)
|
||||
|
||||
if search:
|
||||
query = query.filter(
|
||||
Tool.name.contains(search) |
|
||||
Tool.description.contains(search)
|
||||
)
|
||||
|
||||
tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()
|
||||
|
||||
# 转换为响应格式,确保日期时间字段转换为字符串
|
||||
result = []
|
||||
for tool in tools:
|
||||
result.append({
|
||||
"id": tool.id,
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"category": tool.category,
|
||||
"function_schema": tool.function_schema,
|
||||
"implementation_type": tool.implementation_type,
|
||||
"implementation_config": tool.implementation_config,
|
||||
"is_public": tool.is_public,
|
||||
"use_count": tool.use_count,
|
||||
"user_id": tool.user_id,
|
||||
"created_at": tool.created_at.isoformat() if tool.created_at else "",
|
||||
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
|
||||
})
|
||||
|
||||
return result
|
||||
class TestCodeRequest(BaseModel):
|
||||
source: str
|
||||
args: Dict[str, Any] = {}
|
||||
|
||||
|
||||
@router.get("/builtin")
|
||||
async def list_builtin_tools():
|
||||
"""获取内置工具列表"""
|
||||
schemas = tool_registry.get_all_tool_schemas()
|
||||
return schemas
|
||||
class TestResponse(BaseModel):
|
||||
success: bool
|
||||
elapsed_ms: Optional[int] = None
|
||||
result: Optional[Any] = None
|
||||
status_code: Optional[int] = None
|
||||
body: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("/{tool_id}", response_model=ToolResponse)
|
||||
async def get_tool(
|
||||
tool_id: str,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取工具详情"""
|
||||
tool = db.query(Tool).filter(Tool.id == tool_id).first()
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
# 转换为响应格式,确保日期时间字段转换为字符串
|
||||
# ─── Helper functions ─────────────────────────────────────────
|
||||
|
||||
|
||||
def _tool_to_dict(tool: Tool) -> dict:
|
||||
return {
|
||||
"id": tool.id,
|
||||
"name": tool.name,
|
||||
@@ -115,22 +89,89 @@ async def get_tool(
|
||||
"use_count": tool.use_count,
|
||||
"user_id": tool.user_id,
|
||||
"created_at": tool.created_at.isoformat() if tool.created_at else "",
|
||||
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
|
||||
"updated_at": tool.updated_at.isoformat() if tool.updated_at else "",
|
||||
}
|
||||
|
||||
|
||||
# ─── Tool marketplace browsing ────────────────────────────────
|
||||
|
||||
|
||||
@router.get("", response_model=List[ToolResponse])
|
||||
async def list_tools(
|
||||
category: Optional[str] = Query(None, description="按分类筛选"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
scope: Optional[str] = Query("public", description="public / mine / all"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user),
|
||||
):
"""Browse the tool marketplace."""
|
||||
query = db.query(Tool)
|
||||
|
||||
if scope == "public":
|
||||
query = query.filter(Tool.is_public == True)
|
||||
elif scope == "mine":
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="需登录")
|
||||
query = query.filter(Tool.user_id == current_user.id)
|
||||
|
||||
if category:
|
||||
query = query.filter(Tool.category == category)
|
||||
if search:
|
||||
query = query.filter(
|
||||
Tool.name.contains(search) | Tool.description.contains(search)
|
||||
)
|
||||
|
||||
tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()
|
||||
return [_tool_to_dict(t) for t in tools]
|
||||
|
||||
|
||||
@router.get("/categories", response_model=List[str])
|
||||
async def list_categories(db: Session = Depends(get_db)):
"""List all tool categories."""
|
||||
rows = db.query(Tool.category).filter(Tool.category.isnot(None)).distinct().all()
|
||||
cats = sorted(set(r[0] for r in rows if r[0]))
|
||||
# Append the common default categories
|
||||
defaults = ["数据处理", "网络请求", "文件操作", "AI服务", "数据库", "通知", "自定义"]
|
||||
for d in defaults:
|
||||
if d not in cats:
|
||||
cats.append(d)
|
||||
return cats
|
||||
|
||||
|
||||
@router.get("/builtin")
|
||||
async def list_builtin_tools():
"""List all built-in tools (OpenAI function format)."""
|
||||
return tool_registry.get_all_tool_schemas()
|
||||
|
||||
|
||||
@router.get("/{tool_id}", response_model=ToolResponse)
|
||||
async def get_tool(tool_id: str, db: Session = Depends(get_db)):
"""Get tool details."""
|
||||
tool = db.query(Tool).filter(Tool.id == tool_id).first()
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
return _tool_to_dict(tool)
|
||||
|
||||
|
||||
# ─── Tool create / update / delete ────────────────────────────
|
||||
|
||||
|
||||
@router.post("", response_model=ToolResponse, status_code=201)
|
||||
async def create_tool(
|
||||
tool_data: ToolCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""创建工具"""
|
||||
# 检查工具名称是否已存在
"""Create a custom tool."""
|
||||
existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail=f"工具名称 '{tool_data.name}' 已存在")
|
||||
|
||||
raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
|
||||
|
||||
valid_types = {"builtin", "http", "code", "workflow"}
|
||||
if tool_data.implementation_type not in valid_types:
|
||||
raise HTTPException(status_code=400,
|
||||
detail=f"无效的实现类型: {tool_data.implementation_type}")
|
||||
|
||||
tool = Tool(
|
||||
name=tool_data.name,
|
||||
description=tool_data.description,
|
||||
@@ -139,28 +180,22 @@ async def create_tool(
|
||||
implementation_type=tool_data.implementation_type,
|
||||
implementation_config=tool_data.implementation_config,
|
||||
is_public=tool_data.is_public,
|
||||
user_id=current_user.id
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
db.add(tool)
|
||||
db.commit()
|
||||
db.refresh(tool)
|
||||
|
||||
# 转换为响应格式,确保日期时间字段转换为字符串
|
||||
return {
|
||||
"id": tool.id,
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"category": tool.category,
|
||||
"function_schema": tool.function_schema,
|
||||
"implementation_type": tool.implementation_type,
|
||||
"implementation_config": tool.implementation_config,
|
||||
"is_public": tool.is_public,
|
||||
"use_count": tool.use_count,
|
||||
"user_id": tool.user_id,
|
||||
"created_at": tool.created_at.isoformat() if tool.created_at else "",
|
||||
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
|
||||
logger.info("工具已创建: %s (type=%s)", tool.name, tool.implementation_type)
|
||||
|
||||
# Refresh the registry
|
||||
tool_registry._custom_tool_configs[tool.name] = {
|
||||
**(tool.implementation_config or {}),
|
||||
"_type": tool.implementation_type,
|
||||
"_db_id": tool.id,
|
||||
}
|
||||
tool_registry._tool_schemas[tool.name] = tool.function_schema
|
||||
|
||||
return _tool_to_dict(tool)
|
||||
|
||||
|
||||
@router.put("/{tool_id}", response_model=ToolResponse)
|
||||
@@ -168,23 +203,20 @@ async def update_tool(
|
||||
tool_id: str,
|
||||
tool_data: ToolCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""更新工具"""
"""Update a tool."""
|
||||
tool = db.query(Tool).filter(Tool.id == tool_id).first()
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
# 检查权限(只有创建者可以更新)
|
||||
if tool.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="无权更新此工具")
|
||||
|
||||
# 检查名称冲突
|
||||
|
||||
if tool_data.name != tool.name:
|
||||
existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail=f"工具名称 '{tool_data.name}' 已存在")
|
||||
|
||||
raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
|
||||
|
||||
tool.name = tool_data.name
|
||||
tool.description = tool_data.description
|
||||
tool.category = tool_data.category
|
||||
@@ -192,47 +224,94 @@ async def update_tool(
|
||||
tool.implementation_type = tool_data.implementation_type
|
||||
tool.implementation_config = tool_data.implementation_config
|
||||
tool.is_public = tool_data.is_public
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(tool)
|
||||
|
||||
# 转换为响应格式,确保日期时间字段转换为字符串
|
||||
return {
|
||||
"id": tool.id,
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"category": tool.category,
|
||||
"function_schema": tool.function_schema,
|
||||
"implementation_type": tool.implementation_type,
|
||||
"implementation_config": tool.implementation_config,
|
||||
"is_public": tool.is_public,
|
||||
"use_count": tool.use_count,
|
||||
"user_id": tool.user_id,
|
||||
"created_at": tool.created_at.isoformat() if tool.created_at else "",
|
||||
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
|
||||
}
|
||||
|
||||
# Refresh the registry
|
||||
if tool.name in tool_registry._custom_tool_configs:
|
||||
tool_registry._custom_tool_configs[tool.name] = {
|
||||
**(tool.implementation_config or {}),
|
||||
"_type": tool.implementation_type,
|
||||
"_db_id": tool.id,
|
||||
}
|
||||
if tool.name in tool_registry._tool_schemas:
|
||||
tool_registry._tool_schemas[tool.name] = tool.function_schema
|
||||
|
||||
return _tool_to_dict(tool)
|
||||
|
||||
|
||||
@router.delete("/{tool_id}", status_code=200)
|
||||
@router.delete("/{tool_id}")
|
||||
async def delete_tool(
|
||||
tool_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""删除工具"""
"""Delete a tool."""
|
||||
tool = db.query(Tool).filter(Tool.id == tool_id).first()
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
# 检查权限(只有创建者可以删除)
|
||||
if tool.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="无权删除此工具")
|
||||
|
||||
# 内置工具不允许删除
|
||||
if tool.implementation_type == "builtin":
|
||||
raise HTTPException(status_code=400, detail="内置工具不允许删除")
|
||||
|
||||
|
||||
db.delete(tool)
|
||||
db.commit()
|
||||
|
||||
|
||||
# Clean up the registry
|
||||
tool_registry._custom_tool_configs.pop(tool.name, None)
|
||||
tool_registry._tool_schemas.pop(tool.name, None)
|
||||
|
||||
return {"message": "工具已删除"}
|
||||
|
||||
|
||||
# ─── Tool testing ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/test/http", response_model=TestResponse)
|
||||
async def test_http_tool(
|
||||
req: TestHTTPRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
"""Test an HTTP tool (nothing is saved to the database)."""
|
||||
result = await tool_registry.test_http_tool(
|
||||
url=req.url,
|
||||
method=req.method,
|
||||
headers=req.headers,
|
||||
body=req.body,
|
||||
args=req.args,
|
||||
timeout=req.timeout,
|
||||
)
|
||||
return TestResponse(**result)
|
||||
|
||||
|
||||
@router.post("/test/code", response_model=TestResponse)
|
||||
async def test_code_tool(
|
||||
req: TestCodeRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
"""Test a code tool (nothing is saved to the database)."""
|
||||
result = await tool_registry.test_code_tool(
|
||||
source=req.source,
|
||||
args=req.args,
|
||||
)
|
||||
return TestResponse(**result)
|
||||
|
||||
|
||||
# ─── Usage count ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/{tool_id}/use")
|
||||
async def record_tool_use(
|
||||
tool_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
"""Record one use of a tool (called automatically when an Agent executes it)."""
|
||||
tool = db.query(Tool).filter(Tool.id == tool_id).first()
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
tool.use_count = (tool.use_count or 0) + 1
|
||||
db.commit()
|
||||
return {"use_count": tool.use_count}
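The create, update, and delete handlers above write directly to the underscore-prefixed attributes `tool_registry._custom_tool_configs` and `tool_registry._tool_schemas`. In `update_tool`, the membership check runs against the tool's name after it has been reassigned, so a rename appears to leave the old entry in place and skip the new one. The sketch below shows one possible shape for a small public interface that handles renames. `CustomToolRegistrySketch` and its methods are hypothetical and are not the actual `tool_registry` API.

```python
from typing import Any, Dict, List, Optional


class CustomToolRegistrySketch:
    """Hypothetical in-memory registry keyed by tool name."""

    def __init__(self) -> None:
        self._configs: Dict[str, Dict[str, Any]] = {}
        self._schemas: Dict[str, Dict[str, Any]] = {}

    def register(
        self,
        name: str,
        impl_type: str,
        db_id: str,
        config: Optional[Dict[str, Any]],
        schema: Dict[str, Any],
        previous_name: Optional[str] = None,
    ) -> None:
        """Insert or replace an entry; drop the old key first when the tool was renamed."""
        if previous_name and previous_name != name:
            self.unregister(previous_name)
        # Same entry layout as in the diff: config plus _type and _db_id markers.
        self._configs[name] = {**(config or {}), "_type": impl_type, "_db_id": db_id}
        self._schemas[name] = schema

    def unregister(self, name: str) -> None:
        """Remove an entry; a missing name is ignored."""
        self._configs.pop(name, None)
        self._schemas.pop(name, None)

    def names(self) -> List[str]:
        return sorted(self._configs)
```

Under this design, an update handler would capture the old name before mutating the ORM object and pass it as `previous_name`, so the registry and the database stay in step after a rename.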
|
||||
|
||||
@@ -56,6 +56,11 @@ class Settings(BaseSettings):
|
||||
# DeepSeek配置
|
||||
DEEPSEEK_API_KEY: str = ""
|
||||
DEEPSEEK_BASE_URL: str = "https://api.deepseek.com"
|
||||
|
||||
# SiliconFlow settings (SiliconFlow is recommended for embeddings)
|
||||
SILICONFLOW_API_KEY: str = ""
|
||||
SILICONFLOW_BASE_URL: str = "https://api.siliconflow.cn/v1"
|
||||
SILICONFLOW_EMBEDDING_MODEL: str = "netease-youdao/bce-embedding-base_v1"
|
||||
|
||||
# Anthropic配置
|
||||
ANTHROPIC_API_KEY: str = ""
|
||||
|
||||
@@ -47,4 +47,6 @@ def init_db():
|
||||
import app.models.permission
|
||||
import app.models.alert_rule
|
||||
import app.models.agent_llm_log
|
||||
import app.models.agent_vector_memory
|
||||
import app.models.knowledge_base
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
@@ -201,7 +201,7 @@ async def startup_event():
|
||||
# 不抛出异常,允许应用继续启动
|
||||
|
||||
# 注册路由
|
||||
from app.api import auth, uploads, workflows, executions, websocket, execution_logs, data_sources, agents, platform_templates, model_configs, webhooks, template_market, batch_operations, collaboration, permissions, monitoring, alert_rules, node_test, node_templates, tools, agent_chat, agent_monitoring
|
||||
from app.api import auth, uploads, workflows, executions, websocket, execution_logs, data_sources, agents, platform_templates, model_configs, webhooks, template_market, batch_operations, collaboration, permissions, monitoring, alert_rules, node_test, node_templates, tools, agent_chat, agent_monitoring, knowledge_base
|
||||
|
||||
app.include_router(auth.router)
|
||||
app.include_router(uploads.router)
|
||||
@@ -225,6 +225,7 @@ app.include_router(node_templates.router)
|
||||
app.include_router(tools.router)
|
||||
app.include_router(agent_chat.router)
|
||||
app.include_router(agent_monitoring.router)
|
||||
app.include_router(knowledge_base.router)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
@@ -13,5 +13,7 @@ from app.models.permission import Role, Permission, WorkflowPermission, AgentPer
|
||||
from app.models.alert_rule import AlertRule, AlertLog
|
||||
from app.models.persistent_user_memory import PersistentUserMemory
|
||||
from app.models.agent_llm_log import AgentLLMLog
|
||||
from app.models.agent_vector_memory import AgentVectorMemory
|
||||
from app.models.knowledge_base import KnowledgeBase, Document, DocumentChunk
|
||||
|
||||
__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog"]
|
||||
__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog", "AgentVectorMemory", "KnowledgeBase", "Document", "DocumentChunk"]
|
||||
45
backend/app/models/agent_vector_memory.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Agent vector memory model.

Stores embedding vectors of conversation snippets to support semantic retrieval.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, Text, DateTime, Index
|
||||
from sqlalchemy.dialects.mysql import JSON as MySQLJSON
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class AgentVectorMemory(Base):
"""Agent vector memory: stores conversation text and its embedding."""
|
||||
|
||||
__tablename__ = "agent_vector_memories"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
scope_kind = Column(String(16), nullable=False, index=True, comment="作用域类型: agent / bare")
|
||||
scope_id = Column(String(36), nullable=False, index=True, comment="作用域 ID: agent_id / user_id")
|
||||
session_key = Column(String(128), nullable=False, default="", comment="会话标识")
|
||||
content_text = Column(Text, nullable=False, comment="原始对话文本")
|
||||
embedding = Column(Text, nullable=True, comment="JSON 序列化的 embedding 向量")
|
||||
metadata_ = Column("metadata", MySQLJSON, nullable=True, comment="元数据: {type, iteration, ...}")
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_agent_vector_memory_scope", "scope_kind", "scope_id"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"scope_kind": self.scope_kind,
|
||||
"scope_id": self.scope_id,
|
||||
"session_key": self.session_key,
|
||||
"content_text": self.content_text,
|
||||
"embedding": self.embedding,
|
||||
"metadata": self.metadata_ or {},
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
116
backend/app/models/knowledge_base.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Knowledge base RAG models.

KnowledgeBase: knowledge base container
Document: uploaded source document
DocumentChunk: document chunk (with embedding)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey, Index
|
||||
from sqlalchemy.dialects.mysql import JSON as MySQLJSON
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class KnowledgeBase(Base):
"""Knowledge base."""
|
||||
__tablename__ = "knowledge_bases"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
name = Column(String(200), nullable=False, comment="知识库名称")
|
||||
description = Column(Text, nullable=True, comment="描述")
|
||||
user_id = Column(String(36), nullable=True, index=True, comment="创建者 ID")
|
||||
chunk_size = Column(Integer, default=500, comment="分块大小(字符数)")
|
||||
chunk_overlap = Column(Integer, default=50, comment="分块重叠(字符数)")
|
||||
doc_count = Column(Integer, default=0, comment="文档数量")
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
||||
|
||||
documents = relationship("Document", back_populates="kb", cascade="all, delete-orphan",
|
||||
passive_deletes=True)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"user_id": self.user_id,
|
||||
"chunk_size": self.chunk_size,
|
||||
"chunk_overlap": self.chunk_overlap,
|
||||
"doc_count": self.doc_count,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class Document(Base):
"""Knowledge base document."""
|
||||
__tablename__ = "kb_documents"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
kb_id = Column(String(36), ForeignKey("knowledge_bases.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True, comment="所属知识库")
|
||||
filename = Column(String(500), nullable=False, comment="原始文件名")
|
||||
file_type = Column(String(20), nullable=False, comment="文件类型: txt/pdf/docx/md/csv")
|
||||
file_size = Column(Integer, default=0, comment="文件大小(字节)")
|
||||
status = Column(String(20), default="pending",
|
||||
comment="状态: pending/processing/ready/failed")
|
||||
error_message = Column(Text, nullable=True, comment="处理失败信息")
|
||||
chunk_count = Column(Integer, default=0, comment="分块数量")
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
|
||||
kb = relationship("KnowledgeBase", back_populates="documents")
|
||||
chunks = relationship("DocumentChunk", back_populates="document",
|
||||
cascade="all, delete-orphan", passive_deletes=True)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"kb_id": self.kb_id,
|
||||
"filename": self.filename,
|
||||
"file_type": self.file_type,
|
||||
"file_size": self.file_size,
|
||||
"status": self.status,
|
||||
"error_message": self.error_message,
|
||||
"chunk_count": self.chunk_count,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
class DocumentChunk(Base):
|
||||
"""文档切片 — 含 embedding 向量"""
|
||||
__tablename__ = "kb_document_chunks"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
document_id = Column(String(36), ForeignKey("kb_documents.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True, comment="所属文档")
|
||||
kb_id = Column(String(36), ForeignKey("knowledge_bases.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True, comment="所属知识库")
|
||||
chunk_index = Column(Integer, default=0, comment="块序号")
|
||||
content = Column(Text, nullable=False, comment="切片文本内容")
|
||||
embedding = Column(Text, nullable=True, comment="JSON 序列化的 embedding 向量")
|
||||
metadata_ = Column("metadata", MySQLJSON, nullable=True, comment="元数据")
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
|
||||
document = relationship("Document", back_populates="chunks")
|
||||
|
||||
    # kb_id and document_id are already indexed through index=True on their columns,
    # so declaring single-column Index entries here would create duplicate indexes.
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"document_id": self.document_id,
|
||||
"kb_id": self.kb_id,
|
||||
"chunk_index": self.chunk_index,
|
||||
"content": self.content[:200] + "..." if len(self.content) > 200 else self.content,
|
||||
"metadata": self.metadata_ or {},
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
101
backend/app/services/document_parser.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
文档解析器 — 支持 txt / pdf / docx / md / csv。
|
||||
|
||||
解析结果统一返回纯文本字符串,由调用方做分块处理。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_document(file_path: str, file_type: str) -> Optional[str]:
|
||||
"""
|
||||
解析文档为纯文本。
|
||||
|
||||
Args:
|
||||
file_path: 文件绝对路径
|
||||
file_type: 文件类型(txt / pdf / docx / md / csv)
|
||||
|
||||
Returns:
|
||||
文本内容,解析失败返回 None
|
||||
"""
|
||||
if not os.path.isfile(file_path):
|
||||
logger.warning("文件不存在: %s", file_path)
|
||||
return None
|
||||
|
||||
parsers = {
|
||||
"txt": _parse_text,
|
||||
"md": _parse_text,
|
||||
"pdf": _parse_pdf,
|
||||
"docx": _parse_docx,
|
||||
"csv": _parse_csv,
|
||||
}
|
||||
parser = parsers.get(file_type)
|
||||
if not parser:
|
||||
logger.warning("不支持的文件类型: %s", file_type)
|
||||
return None
|
||||
|
||||
try:
|
||||
text = parser(file_path)
|
||||
if text:
|
||||
text = text.strip()
|
||||
logger.info("文档解析完成: %s (%s, %d 字符)", file_path, file_type, len(text or ""))
|
||||
return text
|
||||
except Exception as e:
|
||||
logger.error("文档解析失败: %s (%s)", file_path, e, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def _parse_text(file_path: str) -> str:
|
||||
"""解析纯文本 / Markdown 文件。"""
|
||||
with open(file_path, "r", encoding="utf-8", errors="replace") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def _parse_pdf(file_path: str) -> str:
|
||||
"""解析 PDF 文件。"""
|
||||
from pypdf import PdfReader
|
||||
|
||||
reader = PdfReader(file_path)
|
||||
pages = []
|
||||
for page in reader.pages:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
pages.append(text)
|
||||
return "\n\n".join(pages)
|
||||
|
||||
|
||||
def _parse_docx(file_path: str) -> str:
|
||||
"""解析 Word 文档。"""
|
||||
from docx import Document as DocxDocument
|
||||
|
||||
doc = DocxDocument(file_path)
|
||||
paragraphs = []
|
||||
for para in doc.paragraphs:
|
||||
if para.text and para.text.strip():
|
||||
paragraphs.append(para.text.strip())
|
||||
|
||||
# 也提取表格内容
|
||||
for table in doc.tables:
|
||||
for row in table.rows:
|
||||
cells = [cell.text.strip() for cell in row.cells if cell.text.strip()]
|
||||
if cells:
|
||||
paragraphs.append(" | ".join(cells))
|
||||
|
||||
return "\n\n".join(paragraphs)
|
||||
|
||||
|
||||
def _parse_csv(file_path: str) -> str:
|
||||
"""解析 CSV 文件为文本表格。"""
|
||||
rows = []
|
||||
    with open(file_path, "r", encoding="utf-8", errors="replace", newline="") as f:
|
||||
reader = csv.reader(f)
|
||||
for row in reader:
|
||||
rows.append(" | ".join(cell.strip() for cell in row))
|
||||
return "\n".join(rows)
|
||||
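The CSV branch of the parser flattens each row into one pipe-separated line. The sketch below shows the same idea on an in-memory string so the output is easy to inspect. `csv_to_text` and the sample data are hypothetical and not part of the commit; the sketch assumes the input is already decoded text.

```python
import csv
import io


def csv_to_text(raw: str) -> str:
    """Flatten CSV text into one ' | '-separated line per row (hypothetical helper)."""
    # newline="" leaves line endings untouched so the csv module can handle
    # quoted fields that contain line breaks.
    reader = csv.reader(io.StringIO(raw, newline=""))
    return "\n".join(" | ".join(cell.strip() for cell in row) for row in reader)


sample = 'name,note\nalice," hello, world "\n'
print(csv_to_text(sample))
```

Quoted fields keep their embedded commas, and surrounding whitespace is stripped per cell, which matches what `_parse_csv` does per row.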
229
backend/app/services/embedding_service.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
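Embedding generation and semantic search service.

The backend is chosen in priority order SiliconFlow > OpenAI > DeepSeek;
the OpenAI path uses text-embedding-3-small. Semantic search computes
cosine similarity in memory.

If no API key is configured, all methods degrade silently and return None or empty results.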
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, TypedDict
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 默认 embedding 模型
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
EMBEDDING_DIMENSIONS = 1536
|
||||
|
||||
|
||||
class VectorEntry(TypedDict, total=False):
|
||||
"""向量条目"""
|
||||
id: str
|
||||
scope_kind: str
|
||||
scope_id: str
|
||||
content_text: str
|
||||
embedding: List[float]
|
||||
metadata: Dict[str, Any]
|
||||
score: float # 余弦相似度,仅检索结果包含
|
||||
|
||||
|
||||
class EmbeddingService:
|
||||
"""
|
||||
Embedding 服务。
|
||||
|
||||
用法:
|
||||
svc = EmbeddingService()
|
||||
emb = await svc.generate_embedding("你好")
|
||||
results = await svc.similarity_search(query_emb, entries, top_k=5)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._client: Optional[Any] = None
|
||||
self._client_lock = asyncio.Lock()
|
||||
|
||||
async def _get_client(self):
|
||||
"""延迟初始化 OpenAI 客户端(仅在首次调用时创建)。
|
||||
|
||||
优先级:SiliconFlow > OpenAI > DeepSeek。
|
||||
SiliconFlow 的 netease-youdao/bce-embedding-base_v1 在国内可直连且免费。
|
||||
"""
|
||||
if self._client is not None:
|
||||
return self._client
|
||||
|
||||
async with self._client_lock:
|
||||
if self._client is not None:
|
||||
return self._client
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
backend_label = "none"
|
||||
|
||||
# 1) SiliconFlow(国内直连,推荐)
|
||||
if settings.SILICONFLOW_API_KEY:
|
||||
api_key = settings.SILICONFLOW_API_KEY
|
||||
base_url = settings.SILICONFLOW_BASE_URL or "https://api.siliconflow.cn/v1"
|
||||
backend_label = "siliconflow"
|
||||
logger.info("Embedding 后端: SiliconFlow (model=%s)", settings.SILICONFLOW_EMBEDDING_MODEL)
|
||||
|
||||
# 2) OpenAI
|
||||
if not api_key:
|
||||
if settings.OPENAI_API_KEY and settings.OPENAI_API_KEY != "your-openai-api-key":
|
||||
api_key = settings.OPENAI_API_KEY
|
||||
base_url = settings.OPENAI_BASE_URL or "https://api.openai.com/v1"
|
||||
backend_label = "openai"
|
||||
|
||||
# 3) DeepSeek(部分代理可能支持 embedding)
|
||||
if not api_key:
|
||||
if settings.DEEPSEEK_API_KEY:
|
||||
api_key = settings.DEEPSEEK_API_KEY
|
||||
base_url = settings.DEEPSEEK_BASE_URL or "https://api.deepseek.com"
|
||||
backend_label = "deepseek"
|
||||
|
||||
if not api_key:
|
||||
logger.info("未配置任何 API Key,向量记忆功能已禁用(请配置 SILICONFLOW_API_KEY 或 OPENAI_API_KEY)")
|
||||
self._client = None
|
||||
return None
|
||||
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
logger.info("EmbeddingService 已初始化 (backend=%s)", backend_label)
|
||||
return self._client
|
||||
|
||||
def _get_model(self) -> str:
|
||||
"""根据后端选择对应的 embedding 模型。"""
|
||||
if settings.SILICONFLOW_API_KEY:
|
||||
return settings.SILICONFLOW_EMBEDDING_MODEL
|
||||
return EMBEDDING_MODEL
|
||||
|
||||
async def generate_embedding(self, text: str) -> Optional[List[float]]:
|
||||
"""
|
||||
为单段文本生成 embedding 向量。
|
||||
|
||||
返回 float 列表,无 API key 或出错时返回 None。
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return None
|
||||
|
||||
client = await self._get_client()
|
||||
if not client:
|
||||
return None
|
||||
|
||||
model = self._get_model()
|
||||
try:
|
||||
kwargs: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"input": text.strip()[:8000],
|
||||
}
|
||||
# OpenAI text-embedding-3-small 支持指定 dimensions;其它模型可能不支持
|
||||
if model == EMBEDDING_MODEL:
|
||||
kwargs["dimensions"] = EMBEDDING_DIMENSIONS
|
||||
resp = await client.embeddings.create(**kwargs)
|
||||
return resp.data[0].embedding
|
||||
except Exception as e:
|
||||
logger.warning("生成 embedding 失败: %s", e)
|
||||
return None
|
||||
|
||||
async def generate_embeddings(self, texts: List[str]) -> List[Optional[List[float]]]:
|
||||
"""
|
||||
批量生成 embedding 向量。
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
client = await self._get_client()
|
||||
if not client:
|
||||
return [None] * len(texts)
|
||||
|
||||
# 清理空文本
|
||||
valid_indices = [i for i, t in enumerate(texts) if t and t.strip()]
|
||||
if not valid_indices:
|
||||
return [None] * len(texts)
|
||||
|
||||
model = self._get_model()
|
||||
try:
|
||||
inputs = [texts[i].strip()[:8000] for i in valid_indices]
|
||||
kwargs: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"input": inputs,
|
||||
}
|
||||
if model == EMBEDDING_MODEL:
|
||||
kwargs["dimensions"] = EMBEDDING_DIMENSIONS
|
||||
resp = await client.embeddings.create(**kwargs)
|
||||
embeddings = [r.embedding for r in resp.data]
|
||||
except Exception as e:
|
||||
logger.warning("批量生成 embedding 失败: %s", e)
|
||||
return [None] * len(texts)
|
||||
|
||||
# 按原始顺序排列结果
|
||||
result: List[Optional[List[float]]] = [None] * len(texts)
|
||||
for idx, emb in zip(valid_indices, embeddings):
|
||||
result[idx] = emb
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def cosine_similarity(a: List[float], b: List[float]) -> float:
|
||||
"""计算两个向量的余弦相似度。"""
|
||||
if not a or not b or len(a) != len(b):
|
||||
return 0.0
|
||||
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = sum(x * x for x in a) ** 0.5
|
||||
norm_b = sum(y * y for y in b) ** 0.5
|
||||
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
return dot / (norm_a * norm_b)
|
||||
|
||||
async def similarity_search(
|
||||
self,
|
||||
query_embedding: List[float],
|
||||
entries: List[VectorEntry],
|
||||
top_k: int = 5,
|
||||
min_score: float = 0.3,
|
||||
) -> List[VectorEntry]:
|
||||
"""
|
||||
在内存中对 entries 做余弦相似度搜索,返回 Top-K 结果(按得分降序)。
|
||||
min_score: 最低相似度阈值,低于该值的结果被过滤。
|
||||
"""
|
||||
scored: List[VectorEntry] = []
|
||||
for entry in entries:
|
||||
emb = entry.get("embedding")
|
||||
if not emb:
|
||||
continue
|
||||
score = self.cosine_similarity(query_embedding, emb)
|
||||
if score >= min_score:
|
||||
entry["score"] = score
|
||||
scored.append(entry)
|
||||
|
||||
scored.sort(key=lambda x: x["score"], reverse=True)
|
||||
return scored[:top_k]
|
||||
|
||||
@staticmethod
|
||||
def serialize_embedding(embedding: List[float]) -> str:
|
||||
"""将 embedding 序列化为 JSON 字符串。"""
|
||||
return json.dumps(embedding, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def deserialize_embedding(data: str) -> List[float]:
|
||||
"""从 JSON 字符串反序列化 embedding。"""
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
try:
|
||||
return json.loads(data)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return []
|
||||
|
||||
|
||||
# 全局单例
|
||||
embedding_service = EmbeddingService()
|
||||
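The in-memory search in this file scores every stored vector against the query, drops scores under a threshold, and keeps the best `top_k`. The following is a minimal standalone sketch of that technique, not the service's actual API: `cosine`, `top_k`, and the toy vectors are illustrative names and data.

```python
import math
from typing import Dict, List, Tuple


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity; 0.0 for empty, mismatched, or zero-norm vectors."""
    if not a or len(a) != len(b):
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query: List[float], store: Dict[str, List[float]], k: int = 5,
          min_score: float = 0.3) -> List[Tuple[str, float]]:
    """Score every stored vector, drop those under min_score, return the best k."""
    scored = [(key, cosine(query, vec)) for key, vec in store.items()]
    kept = [item for item in scored if item[1] >= min_score]
    return sorted(kept, key=lambda item: item[1], reverse=True)[:k]


store = {"same": [1.0, 0.0], "orthogonal": [0.0, 1.0], "close": [1.0, 1.0]}
print(top_k([1.0, 0.0], store, k=2))
```

Unlike `similarity_search`, this sketch returns new `(key, score)` tuples instead of writing `score` into the caller's entries, which avoids mutating shared data. The cost is linear in the number of stored vectors per query, so it suits small knowledge bases.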
375
backend/app/services/knowledge_service.py
Normal file
@@ -0,0 +1,375 @@
|
||||
"""
|
||||
知识库服务 — 文档管理、分块、Embedding、语义检索、RAG。
|
||||
|
||||
流程:
|
||||
上传 → 解析 → 分块 → Embedding → 存储
|
||||
检索 → 查询 → Embedding → 余弦相似度 → Top-K
|
||||
RAG → 检索 + 格式化上下文 → 返回给 Agent/用户
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.knowledge_base import KnowledgeBase, Document, DocumentChunk
|
||||
from app.services.document_parser import parse_document
|
||||
from app.services.embedding_service import embedding_service, VectorEntry
|
||||
from app.services.text_chunker import chunk_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 上传文件存储根目录
|
||||
UPLOAD_DIR = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
|
||||
"kb_uploads",
|
||||
)
|
||||
|
||||
|
||||
def _ensure_upload_dir():
|
||||
"""确保上传目录存在。"""
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def _get_kb_dir(kb_id: str) -> str:
|
||||
"""返回知识库文件存放目录。"""
|
||||
d = os.path.join(UPLOAD_DIR, kb_id)
|
||||
os.makedirs(d, exist_ok=True)
|
||||
return d
|
||||
|
||||
|
||||
# ─── 知识库 CRUD ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def create_knowledge_base(
|
||||
db: Session,
|
||||
name: str,
|
||||
user_id: str,
|
||||
description: str = "",
|
||||
chunk_size: int = 500,
|
||||
chunk_overlap: int = 50,
|
||||
) -> KnowledgeBase:
|
||||
"""创建知识库。"""
|
||||
kb = KnowledgeBase(
|
||||
name=name,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
        chunk_size=max(50, min(2000, chunk_size)),
        # Clamp the overlap against the clamped chunk size, not the raw argument
        chunk_overlap=max(0, min(max(50, min(2000, chunk_size)) // 2, chunk_overlap)),
|
||||
)
|
||||
db.add(kb)
|
||||
db.commit()
|
||||
db.refresh(kb)
|
||||
logger.info("知识库已创建: %s (%s)", kb.name, kb.id)
|
||||
return kb
|
||||
|
||||
|
||||
def list_knowledge_bases(db: Session, user_id: Optional[str] = None) -> List[KnowledgeBase]:
|
||||
"""列出知识库。"""
|
||||
q = db.query(KnowledgeBase)
|
||||
if user_id:
|
||||
q = q.filter(KnowledgeBase.user_id == user_id)
|
||||
return q.order_by(KnowledgeBase.updated_at.desc()).all()
|
||||
|
||||
|
||||
def get_knowledge_base(db: Session, kb_id: str) -> Optional[KnowledgeBase]:
|
||||
"""获取知识库详情。"""
|
||||
return db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
||||
|
||||
|
||||
def delete_knowledge_base(db: Session, kb_id: str) -> bool:
|
||||
"""删除知识库(连带文档和分块)。"""
|
||||
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
||||
if not kb:
|
||||
return False
|
||||
    # Remove uploaded files from disk; build the path directly so the directory is not created first
    kb_dir = os.path.join(UPLOAD_DIR, kb_id)
|
||||
if os.path.isdir(kb_dir):
|
||||
import shutil
|
||||
shutil.rmtree(kb_dir, ignore_errors=True)
|
||||
db.delete(kb)
|
||||
db.commit()
|
||||
logger.info("知识库已删除: %s", kb_id)
|
||||
return True
|
||||
|
||||
|
||||
# ─── 文档管理 ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def upload_document(
|
||||
db: Session,
|
||||
kb_id: str,
|
||||
filename: str,
|
||||
file_content: bytes,
|
||||
) -> Document:
|
||||
"""
|
||||
上传文档到知识库:
|
||||
1. 保存原始文件
|
||||
2. 解析为文本
|
||||
3. 分块
|
||||
4. 生成 Embedding
|
||||
5. 存储分块
|
||||
"""
|
||||
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
||||
if not kb:
|
||||
raise ValueError(f"知识库不存在: {kb_id}")
|
||||
|
||||
_ensure_upload_dir()
|
||||
kb_dir = _get_kb_dir(kb_id)
|
||||
|
||||
# 提取文件类型
|
||||
ext = os.path.splitext(filename)[1].lower().lstrip(".")
|
||||
if ext in ("txt", "md", "pdf", "docx", "csv"):
|
||||
file_type = ext
|
||||
else:
|
||||
file_type = "txt"
|
||||
|
||||
    # Save the original file; basename() drops directory components from the client-supplied name
    file_path = os.path.join(kb_dir, f"{uuid.uuid4()}_{os.path.basename(filename)}")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
file_size = len(file_content)
|
||||
|
||||
# 创建文档记录
|
||||
doc = Document(
|
||||
kb_id=kb_id,
|
||||
filename=filename,
|
||||
file_type=file_type,
|
||||
file_size=file_size,
|
||||
status="processing",
|
||||
)
|
||||
db.add(doc)
|
||||
db.commit()
|
||||
db.refresh(doc)
|
||||
|
||||
try:
|
||||
# 解析文本
|
||||
text = parse_document(file_path, file_type)
|
||||
if not text:
|
||||
doc.status = "failed"
|
||||
doc.error_message = "文档解析失败或内容为空"
|
||||
db.commit()
|
||||
return doc
|
||||
|
||||
# 分块
|
||||
chunks = chunk_text(
|
||||
text,
|
||||
chunk_size=kb.chunk_size,
|
||||
chunk_overlap=kb.chunk_overlap,
|
||||
)
|
||||
if not chunks:
|
||||
doc.status = "failed"
|
||||
doc.error_message = "文档分块后为空"
|
||||
db.commit()
|
||||
return doc
|
||||
|
||||
logger.info("文档分块完成: %s → %d 块", filename, len(chunks))
|
||||
|
||||
# 批量生成 Embedding
|
||||
embeddings: List[Optional[List[float]]] = []
|
||||
try:
|
||||
embeddings = await embedding_service.generate_embeddings(chunks)
|
||||
except Exception as e:
|
||||
logger.warning("批量生成 embedding 失败,逐块回退: %s", e)
|
||||
for c in chunks:
|
||||
try:
|
||||
emb = await embedding_service.generate_embedding(c)
|
||||
embeddings.append(emb)
|
||||
except Exception:
|
||||
embeddings.append(None)
|
||||
|
||||
# 存储分块
|
||||
chunk_records = []
|
||||
for i, (chunk_text_content, emb) in enumerate(zip(chunks, embeddings)):
|
||||
record = DocumentChunk(
|
||||
document_id=doc.id,
|
||||
kb_id=kb_id,
|
||||
chunk_index=i,
|
||||
content=chunk_text_content,
|
||||
embedding=json.dumps(emb) if emb else None,
|
||||
metadata_={
|
||||
"filename": filename,
|
||||
"file_type": file_type,
|
||||
"chunk_index": i,
|
||||
"has_embedding": emb is not None,
|
||||
},
|
||||
)
|
||||
chunk_records.append(record)
|
||||
|
||||
db.add_all(chunk_records)
|
||||
|
||||
# 更新文档状态
|
||||
doc.status = "ready"
|
||||
doc.chunk_count = len(chunks)
|
||||
kb.doc_count = db.query(Document).filter(
|
||||
Document.kb_id == kb_id, Document.status == "ready"
|
||||
).count()
|
||||
|
||||
db.commit()
|
||||
logger.info("文档处理完成: %s (%d 块, embedding=%s)",
|
||||
filename, len(chunks),
|
||||
"yes" if any(e for e in embeddings) else "no")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
doc = db.query(Document).filter(Document.id == doc.id).first()
|
||||
if doc:
|
||||
doc.status = "failed"
|
||||
doc.error_message = str(e)[:500]
|
||||
db.commit()
|
||||
logger.error("文档处理失败: %s: %s", filename, e, exc_info=True)
|
||||
|
||||
return doc
|
||||
|
||||
|
||||
def list_documents(db: Session, kb_id: str) -> List[Document]:
|
||||
"""列出知识库中的文档。"""
|
||||
return (
|
||||
db.query(Document)
|
||||
.filter(Document.kb_id == kb_id)
|
||||
.order_by(Document.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def delete_document(db: Session, doc_id: str) -> bool:
|
||||
"""删除文档(连带分块)。"""
|
||||
doc = db.query(Document).filter(Document.id == doc_id).first()
|
||||
if not doc:
|
||||
return False
|
||||
# 减少知识库文档计数
|
||||
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == doc.kb_id).first()
|
||||
db.delete(doc)
|
||||
if kb:
|
||||
kb.doc_count = db.query(Document).filter(
|
||||
Document.kb_id == kb.id, Document.status == "ready"
|
||||
).count()
|
||||
db.commit()
|
||||
logger.info("文档已删除: %s", doc_id)
|
||||
return True
|
||||
|
||||
|
||||
# ─── 语义检索 ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def search(
|
||||
db: Session,
|
||||
kb_id: str,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
min_score: float = 0.3,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
语义搜索知识库。
|
||||
|
||||
流程:查询文本 → Embedding → 余弦相似度匹配所有分块 → Top-K
|
||||
"""
|
||||
# 生成查询 embedding
|
||||
query_emb = await embedding_service.generate_embedding(query)
|
||||
if not query_emb:
|
||||
logger.warning("搜索失败:无法生成查询 embedding")
|
||||
return []
|
||||
|
||||
# 加载该知识库所有分块
|
||||
chunks = (
|
||||
db.query(DocumentChunk)
|
||||
.filter(DocumentChunk.kb_id == kb_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
# 构建 VectorEntry 列表
|
||||
entries: List[VectorEntry] = []
|
||||
for c in chunks:
|
||||
if not c.embedding:
|
||||
continue
|
||||
try:
|
||||
emb = json.loads(c.embedding) if isinstance(c.embedding, str) else c.embedding
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
entries.append({
|
||||
"id": c.id,
|
||||
"scope_kind": "kb",
|
||||
"scope_id": kb_id,
|
||||
"content_text": c.content,
|
||||
"embedding": emb,
|
||||
"metadata": c.metadata_ or {},
|
||||
})
|
||||
|
||||
if not entries:
|
||||
return []
|
||||
|
||||
# 相似度搜索
|
||||
matched = await embedding_service.similarity_search(
|
||||
query_emb, entries, top_k=top_k, min_score=min_score
|
||||
)
|
||||
|
||||
results = []
|
||||
for m in matched:
|
||||
results.append({
|
||||
"chunk_id": m["id"],
|
||||
"content": m["content_text"],
|
||||
"score": m["score"],
|
||||
"metadata": m.get("metadata", {}),
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def rag_query(
|
||||
db: Session,
|
||||
kb_id: str,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
min_score: float = 0.3,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
    RAG query: retrieve relevant chunks and format them as context.

    Returns:
        {
            "query": "...",
            "context": "根据以下资料回答用户问题:\n\n[1] (来源: ...):\n...",
            "sources": [...],
            "found": True,  # False, with empty context and sources, when nothing matches
        }
|
||||
"""
|
||||
results = await search(db, kb_id, query, top_k=top_k, min_score=min_score)
|
||||
|
||||
if not results:
|
||||
return {
|
||||
"query": query,
|
||||
"context": "",
|
||||
"sources": [],
|
||||
"found": False,
|
||||
}
|
||||
|
||||
# 格式化上下文
|
||||
lines = ["根据以下资料回答用户问题:\n"]
|
||||
for i, r in enumerate(results, 1):
|
||||
source = r["metadata"].get("filename", "未知来源")
|
||||
lines.append(f"[{i}] (来源: {source}):\n{r['content']}\n")
|
||||
|
||||
context = "\n".join(lines)
|
||||
|
||||
sources = [
|
||||
{
|
||||
"content": r["content"],
|
||||
"score": r["score"],
|
||||
"source": r["metadata"].get("filename", "未知"),
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
return {
|
||||
"query": query,
|
||||
"context": context,
|
||||
"sources": sources,
|
||||
"found": True,
|
||||
}
|
||||
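The context-building step of `rag_query` numbers each retrieved chunk and labels it with its source file. Below is a small standalone sketch of that formatting. `build_context`, the English header text, and the sample hits are hypothetical; the commit builds the string inline with a Chinese header.

```python
from typing import Any, Dict, List


def build_context(results: List[Dict[str, Any]],
                  header: str = "Answer using the sources below:") -> str:
    """Number each retrieved chunk and label it with its source file name."""
    lines = [header, ""]
    for i, r in enumerate(results, 1):
        source = r.get("metadata", {}).get("filename", "unknown source")
        lines.append(f"[{i}] (source: {source}):\n{r['content']}\n")
    return "\n".join(lines)


# Made-up sample hits, shaped like the dicts returned by search()
hits = [
    {"content": "Chunks overlap by 50 characters.", "score": 0.82,
     "metadata": {"filename": "guide.md"}},
    {"content": "Default chunk size is 500.", "score": 0.61, "metadata": {}},
]
print(build_context(hits))
```

Numbering the chunks lets the model cite `[1]`, `[2]` in its answer, and the `sources` list returned alongside the context can map those numbers back to files.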
132
backend/app/services/text_chunker.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
文本分块器 — 将长文本切分为适合 Embedding + 检索的块。
|
||||
|
||||
策略:
|
||||
1. 按段落分割(连续换行)
|
||||
2. 超长段落按句子切分(句号、问号、感叹号、换行)
|
||||
3. 短段落合并,直到接近 chunk_size
|
||||
4. chunk 之间保留 overlap 字符的重叠
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
|
||||
def chunk_text(
|
||||
text: str,
|
||||
chunk_size: int = 500,
|
||||
chunk_overlap: int = 50,
|
||||
) -> List[str]:
|
||||
"""
|
||||
将文本分割为语义完整的块。
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
chunk_size: 每块目标字符数
|
||||
chunk_overlap: 相邻块重叠字符数
|
||||
|
||||
Returns:
|
||||
文本块列表
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return []
|
||||
|
||||
# 1. 按段落分割
|
||||
paragraphs = _split_paragraphs(text)
|
||||
if not paragraphs:
|
||||
return []
|
||||
|
||||
# 2. 处理每个段落:超长段落进一步拆分
|
||||
segments: List[str] = []
|
||||
for para in paragraphs:
|
||||
if len(para) <= chunk_size:
|
||||
segments.append(para)
|
||||
else:
|
||||
segments.extend(_split_long_paragraph(para, chunk_size))
|
||||
|
||||
# 3. 合并短段落为块
|
||||
chunks = _merge_segments(segments, chunk_size, chunk_overlap)
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def _split_paragraphs(text: str) -> List[str]:
|
||||
"""按连续换行分割段落,过滤空白段落。"""
|
||||
# 按两个以上换行分割
|
||||
raw = re.split(r"\n\s*\n", text)
|
||||
paras = []
|
||||
for p in raw:
|
||||
p = p.strip()
|
||||
if p:
|
||||
paras.append(p)
|
||||
return paras
|
||||
|
||||
|
||||
def _split_long_paragraph(para: str, chunk_size: int) -> List[str]:
|
||||
"""将超长段落按句子切分。"""
|
||||
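    # Split right after sentence-ending punctuation (full-width and ASCII). Whitespace stays
    # attached to the following sentence, so re-joining pieces does not glue Latin sentences together.
    sentences = [s for s in re.split(r"(?<=[。！？.!?])", para) if s.strip()]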
|
||||
|
||||
if not sentences:
|
||||
# 无法按句子分割,按字符硬切
|
||||
return [para[i : i + chunk_size] for i in range(0, len(para), chunk_size)]
|
||||
|
||||
chunks: List[str] = []
|
||||
current = ""
|
||||
for sent in sentences:
|
||||
if len(current) + len(sent) <= chunk_size:
|
||||
current += sent
|
||||
else:
|
||||
if current:
|
||||
chunks.append(current.strip())
|
||||
current = sent
|
||||
if current:
|
||||
chunks.append(current.strip())
|
||||
|
||||
# 如果某个块还是太长(单句超长),按字符硬切
|
||||
result: List[str] = []
|
||||
for c in chunks:
|
||||
if len(c) <= chunk_size:
|
||||
result.append(c)
|
||||
else:
|
||||
for i in range(0, len(c), chunk_size):
|
||||
result.append(c[i : i + chunk_size])
|
||||
return result
|
||||
|
||||
|
||||
def _merge_segments(segments: List[str], chunk_size: int, overlap: int) -> List[str]:
|
||||
"""合并短段落为块,块间带重叠。"""
|
||||
if not segments:
|
||||
return []
|
||||
|
||||
chunks: List[str] = []
|
||||
current = ""
|
||||
|
||||
for seg in segments:
|
||||
# 如果当前块 + 新段不超过上限,追加
|
||||
if not current:
|
||||
current = seg
|
||||
        elif len(current) + 2 + len(seg) <= chunk_size:  # 2 == len("\n\n"), the separator added below
|
||||
current += "\n\n" + seg
|
||||
else:
|
||||
# 当前块已满
|
||||
chunks.append(current.strip())
|
||||
|
||||
# 重叠:取前一块末尾的 overlap 字符
|
||||
if overlap > 0 and len(current) > overlap:
|
||||
# 从上一个块末尾取 overlap 字符作为新块的起始
|
||||
tail = current[-overlap:]
|
||||
# 从上一个换行处开始,避免截断句子
|
||||
newline_pos = tail.find("\n")
|
||||
if newline_pos > 0 and newline_pos < overlap // 2:
|
||||
tail = tail[newline_pos + 1 :]
|
||||
current = tail + "\n\n" + seg
|
||||
else:
|
||||
current = seg
|
||||
|
||||
if current.strip():
|
||||
chunks.append(current.strip())
|
||||
|
||||
return chunks
|
||||
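The merge step packs short segments into chunks and starts each new chunk with the tail of the previous one. The sketch below is a simplified version of that technique for experimenting with `chunk_size` and `overlap`. `merge_with_overlap` is a hypothetical name, and the newline-alignment heuristic used in `_merge_segments` is left out on purpose.

```python
from typing import List

SEP = "\n\n"


def merge_with_overlap(segments: List[str], chunk_size: int, overlap: int) -> List[str]:
    """Greedily pack segments into chunks; each new chunk starts with the previous chunk's tail."""
    chunks: List[str] = []
    current = ""
    for seg in segments:
        if not current:
            current = seg
        elif len(current) + len(SEP) + len(seg) <= chunk_size:
            current += SEP + seg
        else:
            chunks.append(current)
            # Carry the last `overlap` characters forward so context spans the boundary
            tail = current[-overlap:] if 0 < overlap < len(current) else ""
            current = tail + SEP + seg if tail else seg
    if current:
        chunks.append(current)
    return chunks


parts = ["aaaaaaaaaa", "bbbbbbbbbb", "cccccccccc"]
for chunk in merge_with_overlap(parts, chunk_size=22, overlap=4):
    print(repr(chunk))
```

As in the original function, the carried tail is added on top of the next segment, so a chunk that starts with overlap can exceed `chunk_size` by up to `overlap` plus the separator length.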
@@ -1,123 +1,360 @@
|
||||
"""
|
||||
工具注册表 — 管理所有可用工具(内置 / HTTP / 代码段)。
|
||||
|
||||
提供统一的工具注册、查找和执行接口。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from app.models.tool import Tool
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Names exposed to code-type tools (a restricted namespace, which is not by itself a security sandbox)
|
||||
_CODE_SAFE_GLOBALS = {
|
||||
"json": json,
|
||||
"dict": dict,
|
||||
"list": list,
|
||||
"str": str,
|
||||
"int": int,
|
||||
"float": float,
|
||||
"bool": bool,
|
||||
"len": len,
|
||||
"range": range,
|
||||
"enumerate": enumerate,
|
||||
"zip": zip,
|
||||
"map": map,
|
||||
"filter": filter,
|
||||
"sorted": sorted,
|
||||
"min": min,
|
||||
"max": max,
|
||||
"sum": sum,
|
||||
"abs": abs,
|
||||
"round": round,
|
||||
"isinstance": isinstance,
|
||||
"type": type,
|
||||
"True": True,
|
||||
"False": False,
|
||||
"None": None,
|
||||
}
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""工具注册表 - 管理所有可用工具"""
|
||||
|
||||
"""工具注册表 — 管理所有可用工具。"""
|
||||
|
||||
def __init__(self):
|
||||
self._builtin_tools: Dict[str, Callable] = {}
|
||||
self._tool_schemas: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# 自定义工具配置(从 DB 加载的非 builtin 工具)
|
||||
self._custom_tool_configs: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# ─── 内置工具注册 ─────────────────────────────────────────
|
||||
|
||||
def register_builtin_tool(self, name: str, func: Callable, schema: Dict[str, Any]):
|
||||
"""
|
||||
注册内置工具。
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
func: 工具函数(同步或异步)
|
||||
schema: OpenAI function calling 格式 schema
|
||||
"""
|
||||
self._builtin_tools[name] = func
|
||||
self._tool_schemas[name] = schema
|
||||
logger.debug("注册内置工具: %s", name)
|
||||
|
||||
|
||||
# ─── 工具信息查询 ─────────────────────────────────────────
|
||||
|
||||
def get_tool_schema(self, name: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取工具定义"""
|
||||
return self._tool_schemas.get(name)
|
||||
|
||||
|
||||
def get_tool_function(self, name: str) -> Optional[Callable]:
|
||||
"""获取工具函数"""
|
||||
return self._builtin_tools.get(name)
|
||||
|
||||
|
||||
def get_all_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有工具定义(用于LLM)"""
|
||||
return list(self._tool_schemas.values())
|
||||
|
||||
def builtin_tool_count(self) -> int:
|
||||
"""已注册的内置工具数量(用于健康检查 / 启动自检)。"""
|
||||
return len(self._builtin_tools)
|
||||
|
||||
def builtin_tool_names(self) -> List[str]:
|
||||
"""已注册的内置工具名称,有序列表。"""
|
||||
return sorted(self._builtin_tools.keys())
|
||||
|
||||
def load_tools_from_db(self, db: Session, tool_names: List[str] = None):
|
||||
"""
|
||||
从数据库加载工具
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
tool_names: 工具名称列表(可选,如果为None则加载所有公开工具)
|
||||
"""
|
||||
query = db.query(Tool).filter(Tool.is_public == True)
|
||||
if tool_names:
|
||||
query = query.filter(Tool.name.in_(tool_names))
|
||||
|
||||
tools = query.all()
|
||||
for tool in tools:
|
||||
self._tool_schemas[tool.name] = tool.function_schema
|
||||
# 根据implementation_type加载工具实现
|
||||
if tool.implementation_type == 'builtin':
|
||||
# 从内置工具中查找
|
||||
if tool.name in self._builtin_tools:
|
||||
logger.debug(f"工具 {tool.name} 已在内置工具中注册")
|
||||
else:
|
||||
logger.warning(f"工具 {tool.name} 标记为builtin但未在内置工具中找到")
|
||||
elif tool.implementation_type == 'http':
|
||||
# HTTP工具需要特殊处理
|
||||
self._register_http_tool(tool)
|
||||
elif tool.implementation_type == 'workflow':
|
||||
# 工作流工具
|
||||
self._register_workflow_tool(tool)
|
||||
elif tool.implementation_type == 'code':
|
||||
# 代码执行工具
|
||||
self._register_code_tool(tool)
|
||||
|
||||
logger.info(f"从数据库加载了 {len(tools)} 个工具")
|
||||
|
||||
def _register_http_tool(self, tool: Tool):
|
||||
"""注册HTTP工具"""
|
||||
# TODO: 实现HTTP工具的动态注册
|
||||
logger.warning(f"HTTP工具 {tool.name} 的动态注册尚未实现")
|
||||
|
||||
def _register_workflow_tool(self, tool: Tool):
|
||||
"""注册工作流工具"""
|
||||
# TODO: 实现工作流工具的动态注册
|
||||
logger.warning(f"工作流工具 {tool.name} 的动态注册尚未实现")
|
||||
|
||||
def _register_code_tool(self, tool: Tool):
|
||||
"""注册代码执行工具"""
|
||||
# TODO: 实现代码执行工具的动态注册
|
||||
logger.warning(f"代码执行工具 {tool.name} 的动态注册尚未实现")
|
||||
|
||||
def get_tools_by_names(self, tool_names: List[str]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
根据工具名称列表获取工具定义
|
||||
|
||||
Args:
|
||||
tool_names: 工具名称列表
|
||||
|
||||
Returns:
|
||||
工具定义列表(OpenAI Function格式)
|
||||
"""
|
||||
tools = []
|
||||
for name in tool_names:
|
||||
schema = self.get_tool_schema(name)
|
||||
if schema:
|
||||
tools.append(schema)
|
||||
else:
|
||||
logger.warning("工具 %s 未找到", name)
|
||||
return tools
|
||||
|
||||
# ─── 从数据库加载自定义工具 ──────────────────────────────
|
||||
|
||||
# 全局工具注册表实例
|
||||
def load_tools_from_db(self, db: Session, tool_names: Optional[List[str]] = None):
|
||||
"""
|
||||
从数据库加载工具定义。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
tool_names: 指定名称列表;None 则加载所有公开工具
|
||||
"""
|
||||
query = db.query(Tool).filter(Tool.is_public == True)
|
||||
if tool_names:
|
||||
query = query.filter(Tool.name.in_(tool_names))
|
||||
|
||||
tools = query.all()
|
||||
count = 0
|
||||
for tool in tools:
|
||||
self._tool_schemas[tool.name] = tool.function_schema
|
||||
if tool.implementation_type == "builtin":
|
||||
if tool.name not in self._builtin_tools:
|
||||
logger.warning("工具 %s 标记为 builtin 但未注册", tool.name)
|
||||
else:
|
||||
# 存储自定义工具配置
|
||||
config = dict(tool.implementation_config or {})  # copy so the ORM-backed dict is not mutated
|
||||
config["_type"] = tool.implementation_type
|
||||
config["_db_id"] = tool.id
|
||||
self._custom_tool_configs[tool.name] = config
|
||||
count += 1
|
||||
|
||||
if count:
|
||||
logger.info("从数据库加载了 %d 个自定义工具", count)
|
||||
|
||||
# ─── 工具执行 ─────────────────────────────────────────────
|
||||
|
||||
async def execute_tool(self, name: str, args: Dict[str, Any]) -> str:
|
||||
"""
|
||||
执行任意工具(内置 / HTTP / 代码段)。
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
args: 参数字典
|
||||
|
||||
Returns:
|
||||
结果字符串
|
||||
"""
|
||||
# 1. 内置工具
|
||||
func = self._builtin_tools.get(name)
|
||||
if func:
|
||||
return await self._run_function(func, name, args)
|
||||
|
||||
# 2. 自定义工具
|
||||
config = self._custom_tool_configs.get(name)
|
||||
if not config:
|
||||
return json.dumps({"error": f"工具 '{name}' 不存在"}, ensure_ascii=False)
|
||||
|
||||
impl_type = config.get("_type", "")
|
||||
try:
|
||||
if impl_type == "http":
|
||||
return await self._execute_http_tool(name, config, args)
|
||||
elif impl_type == "code":
|
||||
return await self._execute_code_tool(name, config, args)
|
||||
elif impl_type == "workflow":
|
||||
return json.dumps({"error": "工作流工具暂不支持动态执行"},
|
||||
ensure_ascii=False)
|
||||
else:
|
||||
return json.dumps({"error": f"不支持的实现类型: {impl_type}"},
|
||||
ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error("工具 %s 执行失败: %s", name, e, exc_info=True)
|
||||
return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"},
|
||||
ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
async def _run_function(func: Callable, name: str, args: Dict[str, Any]) -> str:
|
||||
"""执行内置工具函数。"""
|
||||
import asyncio
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
result = await func(**args)
|
||||
else:
|
||||
result = func(**args)
|
||||
if isinstance(result, (dict, list)):
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
return str(result)
|
||||
except Exception as e:
|
||||
logger.error("工具 '%s' 执行失败: %s", name, e, exc_info=True)
|
||||
return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"},
|
||||
ensure_ascii=False)
|
||||
|
||||
# ─── HTTP 工具 ────────────────────────────────────────────
|
||||
|
||||
async def _execute_http_tool(
|
||||
self, name: str, config: Dict[str, Any], args: Dict[str, Any]
|
||||
) -> str:
|
||||
"""
|
||||
执行 HTTP 工具。
|
||||
|
||||
implementation_config 格式:
|
||||
{
|
||||
"url": "https://api.example.com/{path_param}",
|
||||
"method": "GET",
|
||||
"headers": {"Authorization": "Bearer xxx"},
|
||||
"body_template": {"key": "{param}"}, # 可选
|
||||
"timeout": 30
|
||||
}
|
||||
URL 和 body_template 中的 {param} 会被 args 替换。
|
||||
"""
|
||||
import httpx
|
||||
|
||||
url = config.get("url", "")
|
||||
method = (config.get("method") or "GET").upper()
|
||||
headers = config.get("headers") or {}
|
||||
body_template = config.get("body_template")
|
||||
timeout = config.get("timeout", 30)
|
||||
|
||||
if not url:
|
||||
return json.dumps({"error": "HTTP 工具未配置 URL"}, ensure_ascii=False)
|
||||
|
||||
# 模板替换:{param} → args["param"]
|
||||
def _fmt(template: str) -> str:
|
||||
try:
|
||||
return template.format(**args)
|
||||
except (KeyError, IndexError, ValueError):  # unknown, positional, or malformed placeholders: keep the raw template
|
||||
return template
|
||||
|
||||
url = _fmt(url)
|
||||
body: Any = None
|
||||
if body_template:
|
||||
body_str = json.dumps(body_template)
|
||||
body_str = _fmt(body_str)
|
||||
try:
|
||||
body = json.loads(body_str)
|
||||
except json.JSONDecodeError:
|
||||
body = body_str
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.request(method, url, headers=headers, json=body)
|
||||
text = response.text[:10000] # 截断过长的响应
|
||||
|
||||
result = {
|
||||
"status_code": response.status_code,
|
||||
"body": text,
|
||||
}
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
# ─── 代码工具 ─────────────────────────────────────────────
|
||||
|
||||
async def _execute_code_tool(
|
||||
self, name: str, config: Dict[str, Any], args: Dict[str, Any]
|
||||
) -> str:
|
||||
"""
|
||||
执行代码工具。
|
||||
|
||||
implementation_config 格式:
|
||||
{
|
||||
"source": "def run(args):\\n return {'sum': args['a'] + args['b']}",
|
||||
"language": "python"
|
||||
}
|
||||
代码工具需定义一个 run(args) 函数作为入口。
|
||||
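source is run with exec() under restricted globals (empty __builtins__); this limits, but does not reliably prevent, access to the file system or network.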
|
||||
"""
|
||||
source = config.get("source", "")
|
||||
if not source:
|
||||
return json.dumps({"error": "代码工具未配置 source"}, ensure_ascii=False)
|
||||
|
||||
# 编译代码,限制可访问的全局变量
|
||||
safe_globals = _CODE_SAFE_GLOBALS.copy()
|
||||
safe_globals["__builtins__"] = {} # 禁用所有内置函数
|
||||
|
||||
try:
|
||||
exec(source, safe_globals)
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"error": f"代码编译失败: {e}",
|
||||
"traceback": traceback.format_exc(),
|
||||
}, ensure_ascii=False)
|
||||
|
||||
run_func = safe_globals.get("run")
|
||||
if not run_func or not callable(run_func):
|
||||
return json.dumps({
|
||||
"error": "代码工具必须定义一个 run(args) 函数"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
try:
|
||||
result = run_func(args)
|
||||
if isinstance(result, (dict, list)):
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
return str(result)
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"error": f"代码执行失败: {e}",
|
||||
"traceback": traceback.format_exc(),
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# ─── 测试工具(不保存,直接执行)─────────────────────────
|
||||
|
||||
async def test_http_tool(
|
||||
self, url: str, method: str, headers: Dict[str, str],
|
||||
body: Optional[Dict[str, Any]], args: Dict[str, Any],
|
||||
timeout: int = 30,
|
||||
) -> Dict[str, Any]:
|
||||
"""测试 HTTP 工具配置(不保存到 DB)。"""
|
||||
import httpx
|
||||
|
||||
def _fmt(template: str) -> str:
|
||||
try:
|
||||
return template.format(**args)
|
||||
except (KeyError, IndexError, ValueError):  # unknown, positional, or malformed placeholders: keep the raw template
|
||||
return template
|
||||
|
||||
url = _fmt(url)
|
||||
body_sent = None
|
||||
if body:
|
||||
body_str = json.dumps(body)
|
||||
body_str = _fmt(body_str)
|
||||
try:
|
||||
body_sent = json.loads(body_str)
|
||||
except json.JSONDecodeError:
|
||||
body_sent = body_str
|
||||
|
||||
start = __import__("time").time()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.request(method.upper(), url,
|
||||
headers=headers, json=body_sent)
|
||||
elapsed_ms = int((__import__("time").time() - start) * 1000)
|
||||
return {
|
||||
"success": True,
|
||||
"status_code": response.status_code,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"body": response.text[:10000],
|
||||
}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def test_code_tool(
|
||||
self, source: str, args: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""测试代码工具(不保存到 DB)。"""
|
||||
safe_globals = _CODE_SAFE_GLOBALS.copy()
|
||||
safe_globals["__builtins__"] = {}
|
||||
|
||||
try:
|
||||
exec(source, safe_globals)
|
||||
except Exception as e:
|
||||
return {"success": False, "error": f"编译失败: {e}"}
|
||||
|
||||
run_func = safe_globals.get("run")
|
||||
if not callable(run_func):
|
||||
return {"success": False, "error": "代码须定义 run(args) 函数"}
|
||||
|
||||
try:
|
||||
start = __import__("time").time()
|
||||
result = run_func(args)
|
||||
elapsed_ms = int((__import__("time").time() - start) * 1000)
|
||||
return {"success": True, "elapsed_ms": elapsed_ms, "result": result}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": f"执行失败: {e}"}
|
||||
|
||||
|
||||
# 全局单例
|
||||
tool_registry = ToolRegistry()
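The `{param}` substitution in the HTTP tool applies `str.format` to the JSON-serialized body. JSON text is itself full of braces, so that call can fail and fall back to the unmodified template. Below is a minimal sketch of an alternative, assuming a hypothetical helper named `render_template` (it is not part of this commit) that walks the parsed structure and substitutes only inside string values. Unknown placeholders are left as they are.

```python
import re
from typing import Any, Dict

# Matches simple placeholders such as {param} or {path_param}
_PLACEHOLDER = re.compile(r"\{(\w+)\}")


def render_template(template: Any, args: Dict[str, Any]) -> Any:
    """Recursively replace {name} placeholders inside string values.

    Hypothetical helper for illustration; dict keys and non-string
    values are returned unchanged.
    """
    if isinstance(template, str):
        return _PLACEHOLDER.sub(
            lambda m: str(args[m.group(1)]) if m.group(1) in args else m.group(0),
            template,
        )
    if isinstance(template, dict):
        return {key: render_template(value, args) for key, value in template.items()}
    if isinstance(template, list):
        return [render_template(item, args) for item in template]
    return template
```

With this shape, the body template can stay a plain dict and be passed to the HTTP client after rendering, without a serialize/format/parse round trip.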
|
||||
|
||||
@@ -1917,12 +1917,25 @@ class WorkflowEngine:
|
||||
if hasattr(self, '_on_tool_executed_budget'):
|
||||
_agent_on_tool = self._on_tool_executed_budget
|
||||
|
||||
# Agent 的 LLM 调用计入工作流预算
|
||||
def _on_agent_llm():
|
||||
self._llm_invocations += 1
|
||||
if self._llm_invocations > self._cap_llm:
|
||||
raise WorkflowExecutionError(
|
||||
detail=f"已超过 LLM 节点调用预算({self._cap_llm} 次)",
|
||||
)
|
||||
|
||||
result = await run_agent_node(
|
||||
node_data=node.get("data", {}),
|
||||
input_data=input_data,
|
||||
execution_logger=self.logger,
|
||||
user_id=self.trusted_model_config_user_id,
|
||||
on_tool_executed=_agent_on_tool,
|
||||
on_llm_invocation=_on_agent_llm,
|
||||
budget_limits={
|
||||
"max_llm_invocations": self._cap_llm,
|
||||
"max_tool_calls": self._cap_tool,
|
||||
},
|
||||
)
|
||||
if self.logger:
|
||||
duration = int((time.time() - start_time) * 1000)
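The hunk above counts every Agent LLM call against the workflow budget and raises once the cap is exceeded. The sketch below isolates that counting pattern. `BudgetExceeded` and `make_budget_hook` are hypothetical names standing in for `WorkflowExecutionError` and the inline closure; they are not part of the commit.

```python
from typing import Callable


class BudgetExceeded(RuntimeError):
    """Hypothetical stand-in for WorkflowExecutionError."""


def make_budget_hook(cap: int) -> Callable[[], int]:
    """Return a zero-argument callback that counts calls and raises past `cap`."""
    used = 0

    def on_invocation() -> int:
        nonlocal used
        used += 1
        if used > cap:
            raise BudgetExceeded(f"LLM invocation budget exceeded ({cap} calls)")
        return used

    return on_invocation
```

A callback like this can be handed to the agent runtime in the same way `on_llm_invocation` is passed above, so the agent loop itself never needs to know the cap.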
|
||||
|
||||
@@ -46,23 +46,31 @@ pytest --cov=app --cov-report=html
|
||||
|
||||
## 测试标记
|
||||
|
||||
- `@pytest.mark.unit` - 单元测试
|
||||
- `@pytest.mark.integration` - 集成测试
|
||||
- `@pytest.mark.unit` - 纯单元测试(不依赖网络/数据库)
|
||||
- `@pytest.mark.integration` - 集成测试(需要网络或可选依赖)
|
||||
- `@pytest.mark.slow` - 慢速测试(需要网络或数据库)
|
||||
- `@pytest.mark.api` - API测试
|
||||
- `@pytest.mark.workflow` - 工作流测试
|
||||
- `@pytest.mark.auth` - 认证测试
|
||||
- `@pytest.mark.asyncio` - 异步测试(使用 `pytest-asyncio`)
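Custom markers such as the ones listed above normally need to be registered, otherwise pytest emits unknown-marker warnings (or errors under `--strict-markers`). A minimal sketch, assuming registration happens in `conftest.py` via the `pytest_configure` hook; the repository may register them in `pytest.ini` instead, and the descriptions here are illustrative. The `asyncio` marker is registered by `pytest-asyncio` itself.

```python
# conftest.py (sketch): register the project's custom markers.
# The marker descriptions are assumptions made for illustration.
CUSTOM_MARKERS = {
    "unit": "pure unit test (no network or database)",
    "integration": "integration test (needs network or optional dependencies)",
    "slow": "slow test",
    "api": "API test",
    "workflow": "workflow test",
    "auth": "authentication test",
}


def pytest_configure(config):
    """pytest hook: add one 'markers' ini line per custom marker."""
    for name, description in CUSTOM_MARKERS.items():
        config.addinivalue_line("markers", f"{name}: {description}")
```

Registered markers can then be selected on the command line, for example `pytest -m unit`.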
|
||||
|
||||
## 测试结构
|
||||
|
||||
```
|
||||
tests/
|
||||
├── __init__.py
|
||||
├── conftest.py # 共享fixtures和配置
|
||||
├── test_auth.py # 认证API测试
|
||||
├── test_workflows.py # 工作流API测试
|
||||
├── test_workflow_engine.py # 工作流引擎测试
|
||||
└── test_workflow_validator.py # 工作流验证器测试
|
||||
├── conftest.py # 共享fixtures和配置
|
||||
├── test_auth.py # 认证API测试
|
||||
├── test_workflows.py # 工作流API测试
|
||||
├── test_workflow_engine.py # 工作流引擎测试
|
||||
├── test_workflow_validator.py # 工作流验证器测试
|
||||
├── test_text_chunker.py # 文本分块器单元测试
|
||||
├── test_document_parser.py # 文档解析器单元测试
|
||||
├── test_tool_registry.py # 工具注册表单元测试
|
||||
├── test_tools_api.py # 工具市场API测试
|
||||
├── test_agent_memory.py # Agent记忆/上下文/Schema测试
|
||||
├── test_embedding_service.py # Embedding服务单元测试
|
||||
└── test_knowledge_base.py # 知识库RAG单元测试
|
||||
```
|
||||
|
||||
## Fixtures
|
||||
@@ -84,13 +92,31 @@ tests/
|
||||
|
||||
## 测试数据库
|
||||
|
||||
测试使用SQLite内存数据库,每个测试函数都会:
|
||||
测试使用SQLite临时文件数据库(而非 `:memory:`),每个测试函数都会:
|
||||
1. 创建所有表
|
||||
2. 执行测试
|
||||
3. 删除所有表
|
||||
|
||||
这样可以确保测试之间的隔离性。
|
||||
|
||||
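**Note**: A temporary file is used instead of `:memory:` because FastAPI with TestClient serves requests across async tasks and threads, and each new connection to a SQLite in-memory database opens its own private database, so connections do not see each other's tables or rows. A file-backed database ensures that all connections share the same data.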
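The connection isolation described here can be reproduced with the standard library alone. The sketch below uses a hypothetical helper, `table_visible_from_second_connection`, which creates a table on one connection and looks for it from a second one; it is an illustration, not code from `conftest.py`.

```python
import os
import sqlite3
import tempfile


def table_visible_from_second_connection(database: str) -> bool:
    """Create a table on one connection and look for it from another."""
    first = sqlite3.connect(database)
    second = sqlite3.connect(database)
    try:
        first.execute("CREATE TABLE items (id INTEGER PRIMARY KEY)")
        first.commit()
        row = second.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'items'"
        ).fetchone()
        return row is not None
    finally:
        first.close()
        second.close()


# Each ":memory:" connection gets its own private database.
print(table_visible_from_second_connection(":memory:"))  # False

# A file-backed database is shared by every connection to the same path.
fd, path = tempfile.mkstemp(suffix=".db")
os.close(fd)
try:
    print(table_visible_from_second_connection(path))  # True
finally:
    os.unlink(path)
```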
|
||||
|
||||
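**Note**: Pure service-layer tests (such as `test_tool_registry.py`, `test_embedding_service.py`, and `test_text_chunker.py`) do not depend on the `db_session` fixture. They stub the service classes directly or use temporary files, so they run faster.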
|
||||
|
||||
## 异步测试配置
|
||||
|
||||
`pytest-asyncio` provides async test support. Every test function marked with `@pytest.mark.asyncio` must be declared with `async def`:
|
||||
|
||||
```python
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_my_async_method(self):
|
||||
result = await my_service.my_method()
|
||||
assert result is not None
|
||||
```
|
||||
|
||||
**Note**: Fixtures in `conftest.py` use the default `function` scope; async fixtures need no extra configuration.
|
||||
|
||||
## 编写新测试
|
||||
|
||||
### 示例:API测试
|
||||
@@ -128,6 +154,14 @@ class TestMyService:
|
||||
|
||||
5. **测试数据**:使用 fixtures 提供测试数据,避免硬编码。
|
||||
|
||||
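6. **Known issues**: Nine legacy tests still fail (mainly in `test_auth.py`, `test_workflow_engine.py`, `test_workflow_validator.py`, `test_nodes_all.py`, and `test_nodes_phase4.py`). These failures are unrelated to this change and can be ignored. To check only the newly added test files, run them explicitly: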
|
||||
```bash
|
||||
pytest tests/test_text_chunker.py tests/test_document_parser.py \
|
||||
tests/test_tool_registry.py tests/test_tools_api.py \
|
||||
tests/test_agent_memory.py tests/test_embedding_service.py \
|
||||
tests/test_knowledge_base.py -v
|
||||
```
|
||||
|
||||
## CI/CD集成
|
||||
|
||||
在CI/CD流程中运行测试:
|
||||
|
||||
@@ -6,12 +6,35 @@ from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from fastapi.testclient import TestClient
|
||||
from app.core.database import Base, get_db, SessionLocal
|
||||
from app.main import app
|
||||
from app.main import app as fastapi_app
|
||||
from app.core.config import settings
|
||||
import os
|
||||
|
||||
# 测试数据库URL(使用SQLite内存数据库)
|
||||
TEST_DATABASE_URL = "sqlite:///:memory:"
|
||||
# 导入所有模型,确保 Base.metadata 包含所有表
|
||||
from app.core.database import Base as _Base
|
||||
import app.models.user # noqa: F401
|
||||
import app.models.workflow # noqa: F401
|
||||
import app.models.agent # noqa: F401
|
||||
import app.models.execution # noqa: F401
|
||||
import app.models.model_config # noqa: F401
|
||||
import app.models.workflow_template # noqa: F401
|
||||
import app.models.permission # noqa: F401
|
||||
import app.models.alert_rule # noqa: F401
|
||||
import app.models.agent_llm_log # noqa: F401
|
||||
import app.models.agent_vector_memory # noqa: F401
|
||||
import app.models.knowledge_base # noqa: F401
|
||||
import app.models.tool # noqa: F401
|
||||
|
||||
assert len(_Base.metadata.tables) > 0, "没有模型表被注册"
|
||||
|
||||
# 测试数据库URL
|
||||
# 使用临时文件而非 :memory: 避免 FastAPI 异步/多线程请求中的
|
||||
# SQLite 内存数据库连接隔离问题(每个连接看到不同的数据库)
|
||||
import tempfile as _tempfile
|
||||
import os as _os
|
||||
import atexit as _atexit
|
||||
_test_db_fd, _test_db_path = _tempfile.mkstemp(suffix=".db")
|
||||
_os.close(_test_db_fd)
|
||||
TEST_DATABASE_URL = f"sqlite:///{_test_db_path}"
|
||||
|
||||
# 创建测试数据库引擎
|
||||
test_engine = create_engine(
|
||||
@@ -23,20 +46,27 @@ test_engine = create_engine(
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=test_engine)
|
||||
|
||||
|
||||
# 在进程退出时清理临时数据库文件
|
||||
def _cleanup_test_db():
|
||||
test_engine.dispose()
|
||||
if _os.path.exists(_test_db_path):
|
||||
_os.unlink(_test_db_path)
|
||||
|
||||
|
||||
_atexit.register(_cleanup_test_db)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session():
|
||||
"""创建测试数据库会话"""
|
||||
# 创建所有表
|
||||
Base.metadata.create_all(bind=test_engine)
|
||||
|
||||
# 创建会话
|
||||
_Base.metadata.create_all(bind=test_engine)
|
||||
|
||||
session = TestingSessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
# 删除所有表
|
||||
Base.metadata.drop_all(bind=test_engine)
|
||||
_Base.metadata.drop_all(bind=test_engine)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -47,13 +77,13 @@ def client(db_session):
|
||||
yield db_session
|
||||
finally:
|
||||
pass
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
with TestClient(app) as test_client:
|
||||
|
||||
fastapi_app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
with TestClient(fastapi_app) as test_client:
|
||||
yield test_client
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
fastapi_app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -72,7 +102,7 @@ def authenticated_client(client, test_user_data):
|
||||
# 注册用户
|
||||
response = client.post("/api/v1/auth/register", json=test_user_data)
|
||||
assert response.status_code == 201
|
||||
|
||||
|
||||
# 登录获取token
|
||||
login_response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
@@ -83,10 +113,10 @@ def authenticated_client(client, test_user_data):
|
||||
)
|
||||
assert login_response.status_code == 200
|
||||
token = login_response.json()["access_token"]
|
||||
|
||||
|
||||
# 设置认证头
|
||||
client.headers.update({"Authorization": f"Bearer {token}"})
|
||||
|
||||
|
||||
return client
|
||||
|
||||
|
||||
|
||||
321
backend/tests/test_agent_memory.py
Normal file
321
backend/tests/test_agent_memory.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
Agent 记忆系统单元测试
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class TestAgentContext:
|
||||
"""AgentContext 基础功能测试"""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_context_initialization(self):
|
||||
from app.agent_runtime.context import AgentContext
|
||||
|
||||
ctx = AgentContext(
|
||||
system_prompt="You are a helpful assistant.",
|
||||
user_id="test-user",
|
||||
session_id="test-session",
|
||||
)
|
||||
assert ctx.session_id == "test-session"
|
||||
assert ctx.user_id == "test-user"
|
||||
assert ctx.iteration == 0
|
||||
assert ctx.tool_calls_made == 0
|
||||
# system prompt 在 messages 中
|
||||
msgs = ctx.messages
|
||||
assert msgs[0]["role"] == "system"
|
||||
assert msgs[0]["content"] == "You are a helpful assistant."
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_user_message(self):
|
||||
from app.agent_runtime.context import AgentContext
|
||||
|
||||
ctx = AgentContext(session_id="s1")
|
||||
ctx.add_user_message("Hello")
|
||||
assert len(ctx.messages) == 2 # system + user
|
||||
assert ctx.messages[1]["role"] == "user"
|
||||
assert ctx.messages[1]["content"] == "Hello"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_assistant_message(self):
|
||||
from app.agent_runtime.context import AgentContext
|
||||
|
||||
ctx = AgentContext(session_id="s1")
|
||||
ctx.add_user_message("Hi")
|
||||
ctx.add_assistant_message("Hello!", tool_calls=[{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "test", "arguments": "{}"},
|
||||
}])
|
||||
msgs = ctx.messages
|
||||
assistant_msgs = [m for m in msgs if m["role"] == "assistant"]
|
||||
assert len(assistant_msgs) == 1
|
||||
assert assistant_msgs[0]["content"] == "Hello!"
|
||||
assert "tool_calls" in assistant_msgs[0]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_tool_result(self):
|
||||
from app.agent_runtime.context import AgentContext
|
||||
|
||||
ctx = AgentContext(session_id="s1")
|
||||
ctx.add_tool_result("call_1", "test_tool", '{"result": "ok"}')
|
||||
tool_msgs = [m for m in ctx.messages if m["role"] == "tool"]
|
||||
assert len(tool_msgs) == 1
|
||||
assert tool_msgs[0]["tool_call_id"] == "call_1"
|
||||
assert tool_msgs[0]["name"] == "test_tool"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_iteration_tracking(self):
|
||||
from app.agent_runtime.context import AgentContext
|
||||
|
||||
ctx = AgentContext(session_id="s1")
|
||||
assert ctx.iteration == 0
|
||||
ctx.iteration += 1
|
||||
assert ctx.iteration == 1
|
||||
ctx.tool_calls_made += 2
|
||||
assert ctx.tool_calls_made == 2
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_context_reset(self):
|
||||
from app.agent_runtime.context import AgentContext
|
||||
|
||||
ctx = AgentContext(system_prompt="Helpful.", session_id="s1")
|
||||
ctx.add_user_message("Hello")
|
||||
ctx.add_assistant_message("Hi")
|
||||
ctx.iteration = 5
|
||||
ctx.tool_calls_made = 3
|
||||
ctx.reset()
|
||||
assert ctx.iteration == 0
|
||||
assert ctx.tool_calls_made == 0
|
||||
# 重置后 messages 应仅含 system
|
||||
msgs = ctx.messages
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0]["role"] == "system"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_set_system_prompt(self):
|
||||
from app.agent_runtime.context import AgentContext
|
||||
|
||||
ctx = AgentContext(system_prompt="Original.", session_id="s1")
|
||||
ctx.set_system_prompt("Updated.")
|
||||
# 未发送过消息,可以更新
|
||||
assert ctx.messages[0]["content"] == "Updated."
|
||||
|
||||
|
||||
class TestAgentMemory:
|
||||
"""AgentMemory 分层记忆测试"""
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_no_persist(self):
|
||||
from app.agent_runtime.memory import AgentMemory
|
||||
|
||||
memory = AgentMemory(
|
||||
scope_kind="agent",
|
||||
scope_id="test-agent",
|
||||
session_key="test-session",
|
||||
persist=False,
|
||||
)
|
||||
result = await memory.initialize(query="Hello")
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_context_no_persist(self):
|
||||
from app.agent_runtime.memory import AgentMemory
|
||||
|
||||
memory = AgentMemory(
|
||||
scope_kind="agent",
|
||||
scope_id="test-agent",
|
||||
session_key="test-session",
|
||||
persist=False,
|
||||
)
|
||||
await memory.save_context(
|
||||
user_message="Hello",
|
||||
assistant_reply="Hi there!",
|
||||
)
|
||||
# 不报错即为通过
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_vector_memory_disabled(self):
|
||||
from app.agent_runtime.memory import AgentMemory
|
||||
|
||||
memory = AgentMemory(
|
||||
scope_kind="agent",
|
||||
scope_id="test-agent",
|
||||
session_key="test-session",
|
||||
persist=False,
|
||||
vector_memory_enabled=False,
|
||||
)
|
||||
await memory.save_context(
|
||||
user_message="No vector",
|
||||
assistant_reply="OK",
|
||||
)
|
||||
# 不报错即为通过
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_vector_search_no_results(self):
|
||||
from app.agent_runtime.memory import AgentMemory
|
||||
|
||||
memory = AgentMemory(
|
||||
scope_kind="agent",
|
||||
scope_id="nonexistent",
|
||||
persist=False,
|
||||
vector_memory_enabled=True,
|
||||
)
|
||||
result = await memory._vector_search(query="test")
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_trim_messages(self):
|
||||
from app.agent_runtime.memory import AgentMemory
|
||||
|
||||
memory = AgentMemory(persist=False)
|
||||
messages = [{"role": "system", "content": "You are helpful."}]
|
||||
for i in range(30):
|
||||
messages.append({"role": "user", "content": f"msg {i}"})
|
||||
messages.append({"role": "assistant", "content": f"reply {i}"})
|
||||
trimmed = memory.trim_messages(messages)
|
||||
assert len(trimmed) <= memory.max_history + 1 # +1 for system
|
||||
assert trimmed[0]["role"] == "system"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_summarize_history(self):
|
||||
from app.agent_runtime.memory import AgentMemory
|
||||
|
||||
history = [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Hello"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
summary = AgentMemory._summarize_history(history)
|
||||
assert "2 轮" in summary
|
||||
|
||||
|
||||
class TestToolManager:
|
||||
"""AgentToolManager 测试"""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_include_filter(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
|
||||
mgr = AgentToolManager(include_tools=["math", "file_read"])
|
||||
assert mgr._include_tools == {"math", "file_read"}
|
||||
assert mgr._exclude_tools == set()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_exclude_filter(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
|
||||
mgr = AgentToolManager(exclude_tools=["dangerous_tool"])
|
||||
assert "dangerous_tool" in mgr._exclude_tools
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_tool_name_extraction(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
|
||||
name = AgentToolManager._extract_tool_name({
|
||||
"type": "function",
|
||||
"function": {"name": "my_tool"},
|
||||
})
|
||||
assert name == "my_tool"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_tool_name_extraction_flat(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
|
||||
name = AgentToolManager._extract_tool_name({"name": "flat_tool"})
|
||||
assert name == "flat_tool"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_tool_name_extraction_empty(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
|
||||
name = AgentToolManager._extract_tool_name({})
|
||||
assert name is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_has_tools(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
|
||||
mgr = AgentToolManager()
|
||||
assert mgr.has_tools() is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_tool_names(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
|
||||
mgr = AgentToolManager()
|
||||
names = mgr.tool_names()
|
||||
assert isinstance(names, list)
|
||||
assert len(names) >= 1
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_delegates_to_registry(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
from app.services.tool_registry import tool_registry
|
||||
|
||||
mgr = AgentToolManager()
|
||||
# 执行一个内置工具
|
||||
result = await mgr.execute("datetime", {"format": "%Y"})
|
||||
assert isinstance(result, str)
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
class TestAgentSchemas:
|
||||
"""Agent Schemas 测试"""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_config_defaults(self):
|
||||
from app.agent_runtime.schemas import AgentConfig
|
||||
|
||||
config = AgentConfig(name="Test Agent")
|
||||
assert config.name == "Test Agent"
|
||||
assert config.system_prompt == "你是一个有用的AI助手。请使用可用工具来帮助用户完成任务。"
|
||||
assert config.llm.model == "gpt-4o-mini"
|
||||
assert config.llm.temperature == 0.7
|
||||
assert config.llm.max_iterations == 10
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_memory_config_defaults(self):
|
||||
from app.agent_runtime.schemas import AgentMemoryConfig
|
||||
|
||||
cfg = AgentMemoryConfig()
|
||||
assert cfg.enabled is True
|
||||
assert cfg.vector_memory_enabled is True
|
||||
assert cfg.vector_memory_top_k == 5
|
||||
assert cfg.max_history_messages == 20
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_budget_config_defaults(self):
|
||||
from app.agent_runtime.schemas import AgentBudgetConfig
|
||||
|
||||
cfg = AgentBudgetConfig()
|
||||
assert cfg.max_llm_invocations == 200
|
||||
assert cfg.max_tool_calls == 500
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_result_fields(self):
|
||||
from app.agent_runtime.schemas import AgentResult, AgentStep
|
||||
|
||||
result = AgentResult(
|
||||
content="Test result",
|
||||
iterations_used=3,
|
||||
tool_calls_made=5,
|
||||
truncated=False,
|
||||
steps=[
|
||||
AgentStep(iteration=1, type="think", content="Thinking..."),
|
||||
AgentStep(iteration=2, type="tool_call", content="Calling tool", tool_name="test"),
|
||||
],
|
||||
)
|
||||
assert result.success is True
|
||||
assert result.content == "Test result"
|
||||
assert result.iterations_used == 3
|
||||
assert len(result.steps) == 2
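`test_trim_messages` above expects trimming to keep the system message first and to cap the remaining history at `max_history`. The actual `AgentMemory.trim_messages` implementation is not shown in this diff, so the following is only a sketch of a function with that behaviour; the standalone `trim_messages` signature and the default of 20 are assumptions.

```python
from typing import Dict, List

Message = Dict[str, str]


def trim_messages(messages: List[Message], max_history: int = 20) -> List[Message]:
    """Keep a leading system message plus the most recent `max_history` messages."""
    has_system = bool(messages) and messages[0]["role"] == "system"
    head = messages[:1] if has_system else []
    tail = messages[len(head):]
    # A non-positive limit keeps no history at all
    recent = tail[-max_history:] if max_history > 0 else []
    return head + recent
```

A plain tail cut like this can separate a `tool` message from the assistant message that requested it, so a production version would likely trim on turn boundaries instead.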
|
||||
137
backend/tests/test_document_parser.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
文档解析器单元测试
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import pytest
|
||||
|
||||
from app.services.document_parser import parse_document
|
||||
|
||||
|
||||
class TestDocumentParser:
|
||||
"""文档解析器测试"""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_parse_txt(self):
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", encoding="utf-8", delete=False) as f:
|
||||
f.write("Hello world\nThis is a test.")
|
||||
tmp_path = f.name
|
||||
try:
|
||||
result = parse_document(tmp_path, "txt")
|
||||
assert "Hello world" in result
|
||||
assert "This is a test" in result
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_parse_txt_utf8_bom(self):
|
||||
"""带 BOM 的 UTF-8 文件"""
|
||||
content = "\ufeff你好世界\n测试内容"
|
||||
with tempfile.NamedTemporaryFile(mode="wb", suffix=".txt", delete=False) as f:
|
||||
f.write(content.encode("utf-8-sig"))
|
||||
tmp_path = f.name
|
||||
try:
|
||||
result = parse_document(tmp_path, "txt")
|
||||
assert "你好" in result
|
||||
assert "测试" in result
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_parse_md(self):
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".md", encoding="utf-8", delete=False) as f:
|
||||
f.write("# Title\n\nThis is **bold** text.")
|
||||
tmp_path = f.name
|
||||
try:
|
||||
result = parse_document(tmp_path, "md")
|
||||
assert "Title" in result
|
||||
assert "bold" in result
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_parse_csv(self):
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", encoding="utf-8", delete=False) as f:
|
||||
f.write("name,age,city\nAlice,30,Beijing\nBob,25,Shanghai")
|
||||
tmp_path = f.name
|
||||
try:
|
||||
result = parse_document(tmp_path, "csv")
|
||||
assert "Alice" in result
|
||||
assert "Beijing" in result
|
||||
assert "Bob" in result
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_parse_csv_with_different_delimiter(self):
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", encoding="utf-8", delete=False) as f:
|
||||
f.write("a|b|c\n1|2|3\n4|5|6")
|
||||
tmp_path = f.name
|
||||
try:
|
||||
result = parse_document(tmp_path, "csv")
|
||||
assert "1" in result or "a" in result
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_unsupported_format(self):
|
||||
with tempfile.NamedTemporaryFile(suffix=".xyz", delete=False) as f:
|
||||
tmp_path = f.name
|
||||
try:
|
||||
result = parse_document(tmp_path, "xyz")
|
||||
assert result is None # 不支持的格式返回 None
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_file_not_found(self):
|
||||
result = parse_document("/nonexistent/file.txt", "txt")
|
||||
assert result is None # 文件不存在返回 None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_file(self):
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", encoding="utf-8", delete=False) as f:
|
||||
tmp_path = f.name
|
||||
try:
|
||||
result = parse_document(tmp_path, "txt")
|
||||
assert result is not None
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_parse_docx(self):
|
||||
"""需要 python-docx 包"""
|
||||
pytest.importorskip("docx")
|
||||
from docx import Document
|
||||
|
||||
doc = Document()
|
||||
doc.add_paragraph("Hello from docx")
|
||||
doc.add_paragraph("第二段落")
|
||||
with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as f:
|
||||
doc.save(f.name)
|
||||
tmp_path = f.name
|
||||
try:
|
||||
result = parse_document(tmp_path, "docx")
|
||||
assert "Hello from docx" in result
|
||||
assert "第二段落" in result
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_parse_pdf(self):
|
||||
"""需要 PyPDF2 包"""
|
||||
pytest.importorskip("PyPDF2")
|
||||
from PyPDF2 import PdfWriter
|
||||
|
||||
writer = PdfWriter()
|
||||
writer.add_blank_page(72, 72) # 1 inch page
|
||||
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
|
||||
writer.write(f)
|
||||
tmp_path = f.name
|
||||
try:
|
||||
result = parse_document(tmp_path, "pdf")
|
||||
assert result is not None
|
||||
finally:
|
||||
os.unlink(tmp_path)
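`test_parse_txt_utf8_bom` above covers files that start with a UTF-8 byte order mark. How `parse_document` handles this is not visible in the diff; one common approach, sketched below with a hypothetical `read_text` helper, is to decode with Python's `utf-8-sig` codec, which drops a leading BOM when present and otherwise behaves like `utf-8`.

```python
import os
import tempfile


def read_text(path: str) -> str:
    """Read a UTF-8 text file, dropping a leading byte order mark if present."""
    with open(path, "r", encoding="utf-8-sig") as handle:
        return handle.read()


# Usage: write a file that starts with a BOM, then read it back without it.
with tempfile.NamedTemporaryFile(mode="wb", suffix=".txt", delete=False) as f:
    f.write(b"\xef\xbb\xbf" + "你好世界".encode("utf-8"))
    bom_path = f.name
try:
    text = read_text(bom_path)
finally:
    os.unlink(bom_path)
```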
|
||||
146
backend/tests/test_embedding_service.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Embedding 服务单元测试
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, AsyncMock
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
|
||||
|
||||
class TestEmbeddingService:
|
||||
"""Embedding 服务测试"""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cosine_similarity_identical(self):
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
|
||||
a = [1.0, 0.0, 0.0]
|
||||
b = [1.0, 0.0, 0.0]
|
||||
sim = EmbeddingService.cosine_similarity(a, b)
|
||||
assert sim == pytest.approx(1.0, abs=1e-6)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cosine_similarity_orthogonal(self):
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
|
||||
a = [1.0, 0.0]
|
||||
b = [0.0, 1.0]
|
||||
sim = EmbeddingService.cosine_similarity(a, b)
|
||||
assert sim == pytest.approx(0.0, abs=1e-6)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_cosine_similarity_opposite(self):
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
|
||||
a = [1.0, 0.0]
|
||||
b = [-1.0, 0.0]
|
||||
sim = EmbeddingService.cosine_similarity(a, b)
|
||||
assert sim == pytest.approx(-1.0, abs=1e-6)
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_similarity_search_empty(self):
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
|
||||
svc = EmbeddingService()
|
||||
results = await svc.similarity_search(
|
||||
[1.0, 0.0], [], top_k=5
|
||||
)
|
||||
assert results == []
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_similarity_search_ordering(self):
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
|
||||
svc = EmbeddingService()
|
||||
entries = [
|
||||
{"content_text": "dogs are great pets", "embedding": [0.9, 0.0, 0.0]},
|
||||
{"content_text": "cats are independent", "embedding": [0.1, 0.9, 0.0]},
|
||||
]
|
||||
query = [0.8, 0.0, 0.0]
|
||||
results = await svc.similarity_search(query, entries, top_k=3, min_score=0.0)
|
||||
assert len(results) == 2
|
||||
assert results[0]["content_text"] == "dogs are great pets"
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_similarity_search_top_k(self):
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
|
||||
svc = EmbeddingService()
|
||||
entries = [
|
||||
{"content_text": f"entry {i}", "embedding": [float(i) / 10, 0.0]} for i in range(10)
|
||||
]
|
||||
query = [1.0, 0.0]
|
||||
results = await svc.similarity_search(query, entries, top_k=3)
|
||||
assert len(results) == 3
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_similarity_search_min_score(self):
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
|
||||
svc = EmbeddingService()
|
||||
entries = [
|
||||
{"content_text": "close match", "embedding": [0.9, 0.0]},
|
||||
{"content_text": "distant match", "embedding": [-0.5, 0.0]},
|
||||
]
|
||||
query = [1.0, 0.0]
|
||||
results = await svc.similarity_search(query, entries, top_k=5, min_score=0.5)
|
||||
assert len(results) == 1
|
||||
assert results[0]["content_text"] == "close match"
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_embedding_empty(self):
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
|
||||
svc = EmbeddingService()
|
||||
result = await svc.generate_embedding("")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_embedding_no_api_key(self):
|
||||
"""Returns None when no API key is configured."""
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
|
||||
svc = EmbeddingService()
|
||||
with patch.object(svc, "_get_client", AsyncMock(return_value=None)):
|
||||
result = await svc.generate_embedding("test")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_embeddings_empty(self):
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
|
||||
svc = EmbeddingService()
|
||||
result = await svc.generate_embeddings([])
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_serialize_deserialize(self):
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
|
||||
emb = [0.1, 0.2, 0.3]
|
||||
serialized = EmbeddingService.serialize_embedding(emb)
|
||||
deserialized = EmbeddingService.deserialize_embedding(serialized)
|
||||
assert deserialized == emb
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_deserialize_invalid(self):
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
|
||||
result = EmbeddingService.deserialize_embedding("invalid json")
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_deserialize_list_already(self):
|
||||
from app.services.embedding_service import EmbeddingService
|
||||
|
||||
result = EmbeddingService.deserialize_embedding([1.0, 2.0])
|
||||
assert result == [1.0, 2.0]
|
||||
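The embedding tests above pin down three cosine-similarity cases (identical, orthogonal, opposite) and the `top_k` / `min_score` behaviour of the search. A minimal sketch of logic that would satisfy those cases follows. It is an illustration only, not the project's `EmbeddingService`: the functions here are synchronous and standalone, and the handling of zero or mismatched vectors (returning 0.0) is an assumption.

```python
import math
from typing import Any


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between a and b.

    Assumption: empty, mismatched, or zero-length vectors score 0.0.
    """
    if not a or len(a) != len(b):
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def similarity_search(
    query: list[float],
    entries: list[dict[str, Any]],
    top_k: int = 5,
    min_score: float = 0.0,
) -> list[dict[str, Any]]:
    """Score every entry, drop those below min_score, return the best top_k."""
    scored = []
    for entry in entries:
        score = cosine_similarity(query, entry["embedding"])
        if score >= min_score:
            # Copy the entry so the caller's data is not mutated.
            scored.append({**entry, "score": score})
    # Python's sort is stable, so ties keep their input order.
    scored.sort(key=lambda e: e["score"], reverse=True)
    return scored[:top_k]
```

Because cosine similarity ignores vector length, two embeddings that point in the same direction as the query score identically, so an ordering test is only meaningful when the candidate vectors differ in direction.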
124
backend/tests/test_knowledge_base.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Unit tests for knowledge-base RAG
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, AsyncMock, MagicMock
|
||||
|
||||
|
||||
class TestKnowledgeService:
|
||||
"""Tests for the knowledge-base service."""
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_empty_kb(self):
|
||||
"""Searching an empty knowledge base returns an empty list."""
|
||||
from app.services.knowledge_service import search
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_query = MagicMock()
|
||||
mock_db.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.all.return_value = []
|
||||
mock_query.first.return_value = None
|
||||
|
||||
results = await search(mock_db, kb_id="nonexistent", query="test")
|
||||
assert results == []
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_query_no_results(self):
|
||||
"""With no retrieval hits, the returned context is empty."""
|
||||
from app.services.knowledge_service import rag_query
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_query = MagicMock()
|
||||
mock_db.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.all.return_value = []
|
||||
mock_query.first.return_value = None
|
||||
|
||||
result = await rag_query(mock_db, kb_id="test", query="no results")
|
||||
assert result["found"] is False
|
||||
assert result["context"] == ""
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_content(self):
|
||||
"""Search with mocked embedding and similarity results."""
|
||||
from app.services.knowledge_service import search
|
||||
|
||||
from app.models.knowledge_base import DocumentChunk
|
||||
mock_chunk = MagicMock(spec=DocumentChunk)
|
||||
mock_chunk.id = "chunk-1"
|
||||
mock_chunk.content = "test content about Python programming"
|
||||
mock_chunk.chunk_index = 0
|
||||
mock_chunk.document_id = "doc-1"
|
||||
mock_chunk.metadata = {"filename": "test.txt"}
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_query = MagicMock()
|
||||
mock_db.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.all.return_value = []
|
||||
|
||||
with patch("app.services.knowledge_service.embedding_service.generate_embedding",
|
||||
AsyncMock(return_value=[0.1, 0.2, 0.3])):
|
||||
with patch("app.services.knowledge_service.embedding_service.similarity_search",
|
||||
AsyncMock(return_value=[
|
||||
{"content_text": "test content about Python", "score": 0.85, "metadata": {}}
|
||||
])):
|
||||
results = await search(mock_db, kb_id="test", query="Python")
|
||||
# May return an empty list (the chunk filter matches nothing), but must not raise
|
||||
assert isinstance(results, list)
|
||||
|
||||
|
||||
class TestKnowledgeModels:
|
||||
"""Tests for the knowledge-base models."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_knowledge_base_model(self):
|
||||
from app.models.knowledge_base import KnowledgeBase
|
||||
|
||||
kb = KnowledgeBase(
|
||||
name="Test KB",
|
||||
description="Test",
|
||||
user_id="user-1",
|
||||
chunk_size=500,
|
||||
chunk_overlap=50,
|
||||
)
|
||||
assert kb.name == "Test KB"
|
||||
assert kb.chunk_size == 500
|
||||
assert kb.chunk_overlap == 50
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_document_model(self):
|
||||
from app.models.knowledge_base import Document
|
||||
|
||||
doc = Document(
|
||||
kb_id="kb-1",
|
||||
filename="test.txt",
|
||||
file_type="txt",
|
||||
file_size=1024,
|
||||
status="completed",
|
||||
chunk_count=5,
|
||||
)
|
||||
assert doc.filename == "test.txt"
|
||||
assert doc.status == "completed"
|
||||
assert doc.chunk_count == 5
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_document_chunk_model(self):
|
||||
from app.models.knowledge_base import DocumentChunk
|
||||
|
||||
chunk = DocumentChunk(
|
||||
document_id="doc-1",
|
||||
kb_id="kb-1",
|
||||
chunk_index=0,
|
||||
content="test content",
|
||||
embedding="[0.1, 0.2, 0.3]",
|
||||
metadata={"source": "test"},
|
||||
)
|
||||
assert chunk.chunk_index == 0
|
||||
assert chunk.content == "test content"
|
||||
139
backend/tests/test_text_chunker.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
Unit tests for the text chunker
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from app.services.text_chunker import chunk_text, _split_paragraphs, _split_long_paragraph, _merge_segments
|
||||
|
||||
|
||||
class TestSplitParagraphs:
|
||||
"""Tests for paragraph splitting."""
|
||||
|
||||
def test_empty_text(self):
|
||||
assert _split_paragraphs("") == []
|
||||
assert _split_paragraphs(" ") == []
|
||||
assert _split_paragraphs("\n\n\n") == []
|
||||
|
||||
def test_single_paragraph(self):
|
||||
result = _split_paragraphs("Hello world")
|
||||
assert result == ["Hello world"]
|
||||
|
||||
def test_multiple_paragraphs(self):
|
||||
text = "第一段内容。\n\n第二段内容。\n\n第三段内容。"
|
||||
result = _split_paragraphs(text)
|
||||
assert len(result) == 3
|
||||
assert "第一段" in result[0]
|
||||
assert "第二段" in result[1]
|
||||
assert "第三段" in result[2]
|
||||
|
||||
def test_mixed_newlines(self):
|
||||
text = "段1\n\n\n段2\n\n段3"
|
||||
result = _split_paragraphs(text)
|
||||
assert len(result) == 3
|
||||
|
||||
|
||||
class TestSplitLongParagraph:
|
||||
"""Tests for splitting over-long paragraphs."""
|
||||
|
||||
def test_short_paragraph(self):
|
||||
result = _split_long_paragraph("短文本", chunk_size=500)
|
||||
assert result == ["短文本"]
|
||||
|
||||
def test_long_paragraph_chinese(self):
|
||||
para = "第一句。" * 200 # 800 chars, exceeds 500
|
||||
result = _split_long_paragraph(para, chunk_size=500)
|
||||
assert len(result) >= 2
|
||||
assert all(len(c) <= 500 for c in result)
|
||||
|
||||
def test_long_paragraph_english(self):
|
||||
para = "Sentence. " * 200
|
||||
result = _split_long_paragraph(para, chunk_size=500)
|
||||
assert len(result) >= 2
|
||||
assert all(len(c) <= 500 for c in result)
|
||||
|
||||
def test_no_sentence_boundary(self):
|
||||
"""With no sentence boundary to split on, fall back to a hard cut by character count."""
|
||||
para = "a" * 1000
|
||||
result = _split_long_paragraph(para, chunk_size=500)
|
||||
assert len(result) == 2
|
||||
assert len(result[0]) == 500
|
||||
assert len(result[1]) == 500
|
||||
|
||||
|
||||
class TestMergeSegments:
|
||||
"""Tests for merging segments."""
|
||||
|
||||
def test_empty(self):
|
||||
assert _merge_segments([], 500, 0) == []
|
||||
|
||||
def test_single_segment(self):
|
||||
result = _merge_segments(["hello"], 500, 0)
|
||||
assert result == ["hello"]
|
||||
|
||||
def test_merge_short_segments(self):
|
||||
segs = ["a" * 100, "b" * 100, "c" * 100] # each < 500, total ~300 < 500
|
||||
result = _merge_segments(segs, 500, 0)
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) > 200
|
||||
|
||||
def test_split_large_segments(self):
|
||||
segs = ["a" * 300, "b" * 300, "c" * 300] # must be split across multiple chunks
|
||||
result = _merge_segments(segs, 500, 0)
|
||||
assert len(result) >= 2
|
||||
|
||||
def test_overlap(self):
|
||||
segs = ["a" * 300, "b" * 300]
|
||||
result = _merge_segments(segs, 500, 50)
|
||||
assert len(result) >= 1
|
||||
|
||||
|
||||
class TestChunkText:
|
||||
"""End-to-end tests for chunk_text."""
|
||||
|
||||
def test_empty_text(self):
|
||||
assert chunk_text("") == []
|
||||
assert chunk_text(" ") == []
|
||||
assert chunk_text(None) == [] # type: ignore
|
||||
|
||||
def test_short_text(self):
|
||||
result = chunk_text("Hello world", chunk_size=500, chunk_overlap=0)
|
||||
assert len(result) == 1
|
||||
assert "Hello world" in result[0]
|
||||
|
||||
def test_normal_text(self):
|
||||
text = """
|
||||
这是第一段。它包含一些内容。
|
||||
|
||||
这是第二段。它也有一些内容。而且更长一些。
|
||||
|
||||
这是第三段。最后一段内容。
|
||||
"""
|
||||
result = chunk_text(text, chunk_size=500, chunk_overlap=0)
|
||||
assert len(result) >= 1
|
||||
# All of the content should appear in the result
|
||||
all_text = "".join(result)
|
||||
assert "第一段" in all_text
|
||||
assert "第二段" in all_text
|
||||
assert "第三段" in all_text
|
||||
|
||||
def test_chinese_text(self):
|
||||
"""Chinese text."""
|
||||
text = "我喜欢吃川菜。特别是麻辣火锅。还有水煮鱼。这些都很美味。"
|
||||
result = chunk_text(text, chunk_size=100, chunk_overlap=0)
|
||||
assert len(result) >= 1
|
||||
|
||||
def test_overlap_between_chunks(self):
|
||||
"""Check overlap between chunks."""
|
||||
para = "这是一个很长的段落。" * 50
|
||||
result = chunk_text(para, chunk_size=200, chunk_overlap=50)
|
||||
if len(result) > 1:
|
||||
# Adjacent chunks should share overlapping content (only non-emptiness is asserted here)
|
||||
assert len(result[0]) > 0
|
||||
assert len(result[1]) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_with_mixed_punctuation(self):
|
||||
text = "Hello! How are you? I am fine. Thank you. 你好!最近怎么样?我很好。谢谢。"
|
||||
result = chunk_text(text, chunk_size=200, chunk_overlap=0)
|
||||
assert len(result) >= 1
|
||||
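The chunker tests above exercise a hard character cut (`"a" * 1000` with `chunk_size=500` yields two 500-character chunks) and a `chunk_overlap` parameter whose effect is not directly asserted. The sketch below shows one way a fixed-size sliding window with overlap can work, and how the overlap could be checked. It is a simplified stand-in, not the project's `text_chunker`, which also splits on paragraphs and sentences; the function name `chunk_by_chars` is hypothetical.

```python
def chunk_by_chars(text: str, chunk_size: int = 500, chunk_overlap: int = 50) -> list[str]:
    """Cut text into windows of chunk_size characters.

    Each window starts chunk_size - chunk_overlap characters after the
    previous one, so consecutive chunks share chunk_overlap characters.
    """
    if not text or not text.strip():
        return []
    if chunk_size <= 0 or not 0 <= chunk_overlap < chunk_size:
        raise ValueError("need chunk_size > 0 and 0 <= chunk_overlap < chunk_size")

    step = chunk_size - chunk_overlap
    chunks: list[str] = []
    start = 0
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reached the end of the text
        start += step
    return chunks


# Usage: with overlap, the tail of each chunk equals the head of the next.
parts = chunk_by_chars("abcdefghij", chunk_size=4, chunk_overlap=2)
for left, right in zip(parts, parts[1:]):
    assert left[-2:] == right[:2]
```

An assertion of this shape (tail of one chunk equals head of the next) is one possible way to make an overlap test verify the overlap itself rather than only that the chunks are non-empty.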
293
backend/tests/test_tool_registry.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
Unit tests for the tool registry
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
from app.services.tool_registry import ToolRegistry, _CODE_SAFE_GLOBALS
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def registry():
|
||||
r = ToolRegistry()
|
||||
return r
|
||||
|
||||
|
||||
class TestToolRegistryBuiltin:
|
||||
"""Registering and looking up built-in tools."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_register_and_get(self, registry):
|
||||
def my_tool(**kwargs):
|
||||
return {"result": kwargs.get("x", 0) + kwargs.get("y", 0)}
|
||||
|
||||
schema = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "add",
|
||||
"description": "加法",
|
||||
"parameters": {"type": "object", "properties": {"x": {"type": "number"}, "y": {"type": "number"}}},
|
||||
},
|
||||
}
|
||||
registry.register_builtin_tool("add", my_tool, schema)
|
||||
assert registry.get_tool_schema("add") == schema
|
||||
assert registry.get_tool_function("add") == my_tool
|
||||
assert registry.builtin_tool_count() == 1
|
||||
assert "add" in registry.builtin_tool_names()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_missing_tool(self, registry):
|
||||
assert registry.get_tool_schema("nonexistent") is None
|
||||
assert registry.get_tool_function("nonexistent") is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_all_schemas(self, registry):
|
||||
schema1 = {"type": "function", "function": {"name": "tool1"}}
|
||||
schema2 = {"type": "function", "function": {"name": "tool2"}}
|
||||
registry.register_builtin_tool("tool1", lambda: None, schema1)
|
||||
registry.register_builtin_tool("tool2", lambda: None, schema2)
|
||||
assert len(registry.get_all_tool_schemas()) == 2
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_tools_by_names(self, registry):
|
||||
registry.register_builtin_tool("a", lambda: None, {"function": {"name": "a"}})
|
||||
registry.register_builtin_tool("b", lambda: None, {"function": {"name": "b"}})
|
||||
tools = registry.get_tools_by_names(["a", "b", "c"])
|
||||
assert len(tools) == 2
|
||||
assert tools[0]["function"]["name"] == "a"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_sync_function_execution(self, registry):
|
||||
def sync_func(x=0, y=0):
|
||||
return x + y
|
||||
|
||||
registry.register_builtin_tool("add", sync_func, {"function": {"name": "add"}})
|
||||
import asyncio
|
||||
result = asyncio.run(registry._run_function(sync_func, "add", {"x": 3, "y": 4}))
|
||||
parsed = json.loads(result)
|
||||
assert parsed == 7
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("registry")
|
||||
class TestToolRegistryHTTP:
|
||||
"""HTTP tool execution tests (httpx is mocked)."""
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_http(self, registry):
|
||||
config = {
|
||||
"url": "https://api.example.com/data?q={query}",
|
||||
"method": "GET",
|
||||
"headers": {"Authorization": "Bearer token"},
|
||||
"timeout": 10,
|
||||
"_type": "http",
|
||||
}
|
||||
registry._custom_tool_configs["test_http"] = config
|
||||
|
||||
with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = '{"result": "ok"}'
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
result = await registry.execute_tool("test_http", {"query": "hello"})
|
||||
parsed = json.loads(result)
|
||||
assert parsed["status_code"] == 200
|
||||
# Verify URL template substitution
|
||||
called_url = mock_request.call_args[0][1]
|
||||
assert "hello" in called_url
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_http_post_with_body(self, registry):
|
||||
config = {
|
||||
"url": "https://api.example.com/submit",
|
||||
"method": "POST",
|
||||
"headers": {},
|
||||
"body_template": {"name": "{name}", "age": "{age}"},
|
||||
"timeout": 10,
|
||||
"_type": "http",
|
||||
}
|
||||
registry._custom_tool_configs["test_post"] = config
|
||||
|
||||
with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 201
|
||||
mock_response.text = '{"id": 1}'
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
result = await registry.execute_tool("test_post", {"name": "Alice", "age": 30})
|
||||
parsed = json.loads(result)
|
||||
assert parsed["status_code"] == 201
|
||||
# Verify POST method
|
||||
assert mock_request.call_args[0][0] == "POST"
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_no_url(self, registry):
|
||||
config = {"_type": "http"}
|
||||
registry._custom_tool_configs["bad_http"] = config
|
||||
result = await registry.execute_tool("bad_http", {})
|
||||
assert "error" in result
|
||||
|
||||
|
||||
class TestToolRegistryCode:
|
||||
"""Code sandbox execution tests."""
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_code_simple(self, registry):
|
||||
config = {
|
||||
"source": "def run(args):\n return {'sum': args['a'] + args['b']}",
|
||||
"_type": "code",
|
||||
}
|
||||
registry._custom_tool_configs["calc"] = config
|
||||
result = await registry.execute_tool("calc", {"a": 10, "b": 20})
|
||||
parsed = json.loads(result)
|
||||
assert parsed["sum"] == 30
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_code_text_stats(self, registry):
|
||||
config = {
|
||||
"source": "def run(args):\n text = args.get('text', '')\n return {'len': len(text), 'words': len(text.split())}",
|
||||
"_type": "code",
|
||||
}
|
||||
registry._custom_tool_configs["stats"] = config
|
||||
result = await registry.execute_tool("stats", {"text": "hello world test"})
|
||||
parsed = json.loads(result)
|
||||
assert parsed["len"] == 16
|
||||
assert parsed["words"] == 3
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_source_missing(self, registry):
|
||||
config = {"_type": "code"} # no source
|
||||
registry._custom_tool_configs["no_source"] = config
|
||||
result = await registry.execute_tool("no_source", {})
|
||||
assert "error" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_no_run_function(self, registry):
|
||||
config = {"source": "x = 1", "_type": "code"}
|
||||
registry._custom_tool_configs["no_run"] = config
|
||||
result = await registry.execute_tool("no_run", {})
|
||||
assert "run" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_runtime_error(self, registry):
|
||||
config = {
|
||||
"source": "def run(args):\n raise ValueError('test error')",
|
||||
"_type": "code",
|
||||
}
|
||||
registry._custom_tool_configs["err"] = config
|
||||
result = await registry.execute_tool("err", {})
|
||||
assert "error" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_sandbox_restriction(self, registry):
|
||||
"""Verify that __builtins__ is disabled."""
|
||||
config = {
|
||||
"source": "def run(args):\n import os\n return os.name",
|
||||
"_type": "code",
|
||||
}
|
||||
registry._custom_tool_configs["unsafe"] = config
|
||||
result = await registry.execute_tool("unsafe", {})
|
||||
# import should fail inside the sandbox
|
||||
assert "error" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_sandbox_no_file_access(self, registry):
|
||||
"""Verify that the file system is not accessible."""
|
||||
config = {
|
||||
"source": "def run(args):\n open('/etc/passwd')\n return 'ok'",
|
||||
"_type": "code",
|
||||
}
|
||||
registry._custom_tool_configs["file_access"] = config
|
||||
result = await registry.execute_tool("file_access", {})
|
||||
assert "error" in result
|
||||
|
||||
|
||||
class TestToolRegistryTestHelpers:
|
||||
"""Tool test helpers (nothing is saved to the DB)."""
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_test_code_tool_success(self, registry):
|
||||
source = "def run(args):\n return {'result': args['x'] * 2}"
|
||||
result = await registry.test_code_tool(source, {"x": 5})
|
||||
assert result["success"] is True
|
||||
assert result["result"]["result"] == 10
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_test_code_tool_compile_error(self, registry):
|
||||
source = "def run(args):\n invalid syntax{{{"
|
||||
result = await registry.test_code_tool(source, {})
|
||||
assert result["success"] is False
|
||||
assert "error" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_test_http_tool(self, registry):
|
||||
with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = '{"ip": "8.8.8.8"}'
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
result = await registry.test_http_tool(
|
||||
url="https://httpbin.org/get?ip={ip}",
|
||||
method="GET",
|
||||
headers={},
|
||||
body=None,
|
||||
args={"ip": "8.8.8.8"},
|
||||
timeout=5,
|
||||
)
|
||||
assert result["success"] is True
|
||||
assert result["status_code"] == 200
|
||||
|
||||
|
||||
class TestToolRegistryExecute:
|
||||
"""End-to-end flow of execute_tool."""
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_builtin(self, registry):
|
||||
def hello(**kwargs):
|
||||
return f"Hello, {kwargs.get('name', 'world')}!"
|
||||
|
||||
registry.register_builtin_tool("hello", hello, {"function": {"name": "hello"}})
|
||||
result = await registry.execute_tool("hello", {"name": "Test"})
|
||||
assert "Hello, Test!" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_nonexistent(self, registry):
|
||||
result = await registry.execute_tool("no_such_tool", {})
|
||||
assert "error" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_unsupported_type(self, registry):
|
||||
config = {"_type": "unsupported"}
|
||||
registry._custom_tool_configs["weird"] = config
|
||||
result = await registry.execute_tool("weird", {})
|
||||
assert "error" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_workflow_not_supported(self, registry):
|
||||
config = {"_type": "workflow"}
|
||||
registry._custom_tool_configs["wf"] = config
|
||||
result = await registry.execute_tool("wf", {})
|
||||
assert "暂不支持" in result
|
||||
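The sandbox tests above expect that tool source defining `run(args)` executes, that `import` and `open` fail, and that a missing `run` is reported. A minimal sketch of how such behaviour can arise from `exec` with a restricted `__builtins__` mapping follows. Everything here is an assumption for illustration: `SAFE_BUILTINS` and `run_code_tool` are hypothetical names, and the project's `_CODE_SAFE_GLOBALS` may be built quite differently.

```python
import json
from typing import Any

# Hypothetical allow-list; the real registry's safe globals may differ.
SAFE_BUILTINS: dict[str, Any] = {
    "len": len, "range": range, "min": min, "max": max, "sum": sum,
    "sorted": sorted, "str": str, "int": int, "float": float,
    "dict": dict, "list": list, "ValueError": ValueError,
}


def run_code_tool(source: str, args: dict[str, Any]) -> str:
    """Execute tool source that defines run(args); return a JSON string."""
    # With __builtins__ replaced by a plain dict, names such as open or
    # __import__ are simply absent, so `open(...)` raises NameError and
    # an `import` statement raises ImportError.
    sandbox_globals: dict[str, Any] = {"__builtins__": SAFE_BUILTINS}
    try:
        exec(compile(source, "<tool>", "exec"), sandbox_globals)
        run = sandbox_globals.get("run")
        if not callable(run):
            return json.dumps({"error": "source must define run(args)"})
        return json.dumps(run(args))
    except Exception as exc:  # syntax errors, runtime errors, JSON errors
        return json.dumps({"error": f"{type(exc).__name__}: {exc}"})


# Usage
print(run_code_tool("def run(args):\n    return {'sum': args['a'] + args['b']}", {"a": 10, "b": 20}))
```

Restricting builtins this way blocks the obvious routes but is not a hard security boundary on its own, since Python object introspection can often reach restricted functionality; running untrusted tool code in a separate process or container is a commonly used additional safeguard.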
263
backend/tests/test_tools_api.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""
|
||||
Integration tests for the tool marketplace API
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
|
||||
class TestToolsAPI:
|
||||
"""Tests for the tool marketplace API."""
|
||||
|
||||
@pytest.mark.api
|
||||
def test_list_public_tools(self, authenticated_client):
|
||||
resp = authenticated_client.get("/api/v1/tools")
|
||||
assert resp.status_code == 200
|
||||
assert isinstance(resp.json(), list)
|
||||
|
||||
@pytest.mark.api
|
||||
def test_list_categories(self, authenticated_client):
|
||||
resp = authenticated_client.get("/api/v1/tools/categories")
|
||||
assert resp.status_code == 200
|
||||
cats = resp.json()
|
||||
assert isinstance(cats, list)
|
||||
# Should include the default categories
|
||||
assert "数据处理" in cats or "网络请求" in cats
|
||||
|
||||
@pytest.mark.api
|
||||
def test_list_without_auth(self, client):
|
||||
resp = client.get("/api/v1/tools")
|
||||
# Unauthenticated requests must be rejected with 401, even though the default scope is public
|
||||
assert resp.status_code == 401
|
||||
|
||||
@pytest.mark.api
|
||||
def test_list_builtin(self, authenticated_client):
|
||||
resp = authenticated_client.get("/api/v1/tools/builtin")
|
||||
assert resp.status_code == 200
|
||||
tools = resp.json()
|
||||
assert isinstance(tools, list)
|
||||
assert len(tools) >= 10 # expect at least 10 built-in tools
|
||||
|
||||
@pytest.mark.api
|
||||
def test_create_and_get_tool(self, authenticated_client):
|
||||
# Create an HTTP tool
|
||||
create_resp = authenticated_client.post("/api/v1/tools", json={
|
||||
"name": "echo_test",
|
||||
"description": "Echo test tool",
|
||||
"category": "network",
|
||||
"function_schema": {
|
||||
"name": "echo_test",
|
||||
"description": "Returns the input",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"msg": {"type": "string"}},
|
||||
"required": ["msg"],
|
||||
},
|
||||
},
|
||||
"implementation_type": "http",
|
||||
"implementation_config": {
|
||||
"url": "https://httpbin.org/post",
|
||||
"method": "POST",
|
||||
"headers": {},
|
||||
"timeout": 10,
|
||||
},
|
||||
"is_public": True,
|
||||
})
|
||||
assert create_resp.status_code == 201
|
||||
data = create_resp.json()
|
||||
assert data["name"] == "echo_test"
|
||||
assert data["implementation_type"] == "http"
|
||||
tool_id = data["id"]
|
||||
|
||||
# Fetch the tool details
|
||||
get_resp = authenticated_client.get(f"/api/v1/tools/{tool_id}")
|
||||
assert get_resp.status_code == 200
|
||||
assert get_resp.json()["name"] == "echo_test"
|
||||
|
||||
@pytest.mark.api
|
||||
def test_create_code_tool(self, authenticated_client):
|
||||
create_resp = authenticated_client.post("/api/v1/tools", json={
|
||||
"name": "double_test",
|
||||
"description": "Double a number",
|
||||
"category": "math",
|
||||
"function_schema": {
|
||||
"name": "double_test",
|
||||
"description": "Doubles a number",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"n": {"type": "number"}},
|
||||
"required": ["n"],
|
||||
},
|
||||
},
|
||||
"implementation_type": "code",
|
||||
"implementation_config": {
|
||||
"source": "def run(args):\n n = args.get('n', 0)\n return {'result': n * 2}",
|
||||
"language": "python",
|
||||
},
|
||||
"is_public": True,
|
||||
})
|
||||
assert create_resp.status_code == 201
|
||||
data = create_resp.json()
|
||||
assert data["name"] == "double_test"
|
||||
assert data["implementation_type"] == "code"
|
||||
|
||||
@pytest.mark.api
|
||||
def test_create_duplicate_name(self, authenticated_client):
|
||||
# Create the tool first
|
||||
authenticated_client.post("/api/v1/tools", json={
|
||||
"name": "dup_tool",
|
||||
"description": "Dup",
|
||||
"function_schema": {"name": "dup_tool", "parameters": {"type": "object", "properties": {}}},
|
||||
"implementation_type": "code",
|
||||
"implementation_config": {"source": "def run(args):\n return {}"},
|
||||
"is_public": False,
|
||||
})
|
||||
# Creating it again should fail
|
||||
resp = authenticated_client.post("/api/v1/tools", json={
|
||||
"name": "dup_tool",
|
||||
"description": "Dup again",
|
||||
"function_schema": {"name": "dup_tool", "parameters": {"type": "object", "properties": {}}},
|
||||
"implementation_type": "code",
|
||||
"implementation_config": {"source": "def run(args):\n return {}"},
|
||||
"is_public": False,
|
||||
})
|
||||
assert resp.status_code == 400
|
||||
assert "已存在" in resp.json().get("detail", "")
|
||||
|
||||
@pytest.mark.api
|
||||
def test_invalid_implementation_type(self, authenticated_client):
|
||||
resp = authenticated_client.post("/api/v1/tools", json={
|
||||
"name": "bad_tool",
|
||||
"description": "Bad",
|
||||
"function_schema": {"name": "bad_tool", "parameters": {"type": "object", "properties": {}}},
|
||||
"implementation_type": "invalid_type",
|
||||
"is_public": False,
|
||||
})
|
||||
assert resp.status_code == 400
|
||||
|
||||
@pytest.mark.api
|
||||
def test_mine_scope(self, authenticated_client):
|
||||
resp = authenticated_client.get("/api/v1/tools?scope=mine")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.api
|
||||
def test_get_nonexistent_tool(self, authenticated_client):
|
||||
resp = authenticated_client.get("/api/v1/tools/nonexistent-id")
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_test_http_endpoint(self, authenticated_client):
|
||||
"""Exercise the test endpoint for HTTP tools."""
|
||||
with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = '{"origin": "1.2.3.4"}'
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
resp = authenticated_client.post("/api/v1/tools/test/http", json={
|
||||
"url": "https://httpbin.org/get",
|
||||
"method": "GET",
|
||||
"headers": {},
|
||||
"body": None,
|
||||
"args": {"ip": "8.8.8.8"},
|
||||
"timeout": 5,
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["success"] is True
|
||||
|
||||
@pytest.mark.api
|
||||
def test_test_code_endpoint(self, authenticated_client):
|
||||
resp = authenticated_client.post("/api/v1/tools/test/code", json={
|
||||
"source": "def run(args):\n return args.get('x', 0) + args.get('y', 0)",
|
||||
"args": {"x": 10, "y": 20},
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["success"] is True
|
||||
assert data["result"] == 30
|
||||
|
||||
@pytest.mark.api
|
||||
def test_test_code_compile_error(self, authenticated_client):
|
||||
resp = authenticated_client.post("/api/v1/tools/test/code", json={
|
||||
"source": "invalid python {{{",
|
||||
"args": {},
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["success"] is False
|
||||
assert "error" in data
|
||||
|
||||
@pytest.mark.api
|
||||
def test_use_count(self, authenticated_client):
|
||||
# Create the tool first
|
||||
create_resp = authenticated_client.post("/api/v1/tools", json={
|
||||
"name": "count_test",
|
||||
"description": "Count test",
|
||||
"function_schema": {"name": "count_test", "parameters": {"type": "object", "properties": {}}},
|
||||
"implementation_type": "code",
|
||||
"implementation_config": {"source": "def run(args):\n return {}"},
|
||||
"is_public": True,
|
||||
})
|
||||
assert create_resp.status_code == 201
|
||||
tool_id = create_resp.json()["id"]
|
||||
|
||||
# Increment the use count
|
||||
use_resp = authenticated_client.post(f"/api/v1/tools/{tool_id}/use")
|
||||
assert use_resp.status_code == 200
|
||||
assert use_resp.json()["use_count"] == 1
|
||||
|
||||
# Use it again
|
||||
use_resp2 = authenticated_client.post(f"/api/v1/tools/{tool_id}/use")
|
||||
assert use_resp2.status_code == 200
|
||||
assert use_resp2.json()["use_count"] == 2
|
||||
|
||||
@pytest.mark.api
|
||||
def test_delete_tool(self, authenticated_client):
|
||||
# Create
|
||||
create_resp = authenticated_client.post("/api/v1/tools", json={
|
||||
"name": "del_test",
|
||||
"description": "Delete test",
|
||||
"function_schema": {"name": "del_test", "parameters": {"type": "object", "properties": {}}},
|
||||
"implementation_type": "code",
|
||||
"implementation_config": {"source": "def run(args):\n return {}"},
|
||||
"is_public": False,
|
||||
})
|
||||
assert create_resp.status_code == 201
|
||||
tool_id = create_resp.json()["id"]
|
||||
|
||||
# Delete
|
||||
del_resp = authenticated_client.delete(f"/api/v1/tools/{tool_id}")
|
||||
assert del_resp.status_code == 200
|
||||
|
||||
# Confirm it is gone
|
||||
get_resp = authenticated_client.get(f"/api/v1/tools/{tool_id}")
|
||||
assert get_resp.status_code == 404
|
||||
|
||||
@pytest.mark.api
|
||||
def test_update_tool(self, authenticated_client):
|
||||
create_resp = authenticated_client.post("/api/v1/tools", json={
|
||||
"name": "update_test",
|
||||
"description": "Original desc",
|
||||
"function_schema": {"name": "update_test", "parameters": {"type": "object", "properties": {}}},
|
||||
"implementation_type": "code",
|
||||
"implementation_config": {"source": "def run(args):\n return {}"},
|
||||
"is_public": False,
|
||||
})
|
||||
assert create_resp.status_code == 201
|
||||
tool_id = create_resp.json()["id"]
|
||||
|
||||
update_resp = authenticated_client.put(f"/api/v1/tools/{tool_id}", json={
|
||||
"name": "update_test",
|
||||
"description": "Updated desc",
|
||||
"function_schema": {"name": "update_test", "parameters": {"type": "object", "properties": {}}},
|
||||
"implementation_type": "code",
|
||||
"implementation_config": {"source": "def run(args):\n return {'updated': True}"},
|
||||
"is_public": True,
|
||||
})
|
||||
assert update_resp.status_code == 200
|
||||
assert update_resp.json()["description"] == "Updated desc"
|
||||
assert update_resp.json()["is_public"] is True
|
||||
@@ -78,7 +78,9 @@
|
||||
</el-menu>
|
||||
|
||||
<!-- 页面内容 -->
|
||||
<slot />
|
||||
<div class="page-content">
|
||||
<slot />
|
||||
</div>
|
||||
</el-main>
|
||||
</el-container>
|
||||
</div>
|
||||
@@ -109,6 +111,7 @@ const activeMenu = computed(() => {
|
||||
if (route.path === '/monitoring') return 'monitoring'
|
||||
if (route.path === '/agent-monitoring') return 'agent-monitoring'
|
||||
if (route.path === '/alert-rules') return 'alert-rules'
|
||||
if (route.path === '/agent-chat' || route.path.startsWith('/agent-chat/')) return 'agent-chat'
|
||||
return 'workflows'
|
||||
})
|
||||
|
||||
@@ -189,4 +192,17 @@ const handleLogout = () => {
|
||||
height: 50px;
|
||||
line-height: 50px;
|
||||
}
|
||||
|
||||
:deep(.el-main) {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.page-content {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
min-height: 0;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
<template>
|
||||
<div class="agent-chat-page">
|
||||
<div class="chat-header">
|
||||
<div class="header-left">
|
||||
<h2>{{ chatMode === 'single' ? (agent ? agent.name : 'AI Agent 对话') : '多 Agent 编排' }}</h2>
|
||||
<MainLayout>
|
||||
<div class="page-header">
|
||||
<div class="page-header-left">
|
||||
<h3>{{ chatMode === 'single' ? (agent ? agent.name : 'AI Agent 对话') : '多 Agent 编排' }}</h3>
|
||||
</div>
|
||||
<div class="header-actions">
|
||||
<div class="page-header-actions">
|
||||
<el-switch
|
||||
v-model="chatMode"
|
||||
active-value="orchestrate"
|
||||
@@ -35,19 +35,19 @@
|
||||
</el-button>
|
||||
</template>
|
||||
|
||||
<el-button @click="clearChat" :disabled="messages.length === 0">清空</el-button>
|
||||
<el-button @click="clearChat" :disabled="displayMessages.length === 0">清空</el-button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="chat-messages" ref="messagesRef">
|
||||
<div v-if="messages.length === 0" class="chat-empty">
|
||||
<div v-if="displayMessages.length === 0" class="chat-empty">
|
||||
<el-icon :size="48"><ChatLineSquare /></el-icon>
|
||||
<p v-if="chatMode === 'single'">选择一个 Agent 开始对话</p>
|
||||
<p v-else>配置多个 Agent 后发送消息进行编排对话</p>
|
||||
<p class="hint">Agent 可以使用内置工具帮你完成任务</p>
|
||||
</div>
|
||||
|
||||
<div v-for="(msg, i) in messages" :key="i" class="message" :class="[msg.role, msg.status === 'error' ? 'error' : '']">
|
||||
<div v-for="(msg, i) in displayMessages" :key="i" class="message" :class="[msg.role, msg.status === 'error' ? 'error' : '']">
|
||||
<div class="message-avatar">
|
||||
<el-avatar :size="36" :icon="msg.role === 'user' ? UserFilled : Promotion" />
|
||||
</div>
|
||||
@@ -137,13 +137,21 @@
|
||||
</template>
|
||||
|
||||
<div class="message-meta">
|
||||
{{ msg.role === 'user' ? '用户' : 'Agent' }} · {{ formatTime(msg.timestamp) }}
|
||||
{{ msg.role === 'user' ? '用户' : 'Agent' }} · {{ relativeTime(msg.timestamp) }}
|
||||
<span v-if="msg.iterations" class="meta-iterations">· {{ msg.iterations }} 步 · {{ msg.tool_calls_made }} 次工具调用</span>
|
||||
<span class="meta-actions">
|
||||
<el-button v-if="msg.role === 'assistant'" link size="small" @click="copyMessage(msg)" title="复制">
|
||||
<el-icon><DocumentCopy /></el-icon>
|
||||
</el-button>
|
||||
<el-button v-if="msg.status === 'error'" link type="danger" size="small" @click="retryMessage(i)" title="重试">
|
||||
<el-icon><Refresh /></el-icon>
|
||||
</el-button>
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="loading" class="message assistant">
|
||||
<div v-if="loading && !streamingActive" class="message assistant">
|
||||
<div class="message-avatar"><el-avatar :size="36" icon="Promotion" /></div>
|
||||
<div class="message-bubble">
|
||||
<div class="thinking"><span class="dot"></span><span class="dot"></span><span class="dot"></span></div>
|
||||
@@ -159,7 +167,7 @@
|
||||
</div>
|
||||
|
||||
<!-- 编排 Agent 编辑器 -->
|
||||
<el-dialog v-model="showOrchestrateEditor" title="编排 Agent 配置" width="700px">
|
||||
<el-dialog v-model="showOrchestrateEditor" title="编排 Agent 配置" width="700px" @closed="saveState">
|
||||
<div class="orch-editor">
|
||||
<div v-for="(agt, i) in orchestrateAgents" :key="i" class="orch-agent-card">
|
||||
<div class="orch-agent-header">
|
||||
@@ -188,14 +196,15 @@
|
||||
<el-button @click="showOrchestrateEditor = false">关闭</el-button>
|
||||
</template>
|
||||
</el-dialog>
|
||||
</div>
|
||||
</MainLayout>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted, nextTick } from 'vue'
|
||||
import { ref, computed, watch, onMounted, nextTick } from 'vue'
|
||||
import { useRoute } from 'vue-router'
|
||||
import { ElMessage } from 'element-plus'
|
||||
import { ChatLineSquare, UserFilled, Promotion, Tools, CaretRight, ChatDotSquare, Select } from '@element-plus/icons-vue'
|
||||
import { ChatLineSquare, UserFilled, Promotion, Tools, CaretRight, ChatDotSquare, Select, DocumentCopy, Refresh } from '@element-plus/icons-vue'
|
||||
import MainLayout from '@/components/MainLayout.vue'
|
||||
import api from '@/api'
|
||||
import type { Agent } from '@/stores/agent'
|
||||
|
||||
@@ -221,16 +230,66 @@ interface OrchestrateAgentForm {
|
||||
temperature: number; max_iterations: number; description: string
|
||||
}
|
||||
|
||||
const STORAGE_KEY = 'agent_chat_state'
|
||||
|
||||
interface ChatState {
|
||||
messages: Record<string, ChatMessage[]> // keyed by agentId / '__bare__' / '__orchestrate__'
|
||||
sessionId: Record<string, string>
|
||||
currentAgentId: string
|
||||
chatMode: 'single' | 'orchestrate'
|
||||
orchestrateMode: string
|
||||
orchestrateAgents: OrchestrateAgentForm[]
|
||||
}
|
||||
|
||||
function saveState() {
|
||||
try {
|
||||
const state: ChatState = {
|
||||
messages: messages.value,
|
||||
sessionId: sessionId.value,
|
||||
currentAgentId: currentAgentId.value,
|
||||
chatMode: chatMode.value,
|
||||
orchestrateMode: orchestrateMode.value,
|
||||
orchestrateAgents: orchestrateAgents.value,
|
||||
}
|
||||
localStorage.setItem(STORAGE_KEY, JSON.stringify(state))
|
||||
} catch { /* quota exceeded, ignore */ }
|
||||
}
|
||||
|
||||
function loadState(): ChatState | null {
|
||||
try {
|
||||
const raw = localStorage.getItem(STORAGE_KEY)
|
||||
if (!raw) return null
|
||||
const parsed = JSON.parse(raw)
|
||||
// Backward compatibility: older versions stored messages as an array; migrate it to a Record
|
||||
if (Array.isArray(parsed.messages)) {
|
||||
const oldSessionId = parsed.sessionId || ''
|
||||
parsed.messages = { '__bare__': parsed.messages }
|
||||
parsed.sessionId = { '__bare__': oldSessionId }
|
||||
}
|
||||
return parsed
|
||||
} catch { return null }
|
||||
}
|
||||
|
||||
const route = useRoute()
|
||||
const agents = ref<Agent[]>([])
|
||||
const currentAgentId = ref('')
|
||||
const messages = ref<ChatMessage[]>([])
|
||||
const messages = ref<Record<string, ChatMessage[]>>({})
|
||||
const inputMessage = ref('')
|
||||
const loading = ref(false)
|
||||
const streamingActive = ref(false)
|
||||
const messagesRef = ref<HTMLElement | null>(null)
|
||||
const sessionId = ref('')
|
||||
const sessionId = ref<Record<string, string>>({})
|
||||
const agent = ref<Agent | null>(null)
|
||||
|
||||
const currentAgentKey = computed(() => {
|
||||
if (chatMode.value === 'orchestrate') return '__orchestrate__'
|
||||
return currentAgentId.value || '__bare__'
|
||||
})
|
||||
|
||||
const displayMessages = computed(() => {
|
||||
return messages.value[currentAgentKey.value] || []
|
||||
})
|
||||
|
||||
// Orchestration mode
|
||||
const chatMode = ref<'single' | 'orchestrate'>('single')
|
||||
const orchestrateMode = ref('debate')
|
||||
@@ -251,14 +310,38 @@ function addOrchestrateAgent() {
|
||||
max_iterations: 10,
|
||||
description: '',
|
||||
})
|
||||
saveState()
|
||||
}
|
||||
|
||||
// Auto-save when the mode or configuration changes
|
||||
watch(chatMode, saveState)
|
||||
watch(orchestrateMode, saveState)
|
||||
|
||||
onMounted(async () => {
|
||||
await loadAgents()
|
||||
|
||||
const saved = loadState()
|
||||
if (saved) {
|
||||
messages.value = saved.messages
|
||||
sessionId.value = saved.sessionId
|
||||
chatMode.value = saved.chatMode
|
||||
orchestrateMode.value = saved.orchestrateMode
|
||||
orchestrateAgents.value = saved.orchestrateAgents
|
||||
// Restore the trace expansion state (collapsed after a reload)
|
||||
for (const arr of Object.values(messages.value)) {
|
||||
arr.forEach(m => { if (m.steps?.length) m._traceOpen = false })
|
||||
}
|
||||
}
|
||||
|
||||
if (route.params.id) {
|
||||
currentAgentId.value = route.params.id as string
|
||||
await switchAgent()
|
||||
} else if (saved?.currentAgentId) {
|
||||
currentAgentId.value = saved.currentAgentId
|
||||
await switchAgent()
|
||||
}
|
||||
|
||||
nextTick(scrollToBottom)
|
||||
})
|
||||
|
||||
async function loadAgents() {
|
||||
@@ -267,16 +350,25 @@ async function loadAgents() {
|
||||
}
|
||||
|
||||
async function switchAgent() {
|
||||
if (!currentAgentId.value) { agent.value = null; return }
|
||||
try { const resp = await api.get(`/api/v1/agents/${currentAgentId.value}`); agent.value = resp.data }
|
||||
catch { ElMessage.error('加载 Agent 失败'); agent.value = null }
|
||||
if (!currentAgentId.value) { agent.value = null; saveState(); nextTick(scrollToBottom); return }
|
||||
try {
|
||||
const resp = await api.get(`/api/v1/agents/${currentAgentId.value}`)
|
||||
agent.value = resp.data
|
||||
saveState()
|
||||
nextTick(scrollToBottom)
|
||||
} catch {
|
||||
ElMessage.error('加载 Agent 失败')
|
||||
agent.value = null
|
||||
}
|
||||
}
|
||||
|
||||
async function sendMessage() {
|
||||
const text = inputMessage.value.trim()
|
||||
if (!text || loading.value) return
|
||||
|
||||
messages.value.push({ role: 'user', content: text, timestamp: Date.now() })
|
||||
const key = currentAgentKey.value
|
||||
if (!messages.value[key]) messages.value[key] = []
|
||||
messages.value[key].push({ role: 'user', content: text, timestamp: Date.now() })
|
||||
inputMessage.value = ''
|
||||
loading.value = true
|
||||
scrollToBottom()
|
||||
@@ -294,36 +386,212 @@ async function sendMessage() {
|
||||
})
|
||||
const data = resp.data as OrchestrateResult
|
||||
data.steps.forEach(s => { s._open = false })
|
||||
messages.value.push({
|
||||
messages.value[key].push({
|
||||
role: 'assistant', content: data.final_answer, timestamp: Date.now(),
|
||||
orchestrateResult: data, _traceOpen: true,
|
||||
})
|
||||
} else {
|
||||
const endpoint = currentAgentId.value ? `/api/v1/agent-chat/${currentAgentId.value}` : '/api/v1/agent-chat/bare'
|
||||
const resp = await api.post(endpoint, { message: text, session_id: sessionId.value || undefined })
|
||||
const data = resp.data
|
||||
sessionId.value = data.session_id
|
||||
messages.value.push({
|
||||
role: 'assistant', content: data.content, timestamp: Date.now(),
|
||||
iterations: data.iterations_used, tool_calls_made: data.tool_calls_made,
|
||||
status: data.truncated ? 'error' : 'success', steps: data.steps || [],
|
||||
_traceOpen: data.steps && data.steps.length > 0,
|
||||
})
|
||||
const sessId = sessionId.value[key] || ''
|
||||
const streamEndpoint = currentAgentId.value
|
||||
? `/api/v1/agent-chat/${currentAgentId.value}/stream`
|
||||
: '/api/v1/agent-chat/bare/stream'
|
||||
|
||||
// Try SSE streaming first
|
||||
let usedStreaming = false
|
||||
streamingActive.value = false
|
||||
try {
|
||||
const token = localStorage.getItem('token') || ''
|
||||
const resp = await fetch(streamEndpoint, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...(token ? { 'Authorization': `Bearer ${token}` } : {}),
|
||||
},
|
||||
body: JSON.stringify({ message: text, session_id: sessId || undefined }),
|
||||
})
|
||||
|
||||
if (resp.ok && resp.body) {
|
||||
usedStreaming = true
|
||||
|
||||
// Create a placeholder message and update it as events stream in
|
||||
const msg: ChatMessage = {
|
||||
role: 'assistant', content: '', timestamp: Date.now(),
|
||||
steps: [], _traceOpen: true, iterations: 0, tool_calls_made: 0,
|
||||
}
|
||||
const idx = messages.value[key].push(msg) - 1
|
||||
const currentMsg = messages.value[key][idx]
|
||||
|
||||
const reader = resp.body.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
let buffer = ''
|
||||
let receivedFirstEvent = false
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
const parts = buffer.split('\n\n')
|
||||
buffer = parts.pop() || ''
|
||||
|
||||
for (const part of parts) {
|
||||
const lines = part.split('\n')
|
||||
let eventType = ''
|
||||
let dataStr = ''
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('event: ')) eventType = line.slice(7)
|
||||
else if (line.startsWith('data: ')) dataStr = line.slice(6)
|
||||
}
|
||||
if (!dataStr) continue
|
||||
|
||||
try {
|
||||
const data = JSON.parse(dataStr)
|
||||
|
||||
// First event received: hide the loading dots
|
||||
if (!receivedFirstEvent && (eventType === 'think' || eventType === 'tool_call' || eventType === 'tool_result')) {
|
||||
receivedFirstEvent = true
|
||||
streamingActive.value = true
|
||||
}
|
||||
|
||||
if (eventType === 'think') {
|
||||
currentMsg.steps!.push({
|
||||
iteration: data.iteration, type: 'think',
|
||||
content: data.content || '',
|
||||
reasoning: data.reasoning,
|
||||
tool_name: data.tool_names?.[0],
|
||||
})
|
||||
} else if (eventType === 'tool_call') {
|
||||
currentMsg.steps!.push({
|
||||
iteration: data.iteration, type: 'tool_call',
|
||||
content: `调用工具: ${data.name}`,
|
||||
tool_name: data.name,
|
||||
tool_input: data.input,
|
||||
})
|
||||
} else if (eventType === 'tool_result') {
|
||||
currentMsg.steps!.push({
|
||||
iteration: data.iteration, type: 'tool_result',
|
||||
content: `工具 ${data.name} 返回结果`,
|
||||
tool_name: data.name,
|
||||
tool_result: data.result,
|
||||
})
|
||||
} else if (eventType === 'final') {
|
||||
currentMsg.content = data.content || ''
|
||||
currentMsg.iterations = data.iterations_used || 0
|
||||
currentMsg.tool_calls_made = data.tool_calls_made || 0
|
||||
if (data.session_id) {
|
||||
sessionId.value[key] = data.session_id
|
||||
}
|
||||
streamingActive.value = false
|
||||
} else if (eventType === 'error') {
|
||||
currentMsg.content = data.content || ''
|
||||
currentMsg.status = 'error'
|
||||
streamingActive.value = false
|
||||
}
|
||||
} catch { /* skip malformed events */ }
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch { /* streaming unavailable; fall back to a plain POST */
|
||||
streamingActive.value = false
|
||||
}
|
||||
|
||||
if (!usedStreaming) {
|
||||
// Fallback: standard POST request
|
||||
const fallbackEndpoint = currentAgentId.value
|
||||
? `/api/v1/agent-chat/${currentAgentId.value}`
|
||||
: '/api/v1/agent-chat/bare'
|
||||
const resp = await api.post(fallbackEndpoint, { message: text, session_id: sessId || undefined })
|
||||
const data = resp.data
|
||||
sessionId.value[key] = data.session_id
|
||||
messages.value[key].push({
|
||||
role: 'assistant', content: data.content, timestamp: Date.now(),
|
||||
iterations: data.iterations_used, tool_calls_made: data.tool_calls_made,
|
||||
status: data.truncated ? 'error' : 'success', steps: data.steps || [],
|
||||
_traceOpen: data.steps && data.steps.length > 0,
|
||||
})
|
||||
}
|
||||
}
|
||||
saveState()
|
||||
} catch (e: any) {
|
||||
messages.value.push({
|
||||
messages.value[key].push({
|
||||
role: 'assistant', content: `错误:${e.response?.data?.detail || e.message || '请求失败'}`,
|
||||
timestamp: Date.now(), status: 'error',
|
||||
})
|
||||
saveState()
|
||||
} finally {
|
||||
loading.value = false; scrollToBottom()
|
||||
}
|
||||
}
|
||||
|
||||
function toggleTrace(msg: ChatMessage) { msg._traceOpen = !msg._traceOpen }
|
||||
function clearChat() { messages.value = []; sessionId.value = '' }
|
||||
function clearChat() {
|
||||
const key = currentAgentKey.value
|
||||
messages.value[key] = []
|
||||
if (chatMode.value === 'single') {
|
||||
sessionId.value[key] = ''
|
||||
}
|
||||
saveState()
|
||||
}
|
||||
function scrollToBottom() { nextTick(() => { if (messagesRef.value) messagesRef.value.scrollTop = messagesRef.value.scrollHeight }) }
|
||||
function formatTime(ts: number) { return new Date(ts).toLocaleTimeString('zh-CN', { hour: '2-digit', minute: '2-digit' }) }
|
||||
|
||||
function relativeTime(ts: number): string {
|
||||
const diff = Date.now() - ts
|
||||
if (diff < 60000) return '刚刚'
|
||||
if (diff < 3600000) return `${Math.floor(diff / 60000)} 分钟前`
|
||||
if (diff < 86400000) return `${Math.floor(diff / 3600000)} 小时前`
|
||||
const d = new Date(ts)
|
||||
return `${d.getMonth() + 1}月${d.getDate()}日 ${d.getHours().toString().padStart(2, '0')}:${d.getMinutes().toString().padStart(2, '0')}`
|
||||
}
|
||||
|
||||
function copyMessage(msg: ChatMessage) {
|
||||
const text = msg.orchestrateResult?.final_answer || msg.content
|
||||
if (!text) return
|
||||
navigator.clipboard.writeText(text).then(() => {
|
||||
ElMessage.success('已复制')
|
||||
}).catch(() => {
|
||||
ElMessage.warning('复制失败')
|
||||
})
|
||||
}
|
||||
|
||||
function retryMessage(idx: number) {
|
||||
const key = currentAgentKey.value
|
||||
const msgs = messages.value[key]
|
||||
if (!msgs) return
|
||||
|
||||
// Find the last user message before the failed message
|
||||
let userMsg = ''
|
||||
for (let i = idx - 1; i >= 0; i--) {
|
||||
if (msgs[i].role === 'user') {
|
||||
userMsg = msgs[i].content
|
||||
break
|
||||
}
|
||||
}
|
||||
if (!userMsg) {
|
||||
ElMessage.warning('未找到可重试的消息')
|
||||
return
|
||||
}
|
||||
|
||||
// Remove the failed message and its associated user message
|
||||
const removeIndices: number[] = []
|
||||
for (let i = idx - 1; i >= 0; i--) {
|
||||
if (msgs[i].role === 'user' && msgs[i].content === userMsg) {
|
||||
removeIndices.push(i)
|
||||
break
|
||||
}
|
||||
}
|
||||
removeIndices.push(idx)
|
||||
|
||||
// Delete from the end so earlier indices stay valid
|
||||
removeIndices.sort((a, b) => b - a)
|
||||
for (const ri of removeIndices) {
|
||||
msgs.splice(ri, 1)
|
||||
}
|
||||
|
||||
saveState()
|
||||
// Put the text back into the input box and resend
|
||||
inputMessage.value = userMsg
|
||||
nextTick(() => sendMessage())
|
||||
}
|
||||
|
||||
function renderMarkdown(text: string): string {
|
||||
if (!text) return ''
|
||||
@@ -336,12 +604,12 @@ function renderMarkdown(text: string): string {
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.agent-chat-page { display: flex; flex-direction: column; height: calc(100vh - 120px); max-width: 960px; margin: 0 auto; padding: 16px; }
|
||||
.chat-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 16px; padding-bottom: 12px; border-bottom: 1px solid var(--el-border-color-light); }
|
||||
.header-left { display: flex; align-items: center; gap: 8px; }
|
||||
.header-left h2 { margin: 0; font-size: 18px; }
|
||||
.header-actions { display: flex; gap: 8px; align-items: center; flex-wrap: wrap; }
|
||||
.chat-messages { flex: 1; overflow-y: auto; padding: 12px 0; display: flex; flex-direction: column; gap: 16px; }
|
||||
.agent-chat-page { display: flex; flex-direction: column; height: 100%; max-width: 960px; margin: 0 auto; }
|
||||
.page-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 12px; padding-bottom: 12px; border-bottom: 1px solid var(--el-border-color-light); }
|
||||
.page-header-left { display: flex; align-items: center; gap: 8px; }
|
||||
.page-header-left h3 { margin: 0; font-size: 16px; font-weight: 600; }
|
||||
.page-header-actions { display: flex; gap: 8px; align-items: center; flex-wrap: wrap; }
|
||||
.chat-messages { flex: 1; overflow-y: auto; padding: 8px 0; display: flex; flex-direction: column; gap: 12px; min-height: 0; }
|
||||
.chat-empty { display: flex; flex-direction: column; align-items: center; justify-content: center; height: 100%; color: var(--el-text-color-secondary); gap: 12px; }
|
||||
.chat-empty .hint { font-size: 13px; color: var(--el-text-color-placeholder); }
|
||||
.message { display: flex; gap: 12px; max-width: 88%; }
|
||||
@@ -358,8 +626,10 @@ function renderMarkdown(text: string): string {
|
||||
.tool-calls-header { display: flex; align-items: center; gap: 4px; font-size: 12px; color: var(--el-text-color-secondary); margin-bottom: 4px; }
|
||||
.tool-call-item { display: flex; align-items: center; gap: 8px; padding: 4px 8px; font-size: 12px; }
|
||||
.tool-name { font-weight: 500; color: var(--el-color-primary); }
|
||||
.message-meta { font-size: 11px; color: var(--el-text-color-placeholder); margin-top: 4px; }
|
||||
.message-meta { font-size: 11px; color: var(--el-text-color-placeholder); margin-top: 4px; display: flex; align-items: center; gap: 4px; flex-wrap: wrap; }
|
||||
.meta-iterations { color: var(--el-color-info); }
|
||||
.meta-actions { margin-left: auto; display: flex; gap: 2px; opacity: 0; transition: opacity 0.15s; }
|
||||
.message-bubble:hover .meta-actions { opacity: 1; }
|
||||
|
||||
/* Thinking trace */
|
||||
.thinking-trace { margin-top: 10px; border-top: 1px solid var(--el-border-color-light); padding-top: 8px; }
|
||||
|
||||
@@ -99,6 +99,33 @@
|
||||
</el-col>
|
||||
</el-row>
|
||||
|
||||
<!-- 执行预算 -->
|
||||
<el-divider content-position="left">执行预算</el-divider>
|
||||
<el-row :gutter="20">
|
||||
<el-col :span="8">
|
||||
<el-form-item label="LLM 调用次数上限">
|
||||
<el-input-number
|
||||
v-model="form.max_llm_invocations"
|
||||
:min="1"
|
||||
:max="10000"
|
||||
style="width: 100%"
|
||||
/>
|
||||
<div class="form-tip">单次会话最多可调用 LLM 的次数,超限将停止执行</div>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="8">
|
||||
<el-form-item label="工具调用次数上限">
|
||||
<el-input-number
|
||||
v-model="form.max_tool_calls"
|
||||
:min="1"
|
||||
:max="50000"
|
||||
style="width: 100%"
|
||||
/>
|
||||
<div class="form-tip">单次会话最多可调用工具的次数,超限将停止执行</div>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
</el-row>
|
||||
|
||||
<!-- 工具选择 -->
|
||||
<el-form-item label="可用工具">
|
||||
<el-checkbox-group v-model="form.tools" class="tool-checkbox-group">
|
||||
@@ -149,6 +176,9 @@ const form = ref({
|
||||
max_iterations: 10,
|
||||
tools: [] as string[],
|
||||
memory_enabled: true,
|
||||
// Budget settings
|
||||
max_llm_invocations: 200,
|
||||
max_tool_calls: 500,
|
||||
})
|
||||
|
||||
onMounted(async () => {
|
||||
@@ -173,6 +203,11 @@ onMounted(async () => {
|
||||
form.value.max_iterations = data.max_iterations ?? form.value.max_iterations
|
||||
form.value.tools = Array.isArray(data.tools) ? [...data.tools] : []
|
||||
form.value.memory_enabled = data.memory !== false
|
||||
|
||||
// Load budget settings from the agent's budget_config
|
||||
const bc = a.budget_config || {}
|
||||
form.value.max_llm_invocations = bc.max_llm_invocations ?? form.value.max_llm_invocations
|
||||
form.value.max_tool_calls = bc.max_tool_calls ?? form.value.max_tool_calls
|
||||
} catch (e: any) {
|
||||
ElMessage.error('加载 Agent 失败')
|
||||
router.push('/agents')
|
||||
@@ -241,7 +276,13 @@ async function handleSave() {
|
||||
targetNode.data.memory = form.value.memory_enabled
|
||||
|
||||
wf.nodes = nodes
|
||||
await agentStore.updateAgent(agent.value.id, { workflow_config: wf })
|
||||
await agentStore.updateAgent(agent.value.id, {
|
||||
workflow_config: wf,
|
||||
budget_config: {
|
||||
max_llm_invocations: form.value.max_llm_invocations,
|
||||
max_tool_calls: form.value.max_tool_calls,
|
||||
},
|
||||
})
|
||||
ElMessage.success('配置已保存')
|
||||
} catch (e: any) {
|
||||
ElMessage.error(e.response?.data?.detail || '保存失败')
|
||||
|
||||
@@ -10,36 +10,47 @@
|
||||
|
||||
## Completed work

### New files (14)
### New files (22)

| File | Lines | Purpose |
|------|-------|---------|
| `backend/app/agent_runtime/__init__.py` | 45 | Package exports |
| `backend/app/agent_runtime/schemas.py` | 100 | Agent config schema + AgentStep execution tracing |
| `backend/app/agent_runtime/schemas.py` | 115 | Agent config schema + AgentStep execution tracing + vector memory config |
| `backend/app/agent_runtime/context.py` | 85 | Session context |
| `backend/app/agent_runtime/memory.py` | 155 | Layered memory manager + automatic LLM compression/summarization |
| `backend/app/agent_runtime/tool_manager.py` | 80 | Tool manager |
| `backend/app/agent_runtime/core.py` | 260 | **AgentRuntime main loop + execution tracing + LLM instrumentation** |
| `backend/app/agent_runtime/memory.py` | 200 | Layered memory manager + automatic LLM compression/summarization + vector memory retrieval/saving |
| `backend/app/agent_runtime/tool_manager.py` | 77 | Tool manager (slimmed down, delegates to ToolRegistry) |
| `backend/app/agent_runtime/core.py` | 290 | **AgentRuntime main loop + execution tracing + LLM instrumentation + budget control** |
| `backend/app/agent_runtime/orchestrator.py` | 380 | **Multi-agent orchestration engine** |
| `backend/app/agent_runtime/workflow_integration.py` | 100 | Workflow bridge |
| `backend/app/api/agent_chat.py` | 250 | Agent chat + multi-agent orchestration + LLM call logging |
| `backend/app/api/agent_chat.py` | 280 | Agent chat + multi-agent orchestration + LLM call logging + SSE streaming output |
| `backend/app/api/agent_monitoring.py` | 55 | **Agent monitoring API (5 endpoints)** |
| `backend/app/services/agent_monitoring_service.py` | 140 | **Agent monitoring service (5 statistics methods)** |
| `backend/app/services/embedding_service.py` | 65 | **Embedding generation service (SiliconFlow / OpenAI adapters)** |
| `backend/app/services/knowledge_service.py` | 180 | **Knowledge base RAG service (upload → parse → chunk → embed → retrieve)** |
| `backend/app/services/document_parser.py` | 135 | **Document parser (txt/pdf/docx/csv)** |
| `backend/app/services/text_chunker.py` | 132 | **Text chunker (paragraph/sentence semantic chunking + overlap)** |
| `backend/app/services/tool_registry.py` | 361 | **Tool registry (built-in + HTTP execution + sandboxed code execution)** |
| `backend/app/api/tools.py` | 325 | **Tool marketplace API (CRUD + test sandbox + category browsing + use counting)** |
| `backend/app/models/agent_llm_log.py` | 30 | **Agent LLM call log model** |
| `frontend/src/views/AgentChat.vue` | 370 | Agent chat UI + multi-agent orchestration UI |
| `backend/app/models/agent_vector_memory.py` | 30 | **Agent vector memory table model** |
| `backend/app/models/knowledge_base.py` | 80 | **Knowledge base / document / document chunk models (three tables)** |
| `frontend/src/views/AgentChat.vue` | 370 | Agent chat UI + multi-agent orchestration UI + SSE streaming rendering |
| `frontend/src/views/AgentDashboard.vue` | 260 | **Agent monitoring dashboard** |
|
||||
|
||||
### Modified files (10)

| File | Change |
|------|--------|
| `backend/app/services/workflow_engine.py` | `execute_node()` gains an `agent` node-type branch (about 50 lines) |
| `backend/app/main.py` | Registers the `agent_chat` + `agent_monitoring` router modules |
| `backend/app/core/database.py` | `init_db` imports the `agent_llm_log` model |
| `backend/app/models/__init__.py` | Exports `AgentLLMLog` |
| `backend/app/agent_runtime/core.py` | `_LLMClient.chat()` instrumentation: timing + token collection + `on_completion` callback; `AgentRuntime` gains an `on_llm_call` parameter |
| `backend/app/services/workflow_engine.py` | `execute_node()` gains an `agent` node-type branch (about 50 lines); agent nodes inject `budget_limits` + the `_on_agent_llm` budget callback |
| `backend/app/main.py` | Registers the `agent_chat` + `agent_monitoring` + `tools` + `knowledge_base` router modules |
| `backend/app/core/database.py` | `init_db` imports the `agent_llm_log` + `agent_vector_memory` + `knowledge_base` models |
| `backend/app/models/__init__.py` | Exports `AgentLLMLog`, `AgentVectorMemory`, `KnowledgeBase`, `Document`, `DocumentChunk` |
| `backend/app/agent_runtime/core.py` | `_LLMClient.chat()` instrumentation: timing + token collection + `on_completion` callback; `AgentRuntime` gains an `on_llm_call` parameter; the budget check now runs before the LLM call; `on_tool_executed` re-raises `WorkflowExecutionError` |
| `backend/app/agent_runtime/orchestrator.py` | All three orchestration modes pass `on_llm_call` through to child agents |
| `backend/app/api/agent_chat.py` | Three endpoints inject the `on_llm_call` callback and write to the `AgentLLMLog` table |
| `backend/app/api/agent_chat.py` | Three endpoints inject the `on_llm_call` callback and write to the `AgentLLMLog` table; adds an SSE streaming endpoint |
| `backend/app/agent_runtime/memory.py` | `initialize()` injects vector memories into the system prompt; `save_context()` generates an embedding after saving the conversation and stores it in `AgentVectorMemory` |
| `backend/app/agent_runtime/schemas.py` | `AgentMemoryConfig` gains `vector_memory_enabled` / `vector_memory_top_k` |
| `backend/app/agent_runtime/tool_manager.py` | Simplified: `execute()` delegates directly to `tool_registry.execute_tool()` |
| `frontend/src/router/index.ts` | Adds four routes: `/agent-chat`, `/agent-chat/:id`, `/agents/:id/config`, `/agent-monitoring` |
| `frontend/src/components/MainLayout.vue` | Adds the "Agent对话" (Agent Chat) + "Agent监控" (Agent Monitoring) entries to the navigation bar |
| `frontend/src/views/Agents.vue` | Adds a "配置" (Configure) button to the agent list that navigates to AgentConfig |
|
||||
@@ -78,12 +89,20 @@
|
||||
AgentRuntime (new)
│
├── ToolManager ──────→ ToolRegistry + builtin_tools (existing)
│ ├── HTTP tool execution ─→ httpx.AsyncClient
│ └── Code tool execution ─→ sandboxed exec (__builtins__ disabled)
│
├── Memory ───────────→ persistent_memory_service (existing)
│ └── MySQL (existing)
│ ├── MySQL (existing): key-value memory
│ └── AgentVectorMemory (new): semantic retrieval over vector memory
│ └── EmbeddingService ─→ SiliconFlow / OpenAI API
│
├── _LLMClient ───────→ OpenAI SDK (existing)
│ └── on_completion → AgentLLMLog (new) → MySQL
│ ├── on_completion → AgentLLMLog (new) → MySQL
│ └── AgentBudgetConfig → WorkflowEngine._llm_invocations (budget linkage)
│
├── SSE Stream ───────→ POST .../bare/stream → text/event-stream
│ └── AgentStep → pushed to the frontend in real time
│
├── Context ──────────→ in-memory only, no external dependencies
│
|
||||
│
|
||||
@@ -92,6 +111,18 @@ AgentRuntime (新增)
|
||||
│ ├── sequential: Agent A → Agent B → Agent C
│ └── debate: agents answer independently → Aggregator summarizes
│
├── KnowledgeBase RAG (new)
│ ├── DocumentParser → parses txt/pdf/docx/csv/md
│ ├── TextChunker → paragraph/sentence semantic chunking + overlap
│ ├── EmbeddingService → embedding
│ └── Cosine-similarity retrieval → Top-K context injected into the system prompt
│
├── Tool Marketplace API (new)
│ ├── GET /tools + /categories: marketplace browsing
│ ├── POST /tools: create an HTTP/Code tool
│ ├── POST /tools/test/{http,code}: sandbox testing
│ └── POST /tools/{id}/use: use counting
│
└── AgentMonitoring API (new)
├── /overview → overview statistics
├── /llm-calls → LLM call records
|
||||
@@ -103,14 +134,14 @@ AgentRuntime (新增)
|
||||
### Line counts for new code
|
||||
|
||||
```
|
||||
agent_runtime/ → about 1080 lines of Python
api/ → about 305 lines of Python (agent_chat 250 + agent_monitoring 55)
services → about 140 lines of Python (agent_monitoring_service)
models → about 30 lines of Python (agent_llm_log)
agent_runtime/ → about 1150 lines of Python
api/ → about 660 lines of Python (agent_chat 280 + agent_monitoring 55 + tools 325)
services → about 880 lines of Python (embedding_service 65 + knowledge_service 180 + document_parser 135 + text_chunker 132 + monitoring_service 140 + tool_registry 361)
models → about 140 lines of Python (agent_llm_log 30 + agent_vector_memory 30 + knowledge_base 80)
frontend → about 630 lines of Vue/TypeScript (AgentChat 370 + AgentDashboard 260)
Modified (not new) → about 120 lines of Python/TS
Modified (not new) → about 200 lines of Python/TS
─────────────────────────────────────────
Total new → about 2300 lines
Total new → about 3460 lines
|
||||
```
|
||||
|
||||
---
|
||||
@@ -122,14 +153,19 @@ frontend → 约 630 行 Vue/TypeScript(AgentChat 370 + AgentDas
|
||||
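- Tool calling and memory management: ✅ Done
- Execution tracing and memory compression: ✅ Done
- Configuration page and chat UI: ✅ Done
- SSE streaming output: ✅ Done (the agent's reasoning steps are pushed to the frontend in real time)
- Multi-agent orchestration (route/sequential/debate): ✅ Done
- Orchestration frontend UI: ✅ Done
- LLM call instrumentation and logging: ✅ Done
- Agent monitoring dashboard: ✅ Done
- Vector memory (semantic retrieval): ✅ Done (embedding + cosine-similarity Top-K)
- Knowledge base RAG: ✅ Done (upload → parse 5 formats → chunk → embed → retrieve)
- Workflow budget integration: ✅ Done (agent LLM/tool calls count toward the WorkflowEngine global budget)
- Tool marketplace: ✅ Done (HTTP/Code tool CRUD + sandbox testing + marketplace browsing)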
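
The budget item in the list above caps how many LLM and tool calls a session may make, and the change notes say the check runs before the LLM call. A minimal Python sketch of that idea follows. `BudgetTracker` and `BudgetExceeded` are hypothetical names used only for illustration (the repository's own class is `AgentBudgetConfig`, whose real interface is not shown in this diff); the 200 / 500 defaults are taken from the config form in this commit.

```python
from dataclasses import dataclass


class BudgetExceeded(RuntimeError):
    """Raised when a session would exceed one of its call limits."""


@dataclass
class BudgetTracker:
    # Hypothetical helper; defaults mirror the config form (200 / 500).
    max_llm_invocations: int = 200
    max_tool_calls: int = 500
    llm_invocations: int = 0
    tool_calls: int = 0

    def before_llm_call(self) -> None:
        # Check first, then count, so the call that would exceed the limit never runs.
        if self.llm_invocations >= self.max_llm_invocations:
            raise BudgetExceeded("LLM invocation limit reached")
        self.llm_invocations += 1

    def before_tool_call(self) -> None:
        if self.tool_calls >= self.max_tool_calls:
            raise BudgetExceeded("tool call limit reached")
        self.tool_calls += 1


tracker = BudgetTracker(max_llm_invocations=2, max_tool_calls=1)
tracker.before_llm_call()
tracker.before_llm_call()
try:
    tracker.before_llm_call()  # third call is rejected before it is made
    stopped = False
except BudgetExceeded:
    stopped = True
```

Checking before the call, rather than after, means an over-budget request is never sent to the model, so the counter never exceeds the configured limit.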
|
||||
|
||||
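**Unfinished items**: workflow budget integration, vector memory, streaming output, knowledge base RAG, self-learning.
**Unfinished items**: self-learning.

Overall completion: **95-97% → 97-98%** (the Agent Dashboard fills in the observability gap)
Overall completion: **95-97% → 99%** (all core features are closed out; only self-learning remains as a long-term item)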
|
||||
|
||||
---
|
||||
|
||||
@@ -195,7 +231,7 @@ POST /api/v1/agent-chat/{agent_id}
|
||||
|
||||
## Roadmap

### Short term (1-2 weeks)
### Short term (1-2 weeks): all done

| Item | Status | Notes |
|------|------|------|
|
||||
@@ -204,17 +240,17 @@ POST /api/v1/agent-chat/{agent_id}
|
||||
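| Execution tracing | ✅ Done | The backend returns steps; the AgentChat.vue frontend can expand them to show the reasoning chain |
| Multi-Agent orchestration | ✅ Done | Three modes: route (dispatch), sequential (pipeline), debate (independent answers + summary) |
| Orchestration frontend UI | ✅ Done | AgentChat.vue adds mode switching, an Agent edit dialog, and step expansion |
| Budget integration | ⬜ | LLM calls inside an Agent also count toward the workflow execution budget |
| Budget integration | ✅ Done | LLM calls inside an Agent also count toward the workflow execution budget, enforced on both sides: `AgentBudgetConfig` as the internal limit and the `on_llm_invocation` callback as the external hook |
| SSE streaming output | ✅ Done | New `POST /api/v1/agent-chat/bare/stream` endpoint; SSE pushes reasoning steps and the final answer in real time |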
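
The two-sided budget control described in the table above (an internal limit plus an external callback) might look roughly like the sketch below. Only the names `AgentBudgetConfig` and `on_llm_invocation` come from the document; the class names `BudgetSketch` and `BudgetTracker`, the fields `max_llm_calls` and `max_tool_calls`, and the `BudgetExceeded` exception are hypothetical and are not taken from the project's code.

```python
from dataclasses import dataclass
from typing import Callable, Optional


class BudgetExceeded(RuntimeError):
    """Raised when a call would go past the configured limit (hypothetical name)."""


@dataclass
class BudgetSketch:
    # Hypothetical fields; the real AgentBudgetConfig may differ.
    max_llm_calls: int = 10
    max_tool_calls: int = 20


class BudgetTracker:
    """Counts calls locally and notifies an outer engine through a callback."""

    def __init__(
        self,
        config: BudgetSketch,
        on_llm_invocation: Optional[Callable[[], None]] = None,
    ) -> None:
        self.config = config
        self.on_llm_invocation = on_llm_invocation
        self.llm_calls = 0
        self.tool_calls = 0

    def record_llm_call(self) -> None:
        # Internal control: refuse the call before it happens.
        if self.llm_calls >= self.config.max_llm_calls:
            raise BudgetExceeded("LLM call budget exhausted")
        self.llm_calls += 1
        # External control: let the workflow engine update its own counter.
        if self.on_llm_invocation is not None:
            self.on_llm_invocation()

    def record_tool_call(self) -> None:
        if self.tool_calls >= self.config.max_tool_calls:
            raise BudgetExceeded("tool call budget exhausted")
        self.tool_calls += 1


# Usage: a dict stands in for the workflow engine's global counter.
engine_counter = {"llm_invocations": 0}


def bump_engine_counter() -> None:
    engine_counter["llm_invocations"] += 1


tracker = BudgetTracker(BudgetSketch(max_llm_calls=2), on_llm_invocation=bump_engine_counter)
tracker.record_llm_call()
tracker.record_llm_call()
```

Checking the limit before incrementing means a rejected call is never reported to the outer engine, so both counters stay in step.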
|
||||
|
||||
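### Mid term (1-2 months)
### Mid term (1-2 months): all done

| Item | Status | Notes |
|------|------|------|
| Vector memory | ⬜ | Integrate an Embedding API + vector retrieval (semantic memory) |
| Vector memory | ✅ Done | SiliconFlow Embedding API (bce-embedding-base_v1, 768 dimensions) + cosine-similarity Top-K retrieval; degrades automatically when no API key is configured |
| Agent Dashboard | ✅ Done | Agent-specific monitoring panel: LLM call tracing, token statistics, Agent usage ranking, tool call frequency, daily trend chart |
| Tool marketplace | ⬜ | Users can upload custom tool definitions |
| Streaming output | ⬜ | Agent reasoning pushed to the frontend in real time |
| Knowledge base | ⬜ | File upload → chunking → vectorization → RAG retrieval |
| Tool marketplace | ✅ Done | HTTP/Code tool CRUD + sandbox testing (httpx / sandboxed exec) + marketplace browsing (category/search/scope filters) + usage count |
| Knowledge base | ✅ Done | File upload → parsing of 5 formats (txt/pdf/docx/csv/md) → chunking (paragraph/sentence) → embedding → semantic retrieval → RAG context injection |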
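
The "chunking (paragraph/sentence)" step in the knowledge-base row above can be illustrated with a small sketch. It is not the project's `text_chunker`: the function name `chunk_text`, the `max_chars` parameter, and the exact splitting rules are assumptions chosen for the example.

```python
import re

# Split after CJK sentence punctuation, or after Western punctuation followed by whitespace.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[。！？])|(?<=[.!?])\s+")


def chunk_text(text: str, max_chars: int = 500) -> list[str]:
    """Split on blank lines first; break oversized paragraphs at sentence boundaries."""
    chunks: list[str] = []
    for paragraph in re.split(r"\n\s*\n", text):
        paragraph = paragraph.strip()
        if not paragraph:
            continue
        if len(paragraph) <= max_chars:
            chunks.append(paragraph)
            continue
        current = ""
        for sentence in _SENTENCE_BOUNDARY.split(paragraph):
            sentence = sentence.strip()
            if not sentence:
                continue
            # Start a new chunk when adding this sentence would exceed the limit.
            if current and len(current) + 1 + len(sentence) > max_chars:
                chunks.append(current)
                current = sentence
            else:
                current = f"{current} {sentence}".strip()
        if current:
            chunks.append(current)
    return chunks


sample = "Short paragraph.\n\nFirst sentence is here. Second sentence is here. Third one."
pieces = chunk_text(sample, max_chars=30)
```

In this sketch a single sentence longer than `max_chars` is kept whole rather than cut mid-sentence; a production chunker would probably add a hard character limit and some overlap between chunks.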
|
||||
|
||||
### Long term (3-6 months)
|
||||
|
||||
@@ -254,3 +290,15 @@ POST /api/v1/agent-chat/{agent_id}
|
||||
- [x] Agent monitoring Dashboard frontend routes and navigation configured
- [x] `on_llm_call` callback injected in all three endpoints: /bare, /{agent_id}, /orchestrate
- [x] All three orchestration modes pass `on_llm_call` through to the child AgentRuntime
- [x] **SSE streaming output**: `POST /api/v1/agent-chat/bare/stream` returns `text/event-stream`
- [x] **Vector memory**: `AgentVectorMemory` table is created automatically; embedding generation + cosine-similarity retrieval work
- [x] **Workflow budget integration**: `AgentBudgetConfig` inside `AgentRuntime` caps the number of LLM/tool calls; the external `on_llm_invocation` callback syncs to `WorkflowEngine._llm_invocations`
- [x] **Knowledge-base RAG**: the knowledge base / document / document chunk tables are created automatically; txt/pdf/docx/csv parsing is supported; chunking + embedding + semantic retrieval work
- [x] **Tool marketplace**:
  - [x] `GET /api/v1/tools/categories`: returns the category list
  - [x] `GET /api/v1/tools`: scope/mine/public/all filtering
  - [x] `POST /api/v1/tools`: creates an HTTP tool (ip_info_test) and a code tool (text_stats_test)
  - [x] `POST /api/v1/tools/test/http`: HTTP sandbox test (httpbin request, 200 / 4239 ms)
  - [x] `POST /api/v1/tools/test/code`: code sandbox test (text statistics, 0 ms)
  - [x] `POST /api/v1/tools/{id}/use`: usage count increments
  - [x] `DELETE /api/v1/tools/{id}`: deletes a custom tool
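
For the `text/event-stream` item in the checklist above, the wire format of an SSE stream can be sketched as follows. The event names `step` and `final`, the payload fields, and the helper names `format_sse` and `stream_agent_events` are assumptions for illustration; the document does not state what the `/bare/stream` endpoint actually emits.

```python
import json
from typing import Iterable, Iterator


def format_sse(event: str, payload: dict) -> str:
    """Serialize one server-sent event frame: an event line, a data line, a blank line."""
    # json.dumps escapes newlines, so the payload always fits on a single data line.
    data = json.dumps(payload, ensure_ascii=False)
    return f"event: {event}\ndata: {data}\n\n"


def stream_agent_events(steps: Iterable[str], answer: str) -> Iterator[str]:
    """Yield one frame per reasoning step, then a final frame carrying the answer."""
    for index, thought in enumerate(steps, start=1):
        yield format_sse("step", {"index": index, "thought": thought})
    yield format_sse("final", {"answer": answer})


frames = list(stream_agent_events(["look up"], "42"))
```

In a Starlette or FastAPI application (assuming that is the framework in use), a generator like this is typically handed to `StreamingResponse` with `media_type="text/event-stream"`.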