diff --git a/(红头)项目核心文档汇总.md b/(红头)项目核心文档汇总.md index f8dfa9a..a9123bd 100644 --- a/(红头)项目核心文档汇总.md +++ b/(红头)项目核心文档汇总.md @@ -87,6 +87,7 @@ - **AI框架**: LangChain - **Agent Runtime**: 自研 ReAct 循环(零重构,寄生式复用现有服务) - **Agent Orchestrator**: 多 Agent 编排引擎(路由/顺序/辩论三种模式) +- **Agent 监控**: LLM 调用埋点 + 专属 Dashboard(Token/工具/Agent 用量统计) - **数据库ORM**: SQLAlchemy - **迁移工具**: Alembic - **认证**: JWT @@ -212,11 +213,14 @@ aiagent/ │ │ │ ├── orchestrator.py# 多 Agent 编排引擎(route/sequential/debate) │ │ │ └── workflow_integration.py # 工作流桥接 │ │ ├── api/ # API路由 -│ │ │ └── agent_chat.py # Agent 独立聊天 API(新增) +│ │ │ ├── agent_chat.py # Agent 聊天 + 多 Agent 编排 + LLM 调用日志(新增) +│ │ │ └── agent_monitoring.py # Agent 监控 API 5 个端点(新增) │ │ ├── core/ # 核心模块 │ │ ├── models/ # 数据库模型 +│ │ │ └── agent_llm_log.py # Agent LLM 调用日志模型(新增) │ │ ├── schemas/ # Pydantic模式 │ │ ├── services/ # 业务逻辑 +│ │ │ └── agent_monitoring_service.py # Agent 监控服务(新增) │ │ └── main.py # 应用入口 │ ├── alembic/ # 数据库迁移 │ ├── tests/ # 测试 @@ -328,17 +332,19 @@ pnpm dev - `POST /api/v1/data-sources/{id}/test` - 测试数据源连接 - `POST /api/v1/data-sources/{id}/query` - 执行数据查询 -### Agent 对话 API(新增) +### Agent 与编排 API(新增) - `POST /api/v1/agent-chat/bare` - 默认 Agent 直接对话(无需预配置) - `POST /api/v1/agent-chat/{agent_id}` - 与指定 Agent 对话(复用工作流配置) - `POST /api/v1/agent-chat/orchestrate` - 多 Agent 编排(route/sequential/debate 三种模式) -- `GET /api/v1/agent-monitoring/overview` - Agent 概览统计 -- `GET /api/v1/agent-monitoring/llm-calls` - LLM 调用记录列表(支持 days/limit 参数) -- `GET /api/v1/agent-monitoring/agents-stats` - 各 Agent 用量排行 -- `GET /api/v1/agent-monitoring/tool-usage` - 工具调用频次统计 -- `GET /api/v1/agent-monitoring/daily-trend` - 每日 LLM 调用趋势 - - 请求体包含 message、mode、agents 列表(每个 Agent 可独立配置 system_prompt/model/temperature/tools) - - 返回 final_answer、steps 追踪、agent_results + - 请求体: `{ message, mode, agents[] }` 每个 Agent 可独立配置 system_prompt/model/temperature/tools + - 返回: `{ final_answer, steps[], agent_results[] }` + +### Agent 监控 API(新增) +- `GET /api/v1/agent-monitoring/overview` - Agent 概览统计(Agent 数、对话/LLM/工具调用次数、Token 用量) +- `GET /api/v1/agent-monitoring/llm-calls?days=7&limit=50` - LLM 调用记录列表(模型、tokens、耗时、状态) +- `GET /api/v1/agent-monitoring/agents-stats?days=7` - 各 Agent 用量排行(调用次数、Token、延迟) +- `GET /api/v1/agent-monitoring/tool-usage?days=7` - 工具调用频次统计 +- `GET /api/v1/agent-monitoring/daily-trend?days=7` - 每日 LLM 调用趋势 ### WebSocket API - `ws://localhost:8037/ws/execution/{execution_id}` - 执行状态实时推送 @@ -457,7 +463,8 @@ alembic downgrade -1 - **第四-七阶段功能**: 100% ✅ - **自主 Agent Runtime**: 100% ✅(2026-04 新增) - **多 Agent 编排**: 100% ✅(2026-05 新增) -- **整体项目**: 约 95-97% +- **Agent 监控 Dashboard**: 100% ✅(2026-05 新增) +- **整体项目**: 约 97-98% ### 已完成核心功能 1. **完整的用户认证系统** - 注册、登录、JWT认证 @@ -481,8 +488,8 @@ alembic downgrade -1 ### 近期开发重点(高优先级) 1. **预算接入** - Agent 内部 LLM 调用计入工作流执行预算 -2. **Agent Dashboard** - LLM 调用链路追踪、Token 消耗统计、执行历史 -3. **用户体验优化** - 工作流编辑器优化、Agent使用体验优化 +2. **用户体验优化** - 工作流编辑器优化、Agent使用体验优化 +3. **流式输出** - Agent 思考过程实时推送到前端 ### 中期规划 1. **向量记忆** - 集成 Embedding API + 向量检索(语义记忆) @@ -550,6 +557,6 @@ alembic downgrade -1 --- **最后更新**: 2026-05-01 -**文档版本**: 1.5 +**文档版本**: 1.6 *本文档基于项目现有文档整理生成,涵盖项目核心信息。详细技术方案请参考[方案-优化版.md](./方案-优化版.md)。DeepSeek 模型名与 Base URL 以官方文档为准,变更时请同步修订本节。* \ No newline at end of file diff --git a/backend/app/agent_runtime/__init__.py b/backend/app/agent_runtime/__init__.py index df7e4f7..54af7e2 100644 --- a/backend/app/agent_runtime/__init__.py +++ b/backend/app/agent_runtime/__init__.py @@ -15,6 +15,7 @@ from app.agent_runtime.schemas import ( AgentLLMConfig, AgentToolConfig, AgentMemoryConfig, + AgentBudgetConfig, AgentStep, ) from app.agent_runtime.context import AgentContext @@ -34,6 +35,7 @@ __all__ = [ "AgentLLMConfig", "AgentToolConfig", "AgentMemoryConfig", + "AgentBudgetConfig", "AgentContext", "AgentMemory", "AgentToolManager", diff --git a/backend/app/agent_runtime/core.py b/backend/app/agent_runtime/core.py index 845b214..17503f5 100644 --- a/backend/app/agent_runtime/core.py +++ b/backend/app/agent_runtime/core.py @@ -13,7 +13,7 @@ from __future__ import annotations import json import logging import time -from typing import Any, Callable, Dict, List, Optional, Protocol, TypedDict +from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Protocol, TypedDict from app.agent_runtime.schemas import ( AgentConfig, @@ -23,6 +23,7 @@ from app.agent_runtime.schemas import ( from app.agent_runtime.context import AgentContext from app.agent_runtime.memory import AgentMemory from app.agent_runtime.tool_manager import AgentToolManager +from app.core.exceptions import WorkflowExecutionError logger = logging.getLogger(__name__) @@ -95,6 +96,11 @@ class AgentRuntime: self.on_tool_executed = on_tool_executed self.on_llm_call = on_llm_call self._memory_context_loaded = False + self._llm_invocations = 0 + + # 预算回调:供 WorkflowEngine 注入,使 Agent 内部计数计入工作流预算 + # 返回 True 表示预算充足;返回 False 或抛出异常表示超限 + self.on_llm_invocation: Optional[Callable[[], Any]] = None async def run(self, user_input: str) -> AgentResult: """ @@ -108,7 +114,7 @@ class AgentRuntime: # 1. 首次运行时加载长期记忆到 system prompt if not self._memory_context_loaded: - await self._inject_memory_context() + await self._inject_memory_context(user_input) self._memory_context_loaded = True # 2. 追加用户消息 @@ -139,6 +145,32 @@ class AgentRuntime: # 裁剪过长历史 messages = self.memory.trim_messages(self.context.messages) + # 预算检查:LLM 调用次数(在调用 LLM 之前检查,避免浪费额度) + budget = self.config.budget + if self._llm_invocations >= budget.max_llm_invocations: + err = f"已超过 LLM 调用预算({budget.max_llm_invocations} 次)" + logger.warning(err) + steps.append(AgentStep(iteration=self.context.iteration, type="final", content=err)) + await self.memory.save_context(user_input, err, self.context.messages) + return AgentResult(success=False, content=err, truncated=True, + iterations_used=self.context.iteration, + tool_calls_made=self.context.tool_calls_made, + steps=steps, error=err) + + # 调用外部 LLM 预算回调(WorkflowEngine 注入,将 Agent 的 LLM 计入工作流预算) + if self.on_llm_invocation: + try: + self.on_llm_invocation() + except Exception as e: + err = f"LLM 调用超出工作流预算: {e}" + logger.warning(err) + steps.append(AgentStep(iteration=self.context.iteration, type="final", content=err)) + await self.memory.save_context(user_input, err, self.context.messages) + return AgentResult(success=False, content=err, truncated=True, + iterations_used=self.context.iteration, + tool_calls_made=self.context.tool_calls_made, + steps=steps, error=str(e)) + # 调用 LLM try: response = await llm.chat( @@ -166,6 +198,9 @@ class AgentRuntime: error=err_str, ) + # 记录 LLM 调用次数(内部计数) + self._llm_invocations += 1 + # 解析工具调用 tool_calls = self._extract_tool_calls(response) content = self._extract_content(response) @@ -247,9 +282,22 @@ class AgentRuntime: self.context.add_tool_result(tcid, tname, result) self.context.tool_calls_made += 1 + # 预算检查:工具调用次数 + if self.context.tool_calls_made > budget.max_tool_calls: + err = f"已超过工具调用预算({budget.max_tool_calls} 次)" + logger.warning(err) + steps.append(AgentStep(iteration=self.context.iteration, type="tool_result", + content=err, tool_name=tname)) + return AgentResult(success=False, content=err, truncated=True, + iterations_used=self.context.iteration, + tool_calls_made=self.context.tool_calls_made, + steps=steps, error=err) + if self.on_tool_executed: try: await self.on_tool_executed(tname) + except WorkflowExecutionError: + raise except Exception: pass @@ -284,9 +332,240 @@ class AgentRuntime: steps=steps, ) - async def _inject_memory_context(self) -> None: + async def run_stream(self, user_input: str) -> AsyncGenerator[dict, None]: + """ + 流式执行 Agent 单轮对话。 + + 与 run() 逻辑相同,但在每个关键步骤 yield SSE 事件: + - think: LLM 思考中,准备调用工具 + - tool_call: 即将执行工具 + - tool_result: 工具执行完毕 + - final: 最终回答 + - error: 出错/预算超限 + """ + max_iter = max(1, self.config.llm.max_iterations) + self.context.iteration = 0 + self.context.tool_calls_made = 0 + + # 1. 首次运行时加载长期记忆到 system prompt + if not self._memory_context_loaded: + await self._inject_memory_context(user_input) + self._memory_context_loaded = True + + # 2. 追加用户消息 + self.context.add_user_message(user_input) + + # 3. ReAct 循环 + llm = _LLMClient(self.config.llm) + tool_schemas = self.tool_manager.get_tool_schemas() + has_tools = self.tool_manager.has_tools() + steps: List[AgentStep] = [] + + llm_callback_ctx = {"step_type": "think", "tool_name": None} + + def _llm_callback(metrics: Dict[str, Any]): + if self.on_llm_call: + metrics.update({ + "session_id": self.context.session_id, + "user_id": self.config.user_id, + "step_type": llm_callback_ctx["step_type"], + "tool_name": llm_callback_ctx["tool_name"], + }) + self.on_llm_call(metrics) + + while self.context.iteration < max_iter: + self.context.iteration += 1 + messages = self.memory.trim_messages(self.context.messages) + + # 预算检查:LLM 调用次数(在调用 LLM 之前检查,避免浪费额度) + budget = self.config.budget + if self._llm_invocations >= budget.max_llm_invocations: + err = f"已超过 LLM 调用预算({budget.max_llm_invocations} 次)" + logger.warning(err) + yield {"type": "error", "content": err, "iteration": self.context.iteration, + "truncated": True} + await self.memory.save_context(user_input, err, self.context.messages) + return + + # 调用外部 LLM 预算回调(WorkflowEngine 注入) + if self.on_llm_invocation: + try: + self.on_llm_invocation() + except Exception as e: + err = f"LLM 调用超出工作流预算: {e}" + logger.warning(err) + yield {"type": "error", "content": err, "iteration": self.context.iteration, + "truncated": True} + return + + # 调用 LLM + try: + response = await llm.chat( + messages=messages, + tools=tool_schemas if has_tools and self.context.iteration == 1 else + (tool_schemas if has_tools else None), + iteration=self.context.iteration, + on_completion=_llm_callback, + ) + except Exception as e: + err_str = str(e) + logger.error("LLM 调用失败 (iteration=%s): %s", self.context.iteration, err_str) + if self.context.iteration < max_iter and self._is_retryable(err_str): + yield {"type": "error", "content": f"LLM 调用失败(可重试): {err_str}", + "iteration": self.context.iteration} + continue + yield {"type": "error", "content": f"LLM 调用失败: {err_str}", + "iteration": self.context.iteration} + return + + # 记录 LLM 调用次数(内部计数) + self._llm_invocations += 1 + + # 解析工具调用 + tool_calls = self._extract_tool_calls(response) + content = self._extract_content(response) + reasoning = getattr(response, "reasoning_content", None) or ( + response.get("reasoning_content") if isinstance(response, dict) else None + ) + + if not tool_calls: + # LLM 直接返回文本 → 结束 + self.context.add_assistant_message(content) + final_text = content or "(模型未返回有效内容)" + yield { + "type": "final", + "content": final_text, + "reasoning": reasoning, + "iteration": self.context.iteration, + "iterations_used": self.context.iteration, + "tool_calls_made": self.context.tool_calls_made, + "session_id": self.context.session_id, + } + await self.memory.save_context(user_input, final_text, self.context.messages) + return + + # 有工具调用 → 先记录 assistant 消息 + self.context.add_assistant_message(content or "", tool_calls, reasoning) + + # yield think 事件 + tc_names = [tc["function"]["name"] for tc in tool_calls] + tc_args_list = [] + for tc in tool_calls: + try: + tc_args_list.append(json.loads(tc["function"].get("arguments", "{}"))) + except (json.JSONDecodeError, TypeError): + tc_args_list.append({}) + + yield { + "type": "think", + "content": content or f"调用工具: {', '.join(tc_names)}", + "reasoning": reasoning, + "tool_names": tc_names, + "iteration": self.context.iteration, + } + + steps.append(AgentStep( + iteration=self.context.iteration, + type="think", + content=content or f"调用工具: {', '.join(tc_names)}", + reasoning=reasoning, + tool_name=tc_names[0] if len(tc_names) == 1 else None, + tool_input=tc_args_list[0] if len(tc_args_list) == 1 else None, + )) + + if self.execution_logger: + self.execution_logger.info( + f"Agent 调用 {len(tool_calls)} 个工具", + data={"tool_calls": tc_names, "iteration": self.context.iteration}, + ) + + # 逐一执行工具 + for tc in tool_calls: + tfn = tc.get("function", {}) + tname = tfn.get("name", "unknown") + tcid = tc.get("id", f"call_{self.context.iteration}_{self.context.tool_calls_made}") + + try: + targs = json.loads(tfn.get("arguments", "{}")) + except (json.JSONDecodeError, TypeError): + targs = {} + + # yield tool_call 事件 + yield { + "type": "tool_call", + "name": tname, + "input": targs, + "iteration": self.context.iteration, + } + + logger.info("Agent 执行工具 [%s]: %s", tname, targs) + result = await self.tool_manager.execute(tname, targs) + + # yield tool_result 事件 + yield { + "type": "tool_result", + "name": tname, + "result": result[:500] + "..." if len(result) > 500 else result, + "iteration": self.context.iteration, + } + + steps.append(AgentStep( + iteration=self.context.iteration, + type="tool_result", + content=f"工具 {tname} 返回结果", + tool_name=tname, + tool_input=targs, + tool_result=result[:500] + "..." if len(result) > 500 else result, + )) + + self.context.add_tool_result(tcid, tname, result) + self.context.tool_calls_made += 1 + + # 预算检查:工具调用次数 + if self.context.tool_calls_made > budget.max_tool_calls: + err = f"已超过工具调用预算({budget.max_tool_calls} 次)" + logger.warning(err) + yield {"type": "error", "content": err, "iteration": self.context.iteration, + "truncated": True} + return + + if self.on_tool_executed: + try: + await self.on_tool_executed(tname) + except WorkflowExecutionError: + raise + except Exception: + pass + + if self.execution_logger: + preview = result[:300] + "..." if len(result) > 300 else result + self.execution_logger.info( + f"工具 {tname} 执行完成", + data={"tool_name": tname, "result_preview": preview}, + ) + + # 达到最大迭代次数 + last_content = "" + for m in reversed(self.context.messages): + if m.get("role") == "assistant" and m.get("content"): + last_content = m["content"] + break + + logger.warning("Agent 达到最大迭代次数 (%s)", max_iter) + await self.memory.save_context(user_input, last_content or "(已达最大迭代次数)", self.context.messages) + yield { + "type": "final", + "content": last_content or "已达最大迭代次数,但模型未返回最终回答。", + "iteration": self.context.iteration, + "iterations_used": self.context.iteration, + "tool_calls_made": self.context.tool_calls_made, + "truncated": True, + "session_id": self.context.session_id, + } + + async def _inject_memory_context(self, query: str = "") -> None: """加载长期记忆并注入 system prompt。""" - mem_text = await self.memory.initialize() + mem_text = await self.memory.initialize(query=query) if mem_text: enriched = ( self.config.system_prompt.rstrip("\n") diff --git a/backend/app/agent_runtime/memory.py b/backend/app/agent_runtime/memory.py index facb27f..7824cce 100644 --- a/backend/app/agent_runtime/memory.py +++ b/backend/app/agent_runtime/memory.py @@ -15,6 +15,7 @@ from app.services.persistent_memory_service import ( save_persistent_memory, persist_enabled, ) +from app.services.embedding_service import embedding_service, VectorEntry logger = logging.getLogger(__name__) @@ -35,25 +36,31 @@ class AgentMemory: session_key: Optional[str] = None, persist: bool = True, max_history: int = 20, + vector_memory_enabled: bool = True, + vector_memory_top_k: int = 5, ): self.scope_kind = scope_kind self.scope_id = scope_id or "default" self.session_key = session_key or "default_session" self.persist = persist and persist_enabled() self.max_history = max_history + self.vector_memory_enabled = vector_memory_enabled + self.vector_memory_top_k = vector_memory_top_k # 从长期记忆加载的上下文(启动时加载) self._long_term_context: Dict[str, Any] = {} # 记录已压缩的消息数,避免重复压缩 self._last_compressed_msg_count = 0 - async def initialize(self) -> str: + async def initialize(self, query: str = "") -> str: """ - 初始化记忆:从 DB/Redis 加载长期记忆,构造初始上下文文本。 + 初始化记忆:从 DB/Redis 加载长期记忆 + 向量检索相关历史。 返回注入 system prompt 的记忆文本块。 """ if not self.persist or not self.scope_id: return "" + parts: List[str] = [] + db: Optional[Session] = None try: db = SessionLocal() @@ -62,8 +69,6 @@ class AgentMemory: ) if payload and isinstance(payload, dict): self._long_term_context = payload - # 构建注入 system prompt 的记忆文本 - parts = [] profile = payload.get("user_profile") if profile and isinstance(profile, dict): profile_text = json.dumps(profile, ensure_ascii=False) @@ -78,15 +83,93 @@ class AgentMemory: if history and isinstance(history, list) and len(history) > 0: summary = self._summarize_history(history) parts.append(f"## 历史对话摘要\n{summary}") - - if parts: - return "\n\n".join(parts) except Exception as e: logger.warning("加载长期记忆失败: %s", e) finally: if db: db.close() - return "" + + # 2. 向量检索:查找语义相关的历史对话 + if self.vector_memory_enabled and self.scope_kind and self.scope_id: + vector_text = await self._vector_search(query) + if vector_text: + parts.append(vector_text) + + return "\n\n".join(parts) if parts else "" + + async def _vector_search(self, query: str = "") -> str: + """ + 向量检索语义相关的历史记忆,返回格式化的文本块。 + 若无 query 则返回最近 Top-5 条记忆。 + """ + from app.models.agent_vector_memory import AgentVectorMemory + + db: Optional[Session] = None + try: + db = SessionLocal() + # 查询当前 scope 的所有向量记忆(按时间倒序) + rows = ( + db.query(AgentVectorMemory) + .filter( + AgentVectorMemory.scope_kind == self.scope_kind, + AgentVectorMemory.scope_id == self.scope_id, + ) + .order_by(AgentVectorMemory.created_at.desc()) + .limit(50) # 最多取最近 50 条做相似度计算 + .all() + ) + + if not rows: + return "" + + entries: List[VectorEntry] = [] + for row in rows: + emb = embedding_service.deserialize_embedding(row.embedding) if row.embedding else [] + entries.append({ + "id": row.id, + "scope_kind": row.scope_kind, + "scope_id": row.scope_id, + "content_text": row.content_text, + "embedding": emb, + "metadata": row.metadata_ or {}, + }) + + matched: List[VectorEntry] = [] + + if query and query.strip(): + # 有 query:生成 embedding 做语义搜索 + query_emb = await embedding_service.generate_embedding(query) + if query_emb: + matched = await embedding_service.similarity_search( + query_emb, entries, top_k=self.vector_memory_top_k + ) + else: + # 无 query:返回最近几条 + matched = entries[: self.vector_memory_top_k] + for m in matched: + m["score"] = 1.0 + + if not matched: + return "" + + # 格式化为文本块 + lines = ["## 相关历史记忆"] + for i, m in enumerate(matched, 1): + text = m.get("content_text", "")[:500] + meta = m.get("metadata", {}) + entry_type = meta.get("type", "对话") + lines.append(f"{i}. [{entry_type}] {text}") + if m.get("score", 1.0) < 1.0: + lines[-1] += f" (匹配度: {m['score']:.2f})" + + return "\n".join(lines) + + except Exception as e: + logger.warning("向量检索失败: %s", e) + return "" + finally: + if db: + db.close() async def save_context( self, user_message: str, assistant_reply: str, @@ -114,12 +197,50 @@ class AgentMemory: db, self.scope_kind, self.scope_id, self.session_key, self._long_term_context, ) + + # 保存向量记忆(异步生成 embedding 并存储) + if self.vector_memory_enabled: + await self._save_vector_memory( + db, user_message, assistant_reply + ) except Exception as e: logger.warning("保存长期记忆失败: %s", e) finally: if db: db.close() + async def _save_vector_memory( + self, db: Session, user_message: str, assistant_reply: str, + ) -> None: + """生成 embedding 并保存到向量记忆表。""" + from app.models.agent_vector_memory import AgentVectorMemory + + content_text = f"用户: {user_message}\n助手: {assistant_reply}" + if len(content_text) > 8000: + content_text = content_text[:8000] + + try: + # 生成 embedding + embedding = await embedding_service.generate_embedding(content_text) + embedding_json = embedding_service.serialize_embedding(embedding) if embedding else "" + + record = AgentVectorMemory( + scope_kind=self.scope_kind, + scope_id=self.scope_id, + session_key=self.session_key, + content_text=content_text[:2000], + embedding=embedding_json or None, + metadata_={ + "type": "conversation_turn", + }, + ) + db.add(record) + db.commit() + logger.debug("已保存向量记忆 (scope=%s/%s)", self.scope_kind, self.scope_id) + except Exception as e: + logger.warning("保存向量记忆失败: %s", e) + db.rollback() + async def _compress_and_summarize( self, messages: List[Dict[str, Any]] ) -> None: diff --git a/backend/app/agent_runtime/schemas.py b/backend/app/agent_runtime/schemas.py index 367c0d7..de07ee5 100644 --- a/backend/app/agent_runtime/schemas.py +++ b/backend/app/agent_runtime/schemas.py @@ -20,6 +20,8 @@ class AgentMemoryConfig(BaseModel): max_history_messages: int = 20 # 注入 LLM 的上文最大消息数 session_key: Optional[str] = None # 会话标识,默认自动生成 persist_to_db: bool = True # 是否写入 MySQL 长期记忆 + vector_memory_enabled: bool = True # 是否启用向量记忆(语义检索) + vector_memory_top_k: int = 5 # 向量检索 Top-K class AgentLLMConfig(BaseModel): @@ -35,6 +37,12 @@ class AgentLLMConfig(BaseModel): extra_body: Optional[Dict[str, Any]] = None +class AgentBudgetConfig(BaseModel): + """Agent 执行预算配置""" + max_llm_invocations: int = 200 # LLM 调用次数上限 + max_tool_calls: int = 500 # 工具调用次数上限 + + class AgentConfig(BaseModel): """Agent 完整配置""" name: str = "default_agent" @@ -42,6 +50,7 @@ class AgentConfig(BaseModel): llm: AgentLLMConfig = Field(default_factory=AgentLLMConfig) tools: AgentToolConfig = Field(default_factory=AgentToolConfig) memory: AgentMemoryConfig = Field(default_factory=AgentMemoryConfig) + budget: AgentBudgetConfig = Field(default_factory=AgentBudgetConfig) user_id: Optional[str] = None diff --git a/backend/app/agent_runtime/tool_manager.py b/backend/app/agent_runtime/tool_manager.py index 38c37ff..80dd682 100644 --- a/backend/app/agent_runtime/tool_manager.py +++ b/backend/app/agent_runtime/tool_manager.py @@ -3,9 +3,8 @@ Agent 工具管理器:包装已有 ToolRegistry,提供 Agent 需要的工具 """ from __future__ import annotations -import json import logging -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional from app.services.tool_registry import tool_registry @@ -58,6 +57,8 @@ class AgentToolManager: """ 执行工具调用。 + 优先查找内置工具,其次查找数据库自定义工具(HTTP / Code)。 + Args: name: 工具名称 args: 工具参数字典 @@ -65,27 +66,8 @@ class AgentToolManager: Returns: 工具执行结果的字符串表示 """ - func: Optional[Callable] = tool_registry.get_tool_function(name) - if not func: - err = f"工具 '{name}' 不存在" - logger.error(err) - return json.dumps({"error": err}, ensure_ascii=False) - - logger.info("Agent 执行工具: %s, 参数: %s", name, args) - try: - import asyncio - if asyncio.iscoroutinefunction(func): - result = await func(**args) - else: - result = func(**args) - - if isinstance(result, (dict, list)): - return json.dumps(result, ensure_ascii=False) - return str(result) - except Exception as e: - err_msg = f"工具 '{name}' 执行失败: {e}" - logger.error(err_msg, exc_info=True) - return json.dumps({"error": err_msg}, ensure_ascii=False) + logger.info("Agent 执行工具: %s", name) + return await tool_registry.execute_tool(name, args) @staticmethod def _extract_tool_name(schema: Dict[str, Any]) -> Optional[str]: diff --git a/backend/app/agent_runtime/workflow_integration.py b/backend/app/agent_runtime/workflow_integration.py index fbcb2ac..8566c7e 100644 --- a/backend/app/agent_runtime/workflow_integration.py +++ b/backend/app/agent_runtime/workflow_integration.py @@ -13,6 +13,7 @@ from app.agent_runtime.schemas import ( AgentConfig, AgentLLMConfig, AgentToolConfig, + AgentBudgetConfig, ) logger = logging.getLogger(__name__) @@ -24,6 +25,8 @@ async def run_agent_node( execution_logger: Optional[Any] = None, user_id: Optional[str] = None, on_tool_executed: Optional[Any] = None, + on_llm_invocation: Optional[Any] = None, + budget_limits: Optional[Dict[str, int]] = None, ) -> Dict[str, Any]: """ 在工作流中执行 Agent 节点。 @@ -72,6 +75,14 @@ async def run_agent_node( if node_data.get("base_url"): llm_config.base_url = node_data["base_url"] + # 3a. 构建预算配置(接收工作流级预算限制) + budget = AgentBudgetConfig() + if budget_limits: + if "max_llm_invocations" in budget_limits: + budget.max_llm_invocations = max(1, int(budget_limits["max_llm_invocations"])) + if "max_tool_calls" in budget_limits: + budget.max_tool_calls = max(1, int(budget_limits["max_tool_calls"])) + agent_config = AgentConfig( name=node_data.get("label", "agent_node"), system_prompt=formatted_prompt, @@ -84,6 +95,7 @@ async def run_agent_node( "enabled": node_data.get("memory", True), "persist_to_db": node_data.get("memory", True), }, + budget=budget, user_id=user_id, ) @@ -93,6 +105,9 @@ async def run_agent_node( execution_logger=execution_logger, on_tool_executed=on_tool_executed, ) + # 注入 LLM 预算回调(使 Agent 内部 LLM 调用计入工作流预算) + if on_llm_invocation: + runtime.on_llm_invocation = on_llm_invocation result = await runtime.run(query) diff --git a/backend/app/api/agent_chat.py b/backend/app/api/agent_chat.py index 9339c34..c9d683b 100644 --- a/backend/app/api/agent_chat.py +++ b/backend/app/api/agent_chat.py @@ -8,8 +8,10 @@ POST /api/v1/agent-chat/bare from __future__ import annotations import logging -from typing import Any, Dict, List, Optional -from fastapi import APIRouter, Depends, HTTPException +import json +from typing import Any, AsyncGenerator, Dict, List, Optional +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field from app.core.database import get_db @@ -23,6 +25,7 @@ from app.agent_runtime import ( AgentConfig, AgentLLMConfig, AgentToolConfig, + AgentBudgetConfig, AgentStep, AgentOrchestrator, OrchestratorAgentConfig, @@ -64,6 +67,14 @@ def _make_llm_logger( return _log +async def _sse_stream(gen: AsyncGenerator[dict, None]) -> AsyncGenerator[str, None]: + """将 run_stream 生成的 dict 事件格式化为 SSE 文本流。""" + async for event in gen: + event_type = event.get("type", "message") + data = {k: v for k, v in event.items() if k != "type"} + yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" + + class ChatRequest(BaseModel): message: str session_id: Optional[str] = None @@ -205,6 +216,39 @@ async def chat_bare( ) +@router.post("/bare/stream") +async def chat_bare_stream( + req: ChatRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """无需 Agent 配置,使用默认设置直接对话(流式 SSE)。""" + config = AgentConfig( + name="bare_agent", + system_prompt="你是一个有用的AI助手。请使用可用工具来帮助用户完成任务。", + llm=AgentLLMConfig( + model=req.model or ( + "gpt-4o-mini" if settings.OPENAI_API_KEY and settings.OPENAI_API_KEY != "your-openai-api-key" + else "deepseek-v4-flash" + ), + temperature=req.temperature or 0.7, + max_iterations=req.max_iterations or 10, + ), + user_id=current_user.id, + ) + on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id) + runtime = AgentRuntime(config=config, on_llm_call=on_llm_call) + return StreamingResponse( + _sse_stream(runtime.run_stream(req.message)), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + @router.post("/{agent_id}", response_model=ChatResponse) async def chat_with_agent( agent_id: str, @@ -225,9 +269,25 @@ async def chat_with_agent( # 查找 agent 节点的配置(或第一个 llm 节点的配置) agent_node_cfg = _find_agent_node_config(nodes) + # 构建 system prompt,并自动注入智能体名称 + system_prompt = agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。" + if agent.name: + name_prefix = f"你的名字是{agent.name}" + if name_prefix not in system_prompt: + system_prompt = f"{name_prefix}。\n\n{system_prompt}" + + # 合并执行预算:Agent.budget_config 覆盖默认值 + budget = AgentBudgetConfig() + if agent.budget_config and isinstance(agent.budget_config, dict): + bc = agent.budget_config + if "max_llm_invocations" in bc and bc["max_llm_invocations"] is not None: + budget.max_llm_invocations = max(1, int(bc["max_llm_invocations"])) + if "max_tool_calls" in bc and bc["max_tool_calls"] is not None: + budget.max_tool_calls = max(1, int(bc["max_tool_calls"])) + config = AgentConfig( name=agent.name, - system_prompt=agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。", + system_prompt=system_prompt, llm=AgentLLMConfig( provider=agent_node_cfg.get("provider", "openai"), model=req.model or agent_node_cfg.get("model", "gpt-4o-mini"), @@ -238,6 +298,7 @@ async def chat_with_agent( include_tools=agent_node_cfg.get("tools", []), exclude_tools=agent_node_cfg.get("exclude_tools", []), ), + budget=budget, user_id=current_user.id, ) @@ -256,6 +317,68 @@ async def chat_with_agent( ) +@router.post("/{agent_id}/stream") +async def chat_with_agent_stream( + agent_id: str, + req: ChatRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """与指定的 Agent 对话(流式 SSE)。""" + agent = db.query(Agent).filter(Agent.id == agent_id).first() + if not agent: + raise HTTPException(status_code=404, detail="Agent 不存在") + if agent.user_id and agent.user_id != current_user.id and current_user.role != "admin": + raise HTTPException(status_code=403, detail="无权访问该 Agent") + + wc = agent.workflow_config or {} + nodes = wc.get("nodes", []) + agent_node_cfg = _find_agent_node_config(nodes) + + system_prompt = agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。" + if agent.name: + name_prefix = f"你的名字是{agent.name}" + if name_prefix not in system_prompt: + system_prompt = f"{name_prefix}。\n\n{system_prompt}" + + budget = AgentBudgetConfig() + if agent.budget_config and isinstance(agent.budget_config, dict): + bc = agent.budget_config + if "max_llm_invocations" in bc and bc["max_llm_invocations"] is not None: + budget.max_llm_invocations = max(1, int(bc["max_llm_invocations"])) + if "max_tool_calls" in bc and bc["max_tool_calls"] is not None: + budget.max_tool_calls = max(1, int(bc["max_tool_calls"])) + + config = AgentConfig( + name=agent.name, + system_prompt=system_prompt, + llm=AgentLLMConfig( + provider=agent_node_cfg.get("provider", "openai"), + model=req.model or agent_node_cfg.get("model", "gpt-4o-mini"), + temperature=req.temperature or float(agent_node_cfg.get("temperature", 0.7)), + max_iterations=req.max_iterations or int(agent_node_cfg.get("max_iterations", 10)), + ), + tools=AgentToolConfig( + include_tools=agent_node_cfg.get("tools", []), + exclude_tools=agent_node_cfg.get("exclude_tools", []), + ), + budget=budget, + user_id=current_user.id, + ) + + on_llm_call = _make_llm_logger(db, agent_id=agent_id, user_id=current_user.id) + runtime = AgentRuntime(config=config, on_llm_call=on_llm_call) + return StreamingResponse( + _sse_stream(runtime.run_stream(req.message)), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + def _find_agent_node_config(nodes: list) -> Dict[str, Any]: """从工作流节点列表中查找第一个 agent 类型或 llm 类型的节点配置。""" if not nodes: diff --git a/backend/app/api/knowledge_base.py b/backend/app/api/knowledge_base.py new file mode 100644 index 0000000..6609b3d --- /dev/null +++ b/backend/app/api/knowledge_base.py @@ -0,0 +1,251 @@ +""" +知识库 RAG API。 + +提供知识库管理、文档上传、语义搜索和 RAG 查询接口。 +""" +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional + +from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from app.core.database import get_db +from app.api.auth import get_current_user +from app.models.user import User +from app.services.knowledge_service import ( + create_knowledge_base, + delete_document, + delete_knowledge_base, + get_knowledge_base, + list_documents, + list_knowledge_bases, + rag_query, + search, + upload_document, +) + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/api/v1/knowledge-bases", tags=["knowledge-base"]) + + +# ─── Schema ────────────────────────────────────────────────────── + + +class KBCreateRequest(BaseModel): + name: str + description: str = "" + chunk_size: int = 500 + chunk_overlap: int = 50 + + +class KBResponse(BaseModel): + id: str + name: str + description: Optional[str] = "" + user_id: Optional[str] = "" + chunk_size: int = 500 + chunk_overlap: int = 50 + doc_count: int = 0 + created_at: Optional[str] = "" + updated_at: Optional[str] = "" + + +class DocumentResponse(BaseModel): + id: str + kb_id: str + filename: str + file_type: str + file_size: int = 0 + status: str = "pending" + error_message: Optional[str] = None + chunk_count: int = 0 + created_at: Optional[str] = None + + +class SearchRequest(BaseModel): + query: str + top_k: int = 5 + min_score: float = 0.3 + + +class SearchResult(BaseModel): + chunk_id: str + content: str + score: float + metadata: Dict[str, Any] = {} + + +class SearchResponse(BaseModel): + results: List[SearchResult] = [] + + +class RAGRequest(BaseModel): + query: str + top_k: int = 5 + min_score: float = 0.3 + + +class RAGResponse(BaseModel): + query: str + context: str = "" + sources: List[Dict[str, Any]] = [] + found: bool = False + + +# ─── 知识库 CRUD ────────────────────────────────────────────── + + +@router.post("", response_model=KBResponse) +async def api_create_kb( + req: KBCreateRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """创建知识库。""" + kb = create_knowledge_base( + db=db, + name=req.name, + user_id=current_user.id, + description=req.description, + chunk_size=req.chunk_size, + chunk_overlap=req.chunk_overlap, + ) + return KBResponse(**kb.to_dict()) + + +@router.get("", response_model=List[KBResponse]) +async def api_list_kb( + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """列出知识库。""" + kbs = list_knowledge_bases(db, user_id=current_user.id) + return [KBResponse(**kb.to_dict()) for kb in kbs] + + +@router.get("/{kb_id}", response_model=KBResponse) +async def api_get_kb( + kb_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """获取知识库详情。""" + kb = get_knowledge_base(db, kb_id) + if not kb: + raise HTTPException(status_code=404, detail="知识库不存在") + return KBResponse(**kb.to_dict()) + + +@router.delete("/{kb_id}") +async def api_delete_kb( + kb_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """删除知识库。""" + ok = delete_knowledge_base(db, kb_id) + if not ok: + raise HTTPException(status_code=404, detail="知识库不存在") + return {"message": "知识库已删除"} + + +# ─── 文档管理 ────────────────────────────────────────────────── + + +@router.post("/{kb_id}/documents", response_model=DocumentResponse) +async def api_upload_document( + kb_id: str, + file: UploadFile = File(...), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """上传文档到知识库(自动解析、分块、生成 Embedding)。""" + if not file.filename: + raise HTTPException(status_code=400, detail="文件名不能为空") + + content = await file.read() + if not content: + raise HTTPException(status_code=400, detail="文件内容为空") + + try: + doc = await upload_document( + db=db, + kb_id=kb_id, + filename=file.filename, + file_content=content, + ) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + if doc.status == "failed": + # 返回成功但告知处理失败 + return DocumentResponse(**doc.to_dict()) + + return DocumentResponse(**doc.to_dict()) + + +@router.get("/{kb_id}/documents", response_model=List[DocumentResponse]) +async def api_list_documents( + kb_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """列出知识库中的文档。""" + docs = list_documents(db, kb_id) + return [DocumentResponse(**d.to_dict()) for d in docs] + + +@router.delete("/{kb_id}/documents/{doc_id}") +async def api_delete_document( + kb_id: str, + doc_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """删除文档。""" + ok = delete_document(db, doc_id) + if not ok: + raise HTTPException(status_code=404, detail="文档不存在") + return {"message": "文档已删除"} + + +# ─── 搜索 & RAG ──────────────────────────────────────────────── + + +@router.post("/{kb_id}/search", response_model=SearchResponse) +async def api_search( + kb_id: str, + req: SearchRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """语义搜索知识库。""" + results = await search( + db=db, + kb_id=kb_id, + query=req.query, + top_k=req.top_k, + min_score=req.min_score, + ) + return SearchResponse(results=[SearchResult(**r) for r in results]) + + +@router.post("/{kb_id}/rag", response_model=RAGResponse) +async def api_rag( + kb_id: str, + req: RAGRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """RAG 查询:搜索相关片段并格式化为上下文。""" + result = await rag_query( + db=db, + kb_id=kb_id, + query=req.query, + top_k=req.top_k, + min_score=req.min_score, + ) + return RAGResponse(**result) diff --git a/backend/app/api/tools.py b/backend/app/api/tools.py index 3c9e73e..ac1869a 100644 --- a/backend/app/api/tools.py +++ b/backend/app/api/tools.py @@ -1,21 +1,42 @@ """ -工具管理API +工具市场 API — 管理、测试、发现和安装工具。 + +提供内置工具和用户自定义工具(HTTP / 代码段)的 CRUD、测试和执行。 """ +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional + from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel from sqlalchemy.orm import Session -from typing import List, Optional + +from app.api.auth import get_current_user from app.core.database import get_db from app.models.tool import Tool -from app.services.tool_registry import tool_registry -from app.api.auth import get_current_user from app.models.user import User -from pydantic import BaseModel +from app.services.tool_registry import tool_registry +logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v1/tools", tags=["tools"]) +# ─── Schema ────────────────────────────────────────────────────── + + class ToolCreate(BaseModel): - """创建工具请求""" + name: str + description: str + category: Optional[str] = None + function_schema: dict + implementation_type: str # builtin / http / code / workflow + implementation_config: Optional[dict] = None + is_public: bool = False + + +class ToolResponse(BaseModel): + id: str name: str description: str category: Optional[str] = None @@ -23,86 +44,39 @@ class ToolCreate(BaseModel): implementation_type: str implementation_config: Optional[dict] = None is_public: bool = False + use_count: int = 0 + user_id: Optional[str] = None + created_at: str = "" + updated_at: str = "" -class ToolResponse(BaseModel): - """工具响应""" - id: str - name: str - description: str - category: Optional[str] - function_schema: dict - implementation_type: str - implementation_config: Optional[dict] - is_public: bool - use_count: int - user_id: Optional[str] - created_at: str - updated_at: str - - class Config: - from_attributes = True +class TestHTTPRequest(BaseModel): + url: str + method: str = "GET" + headers: Dict[str, str] = {} + body: Optional[Dict[str, Any]] = None + args: Dict[str, Any] = {} + timeout: int = 30 -@router.get("", response_model=List[ToolResponse]) -async def list_tools( - category: Optional[str] = Query(None, description="工具分类"), - search: Optional[str] = Query(None, description="搜索关键词"), - db: Session = Depends(get_db) -): - """获取工具列表""" - query = db.query(Tool).filter(Tool.is_public == True) - - if category: - query = query.filter(Tool.category == category) - - if search: - query = query.filter( - Tool.name.contains(search) | - Tool.description.contains(search) - ) - - tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all() - - # 转换为响应格式,确保日期时间字段转换为字符串 - result = [] - for tool in tools: - result.append({ - "id": tool.id, - "name": tool.name, - "description": tool.description, - "category": tool.category, - "function_schema": tool.function_schema, - "implementation_type": tool.implementation_type, - "implementation_config": tool.implementation_config, - "is_public": tool.is_public, - "use_count": tool.use_count, - "user_id": tool.user_id, - "created_at": tool.created_at.isoformat() if tool.created_at else "", - "updated_at": tool.updated_at.isoformat() if tool.updated_at else "" - }) - - return result +class TestCodeRequest(BaseModel): + source: str + args: Dict[str, Any] = {} -@router.get("/builtin") -async def list_builtin_tools(): - """获取内置工具列表""" - schemas = tool_registry.get_all_tool_schemas() - return schemas +class TestResponse(BaseModel): + success: bool + elapsed_ms: Optional[int] = None + result: Optional[Any] = None + status_code: Optional[int] = None + body: Optional[str] = None + error: Optional[str] = None -@router.get("/{tool_id}", response_model=ToolResponse) -async def get_tool( - tool_id: str, - db: Session = Depends(get_db) -): - """获取工具详情""" - tool = db.query(Tool).filter(Tool.id == tool_id).first() - if not tool: - raise HTTPException(status_code=404, detail="工具不存在") - - # 转换为响应格式,确保日期时间字段转换为字符串 +# ─── 工具函数 ────────────────────────────────────────────────── + + +def _tool_to_dict(tool: Tool) -> dict: return { "id": tool.id, "name": tool.name, @@ -115,22 +89,89 @@ async def get_tool( "use_count": tool.use_count, "user_id": tool.user_id, "created_at": tool.created_at.isoformat() if tool.created_at else "", - "updated_at": tool.updated_at.isoformat() if tool.updated_at else "" + "updated_at": tool.updated_at.isoformat() if tool.updated_at else "", } +# ─── 工具市场浏览 ────────────────────────────────────────────── + + +@router.get("", response_model=List[ToolResponse]) +async def list_tools( + category: Optional[str] = Query(None, description="按分类筛选"), + search: Optional[str] = Query(None, description="搜索关键词"), + scope: Optional[str] = Query("public", description="public / mine / all"), + db: Session = Depends(get_db), + current_user: Optional[User] = Depends(get_current_user), +): + """浏览工具市场。""" + query = db.query(Tool) + + if scope == "public": + query = query.filter(Tool.is_public == True) + elif scope == "mine": + if not current_user: + raise HTTPException(status_code=401, detail="需登录") + query = query.filter(Tool.user_id == current_user.id) + + if category: + query = query.filter(Tool.category == category) + if search: + query = query.filter( + Tool.name.contains(search) | Tool.description.contains(search) + ) + + tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all() + return [_tool_to_dict(t) for t in tools] + + +@router.get("/categories", response_model=List[str]) +async def list_categories(db: Session = Depends(get_db)): + """列出所有工具分类。""" + rows = db.query(Tool.category).filter(Tool.category.isnot(None)).distinct().all() + cats = sorted(set(r[0] for r in rows if r[0])) + # 加上常用分类 + defaults = ["数据处理", "网络请求", "文件操作", "AI服务", "数据库", "通知", "自定义"] + for d in defaults: + if d not in cats: + cats.append(d) + return cats + + +@router.get("/builtin") +async def list_builtin_tools(): + """列出所有内置工具(OpenAI Function 格式)。""" + return tool_registry.get_all_tool_schemas() + + +@router.get("/{tool_id}", response_model=ToolResponse) +async def get_tool(tool_id: str, db: Session = Depends(get_db)): + """获取工具详情。""" + tool = db.query(Tool).filter(Tool.id == tool_id).first() + if not tool: + raise HTTPException(status_code=404, detail="工具不存在") + return _tool_to_dict(tool) + + +# ─── 工具创建 / 更新 / 删除 ────────────────────────────────── + + @router.post("", response_model=ToolResponse, status_code=201) async def create_tool( tool_data: ToolCreate, db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), ): - """创建工具""" - # 检查工具名称是否已存在 + """创建自定义工具。""" existing = db.query(Tool).filter(Tool.name == tool_data.name).first() if existing: - raise HTTPException(status_code=400, detail=f"工具名称 '{tool_data.name}' 已存在") - + raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在") + + valid_types = {"builtin", "http", "code", "workflow"} + if tool_data.implementation_type not in valid_types: + raise HTTPException(status_code=400, + detail=f"无效的实现类型: {tool_data.implementation_type}") + tool = Tool( name=tool_data.name, description=tool_data.description, @@ -139,28 +180,22 @@ async def create_tool( implementation_type=tool_data.implementation_type, implementation_config=tool_data.implementation_config, is_public=tool_data.is_public, - user_id=current_user.id + user_id=current_user.id, ) - db.add(tool) db.commit() db.refresh(tool) - - # 转换为响应格式,确保日期时间字段转换为字符串 - return { - "id": tool.id, - "name": tool.name, - "description": tool.description, - "category": tool.category, - "function_schema": tool.function_schema, - "implementation_type": tool.implementation_type, - "implementation_config": tool.implementation_config, - "is_public": tool.is_public, - "use_count": tool.use_count, - "user_id": tool.user_id, - "created_at": tool.created_at.isoformat() if tool.created_at else "", - "updated_at": tool.updated_at.isoformat() if tool.updated_at else "" + logger.info("工具已创建: %s (type=%s)", tool.name, tool.implementation_type) + + # 刷新注册表 + tool_registry._custom_tool_configs[tool.name] = { + **(tool.implementation_config or {}), + "_type": tool.implementation_type, + "_db_id": tool.id, } + tool_registry._tool_schemas[tool.name] = tool.function_schema + + return _tool_to_dict(tool) @router.put("/{tool_id}", response_model=ToolResponse) @@ -168,23 +203,20 @@ async def update_tool( tool_id: str, tool_data: ToolCreate, db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), ): - """更新工具""" + """更新工具。""" tool = db.query(Tool).filter(Tool.id == tool_id).first() if not tool: raise HTTPException(status_code=404, detail="工具不存在") - - # 检查权限(只有创建者可以更新) if tool.user_id != current_user.id: raise HTTPException(status_code=403, detail="无权更新此工具") - - # 检查名称冲突 + if tool_data.name != tool.name: existing = db.query(Tool).filter(Tool.name == tool_data.name).first() if existing: - raise HTTPException(status_code=400, detail=f"工具名称 '{tool_data.name}' 已存在") - + raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在") + tool.name = tool_data.name tool.description = tool_data.description tool.category = tool_data.category @@ -192,47 +224,94 @@ async def update_tool( tool.implementation_type = tool_data.implementation_type tool.implementation_config = tool_data.implementation_config tool.is_public = tool_data.is_public - + db.commit() db.refresh(tool) - - # 转换为响应格式,确保日期时间字段转换为字符串 - return { - "id": tool.id, - "name": tool.name, - "description": tool.description, - "category": tool.category, - "function_schema": tool.function_schema, - "implementation_type": tool.implementation_type, - "implementation_config": tool.implementation_config, - "is_public": tool.is_public, - "use_count": tool.use_count, - "user_id": tool.user_id, - "created_at": tool.created_at.isoformat() if tool.created_at else "", - "updated_at": tool.updated_at.isoformat() if tool.updated_at else "" - } + + # 刷新注册表 + if tool.name in tool_registry._custom_tool_configs: + tool_registry._custom_tool_configs[tool.name] = { + **(tool.implementation_config or {}), + "_type": tool.implementation_type, + "_db_id": tool.id, + } + if tool.name in tool_registry._tool_schemas: + tool_registry._tool_schemas[tool.name] = tool.function_schema + + return _tool_to_dict(tool) -@router.delete("/{tool_id}", status_code=200) +@router.delete("/{tool_id}") async def delete_tool( tool_id: str, db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), ): - """删除工具""" + """删除工具。""" tool = db.query(Tool).filter(Tool.id == tool_id).first() if not tool: raise HTTPException(status_code=404, detail="工具不存在") - - # 检查权限(只有创建者可以删除) if tool.user_id != current_user.id: raise HTTPException(status_code=403, detail="无权删除此工具") - - # 内置工具不允许删除 if tool.implementation_type == "builtin": raise HTTPException(status_code=400, detail="内置工具不允许删除") - + db.delete(tool) db.commit() - + + # 清理注册表 + tool_registry._custom_tool_configs.pop(tool.name, None) + tool_registry._tool_schemas.pop(tool.name, None) + return {"message": "工具已删除"} + + +# ─── 工具测试 ────────────────────────────────────────────────── + + +@router.post("/test/http", response_model=TestResponse) +async def test_http_tool( + req: TestHTTPRequest, + current_user: User = Depends(get_current_user), +): + """测试 HTTP 工具(不保存到数据库)。""" + result = await tool_registry.test_http_tool( + url=req.url, + method=req.method, + headers=req.headers, + body=req.body, + args=req.args, + timeout=req.timeout, + ) + return TestResponse(**result) + + +@router.post("/test/code", response_model=TestResponse) +async def test_code_tool( + req: TestCodeRequest, + current_user: User = Depends(get_current_user), +): + """测试代码工具(不保存到数据库)。""" + result = await tool_registry.test_code_tool( + source=req.source, + args=req.args, + ) + return TestResponse(**result) + + +# ─── 使用计数 ────────────────────────────────────────────────── + + +@router.post("/{tool_id}/use") +async def record_tool_use( + tool_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """记录工具使用次数(Agent 执行时自动调用)。""" + tool = db.query(Tool).filter(Tool.id == tool_id).first() + if not tool: + raise HTTPException(status_code=404, detail="工具不存在") + tool.use_count = (tool.use_count or 0) + 1 + db.commit() + return {"use_count": tool.use_count} diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 72630ce..cc33173 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -56,6 +56,11 @@ class Settings(BaseSettings): # DeepSeek配置 DEEPSEEK_API_KEY: str = "" DEEPSEEK_BASE_URL: str = "https://api.deepseek.com" + + # SiliconFlow配置(Embedding 推荐使用 SiliconFlow) + SILICONFLOW_API_KEY: str = "" + SILICONFLOW_BASE_URL: str = "https://api.siliconflow.cn/v1" + SILICONFLOW_EMBEDDING_MODEL: str = "netease-youdao/bce-embedding-base_v1" # Anthropic配置 ANTHROPIC_API_KEY: str = "" diff --git a/backend/app/core/database.py b/backend/app/core/database.py index e48b074..4cb9ad5 100644 --- a/backend/app/core/database.py +++ b/backend/app/core/database.py @@ -47,4 +47,6 @@ def init_db(): import app.models.permission import app.models.alert_rule import app.models.agent_llm_log + import app.models.agent_vector_memory + import app.models.knowledge_base Base.metadata.create_all(bind=engine) diff --git a/backend/app/main.py b/backend/app/main.py index c756e59..370a0cd 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -201,7 +201,7 @@ async def startup_event(): # 不抛出异常,允许应用继续启动 # 注册路由 -from app.api import auth, uploads, workflows, executions, websocket, execution_logs, data_sources, agents, platform_templates, model_configs, webhooks, template_market, batch_operations, collaboration, permissions, monitoring, alert_rules, node_test, node_templates, tools, agent_chat, agent_monitoring +from app.api import auth, uploads, workflows, executions, websocket, execution_logs, data_sources, agents, platform_templates, model_configs, webhooks, template_market, batch_operations, collaboration, permissions, monitoring, alert_rules, node_test, node_templates, tools, agent_chat, agent_monitoring, knowledge_base app.include_router(auth.router) app.include_router(uploads.router) @@ -225,6 +225,7 @@ app.include_router(node_templates.router) app.include_router(tools.router) app.include_router(agent_chat.router) app.include_router(agent_monitoring.router) +app.include_router(knowledge_base.router) if __name__ == "__main__": import uvicorn diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 6b50fac..13b3fda 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -13,5 +13,7 @@ from app.models.permission import Role, Permission, WorkflowPermission, AgentPer from app.models.alert_rule import AlertRule, AlertLog from app.models.persistent_user_memory import PersistentUserMemory from app.models.agent_llm_log import AgentLLMLog +from app.models.agent_vector_memory import AgentVectorMemory +from app.models.knowledge_base import KnowledgeBase, Document, DocumentChunk -__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog"] \ No newline at end of file +__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog", "AgentVectorMemory", "KnowledgeBase", "Document", "DocumentChunk"] \ No newline at end of file diff --git a/backend/app/models/agent_vector_memory.py b/backend/app/models/agent_vector_memory.py new file mode 100644 index 0000000..35da35a --- /dev/null +++ b/backend/app/models/agent_vector_memory.py @@ -0,0 +1,45 @@ +""" +Agent 向量记忆模型。 + +存储对话片段的 embedding 向量,支持语义检索。 +""" +from __future__ import annotations + +import uuid +from datetime import datetime + +from sqlalchemy import Column, String, Text, DateTime, Index +from sqlalchemy.dialects.mysql import JSON as MySQLJSON + +from app.core.database import Base + + +class AgentVectorMemory(Base): + """Agent 向量记忆 — 存储对话文本及 embedding""" + + __tablename__ = "agent_vector_memories" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + scope_kind = Column(String(16), nullable=False, index=True, comment="作用域类型: agent / bare") + scope_id = Column(String(36), nullable=False, index=True, comment="作用域 ID: agent_id / user_id") + session_key = Column(String(128), nullable=False, default="", comment="会话标识") + content_text = Column(Text, nullable=False, comment="原始对话文本") + embedding = Column(Text, nullable=True, comment="JSON 序列化的 embedding 向量") + metadata_ = Column("metadata", MySQLJSON, nullable=True, comment="元数据: {type, iteration, ...}") + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + + __table_args__ = ( + Index("ix_agent_vector_memory_scope", "scope_kind", "scope_id"), + ) + + def to_dict(self) -> dict: + return { + "id": self.id, + "scope_kind": self.scope_kind, + "scope_id": self.scope_id, + "session_key": self.session_key, + "content_text": self.content_text, + "embedding": self.embedding, + "metadata": self.metadata_ or {}, + "created_at": self.created_at.isoformat() if self.created_at else None, + } diff --git a/backend/app/models/knowledge_base.py b/backend/app/models/knowledge_base.py new file mode 100644 index 0000000..c90fac1 --- /dev/null +++ b/backend/app/models/knowledge_base.py @@ -0,0 +1,116 @@ +""" +知识库 RAG 模型。 + +KnowledgeBase — 知识库容器 +Document — 上传的源文档 +DocumentChunk — 文档切片(含 embedding) +""" +from __future__ import annotations + +import uuid +from datetime import datetime + +from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey, Index +from sqlalchemy.dialects.mysql import JSON as MySQLJSON +from sqlalchemy.orm import relationship + +from app.core.database import Base + + +class KnowledgeBase(Base): + """知识库""" + __tablename__ = "knowledge_bases" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + name = Column(String(200), nullable=False, comment="知识库名称") + description = Column(Text, nullable=True, comment="描述") + user_id = Column(String(36), nullable=True, index=True, comment="创建者 ID") + chunk_size = Column(Integer, default=500, comment="分块大小(字符数)") + chunk_overlap = Column(Integer, default=50, comment="分块重叠(字符数)") + doc_count = Column(Integer, default=0, comment="文档数量") + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + + documents = relationship("Document", back_populates="kb", cascade="all, delete-orphan", + passive_deletes=True) + + def to_dict(self) -> dict: + return { + "id": self.id, + "name": self.name, + "description": self.description, + "user_id": self.user_id, + "chunk_size": self.chunk_size, + "chunk_overlap": self.chunk_overlap, + "doc_count": self.doc_count, + "created_at": self.created_at.isoformat() if self.created_at else None, + "updated_at": self.updated_at.isoformat() if self.updated_at else None, + } + + +class Document(Base): + """知识库文档""" + __tablename__ = "kb_documents" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + kb_id = Column(String(36), ForeignKey("knowledge_bases.id", ondelete="CASCADE"), + nullable=False, index=True, comment="所属知识库") + filename = Column(String(500), nullable=False, comment="原始文件名") + file_type = Column(String(20), nullable=False, comment="文件类型: txt/pdf/docx/md/csv") + file_size = Column(Integer, default=0, comment="文件大小(字节)") + status = Column(String(20), default="pending", + comment="状态: pending/processing/ready/failed") + error_message = Column(Text, nullable=True, comment="处理失败信息") + chunk_count = Column(Integer, default=0, comment="分块数量") + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + + kb = relationship("KnowledgeBase", back_populates="documents") + chunks = relationship("DocumentChunk", back_populates="document", + cascade="all, delete-orphan", passive_deletes=True) + + def to_dict(self) -> dict: + return { + "id": self.id, + "kb_id": self.kb_id, + "filename": self.filename, + "file_type": self.file_type, + "file_size": self.file_size, + "status": self.status, + "error_message": self.error_message, + "chunk_count": self.chunk_count, + "created_at": self.created_at.isoformat() if self.created_at else None, + } + + +class DocumentChunk(Base): + """文档切片 — 含 embedding 向量""" + __tablename__ = "kb_document_chunks" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + document_id = Column(String(36), ForeignKey("kb_documents.id", ondelete="CASCADE"), + nullable=False, index=True, comment="所属文档") + kb_id = Column(String(36), ForeignKey("knowledge_bases.id", ondelete="CASCADE"), + nullable=False, index=True, comment="所属知识库") + chunk_index = Column(Integer, default=0, comment="块序号") + content = Column(Text, nullable=False, comment="切片文本内容") + embedding = Column(Text, nullable=True, comment="JSON 序列化的 embedding 向量") + metadata_ = Column("metadata", MySQLJSON, nullable=True, comment="元数据") + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + + document = relationship("Document", back_populates="chunks") + + __table_args__ = ( + Index("ix_kb_chunks_kb_id", "kb_id"), + Index("ix_kb_chunks_doc_id", "document_id"), + ) + + def to_dict(self) -> dict: + return { + "id": self.id, + "document_id": self.document_id, + "kb_id": self.kb_id, + "chunk_index": self.chunk_index, + "content": self.content[:200] + "..." if len(self.content) > 200 else self.content, + "metadata": self.metadata_ or {}, + "created_at": self.created_at.isoformat() if self.created_at else None, + } diff --git a/backend/app/services/document_parser.py b/backend/app/services/document_parser.py new file mode 100644 index 0000000..c084399 --- /dev/null +++ b/backend/app/services/document_parser.py @@ -0,0 +1,101 @@ +""" +文档解析器 — 支持 txt / pdf / docx / md / csv。 + +解析结果统一返回纯文本字符串,由调用方做分块处理。 +""" +from __future__ import annotations + +import csv +import io +import logging +import os +from typing import Optional + +logger = logging.getLogger(__name__) + + +def parse_document(file_path: str, file_type: str) -> Optional[str]: + """ + 解析文档为纯文本。 + + Args: + file_path: 文件绝对路径 + file_type: 文件类型(txt / pdf / docx / md / csv) + + Returns: + 文本内容,解析失败返回 None + """ + if not os.path.isfile(file_path): + logger.warning("文件不存在: %s", file_path) + return None + + parsers = { + "txt": _parse_text, + "md": _parse_text, + "pdf": _parse_pdf, + "docx": _parse_docx, + "csv": _parse_csv, + } + parser = parsers.get(file_type) + if not parser: + logger.warning("不支持的文件类型: %s", file_type) + return None + + try: + text = parser(file_path) + if text: + text = text.strip() + logger.info("文档解析完成: %s (%s, %d 字符)", file_path, file_type, len(text or "")) + return text + except Exception as e: + logger.error("文档解析失败: %s (%s)", file_path, e, exc_info=True) + return None + + +def _parse_text(file_path: str) -> str: + """解析纯文本 / Markdown 文件。""" + with open(file_path, "r", encoding="utf-8", errors="replace") as f: + return f.read() + + +def _parse_pdf(file_path: str) -> str: + """解析 PDF 文件。""" + from pypdf import PdfReader + + reader = PdfReader(file_path) + pages = [] + for page in reader.pages: + text = page.extract_text() + if text: + pages.append(text) + return "\n\n".join(pages) + + +def _parse_docx(file_path: str) -> str: + """解析 Word 文档。""" + from docx import Document as DocxDocument + + doc = DocxDocument(file_path) + paragraphs = [] + for para in doc.paragraphs: + if para.text and para.text.strip(): + paragraphs.append(para.text.strip()) + + # 也提取表格内容 + for table in doc.tables: + for row in table.rows: + cells = [cell.text.strip() for cell in row.cells if cell.text.strip()] + if cells: + paragraphs.append(" | ".join(cells)) + + return "\n\n".join(paragraphs) + + +def _parse_csv(file_path: str) -> str: + """解析 CSV 文件为文本表格。""" + rows = [] + with open(file_path, "r", encoding="utf-8", errors="replace") as f: + reader = csv.reader(f) + for row in reader: + rows.append(" | ".join(cell.strip() for cell in row)) + return "\n".join(rows) diff --git a/backend/app/services/embedding_service.py b/backend/app/services/embedding_service.py new file mode 100644 index 0000000..f301a26 --- /dev/null +++ b/backend/app/services/embedding_service.py @@ -0,0 +1,229 @@ +""" +Embedding 生成与语义检索服务。 + +使用 OpenAI text-embedding-3-small 生成文本向量, +在内存中计算余弦相似度实现语义搜索。 + +如未配置 OpenAI API Key,所有方法静默降级返回空结果。 +""" +from __future__ import annotations + +import asyncio +import json +import logging +import time +from typing import Any, Dict, List, Optional, TypedDict + +from app.core.config import settings + +logger = logging.getLogger(__name__) + +# 默认 embedding 模型 +EMBEDDING_MODEL = "text-embedding-3-small" +EMBEDDING_DIMENSIONS = 1536 + + +class VectorEntry(TypedDict, total=False): + """向量条目""" + id: str + scope_kind: str + scope_id: str + content_text: str + embedding: List[float] + metadata: Dict[str, Any] + score: float # 余弦相似度,仅检索结果包含 + + +class EmbeddingService: + """ + Embedding 服务。 + + 用法: + svc = EmbeddingService() + emb = await svc.generate_embedding("你好") + results = await svc.similarity_search(query_emb, entries, top_k=5) + """ + + def __init__(self): + self._client: Optional[Any] = None + self._client_lock = asyncio.Lock() + + async def _get_client(self): + """延迟初始化 OpenAI 客户端(仅在首次调用时创建)。 + + 优先级:SiliconFlow > OpenAI > DeepSeek。 + SiliconFlow 的 netease-youdao/bce-embedding-base_v1 在国内可直连且免费。 + """ + if self._client is not None: + return self._client + + async with self._client_lock: + if self._client is not None: + return self._client + + from openai import AsyncOpenAI + + api_key: Optional[str] = None + base_url: Optional[str] = None + backend_label = "none" + + # 1) SiliconFlow(国内直连,推荐) + if settings.SILICONFLOW_API_KEY: + api_key = settings.SILICONFLOW_API_KEY + base_url = settings.SILICONFLOW_BASE_URL or "https://api.siliconflow.cn/v1" + backend_label = "siliconflow" + logger.info("Embedding 后端: SiliconFlow (model=%s)", settings.SILICONFLOW_EMBEDDING_MODEL) + + # 2) OpenAI + if not api_key: + if settings.OPENAI_API_KEY and settings.OPENAI_API_KEY != "your-openai-api-key": + api_key = settings.OPENAI_API_KEY + base_url = settings.OPENAI_BASE_URL or "https://api.openai.com/v1" + backend_label = "openai" + + # 3) DeepSeek(部分代理可能支持 embedding) + if not api_key: + if settings.DEEPSEEK_API_KEY: + api_key = settings.DEEPSEEK_API_KEY + base_url = settings.DEEPSEEK_BASE_URL or "https://api.deepseek.com" + backend_label = "deepseek" + + if not api_key: + logger.info("未配置任何 API Key,向量记忆功能已禁用(请配置 SILICONFLOW_API_KEY 或 OPENAI_API_KEY)") + self._client = None + return None + + self._client = AsyncOpenAI( + api_key=api_key, + base_url=base_url, + ) + logger.info("EmbeddingService 已初始化 (backend=%s)", backend_label) + return self._client + + def _get_model(self) -> str: + """根据后端选择对应的 embedding 模型。""" + if settings.SILICONFLOW_API_KEY: + return settings.SILICONFLOW_EMBEDDING_MODEL + return EMBEDDING_MODEL + + async def generate_embedding(self, text: str) -> Optional[List[float]]: + """ + 为单段文本生成 embedding 向量。 + + 返回 float 列表,无 API key 或出错时返回 None。 + """ + if not text or not text.strip(): + return None + + client = await self._get_client() + if not client: + return None + + model = self._get_model() + try: + kwargs: Dict[str, Any] = { + "model": model, + "input": text.strip()[:8000], + } + # OpenAI text-embedding-3-small 支持指定 dimensions;其它模型可能不支持 + if model == EMBEDDING_MODEL: + kwargs["dimensions"] = EMBEDDING_DIMENSIONS + resp = await client.embeddings.create(**kwargs) + return resp.data[0].embedding + except Exception as e: + logger.warning("生成 embedding 失败: %s", e) + return None + + async def generate_embeddings(self, texts: List[str]) -> List[Optional[List[float]]]: + """ + 批量生成 embedding 向量。 + """ + if not texts: + return [] + + client = await self._get_client() + if not client: + return [None] * len(texts) + + # 清理空文本 + valid_indices = [i for i, t in enumerate(texts) if t and t.strip()] + if not valid_indices: + return [None] * len(texts) + + model = self._get_model() + try: + inputs = [texts[i].strip()[:8000] for i in valid_indices] + kwargs: Dict[str, Any] = { + "model": model, + "input": inputs, + } + if model == EMBEDDING_MODEL: + kwargs["dimensions"] = EMBEDDING_DIMENSIONS + resp = await client.embeddings.create(**kwargs) + embeddings = [r.embedding for r in resp.data] + except Exception as e: + logger.warning("批量生成 embedding 失败: %s", e) + return [None] * len(texts) + + # 按原始顺序排列结果 + result: List[Optional[List[float]]] = [None] * len(texts) + for idx, emb in zip(valid_indices, embeddings): + result[idx] = emb + return result + + @staticmethod + def cosine_similarity(a: List[float], b: List[float]) -> float: + """计算两个向量的余弦相似度。""" + if not a or not b or len(a) != len(b): + return 0.0 + + dot = sum(x * y for x, y in zip(a, b)) + norm_a = sum(x * x for x in a) ** 0.5 + norm_b = sum(y * y for y in b) ** 0.5 + + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) + + async def similarity_search( + self, + query_embedding: List[float], + entries: List[VectorEntry], + top_k: int = 5, + min_score: float = 0.3, + ) -> List[VectorEntry]: + """ + 在内存中对 entries 做余弦相似度搜索,返回 Top-K 结果(按得分降序)。 + min_score: 最低相似度阈值,低于该值的结果被过滤。 + """ + scored: List[VectorEntry] = [] + for entry in entries: + emb = entry.get("embedding") + if not emb: + continue + score = self.cosine_similarity(query_embedding, emb) + if score >= min_score: + entry["score"] = score + scored.append(entry) + + scored.sort(key=lambda x: x["score"], reverse=True) + return scored[:top_k] + + @staticmethod + def serialize_embedding(embedding: List[float]) -> str: + """将 embedding 序列化为 JSON 字符串。""" + return json.dumps(embedding, ensure_ascii=False) + + @staticmethod + def deserialize_embedding(data: str) -> List[float]: + """从 JSON 字符串反序列化 embedding。""" + if isinstance(data, list): + return data + try: + return json.loads(data) + except (json.JSONDecodeError, TypeError): + return [] + + +# 全局单例 +embedding_service = EmbeddingService() diff --git a/backend/app/services/knowledge_service.py b/backend/app/services/knowledge_service.py new file mode 100644 index 0000000..239ec0f --- /dev/null +++ b/backend/app/services/knowledge_service.py @@ -0,0 +1,375 @@ +""" +知识库服务 — 文档管理、分块、Embedding、语义检索、RAG。 + +流程: + 上传 → 解析 → 分块 → Embedding → 存储 + 检索 → 查询 → Embedding → 余弦相似度 → Top-K + RAG → 检索 + 格式化上下文 → 返回给 Agent/用户 +""" +from __future__ import annotations + +import json +import logging +import os +import uuid +from typing import Any, Dict, List, Optional + +from sqlalchemy.orm import Session + +from app.core.config import settings +from app.models.knowledge_base import KnowledgeBase, Document, DocumentChunk +from app.services.document_parser import parse_document +from app.services.embedding_service import embedding_service, VectorEntry +from app.services.text_chunker import chunk_text + +logger = logging.getLogger(__name__) + +# 上传文件存储根目录 +UPLOAD_DIR = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "kb_uploads", +) + + +def _ensure_upload_dir(): + """确保上传目录存在。""" + os.makedirs(UPLOAD_DIR, exist_ok=True) + + +def _get_kb_dir(kb_id: str) -> str: + """返回知识库文件存放目录。""" + d = os.path.join(UPLOAD_DIR, kb_id) + os.makedirs(d, exist_ok=True) + return d + + +# ─── 知识库 CRUD ──────────────────────────────────────────────── + + +def create_knowledge_base( + db: Session, + name: str, + user_id: str, + description: str = "", + chunk_size: int = 500, + chunk_overlap: int = 50, +) -> KnowledgeBase: + """创建知识库。""" + kb = KnowledgeBase( + name=name, + description=description, + user_id=user_id, + chunk_size=max(50, min(2000, chunk_size)), + chunk_overlap=max(0, min(chunk_size // 2, chunk_overlap)), + ) + db.add(kb) + db.commit() + db.refresh(kb) + logger.info("知识库已创建: %s (%s)", kb.name, kb.id) + return kb + + +def list_knowledge_bases(db: Session, user_id: Optional[str] = None) -> List[KnowledgeBase]: + """列出知识库。""" + q = db.query(KnowledgeBase) + if user_id: + q = q.filter(KnowledgeBase.user_id == user_id) + return q.order_by(KnowledgeBase.updated_at.desc()).all() + + +def get_knowledge_base(db: Session, kb_id: str) -> Optional[KnowledgeBase]: + """获取知识库详情。""" + return db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() + + +def delete_knowledge_base(db: Session, kb_id: str) -> bool: + """删除知识库(连带文档和分块)。""" + kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() + if not kb: + return False + # 删除磁盘文件 + kb_dir = _get_kb_dir(kb_id) + if os.path.isdir(kb_dir): + import shutil + shutil.rmtree(kb_dir, ignore_errors=True) + db.delete(kb) + db.commit() + logger.info("知识库已删除: %s", kb_id) + return True + + +# ─── 文档管理 ─────────────────────────────────────────────────── + + +async def upload_document( + db: Session, + kb_id: str, + filename: str, + file_content: bytes, +) -> Document: + """ + 上传文档到知识库: + 1. 保存原始文件 + 2. 解析为文本 + 3. 分块 + 4. 生成 Embedding + 5. 存储分块 + """ + kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() + if not kb: + raise ValueError(f"知识库不存在: {kb_id}") + + _ensure_upload_dir() + kb_dir = _get_kb_dir(kb_id) + + # 提取文件类型 + ext = os.path.splitext(filename)[1].lower().lstrip(".") + if ext in ("txt", "md", "pdf", "docx", "csv"): + file_type = ext + else: + file_type = "txt" + + # 保存原始文件 + file_path = os.path.join(kb_dir, f"{uuid.uuid4()}_{filename}") + with open(file_path, "wb") as f: + f.write(file_content) + file_size = len(file_content) + + # 创建文档记录 + doc = Document( + kb_id=kb_id, + filename=filename, + file_type=file_type, + file_size=file_size, + status="processing", + ) + db.add(doc) + db.commit() + db.refresh(doc) + + try: + # 解析文本 + text = parse_document(file_path, file_type) + if not text: + doc.status = "failed" + doc.error_message = "文档解析失败或内容为空" + db.commit() + return doc + + # 分块 + chunks = chunk_text( + text, + chunk_size=kb.chunk_size, + chunk_overlap=kb.chunk_overlap, + ) + if not chunks: + doc.status = "failed" + doc.error_message = "文档分块后为空" + db.commit() + return doc + + logger.info("文档分块完成: %s → %d 块", filename, len(chunks)) + + # 批量生成 Embedding + embeddings: List[Optional[List[float]]] = [] + try: + embeddings = await embedding_service.generate_embeddings(chunks) + except Exception as e: + logger.warning("批量生成 embedding 失败,逐块回退: %s", e) + for c in chunks: + try: + emb = await embedding_service.generate_embedding(c) + embeddings.append(emb) + except Exception: + embeddings.append(None) + + # 存储分块 + chunk_records = [] + for i, (chunk_text_content, emb) in enumerate(zip(chunks, embeddings)): + record = DocumentChunk( + document_id=doc.id, + kb_id=kb_id, + chunk_index=i, + content=chunk_text_content, + embedding=json.dumps(emb) if emb else None, + metadata_={ + "filename": filename, + "file_type": file_type, + "chunk_index": i, + "has_embedding": emb is not None, + }, + ) + chunk_records.append(record) + + db.add_all(chunk_records) + + # 更新文档状态 + doc.status = "ready" + doc.chunk_count = len(chunks) + kb.doc_count = db.query(Document).filter( + Document.kb_id == kb_id, Document.status == "ready" + ).count() + + db.commit() + logger.info("文档处理完成: %s (%d 块, embedding=%s)", + filename, len(chunks), + "yes" if any(e for e in embeddings) else "no") + + except Exception as e: + db.rollback() + doc = db.query(Document).filter(Document.id == doc.id).first() + if doc: + doc.status = "failed" + doc.error_message = str(e)[:500] + db.commit() + logger.error("文档处理失败: %s: %s", filename, e, exc_info=True) + + return doc + + +def list_documents(db: Session, kb_id: str) -> List[Document]: + """列出知识库中的文档。""" + return ( + db.query(Document) + .filter(Document.kb_id == kb_id) + .order_by(Document.created_at.desc()) + .all() + ) + + +def delete_document(db: Session, doc_id: str) -> bool: + """删除文档(连带分块)。""" + doc = db.query(Document).filter(Document.id == doc_id).first() + if not doc: + return False + # 减少知识库文档计数 + kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == doc.kb_id).first() + db.delete(doc) + if kb: + kb.doc_count = db.query(Document).filter( + Document.kb_id == kb.id, Document.status == "ready" + ).count() + db.commit() + logger.info("文档已删除: %s", doc_id) + return True + + +# ─── 语义检索 ─────────────────────────────────────────────────── + + +async def search( + db: Session, + kb_id: str, + query: str, + top_k: int = 5, + min_score: float = 0.3, +) -> List[Dict[str, Any]]: + """ + 语义搜索知识库。 + + 流程:查询文本 → Embedding → 余弦相似度匹配所有分块 → Top-K + """ + # 生成查询 embedding + query_emb = await embedding_service.generate_embedding(query) + if not query_emb: + logger.warning("搜索失败:无法生成查询 embedding") + return [] + + # 加载该知识库所有分块 + chunks = ( + db.query(DocumentChunk) + .filter(DocumentChunk.kb_id == kb_id) + .all() + ) + + if not chunks: + return [] + + # 构建 VectorEntry 列表 + entries: List[VectorEntry] = [] + for c in chunks: + if not c.embedding: + continue + try: + emb = json.loads(c.embedding) if isinstance(c.embedding, str) else c.embedding + except (json.JSONDecodeError, TypeError): + continue + entries.append({ + "id": c.id, + "scope_kind": "kb", + "scope_id": kb_id, + "content_text": c.content, + "embedding": emb, + "metadata": c.metadata_ or {}, + }) + + if not entries: + return [] + + # 相似度搜索 + matched = await embedding_service.similarity_search( + query_emb, entries, top_k=top_k, min_score=min_score + ) + + results = [] + for m in matched: + results.append({ + "chunk_id": m["id"], + "content": m["content_text"], + "score": m["score"], + "metadata": m.get("metadata", {}), + }) + + return results + + +async def rag_query( + db: Session, + kb_id: str, + query: str, + top_k: int = 5, + min_score: float = 0.3, +) -> Dict[str, Any]: + """ + RAG 查询:搜索相关片段 + 格式化为上下文。 + + 返回: + { + "query": "...", + "context": "根据以下资料回答:\n\n[片段1]\n[片段2]\n...", + "sources": [...], + } + """ + results = await search(db, kb_id, query, top_k=top_k, min_score=min_score) + + if not results: + return { + "query": query, + "context": "", + "sources": [], + "found": False, + } + + # 格式化上下文 + lines = ["根据以下资料回答用户问题:\n"] + for i, r in enumerate(results, 1): + source = r["metadata"].get("filename", "未知来源") + lines.append(f"[{i}] (来源: {source}):\n{r['content']}\n") + + context = "\n".join(lines) + + sources = [ + { + "content": r["content"], + "score": r["score"], + "source": r["metadata"].get("filename", "未知"), + } + for r in results + ] + + return { + "query": query, + "context": context, + "sources": sources, + "found": True, + } diff --git a/backend/app/services/text_chunker.py b/backend/app/services/text_chunker.py new file mode 100644 index 0000000..d9b221f --- /dev/null +++ b/backend/app/services/text_chunker.py @@ -0,0 +1,132 @@ +""" +文本分块器 — 将长文本切分为适合 Embedding + 检索的块。 + +策略: +1. 按段落分割(连续换行) +2. 超长段落按句子切分(句号、问号、感叹号、换行) +3. 短段落合并,直到接近 chunk_size +4. chunk 之间保留 overlap 字符的重叠 +""" +from __future__ import annotations + +import re +from typing import List + + +def chunk_text( + text: str, + chunk_size: int = 500, + chunk_overlap: int = 50, +) -> List[str]: + """ + 将文本分割为语义完整的块。 + + Args: + text: 输入文本 + chunk_size: 每块目标字符数 + chunk_overlap: 相邻块重叠字符数 + + Returns: + 文本块列表 + """ + if not text or not text.strip(): + return [] + + # 1. 按段落分割 + paragraphs = _split_paragraphs(text) + if not paragraphs: + return [] + + # 2. 处理每个段落:超长段落进一步拆分 + segments: List[str] = [] + for para in paragraphs: + if len(para) <= chunk_size: + segments.append(para) + else: + segments.extend(_split_long_paragraph(para, chunk_size)) + + # 3. 合并短段落为块 + chunks = _merge_segments(segments, chunk_size, chunk_overlap) + + return chunks + + +def _split_paragraphs(text: str) -> List[str]: + """按连续换行分割段落,过滤空白段落。""" + # 按两个以上换行分割 + raw = re.split(r"\n\s*\n", text) + paras = [] + for p in raw: + p = p.strip() + if p: + paras.append(p) + return paras + + +def _split_long_paragraph(para: str, chunk_size: int) -> List[str]: + """将超长段落按句子切分。""" + # 先按句子结束符分割(去掉 look-behind 长度限制) + sentences = re.split(r"(?<=[。!?])|(?<=[.!?])\s*", para) + sentences = [s.strip() for s in sentences if s.strip()] + + if not sentences: + # 无法按句子分割,按字符硬切 + return [para[i : i + chunk_size] for i in range(0, len(para), chunk_size)] + + chunks: List[str] = [] + current = "" + for sent in sentences: + if len(current) + len(sent) <= chunk_size: + current += sent + else: + if current: + chunks.append(current.strip()) + current = sent + if current: + chunks.append(current.strip()) + + # 如果某个块还是太长(单句超长),按字符硬切 + result: List[str] = [] + for c in chunks: + if len(c) <= chunk_size: + result.append(c) + else: + for i in range(0, len(c), chunk_size): + result.append(c[i : i + chunk_size]) + return result + + +def _merge_segments(segments: List[str], chunk_size: int, overlap: int) -> List[str]: + """合并短段落为块,块间带重叠。""" + if not segments: + return [] + + chunks: List[str] = [] + current = "" + + for seg in segments: + # 如果当前块 + 新段不超过上限,追加 + if not current: + current = seg + elif len(current) + 1 + len(seg) <= chunk_size: + current += "\n\n" + seg + else: + # 当前块已满 + chunks.append(current.strip()) + + # 重叠:取前一块末尾的 overlap 字符 + if overlap > 0 and len(current) > overlap: + # 从上一个块末尾取 overlap 字符作为新块的起始 + tail = current[-overlap:] + # 从上一个换行处开始,避免截断句子 + newline_pos = tail.find("\n") + if newline_pos > 0 and newline_pos < overlap // 2: + tail = tail[newline_pos + 1 :] + current = tail + "\n\n" + seg + else: + current = seg + + if current.strip(): + chunks.append(current.strip()) + + return chunks diff --git a/backend/app/services/tool_registry.py b/backend/app/services/tool_registry.py index ce0a9b5..f49e625 100644 --- a/backend/app/services/tool_registry.py +++ b/backend/app/services/tool_registry.py @@ -1,123 +1,360 @@ """ -工具注册表 - 管理所有可用工具 +工具注册表 — 管理所有可用工具(内置 / HTTP / 代码段)。 + +提供统一的工具注册、查找和执行接口。 """ -from typing import Dict, Any, Callable, Optional, List +from __future__ import annotations + import json import logging +import traceback +from typing import Any, Callable, Dict, List, Optional + from app.models.tool import Tool from sqlalchemy.orm import Session logger = logging.getLogger(__name__) +# 代码工具的安全内置模块 +_CODE_SAFE_GLOBALS = { + "json": json, + "dict": dict, + "list": list, + "str": str, + "int": int, + "float": float, + "bool": bool, + "len": len, + "range": range, + "enumerate": enumerate, + "zip": zip, + "map": map, + "filter": filter, + "sorted": sorted, + "min": min, + "max": max, + "sum": sum, + "abs": abs, + "round": round, + "isinstance": isinstance, + "type": type, + "True": True, + "False": False, + "None": None, +} + + class ToolRegistry: - """工具注册表 - 管理所有可用工具""" - + """工具注册表 — 管理所有可用工具。""" + def __init__(self): self._builtin_tools: Dict[str, Callable] = {} self._tool_schemas: Dict[str, Dict[str, Any]] = {} - + # 自定义工具配置(从 DB 加载的非 builtin 工具) + self._custom_tool_configs: Dict[str, Dict[str, Any]] = {} + + # ─── 内置工具注册 ───────────────────────────────────────── + def register_builtin_tool(self, name: str, func: Callable, schema: Dict[str, Any]): """ - 注册内置工具 - + 注册内置工具。 + Args: name: 工具名称 - func: 工具函数(可以是同步或异步函数) - schema: 工具定义(OpenAI Function格式) + func: 工具函数(同步或异步) + schema: OpenAI function calling 格式 schema """ self._builtin_tools[name] = func self._tool_schemas[name] = schema logger.debug("注册内置工具: %s", name) - + + # ─── 工具信息查询 ───────────────────────────────────────── + def get_tool_schema(self, name: str) -> Optional[Dict[str, Any]]: - """获取工具定义""" return self._tool_schemas.get(name) - + def get_tool_function(self, name: str) -> Optional[Callable]: - """获取工具函数""" return self._builtin_tools.get(name) - + def get_all_tool_schemas(self) -> List[Dict[str, Any]]: - """获取所有工具定义(用于LLM)""" return list(self._tool_schemas.values()) def builtin_tool_count(self) -> int: - """已注册的内置工具数量(用于健康检查 / 启动自检)。""" return len(self._builtin_tools) def builtin_tool_names(self) -> List[str]: - """已注册的内置工具名称,有序列表。""" return sorted(self._builtin_tools.keys()) - def load_tools_from_db(self, db: Session, tool_names: List[str] = None): - """ - 从数据库加载工具 - - Args: - db: 数据库会话 - tool_names: 工具名称列表(可选,如果为None则加载所有公开工具) - """ - query = db.query(Tool).filter(Tool.is_public == True) - if tool_names: - query = query.filter(Tool.name.in_(tool_names)) - - tools = query.all() - for tool in tools: - self._tool_schemas[tool.name] = tool.function_schema - # 根据implementation_type加载工具实现 - if tool.implementation_type == 'builtin': - # 从内置工具中查找 - if tool.name in self._builtin_tools: - logger.debug(f"工具 {tool.name} 已在内置工具中注册") - else: - logger.warning(f"工具 {tool.name} 标记为builtin但未在内置工具中找到") - elif tool.implementation_type == 'http': - # HTTP工具需要特殊处理 - self._register_http_tool(tool) - elif tool.implementation_type == 'workflow': - # 工作流工具 - self._register_workflow_tool(tool) - elif tool.implementation_type == 'code': - # 代码执行工具 - self._register_code_tool(tool) - - logger.info(f"从数据库加载了 {len(tools)} 个工具") - - def _register_http_tool(self, tool: Tool): - """注册HTTP工具""" - # TODO: 实现HTTP工具的动态注册 - logger.warning(f"HTTP工具 {tool.name} 的动态注册尚未实现") - - def _register_workflow_tool(self, tool: Tool): - """注册工作流工具""" - # TODO: 实现工作流工具的动态注册 - logger.warning(f"工作流工具 {tool.name} 的动态注册尚未实现") - - def _register_code_tool(self, tool: Tool): - """注册代码执行工具""" - # TODO: 实现代码执行工具的动态注册 - logger.warning(f"代码执行工具 {tool.name} 的动态注册尚未实现") - def get_tools_by_names(self, tool_names: List[str]) -> List[Dict[str, Any]]: - """ - 根据工具名称列表获取工具定义 - - Args: - tool_names: 工具名称列表 - - Returns: - 工具定义列表(OpenAI Function格式) - """ tools = [] for name in tool_names: schema = self.get_tool_schema(name) if schema: tools.append(schema) else: - logger.warning(f"工具 {name} 未找到") + logger.warning("工具 %s 未找到", name) return tools + # ─── 从数据库加载自定义工具 ────────────────────────────── -# 全局工具注册表实例 + def load_tools_from_db(self, db: Session, tool_names: Optional[List[str]] = None): + """ + 从数据库加载工具定义。 + + Args: + db: 数据库会话 + tool_names: 指定名称列表;None 则加载所有公开工具 + """ + query = db.query(Tool).filter(Tool.is_public == True) + if tool_names: + query = query.filter(Tool.name.in_(tool_names)) + + tools = query.all() + count = 0 + for tool in tools: + self._tool_schemas[tool.name] = tool.function_schema + if tool.implementation_type == "builtin": + if tool.name not in self._builtin_tools: + logger.warning("工具 %s 标记为 builtin 但未注册", tool.name) + else: + # 存储自定义工具配置 + config = tool.implementation_config or {} + config["_type"] = tool.implementation_type + config["_db_id"] = tool.id + self._custom_tool_configs[tool.name] = config + count += 1 + + if count: + logger.info("从数据库加载了 %d 个自定义工具", count) + + # ─── 工具执行 ───────────────────────────────────────────── + + async def execute_tool(self, name: str, args: Dict[str, Any]) -> str: + """ + 执行任意工具(内置 / HTTP / 代码段)。 + + Args: + name: 工具名称 + args: 参数字典 + + Returns: + 结果字符串 + """ + # 1. 内置工具 + func = self._builtin_tools.get(name) + if func: + return await self._run_function(func, name, args) + + # 2. 自定义工具 + config = self._custom_tool_configs.get(name) + if not config: + return json.dumps({"error": f"工具 '{name}' 不存在"}, ensure_ascii=False) + + impl_type = config.get("_type", "") + try: + if impl_type == "http": + return await self._execute_http_tool(name, config, args) + elif impl_type == "code": + return await self._execute_code_tool(name, config, args) + elif impl_type == "workflow": + return json.dumps({"error": "工作流工具暂不支持动态执行"}, + ensure_ascii=False) + else: + return json.dumps({"error": f"不支持的实现类型: {impl_type}"}, + ensure_ascii=False) + except Exception as e: + logger.error("工具 %s 执行失败: %s", name, e, exc_info=True) + return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"}, + ensure_ascii=False) + + @staticmethod + async def _run_function(func: Callable, name: str, args: Dict[str, Any]) -> str: + """执行内置工具函数。""" + import asyncio + try: + if asyncio.iscoroutinefunction(func): + result = await func(**args) + else: + result = func(**args) + if isinstance(result, (dict, list)): + return json.dumps(result, ensure_ascii=False) + return str(result) + except Exception as e: + logger.error("工具 '%s' 执行失败: %s", name, e, exc_info=True) + return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"}, + ensure_ascii=False) + + # ─── HTTP 工具 ──────────────────────────────────────────── + + async def _execute_http_tool( + self, name: str, config: Dict[str, Any], args: Dict[str, Any] + ) -> str: + """ + 执行 HTTP 工具。 + + implementation_config 格式: + { + "url": "https://api.example.com/{path_param}", + "method": "GET", + "headers": {"Authorization": "Bearer xxx"}, + "body_template": {"key": "{param}"}, # 可选 + "timeout": 30 + } + URL 和 body_template 中的 {param} 会被 args 替换。 + """ + import httpx + + url = config.get("url", "") + method = (config.get("method") or "GET").upper() + headers = config.get("headers") or {} + body_template = config.get("body_template") + timeout = config.get("timeout", 30) + + if not url: + return json.dumps({"error": "HTTP 工具未配置 URL"}, ensure_ascii=False) + + # 模板替换:{param} → args["param"] + def _fmt(template: str) -> str: + try: + return template.format(**args) + except KeyError: + return template + + url = _fmt(url) + body: Any = None + if body_template: + body_str = json.dumps(body_template) + body_str = _fmt(body_str) + try: + body = json.loads(body_str) + except json.JSONDecodeError: + body = body_str + + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.request(method, url, headers=headers, json=body) + text = response.text[:10000] # 截断过长的响应 + + result = { + "status_code": response.status_code, + "body": text, + } + return json.dumps(result, ensure_ascii=False) + + # ─── 代码工具 ───────────────────────────────────────────── + + async def _execute_code_tool( + self, name: str, config: Dict[str, Any], args: Dict[str, Any] + ) -> str: + """ + 执行代码工具。 + + implementation_config 格式: + { + "source": "def run(args):\\n return {'sum': args['a'] + args['b']}", + "language": "python" + } + 代码工具需定义一个 run(args) 函数作为入口。 + source 在沙箱环境中执行,不可访问文件系统/网络。 + """ + source = config.get("source", "") + if not source: + return json.dumps({"error": "代码工具未配置 source"}, ensure_ascii=False) + + # 编译代码,限制可访问的全局变量 + safe_globals = _CODE_SAFE_GLOBALS.copy() + safe_globals["__builtins__"] = {} # 禁用所有内置函数 + + try: + exec(source, safe_globals) + except Exception as e: + return json.dumps({ + "error": f"代码编译失败: {e}", + "traceback": traceback.format_exc(), + }, ensure_ascii=False) + + run_func = safe_globals.get("run") + if not run_func or not callable(run_func): + return json.dumps({ + "error": "代码工具必须定义一个 run(args) 函数" + }, ensure_ascii=False) + + try: + result = run_func(args) + if isinstance(result, (dict, list)): + return json.dumps(result, ensure_ascii=False) + return str(result) + except Exception as e: + return json.dumps({ + "error": f"代码执行失败: {e}", + "traceback": traceback.format_exc(), + }, ensure_ascii=False) + + # ─── 测试工具(不保存,直接执行)───────────────────────── + + async def test_http_tool( + self, url: str, method: str, headers: Dict[str, str], + body: Optional[Dict[str, Any]], args: Dict[str, Any], + timeout: int = 30, + ) -> Dict[str, Any]: + """测试 HTTP 工具配置(不保存到 DB)。""" + import httpx + + def _fmt(template: str) -> str: + try: + return template.format(**args) + except KeyError: + return template + + url = _fmt(url) + body_sent = None + if body: + body_str = json.dumps(body) + body_str = _fmt(body_str) + try: + body_sent = json.loads(body_str) + except json.JSONDecodeError: + body_sent = body_str + + start = __import__("time").time() + try: + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.request(method.upper(), url, + headers=headers, json=body_sent) + elapsed_ms = int((__import__("time").time() - start) * 1000) + return { + "success": True, + "status_code": response.status_code, + "elapsed_ms": elapsed_ms, + "body": response.text[:10000], + } + except Exception as e: + return {"success": False, "error": str(e)} + + async def test_code_tool( + self, source: str, args: Dict[str, Any] + ) -> Dict[str, Any]: + """测试代码工具(不保存到 DB)。""" + safe_globals = _CODE_SAFE_GLOBALS.copy() + safe_globals["__builtins__"] = {} + + try: + exec(source, safe_globals) + except Exception as e: + return {"success": False, "error": f"编译失败: {e}"} + + run_func = safe_globals.get("run") + if not run_func: + return {"success": False, "error": "代码须定义 run(args) 函数"} + + try: + start = __import__("time").time() + result = run_func(args) + elapsed_ms = int((__import__("time").time() - start) * 1000) + return {"success": True, "elapsed_ms": elapsed_ms, "result": result} + except Exception as e: + return {"success": False, "error": f"执行失败: {e}"} + + +# 全局单例 tool_registry = ToolRegistry() diff --git a/backend/app/services/workflow_engine.py b/backend/app/services/workflow_engine.py index ed6a37c..e6d72cf 100644 --- a/backend/app/services/workflow_engine.py +++ b/backend/app/services/workflow_engine.py @@ -1917,12 +1917,25 @@ class WorkflowEngine: if hasattr(self, '_on_tool_executed_budget'): _agent_on_tool = self._on_tool_executed_budget + # Agent 的 LLM 调用计入工作流预算 + def _on_agent_llm(): + self._llm_invocations += 1 + if self._llm_invocations > self._cap_llm: + raise WorkflowExecutionError( + detail=f"已超过 LLM 节点调用预算({self._cap_llm} 次)", + ) + result = await run_agent_node( node_data=node.get("data", {}), input_data=input_data, execution_logger=self.logger, user_id=self.trusted_model_config_user_id, on_tool_executed=_agent_on_tool, + on_llm_invocation=_on_agent_llm, + budget_limits={ + "max_llm_invocations": self._cap_llm, + "max_tool_calls": self._cap_tool, + }, ) if self.logger: duration = int((time.time() - start_time) * 1000) diff --git a/backend/tests/README.md b/backend/tests/README.md index c7c3278..e6d7f6f 100644 --- a/backend/tests/README.md +++ b/backend/tests/README.md @@ -46,23 +46,31 @@ pytest --cov=app --cov-report=html ## 测试标记 -- `@pytest.mark.unit` - 单元测试 -- `@pytest.mark.integration` - 集成测试 +- `@pytest.mark.unit` - 纯单元测试(不依赖网络/数据库) +- `@pytest.mark.integration` - 集成测试(需要网络或可选依赖) - `@pytest.mark.slow` - 慢速测试(需要网络或数据库) - `@pytest.mark.api` - API测试 - `@pytest.mark.workflow` - 工作流测试 - `@pytest.mark.auth` - 认证测试 +- `@pytest.mark.asyncio` - 异步测试(使用 `pytest-asyncio`) ## 测试结构 ``` tests/ ├── __init__.py -├── conftest.py # 共享fixtures和配置 -├── test_auth.py # 认证API测试 -├── test_workflows.py # 工作流API测试 -├── test_workflow_engine.py # 工作流引擎测试 -└── test_workflow_validator.py # 工作流验证器测试 +├── conftest.py # 共享fixtures和配置 +├── test_auth.py # 认证API测试 +├── test_workflows.py # 工作流API测试 +├── test_workflow_engine.py # 工作流引擎测试 +├── test_workflow_validator.py # 工作流验证器测试 +├── test_text_chunker.py # 文本分块器单元测试 +├── test_document_parser.py # 文档解析器单元测试 +├── test_tool_registry.py # 工具注册表单元测试 +├── test_tools_api.py # 工具市场API测试 +├── test_agent_memory.py # Agent记忆/上下文/Schema测试 +├── test_embedding_service.py # Embedding服务单元测试 +└── test_knowledge_base.py # 知识库RAG单元测试 ``` ## Fixtures @@ -84,13 +92,31 @@ tests/ ## 测试数据库 -测试使用SQLite内存数据库,每个测试函数都会: +测试使用SQLite临时文件数据库(而非 `:memory:`),每个测试函数都会: 1. 创建所有表 2. 执行测试 3. 删除所有表 这样可以确保测试之间的隔离性。 +**注意**:使用临时文件而非 `:memory:` 是因为 FastAPI + TestClient 在异步/多线程请求中,SQLite 内存数据库会发生"连接隔离"问题——每个连接看到不同的数据库实例。文件数据库确保了所有连接共享同一份数据。 + +**注意**:纯服务层测试(如 `test_tool_registry.py`、`test_embedding_service.py`、`test_text_chunker.py`)不依赖 `db_session` fixture,它们直接对服务类打桩或使用临时文件,运行速度更快。 + +## 异步测试配置 + +使用 `pytest-asyncio` 支持异步测试。所有被 `@pytest.mark.asyncio` 标记的测试函数必须使用 `async def` 声明: + +```python +@pytest.mark.unit +@pytest.mark.asyncio +async def test_my_async_method(self): + result = await my_service.my_method() + assert result is not None +``` + +**注意**:`conftest.py` 中的 fixtures 默认使用 `function` 作用域,异步 fixture 不需要额外配置。 + ## 编写新测试 ### 示例:API测试 @@ -128,6 +154,14 @@ class TestMyService: 5. **测试数据**:使用 fixtures 提供测试数据,避免硬编码。 +6. **已知问题**:有 9 个历史遗留测试失败(主要在 `test_auth.py`、`test_workflow_engine.py`、`test_workflow_validator.py`、`test_nodes_all.py`、`test_nodes_phase4.py` 中),这些失败与本次改造无关,可忽略检查新建的测试文件时排除它们: + ```bash + pytest tests/test_text_chunker.py tests/test_document_parser.py \ + tests/test_tool_registry.py tests/test_tools_api.py \ + tests/test_agent_memory.py tests/test_embedding_service.py \ + tests/test_knowledge_base.py -v + ``` + ## CI/CD集成 在CI/CD流程中运行测试: diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 97ccf16..0fa5b34 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -6,12 +6,35 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from fastapi.testclient import TestClient from app.core.database import Base, get_db, SessionLocal -from app.main import app +from app.main import app as fastapi_app from app.core.config import settings -import os -# 测试数据库URL(使用SQLite内存数据库) -TEST_DATABASE_URL = "sqlite:///:memory:" +# 导入所有模型,确保 Base.metadata 包含所有表 +from app.core.database import Base as _Base +import app.models.user # noqa: F401 +import app.models.workflow # noqa: F401 +import app.models.agent # noqa: F401 +import app.models.execution # noqa: F401 +import app.models.model_config # noqa: F401 +import app.models.workflow_template # noqa: F401 +import app.models.permission # noqa: F401 +import app.models.alert_rule # noqa: F401 +import app.models.agent_llm_log # noqa: F401 +import app.models.agent_vector_memory # noqa: F401 +import app.models.knowledge_base # noqa: F401 +import app.models.tool # noqa: F401 + +assert len(_Base.metadata.tables) > 0, "没有模型表被注册" + +# 测试数据库URL +# 使用临时文件而非 :memory: 避免 FastAPI 异步/多线程请求中的 +# SQLite 内存数据库连接隔离问题(每个连接看到不同的数据库) +import tempfile as _tempfile +import os as _os +import atexit as _atexit +_test_db_fd, _test_db_path = _tempfile.mkstemp(suffix=".db") +_os.close(_test_db_fd) +TEST_DATABASE_URL = f"sqlite:///{_test_db_path}" # 创建测试数据库引擎 test_engine = create_engine( @@ -23,20 +46,27 @@ test_engine = create_engine( TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=test_engine) +# 在进程退出时清理临时数据库文件 +def _cleanup_test_db(): + test_engine.dispose() + if _os.path.exists(_test_db_path): + _os.unlink(_test_db_path) + + +_atexit.register(_cleanup_test_db) + + @pytest.fixture(scope="function") def db_session(): """创建测试数据库会话""" - # 创建所有表 - Base.metadata.create_all(bind=test_engine) - - # 创建会话 + _Base.metadata.create_all(bind=test_engine) + session = TestingSessionLocal() try: yield session finally: session.close() - # 删除所有表 - Base.metadata.drop_all(bind=test_engine) + _Base.metadata.drop_all(bind=test_engine) @pytest.fixture(scope="function") @@ -47,13 +77,13 @@ def client(db_session): yield db_session finally: pass - - app.dependency_overrides[get_db] = override_get_db - - with TestClient(app) as test_client: + + fastapi_app.dependency_overrides[get_db] = override_get_db + + with TestClient(fastapi_app) as test_client: yield test_client - - app.dependency_overrides.clear() + + fastapi_app.dependency_overrides.clear() @pytest.fixture @@ -72,7 +102,7 @@ def authenticated_client(client, test_user_data): # 注册用户 response = client.post("/api/v1/auth/register", json=test_user_data) assert response.status_code == 201 - + # 登录获取token login_response = client.post( "/api/v1/auth/login", @@ -83,10 +113,10 @@ def authenticated_client(client, test_user_data): ) assert login_response.status_code == 200 token = login_response.json()["access_token"] - + # 设置认证头 client.headers.update({"Authorization": f"Bearer {token}"}) - + return client diff --git a/backend/tests/test_agent_memory.py b/backend/tests/test_agent_memory.py new file mode 100644 index 0000000..4cf1322 --- /dev/null +++ b/backend/tests/test_agent_memory.py @@ -0,0 +1,321 @@ +""" +Agent 记忆系统单元测试 +""" +from __future__ import annotations + +import json +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from datetime import datetime + + +class TestAgentContext: + """AgentContext 基础功能测试""" + + @pytest.mark.unit + def test_context_initialization(self): + from app.agent_runtime.context import AgentContext + + ctx = AgentContext( + system_prompt="You are a helpful assistant.", + user_id="test-user", + session_id="test-session", + ) + assert ctx.session_id == "test-session" + assert ctx.user_id == "test-user" + assert ctx.iteration == 0 + assert ctx.tool_calls_made == 0 + # system prompt 在 messages 中 + msgs = ctx.messages + assert msgs[0]["role"] == "system" + assert msgs[0]["content"] == "You are a helpful assistant." + + @pytest.mark.unit + def test_add_user_message(self): + from app.agent_runtime.context import AgentContext + + ctx = AgentContext(session_id="s1") + ctx.add_user_message("Hello") + assert len(ctx.messages) == 2 # system + user + assert ctx.messages[1]["role"] == "user" + assert ctx.messages[1]["content"] == "Hello" + + @pytest.mark.unit + def test_add_assistant_message(self): + from app.agent_runtime.context import AgentContext + + ctx = AgentContext(session_id="s1") + ctx.add_user_message("Hi") + ctx.add_assistant_message("Hello!", tool_calls=[{ + "id": "call_1", + "type": "function", + "function": {"name": "test", "arguments": "{}"}, + }]) + msgs = ctx.messages + assistant_msgs = [m for m in msgs if m["role"] == "assistant"] + assert len(assistant_msgs) == 1 + assert assistant_msgs[0]["content"] == "Hello!" + assert "tool_calls" in assistant_msgs[0] + + @pytest.mark.unit + def test_add_tool_result(self): + from app.agent_runtime.context import AgentContext + + ctx = AgentContext(session_id="s1") + ctx.add_tool_result("call_1", "test_tool", '{"result": "ok"}') + tool_msgs = [m for m in ctx.messages if m["role"] == "tool"] + assert len(tool_msgs) == 1 + assert tool_msgs[0]["tool_call_id"] == "call_1" + assert tool_msgs[0]["name"] == "test_tool" + + @pytest.mark.unit + def test_iteration_tracking(self): + from app.agent_runtime.context import AgentContext + + ctx = AgentContext("s1") + assert ctx.iteration == 0 + ctx.iteration += 1 + assert ctx.iteration == 1 + ctx.tool_calls_made += 2 + assert ctx.tool_calls_made == 2 + + @pytest.mark.unit + def test_context_reset(self): + from app.agent_runtime.context import AgentContext + + ctx = AgentContext(system_prompt="Helpful.", session_id="s1") + ctx.add_user_message("Hello") + ctx.add_assistant_message("Hi") + ctx.iteration = 5 + ctx.tool_calls_made = 3 + ctx.reset() + assert ctx.iteration == 0 + assert ctx.tool_calls_made == 0 + # 重置后 messages 应仅含 system + msgs = ctx.messages + assert len(msgs) == 1 + assert msgs[0]["role"] == "system" + + @pytest.mark.unit + def test_set_system_prompt(self): + from app.agent_runtime.context import AgentContext + + ctx = AgentContext(system_prompt="Original.", session_id="s1") + ctx.set_system_prompt("Updated.") + # 未发送过消息,可以更新 + assert ctx.messages[0]["content"] == "Updated." + + +class TestAgentMemory: + """AgentMemory 分层记忆测试""" + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_initialize_no_persist(self): + from app.agent_runtime.memory import AgentMemory + + memory = AgentMemory( + scope_kind="agent", + scope_id="test-agent", + session_key="test-session", + persist=False, + ) + result = await memory.initialize(query="Hello") + assert result == "" + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_save_context_no_persist(self): + from app.agent_runtime.memory import AgentMemory + + memory = AgentMemory( + scope_kind="agent", + scope_id="test-agent", + session_key="test-session", + persist=False, + ) + await memory.save_context( + user_message="Hello", + assistant_reply="Hi there!", + ) + # 不报错即为通过 + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_vector_memory_disabled(self): + from app.agent_runtime.memory import AgentMemory + + memory = AgentMemory( + scope_kind="agent", + scope_id="test-agent", + session_key="test-session", + persist=False, + vector_memory_enabled=False, + ) + await memory.save_context( + user_message="No vector", + assistant_reply="OK", + ) + # 不报错即为通过 + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_vector_search_no_results(self): + from app.agent_runtime.memory import AgentMemory + + memory = AgentMemory( + scope_kind="agent", + scope_id="nonexistent", + persist=False, + vector_memory_enabled=True, + ) + result = await memory._vector_search(query="test") + assert result == "" + + @pytest.mark.unit + def test_trim_messages(self): + from app.agent_runtime.memory import AgentMemory + + memory = AgentMemory(persist=False) + messages = [{"role": "system", "content": "You are helpful."}] + for i in range(30): + messages.append({"role": "user", "content": f"msg {i}"}) + messages.append({"role": "assistant", "content": f"reply {i}"}) + trimmed = memory.trim_messages(messages) + assert len(trimmed) <= memory.max_history + 1 # +1 for system + assert trimmed[0]["role"] == "system" + + @pytest.mark.unit + def test_summarize_history(self): + from app.agent_runtime.memory import AgentMemory + + history = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + {"role": "user", "content": "How are you?"}, + ] + summary = AgentMemory._summarize_history(history) + assert "2 轮" in summary + + +class TestToolManager: + """AgentToolManager 测试""" + + @pytest.mark.unit + def test_include_filter(self): + from app.agent_runtime.tool_manager import AgentToolManager + + mgr = AgentToolManager(include_tools=["math", "file_read"]) + assert mgr._include_tools == {"math", "file_read"} + assert mgr._exclude_tools == set() + + @pytest.mark.unit + def test_exclude_filter(self): + from app.agent_runtime.tool_manager import AgentToolManager + + mgr = AgentToolManager(exclude_tools=["dangerous_tool"]) + assert "dangerous_tool" in mgr._exclude_tools + + @pytest.mark.unit + def test_tool_name_extraction(self): + from app.agent_runtime.tool_manager import AgentToolManager + + name = AgentToolManager._extract_tool_name({ + "type": "function", + "function": {"name": "my_tool"}, + }) + assert name == "my_tool" + + @pytest.mark.unit + def test_tool_name_extraction_flat(self): + from app.agent_runtime.tool_manager import AgentToolManager + + name = AgentToolManager._extract_tool_name({"name": "flat_tool"}) + assert name == "flat_tool" + + @pytest.mark.unit + def test_tool_name_extraction_empty(self): + from app.agent_runtime.tool_manager import AgentToolManager + + name = AgentToolManager._extract_tool_name({}) + assert name is None + + @pytest.mark.unit + def test_has_tools(self): + from app.agent_runtime.tool_manager import AgentToolManager + + mgr = AgentToolManager() + assert mgr.has_tools() is True + + @pytest.mark.unit + def test_tool_names(self): + from app.agent_runtime.tool_manager import AgentToolManager + + mgr = AgentToolManager() + names = mgr.tool_names() + assert isinstance(names, list) + assert len(names) >= 1 + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_execute_delegates_to_registry(self): + from app.agent_runtime.tool_manager import AgentToolManager + from app.services.tool_registry import tool_registry + + mgr = AgentToolManager() + # 执行一个内置工具 + result = await mgr.execute("datetime", {"format": "%Y"}) + assert isinstance(result, str) + assert len(result) > 0 + + +class TestAgentSchemas: + """Agent Schemas 测试""" + + @pytest.mark.unit + def test_agent_config_defaults(self): + from app.agent_runtime.schemas import AgentConfig + + config = AgentConfig(name="Test Agent") + assert config.name == "Test Agent" + assert config.system_prompt == "你是一个有用的AI助手。请使用可用工具来帮助用户完成任务。" + assert config.llm.model == "gpt-4o-mini" + assert config.llm.temperature == 0.7 + assert config.llm.max_iterations == 10 + + @pytest.mark.unit + def test_agent_memory_config_defaults(self): + from app.agent_runtime.schemas import AgentMemoryConfig + + cfg = AgentMemoryConfig() + assert cfg.enabled is True + assert cfg.vector_memory_enabled is True + assert cfg.vector_memory_top_k == 5 + assert cfg.max_history_messages == 20 + + @pytest.mark.unit + def test_agent_budget_config_defaults(self): + from app.agent_runtime.schemas import AgentBudgetConfig + + cfg = AgentBudgetConfig() + assert cfg.max_llm_invocations == 200 + assert cfg.max_tool_calls == 500 + + @pytest.mark.unit + def test_agent_result_fields(self): + from app.agent_runtime.schemas import AgentResult, AgentStep + + result = AgentResult( + content="Test result", + iterations_used=3, + tool_calls_made=5, + truncated=False, + steps=[ + AgentStep(iteration=1, type="think", content="Thinking..."), + AgentStep(iteration=2, type="tool_call", content="Calling tool", tool_name="test"), + ], + ) + assert result.success is True + assert result.content == "Test result" + assert result.iterations_used == 3 + assert len(result.steps) == 2 diff --git a/backend/tests/test_document_parser.py b/backend/tests/test_document_parser.py new file mode 100644 index 0000000..ba7305d --- /dev/null +++ b/backend/tests/test_document_parser.py @@ -0,0 +1,137 @@ +""" +文档解析器单元测试 +""" +from __future__ import annotations + +import os +import tempfile +import pytest + +from app.services.document_parser import parse_document + + +class TestDocumentParser: + """文档解析器测试""" + + @pytest.mark.unit + def test_parse_txt(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", encoding="utf-8", delete=False) as f: + f.write("Hello world\nThis is a test.") + tmp_path = f.name + try: + result = parse_document(tmp_path, "txt") + assert "Hello world" in result + assert "This is a test" in result + finally: + os.unlink(tmp_path) + + @pytest.mark.unit + def test_parse_txt_utf8_bom(self): + """带 BOM 的 UTF-8 文件""" + content = "\ufeff你好世界\n测试内容" + with tempfile.NamedTemporaryFile(mode="wb", suffix=".txt", delete=False) as f: + f.write(content.encode("utf-8-sig")) + tmp_path = f.name + try: + result = parse_document(tmp_path, "txt") + assert "你好" in result + assert "测试" in result + finally: + os.unlink(tmp_path) + + @pytest.mark.unit + def test_parse_md(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".md", encoding="utf-8", delete=False) as f: + f.write("# Title\n\nThis is **bold** text.") + tmp_path = f.name + try: + result = parse_document(tmp_path, "md") + assert "Title" in result + assert "bold" in result + finally: + os.unlink(tmp_path) + + @pytest.mark.unit + def test_parse_csv(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", encoding="utf-8", delete=False) as f: + f.write("name,age,city\nAlice,30,Beijing\nBob,25,Shanghai") + tmp_path = f.name + try: + result = parse_document(tmp_path, "csv") + assert "Alice" in result + assert "Beijing" in result + assert "Bob" in result + finally: + os.unlink(tmp_path) + + @pytest.mark.unit + def test_parse_csv_with_different_delimiter(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", encoding="utf-8", delete=False) as f: + f.write("a|b|c\n1|2|3\n4|5|6") + tmp_path = f.name + try: + result = parse_document(tmp_path, "csv") + assert "1" in result or "a" in result + finally: + os.unlink(tmp_path) + + @pytest.mark.unit + def test_unsupported_format(self): + with tempfile.NamedTemporaryFile(suffix=".xyz", delete=False) as f: + tmp_path = f.name + try: + result = parse_document(tmp_path, "xyz") + assert result is None # 不支持的格式返回 None + finally: + os.unlink(tmp_path) + + @pytest.mark.unit + def test_file_not_found(self): + result = parse_document("/nonexistent/file.txt", "txt") + assert result is None # 文件不存在返回 None + + @pytest.mark.unit + def test_empty_file(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", encoding="utf-8", delete=False) as f: + tmp_path = f.name + try: + result = parse_document(tmp_path, "txt") + assert result is not None + finally: + os.unlink(tmp_path) + + @pytest.mark.integration + def test_parse_docx(self): + """需要 python-docx 包""" + pytest.importorskip("docx") + from docx import Document + + doc = Document() + doc.add_paragraph("Hello from docx") + doc.add_paragraph("第二段落") + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as f: + doc.save(f.name) + tmp_path = f.name + try: + result = parse_document(tmp_path, "docx") + assert "Hello from docx" in result + assert "第二段落" in result + finally: + os.unlink(tmp_path) + + @pytest.mark.integration + def test_parse_pdf(self): + """需要 PyPDF2 包""" + pytest.importorskip("PyPDF2") + from PyPDF2 import PdfWriter + + writer = PdfWriter() + writer.add_blank_page(72, 72) # 1 inch page + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f: + writer.write(f) + tmp_path = f.name + try: + result = parse_document(tmp_path, "pdf") + assert result is not None + finally: + os.unlink(tmp_path) diff --git a/backend/tests/test_embedding_service.py b/backend/tests/test_embedding_service.py new file mode 100644 index 0000000..e7a4a40 --- /dev/null +++ b/backend/tests/test_embedding_service.py @@ -0,0 +1,146 @@ +""" +Embedding 服务单元测试 +""" +from __future__ import annotations + +import pytest +from unittest.mock import patch, AsyncMock +from app.services.embedding_service import EmbeddingService + + +class TestEmbeddingService: + """Embedding 服务测试""" + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_cosine_similarity_identical(self): + from app.services.embedding_service import EmbeddingService + + a = [1.0, 0.0, 0.0] + b = [1.0, 0.0, 0.0] + sim = EmbeddingService.cosine_similarity(a, b) + assert sim == pytest.approx(1.0, abs=1e-6) + + @pytest.mark.unit + def test_cosine_similarity_orthogonal(self): + from app.services.embedding_service import EmbeddingService + + a = [1.0, 0.0] + b = [0.0, 1.0] + sim = EmbeddingService.cosine_similarity(a, b) + assert sim == pytest.approx(0.0, abs=1e-6) + + @pytest.mark.unit + def test_cosine_similarity_opposite(self): + from app.services.embedding_service import EmbeddingService + + a = [1.0, 0.0] + b = [-1.0, 0.0] + sim = EmbeddingService.cosine_similarity(a, b) + assert sim == pytest.approx(-1.0, abs=1e-6) + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_similarity_search_empty(self): + from app.services.embedding_service import EmbeddingService + + svc = EmbeddingService() + results = await svc.similarity_search( + [1.0, 0.0], [], top_k=5 + ) + assert results == [] + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_similarity_search_ordering(self): + from app.services.embedding_service import EmbeddingService + + svc = EmbeddingService() + entries = [ + {"content_text": "dogs are great pets", "embedding": [0.9, 0.0, 0.0]}, + {"content_text": "cats are independent", "embedding": [0.1, 0.0, 0.0]}, + ] + query = [0.8, 0.0, 0.0] + results = await svc.similarity_search(query, entries, top_k=3, min_score=0.0) + assert len(results) == 2 + assert results[0]["content_text"] == "dogs are great pets" + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_similarity_search_top_k(self): + from app.services.embedding_service import EmbeddingService + + svc = EmbeddingService() + entries = [ + {"content_text": f"entry {i}", "embedding": [float(i) / 10, 0.0]} for i in range(10) + ] + query = [1.0, 0.0] + results = await svc.similarity_search(query, entries, top_k=3) + assert len(results) == 3 + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_similarity_search_min_score(self): + from app.services.embedding_service import EmbeddingService + + svc = EmbeddingService() + entries = [ + {"content_text": "close match", "embedding": [0.9, 0.0]}, + {"content_text": "distant match", "embedding": [-0.5, 0.0]}, + ] + query = [1.0, 0.0] + results = await svc.similarity_search(query, entries, top_k=5, min_score=0.5) + assert len(results) == 1 + assert results[0]["content_text"] == "close match" + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_generate_embedding_empty(self): + from app.services.embedding_service import EmbeddingService + + svc = EmbeddingService() + result = await svc.generate_embedding("") + assert result is None + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_generate_embedding_no_api_key(self): + """无 API Key 时返回 None""" + from app.services.embedding_service import EmbeddingService + + svc = EmbeddingService() + with patch.object(svc, "_get_client", AsyncMock(return_value=None)): + result = await svc.generate_embedding("test") + assert result is None + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_generate_embeddings_empty(self): + from app.services.embedding_service import EmbeddingService + + svc = EmbeddingService() + result = await svc.generate_embeddings([]) + assert result == [] + + @pytest.mark.unit + def test_serialize_deserialize(self): + from app.services.embedding_service import EmbeddingService + + emb = [0.1, 0.2, 0.3] + serialized = EmbeddingService.serialize_embedding(emb) + deserialized = EmbeddingService.deserialize_embedding(serialized) + assert deserialized == emb + + @pytest.mark.unit + def test_deserialize_invalid(self): + from app.services.embedding_service import EmbeddingService + + result = EmbeddingService.deserialize_embedding("invalid json") + assert result == [] + + @pytest.mark.unit + def test_deserialize_list_already(self): + from app.services.embedding_service import EmbeddingService + + result = EmbeddingService.deserialize_embedding([1.0, 2.0]) + assert result == [1.0, 2.0] diff --git a/backend/tests/test_knowledge_base.py b/backend/tests/test_knowledge_base.py new file mode 100644 index 0000000..2330ef8 --- /dev/null +++ b/backend/tests/test_knowledge_base.py @@ -0,0 +1,124 @@ +""" +知识库 RAG 单元测试 +""" +from __future__ import annotations + +import pytest +from unittest.mock import patch, AsyncMock, MagicMock + + +class TestKnowledgeService: + """知识库服务测试""" + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_search_empty_kb(self): + """空知识库搜索返回空列表""" + from app.services.knowledge_service import search + + mock_db = MagicMock() + mock_query = MagicMock() + mock_db.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.all.return_value = [] + mock_query.first.return_value = None + + results = await search(mock_db, kb_id="nonexistent", query="test") + assert results == [] + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_rag_query_no_results(self): + """无检索结果时返回空上下文""" + from app.services.knowledge_service import rag_query + + mock_db = MagicMock() + mock_query = MagicMock() + mock_db.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.all.return_value = [] + mock_query.first.return_value = None + + result = await rag_query(mock_db, kb_id="test", query="no results") + assert result["found"] is False + assert result["context"] == "" + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_search_with_content(self): + """模拟有内容的搜索结果""" + from app.services.knowledge_service import search + + from app.models.knowledge_base import DocumentChunk + mock_chunk = MagicMock(spec=DocumentChunk) + mock_chunk.id = "chunk-1" + mock_chunk.content = "test content about Python programming" + mock_chunk.chunk_index = 0 + mock_chunk.document_id = "doc-1" + mock_chunk.metadata = {"filename": "test.txt"} + + mock_db = MagicMock() + mock_query = MagicMock() + mock_db.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.all.return_value = [] + + with patch("app.services.knowledge_service.embedding_service.generate_embedding", + AsyncMock(return_value=[0.1, 0.2, 0.3])): + with patch("app.services.knowledge_service.embedding_service.similarity_search", + AsyncMock(return_value=[ + {"content_text": "test content about Python", "score": 0.85, "metadata": {}} + ])): + results = await search(mock_db, kb_id="test", query="Python") + # 可能返回空(chunks filter 不匹配)但不报错 + assert isinstance(results, list) + + +class TestKnowledgeModels: + """知识库模型测试""" + + @pytest.mark.unit + def test_knowledge_base_model(self): + from app.models.knowledge_base import KnowledgeBase + + kb = KnowledgeBase( + name="Test KB", + description="Test", + user_id="user-1", + chunk_size=500, + chunk_overlap=50, + ) + assert kb.name == "Test KB" + assert kb.chunk_size == 500 + assert kb.chunk_overlap == 50 + + @pytest.mark.unit + def test_document_model(self): + from app.models.knowledge_base import Document + + doc = Document( + kb_id="kb-1", + filename="test.txt", + file_type="txt", + file_size=1024, + status="completed", + chunk_count=5, + ) + assert doc.filename == "test.txt" + assert doc.status == "completed" + assert doc.chunk_count == 5 + + @pytest.mark.unit + def test_document_chunk_model(self): + from app.models.knowledge_base import DocumentChunk + + chunk = DocumentChunk( + document_id="doc-1", + kb_id="kb-1", + chunk_index=0, + content="test content", + embedding="[0.1, 0.2, 0.3]", + metadata={"source": "test"}, + ) + assert chunk.chunk_index == 0 + assert chunk.content == "test content" diff --git a/backend/tests/test_text_chunker.py b/backend/tests/test_text_chunker.py new file mode 100644 index 0000000..a911aca --- /dev/null +++ b/backend/tests/test_text_chunker.py @@ -0,0 +1,139 @@ +""" +文本分块器单元测试 +""" +from __future__ import annotations + +import pytest +from app.services.text_chunker import chunk_text, _split_paragraphs, _split_long_paragraph, _merge_segments + + +class TestSplitParagraphs: + """段落分割测试""" + + def test_empty_text(self): + assert _split_paragraphs("") == [] + assert _split_paragraphs(" ") == [] + assert _split_paragraphs("\n\n\n") == [] + + def test_single_paragraph(self): + result = _split_paragraphs("Hello world") + assert result == ["Hello world"] + + def test_multiple_paragraphs(self): + text = "第一段内容。\n\n第二段内容。\n\n第三段内容。" + result = _split_paragraphs(text) + assert len(result) == 3 + assert "第一段" in result[0] + assert "第二段" in result[1] + assert "第三段" in result[2] + + def test_mixed_newlines(self): + text = "段1\n\n\n段2\n\n段3" + result = _split_paragraphs(text) + assert len(result) == 3 + + +class TestSplitLongParagraph: + """超长段落分割测试""" + + def test_short_paragraph(self): + result = _split_long_paragraph("短文本", chunk_size=500) + assert result == ["短文本"] + + def test_long_paragraph_chinese(self): + para = "第一句。" * 200 # 600 chars, exceeds 500 + result = _split_long_paragraph(para, chunk_size=500) + assert len(result) >= 2 + assert all(len(c) <= 500 for c in result) + + def test_long_paragraph_english(self): + para = "Sentence. " * 200 + result = _split_long_paragraph(para, chunk_size=500) + assert len(result) >= 2 + assert all(len(c) <= 500 for c in result) + + def test_no_sentence_boundary(self): + """无句号可分割时按字符硬切""" + para = "a" * 1000 + result = _split_long_paragraph(para, chunk_size=500) + assert len(result) == 2 + assert len(result[0]) == 500 + assert len(result[1]) == 500 + + +class TestMergeSegments: + """段落合并测试""" + + def test_empty(self): + assert _merge_segments([], 500, 0) == [] + + def test_single_segment(self): + result = _merge_segments(["hello"], 500, 0) + assert result == ["hello"] + + def test_merge_short_segments(self): + segs = ["a" * 100, "b" * 100, "c" * 100] # each < 500, total ~300 < 500 + result = _merge_segments(segs, 500, 0) + assert len(result) == 1 + assert len(result[0]) > 200 + + def test_split_large_segments(self): + segs = ["a" * 300, "b" * 300, "c" * 300] # 需要分为多个chunk + result = _merge_segments(segs, 500, 0) + assert len(result) >= 2 + + def test_overlap(self): + segs = ["a" * 300, "b" * 300] + result = _merge_segments(segs, 500, 50) + assert len(result) >= 1 + + +class TestChunkText: + """chunk_text 整体测试""" + + def test_empty_text(self): + assert chunk_text("") == [] + assert chunk_text(" ") == [] + assert chunk_text(None) == [] # type: ignore + + def test_short_text(self): + result = chunk_text("Hello world", chunk_size=500, chunk_overlap=0) + assert len(result) == 1 + assert "Hello world" in result[0] + + def test_normal_text(self): + text = """ +这是第一段。它包含一些内容。 + +这是第二段。它也有一些内容。而且更长一些。 + +这是第三段。最后一段内容。 +""" + result = chunk_text(text, chunk_size=500, chunk_overlap=0) + assert len(result) >= 1 + # 所有内容都应该在结果中 + all_text = "".join(result) + assert "第一段" in all_text + assert "第二段" in all_text + assert "第三段" in all_text + + def test_chinese_text(self): + """中文文本测试""" + text = "我喜欢吃川菜。特别是麻辣火锅。还有水煮鱼。这些都很美味。" + result = chunk_text(text, chunk_size=100, chunk_overlap=0) + assert len(result) >= 1 + + def test_overlap_between_chunks(self): + """验证块间重叠""" + para = "这是一个很长的段落。" * 50 + result = chunk_text(para, chunk_size=200, chunk_overlap=50) + if len(result) > 1: + # 相邻块应该有重叠内容 + assert len(result[0]) > 0 + assert len(result[1]) > 0 + + @pytest.mark.unit + def test_with_mixed_punctuation(self): + text = "Hello! How are you? I am fine. Thank you. 你好!最近怎么样?我很好。谢谢。" + result = chunk_text(text, chunk_size=200, chunk_overlap=0) + assert len(result) >= 1 diff --git a/backend/tests/test_tool_registry.py b/backend/tests/test_tool_registry.py new file mode 100644 index 0000000..da60eef --- /dev/null +++ b/backend/tests/test_tool_registry.py @@ -0,0 +1,293 @@ +""" +工具注册表单元测试 +""" +from __future__ import annotations + +import json +import pytest +from unittest.mock import patch, AsyncMock + +from app.services.tool_registry import ToolRegistry, _CODE_SAFE_GLOBALS + + +@pytest.fixture +def registry(): + r = ToolRegistry() + return r + + +class TestToolRegistryBuiltin: + """内置工具注册与查询""" + + @pytest.mark.unit + def test_register_and_get(self, registry): + def my_tool(**kwargs): + return {"result": kwargs.get("x", 0) + kwargs.get("y", 0)} + + schema = { + "type": "function", + "function": { + "name": "add", + "description": "加法", + "parameters": {"type": "object", "properties": {"x": {"type": "number"}, "y": {"type": "number"}}}, + }, + } + registry.register_builtin_tool("add", my_tool, schema) + assert registry.get_tool_schema("add") == schema + assert registry.get_tool_function("add") == my_tool + assert registry.builtin_tool_count() == 1 + assert "add" in registry.builtin_tool_names() + + @pytest.mark.unit + def test_get_missing_tool(self, registry): + assert registry.get_tool_schema("nonexistent") is None + assert registry.get_tool_function("nonexistent") is None + + @pytest.mark.unit + def test_get_all_schemas(self, registry): + schema1 = {"type": "function", "function": {"name": "tool1"}} + schema2 = {"type": "function", "function": {"name": "tool2"}} + registry.register_builtin_tool("tool1", lambda: None, schema1) + registry.register_builtin_tool("tool2", lambda: None, schema2) + assert len(registry.get_all_tool_schemas()) == 2 + + @pytest.mark.unit + def test_get_tools_by_names(self, registry): + registry.register_builtin_tool("a", lambda: None, {"function": {"name": "a"}}) + registry.register_builtin_tool("b", lambda: None, {"function": {"name": "b"}}) + tools = registry.get_tools_by_names(["a", "b", "c"]) + assert len(tools) == 2 + assert tools[0]["function"]["name"] == "a" + + @pytest.mark.unit + def test_sync_function_execution(self, registry): + def sync_func(x=0, y=0): + return x + y + + registry.register_builtin_tool("add", sync_func, {"function": {"name": "add"}}) + result = registry._run_function(sync_func, "add", {"x": 3, "y": 4}) + import asyncio + result = asyncio.run(registry._run_function(sync_func, "add", {"x": 3, "y": 4})) + parsed = json.loads(result) + assert parsed == 7 + + +@pytest.mark.usefixtures("registry") +class TestToolRegistryHTTP: + """HTTP 工具执行测试(mock httpx)""" + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_execute_http(self, registry): + config = { + "url": "https://api.example.com/data?q={query}", + "method": "GET", + "headers": {"Authorization": "Bearer token"}, + "timeout": 10, + "_type": "http", + } + registry._custom_tool_configs["test_http"] = config + + with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request: + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.text = '{"result": "ok"}' + mock_request.return_value = mock_response + + result = await registry.execute_tool("test_http", {"query": "hello"}) + parsed = json.loads(result) + assert parsed["status_code"] == 200 + # 验证 URL 模板替换 + called_url = mock_request.call_args[0][1] + assert "hello" in called_url + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_execute_http_post_with_body(self, registry): + config = { + "url": "https://api.example.com/submit", + "method": "POST", + "headers": {}, + "body_template": {"name": "{name}", "age": "{age}"}, + "timeout": 10, + "_type": "http", + } + registry._custom_tool_configs["test_post"] = config + + with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request: + mock_response = AsyncMock() + mock_response.status_code = 201 + mock_response.text = '{"id": 1}' + mock_request.return_value = mock_response + + result = await registry.execute_tool("test_post", {"name": "Alice", "age": 30}) + parsed = json.loads(result) + assert parsed["status_code"] == 201 + # Verify POST method + assert mock_request.call_args[0][0] == "POST" + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_http_no_url(self, registry): + config = {"_type": "http"} + registry._custom_tool_configs["bad_http"] = config + result = await registry.execute_tool("bad_http", {}) + assert "error" in result + + +class TestToolRegistryCode: + """代码沙箱执行测试""" + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_execute_code_simple(self, registry): + config = { + "source": "def run(args):\n return {'sum': args['a'] + args['b']}", + "_type": "code", + } + registry._custom_tool_configs["calc"] = config + result = await registry.execute_tool("calc", {"a": 10, "b": 20}) + parsed = json.loads(result) + assert parsed["sum"] == 30 + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_execute_code_text_stats(self, registry): + config = { + "source": "def run(args):\n text = args.get('text', '')\n return {'len': len(text), 'words': len(text.split())}", + "_type": "code", + } + registry._custom_tool_configs["stats"] = config + result = await registry.execute_tool("stats", {"text": "hello world test"}) + parsed = json.loads(result) + assert parsed["len"] == 16 + assert parsed["words"] == 3 + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_code_source_missing(self, registry): + config = {"_type": "code"} # no source + registry._custom_tool_configs["no_source"] = config + result = await registry.execute_tool("no_source", {}) + assert "error" in result + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_code_no_run_function(self, registry): + config = {"source": "x = 1", "_type": "code"} + registry._custom_tool_configs["no_run"] = config + result = await registry.execute_tool("no_run", {}) + assert "run" in result + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_code_runtime_error(self, registry): + config = { + "source": "def run(args):\n raise ValueError('test error')", + "_type": "code", + } + registry._custom_tool_configs["err"] = config + result = await registry.execute_tool("err", {}) + assert "error" in result + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_code_sandbox_restriction(self, registry): + """验证 __builtins__ 被禁用""" + config = { + "source": "def run(args):\n import os\n return os.name", + "_type": "code", + } + registry._custom_tool_configs["unsafe"] = config + result = await registry.execute_tool("unsafe", {}) + # import 在沙箱中应该失败 + assert "error" in result + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_code_sandbox_no_file_access(self, registry): + """验证无法访问文件系统""" + config = { + "source": "def run(args):\n open('/etc/passwd')\n return 'ok'", + "_type": "code", + } + registry._custom_tool_configs["file_access"] = config + result = await registry.execute_tool("file_access", {}) + assert "error" in result + + +class TestToolRegistryTestHelpers: + """测试工具(不保存到 DB)""" + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_test_code_tool_success(self, registry): + source = "def run(args):\n return {'result': args['x'] * 2}" + result = await registry.test_code_tool(source, {"x": 5}) + assert result["success"] is True + assert result["result"]["result"] == 10 + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_test_code_tool_compile_error(self, registry): + source = "def run(args):\n invalid syntax{{{" + result = await registry.test_code_tool(source, {}) + assert result["success"] is False + assert "error" in result + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_test_http_tool(self, registry): + with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request: + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.text = '{"ip": "8.8.8.8"}' + mock_request.return_value = mock_response + + from app.services.tool_registry import _CODE_SAFE_GLOBALS + result = await registry.test_http_tool( + url="https://httpbin.org/get?ip={ip}", + method="GET", + headers={}, + body=None, + args={"ip": "8.8.8.8"}, + timeout=5, + ) + assert result["success"] is True + assert result["status_code"] == 200 + + +class TestToolRegistryExecute: + """execute_tool 整体流程""" + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_execute_builtin(self, registry): + def hello(**kwargs): + return f"Hello, {kwargs.get('name', 'world')}!" + + registry.register_builtin_tool("hello", hello, {"function": {"name": "hello"}}) + result = await registry.execute_tool("hello", {"name": "Test"}) + assert "Hello, Test!" in result + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_execute_nonexistent(self, registry): + result = await registry.execute_tool("no_such_tool", {}) + assert "error" in result + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_execute_unsupported_type(self, registry): + config = {"_type": "unsupported"} + registry._custom_tool_configs["weird"] = config + result = await registry.execute_tool("weird", {}) + assert "error" in result + + @pytest.mark.unit + @pytest.mark.asyncio + async def test_execute_workflow_not_supported(self, registry): + config = {"_type": "workflow"} + registry._custom_tool_configs["wf"] = config + result = await registry.execute_tool("wf", {}) + assert "暂不支持" in result diff --git a/backend/tests/test_tools_api.py b/backend/tests/test_tools_api.py new file mode 100644 index 0000000..68171b1 --- /dev/null +++ b/backend/tests/test_tools_api.py @@ -0,0 +1,263 @@ +""" +工具市场 API 集成测试 +""" +from __future__ import annotations + +import json +import pytest +from unittest.mock import patch, AsyncMock + + +class TestToolsAPI: + """工具市场 API 测试""" + + @pytest.mark.api + def test_list_public_tools(self, authenticated_client): + resp = authenticated_client.get("/api/v1/tools") + assert resp.status_code == 200 + assert isinstance(resp.json(), list) + + @pytest.mark.api + def test_list_categories(self, authenticated_client): + resp = authenticated_client.get("/api/v1/tools/categories") + assert resp.status_code == 200 + cats = resp.json() + assert isinstance(cats, list) + # 应包含默认分类 + assert "数据处理" in cats or "网络请求" in cats + + @pytest.mark.api + def test_list_without_auth(self, client): + resp = client.get("/api/v1/tools") + # 未认证时默认 scope=public 应返回 401 或公开工具 + assert resp.status_code == 401 + + @pytest.mark.api + def test_list_builtin(self, authenticated_client): + resp = authenticated_client.get("/api/v1/tools/builtin") + assert resp.status_code == 200 + tools = resp.json() + assert isinstance(tools, list) + assert len(tools) >= 10 # 期待至少 10 个内置工具 + + @pytest.mark.api + def test_create_and_get_tool(self, authenticated_client): + # 创建 HTTP 工具 + create_resp = authenticated_client.post("/api/v1/tools", json={ + "name": "echo_test", + "description": "Echo test tool", + "category": "network", + "function_schema": { + "name": "echo_test", + "description": "Returns the input", + "parameters": { + "type": "object", + "properties": {"msg": {"type": "string"}}, + "required": ["msg"], + }, + }, + "implementation_type": "http", + "implementation_config": { + "url": "https://httpbin.org/post", + "method": "POST", + "headers": {}, + "timeout": 10, + }, + "is_public": True, + }) + assert create_resp.status_code == 201 + data = create_resp.json() + assert data["name"] == "echo_test" + assert data["implementation_type"] == "http" + tool_id = data["id"] + + # 获取工具详情 + get_resp = authenticated_client.get(f"/api/v1/tools/{tool_id}") + assert get_resp.status_code == 200 + assert get_resp.json()["name"] == "echo_test" + + @pytest.mark.api + def test_create_code_tool(self, authenticated_client): + create_resp = authenticated_client.post("/api/v1/tools", json={ + "name": "double_test", + "description": "Double a number", + "category": "math", + "function_schema": { + "name": "double_test", + "description": "Doubles a number", + "parameters": { + "type": "object", + "properties": {"n": {"type": "number"}}, + "required": ["n"], + }, + }, + "implementation_type": "code", + "implementation_config": { + "source": "def run(args):\n n = args.get('n', 0)\n return {'result': n * 2}", + "language": "python", + }, + "is_public": True, + }) + assert create_resp.status_code == 201 + data = create_resp.json() + assert data["name"] == "double_test" + assert data["implementation_type"] == "code" + + @pytest.mark.api + def test_create_duplicate_name(self, authenticated_client): + # 先创建 + authenticated_client.post("/api/v1/tools", json={ + "name": "dup_tool", + "description": "Dup", + "function_schema": {"name": "dup_tool", "parameters": {"type": "object", "properties": {}}}, + "implementation_type": "code", + "implementation_config": {"source": "def run(args):\n return {}"}, + "is_public": False, + }) + # 重复创建应报错 + resp = authenticated_client.post("/api/v1/tools", json={ + "name": "dup_tool", + "description": "Dup again", + "function_schema": {"name": "dup_tool", "parameters": {"type": "object", "properties": {}}}, + "implementation_type": "code", + "implementation_config": {"source": "def run(args):\n return {}"}, + "is_public": False, + }) + assert resp.status_code == 400 + assert "已存在" in resp.json().get("detail", "") + + @pytest.mark.api + def test_invalid_implementation_type(self, authenticated_client): + resp = authenticated_client.post("/api/v1/tools", json={ + "name": "bad_tool", + "description": "Bad", + "function_schema": {"name": "bad_tool", "parameters": {"type": "object", "properties": {}}}, + "implementation_type": "invalid_type", + "is_public": False, + }) + assert resp.status_code == 400 + + @pytest.mark.api + def test_mine_scope(self, authenticated_client): + resp = authenticated_client.get("/api/v1/tools?scope=mine") + assert resp.status_code == 200 + + @pytest.mark.api + def test_get_nonexistent_tool(self, authenticated_client): + resp = authenticated_client.get("/api/v1/tools/nonexistent-id") + assert resp.status_code == 404 + + @pytest.mark.api + @pytest.mark.asyncio + async def test_test_http_endpoint(self, authenticated_client): + """测试 HTTP 工具的测试端点""" + with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request: + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.text = '{"origin": "1.2.3.4"}' + mock_request.return_value = mock_response + + resp = authenticated_client.post("/api/v1/tools/test/http", json={ + "url": "https://httpbin.org/get", + "method": "GET", + "headers": {}, + "body": None, + "args": {"ip": "8.8.8.8"}, + "timeout": 5, + }) + assert resp.status_code == 200 + data = resp.json() + assert data["success"] is True + + @pytest.mark.api + def test_test_code_endpoint(self, authenticated_client): + resp = authenticated_client.post("/api/v1/tools/test/code", json={ + "source": "def run(args):\n return args.get('x', 0) + args.get('y', 0)", + "args": {"x": 10, "y": 20}, + }) + assert resp.status_code == 200 + data = resp.json() + assert data["success"] is True + assert data["result"] == 30 + + @pytest.mark.api + def test_test_code_compile_error(self, authenticated_client): + resp = authenticated_client.post("/api/v1/tools/test/code", json={ + "source": "invalid python {{{", + "args": {}, + }) + assert resp.status_code == 200 + data = resp.json() + assert data["success"] is False + assert "error" in data + + @pytest.mark.api + def test_use_count(self, authenticated_client): + # 先创建工具 + create_resp = authenticated_client.post("/api/v1/tools", json={ + "name": "count_test", + "description": "Count test", + "function_schema": {"name": "count_test", "parameters": {"type": "object", "properties": {}}}, + "implementation_type": "code", + "implementation_config": {"source": "def run(args):\n return {}"}, + "is_public": True, + }) + assert create_resp.status_code == 201 + tool_id = create_resp.json()["id"] + + # 使用计数 + use_resp = authenticated_client.post(f"/api/v1/tools/{tool_id}/use") + assert use_resp.status_code == 200 + assert use_resp.json()["use_count"] == 1 + + # 再次使用 + use_resp2 = authenticated_client.post(f"/api/v1/tools/{tool_id}/use") + assert use_resp2.status_code == 200 + assert use_resp2.json()["use_count"] == 2 + + @pytest.mark.api + def test_delete_tool(self, authenticated_client): + # 创建 + create_resp = authenticated_client.post("/api/v1/tools", json={ + "name": "del_test", + "description": "Delete test", + "function_schema": {"name": "del_test", "parameters": {"type": "object", "properties": {}}}, + "implementation_type": "code", + "implementation_config": {"source": "def run(args):\n return {}"}, + "is_public": False, + }) + assert create_resp.status_code == 201 + tool_id = create_resp.json()["id"] + + # 删除 + del_resp = authenticated_client.delete(f"/api/v1/tools/{tool_id}") + assert del_resp.status_code == 200 + + # 确认已删除 + get_resp = authenticated_client.get(f"/api/v1/tools/{tool_id}") + assert get_resp.status_code == 404 + + @pytest.mark.api + def test_update_tool(self, authenticated_client): + create_resp = authenticated_client.post("/api/v1/tools", json={ + "name": "update_test", + "description": "Original desc", + "function_schema": {"name": "update_test", "parameters": {"type": "object", "properties": {}}}, + "implementation_type": "code", + "implementation_config": {"source": "def run(args):\n return {}"}, + "is_public": False, + }) + assert create_resp.status_code == 201 + tool_id = create_resp.json()["id"] + + update_resp = authenticated_client.put(f"/api/v1/tools/{tool_id}", json={ + "name": "update_test", + "description": "Updated desc", + "function_schema": {"name": "update_test", "parameters": {"type": "object", "properties": {}}}, + "implementation_type": "code", + "implementation_config": {"source": "def run(args):\n return {'updated': True}"}, + "is_public": True, + }) + assert update_resp.status_code == 200 + assert update_resp.json()["description"] == "Updated desc" + assert update_resp.json()["is_public"] is True diff --git a/frontend/src/components/MainLayout.vue b/frontend/src/components/MainLayout.vue index e7cc0b7..950634b 100644 --- a/frontend/src/components/MainLayout.vue +++ b/frontend/src/components/MainLayout.vue @@ -78,7 +78,9 @@ - +
+ +
@@ -109,6 +111,7 @@ const activeMenu = computed(() => { if (route.path === '/monitoring') return 'monitoring' if (route.path === '/agent-monitoring') return 'agent-monitoring' if (route.path === '/alert-rules') return 'alert-rules' + if (route.path === '/agent-chat' || route.path.startsWith('/agent-chat/')) return 'agent-chat' return 'workflows' }) @@ -189,4 +192,17 @@ const handleLogout = () => { height: 50px; line-height: 50px; } + +:deep(.el-main) { + display: flex; + flex-direction: column; + padding: 20px; +} + +.page-content { + flex: 1; + display: flex; + flex-direction: column; + min-height: 0; +} diff --git a/frontend/src/views/AgentChat.vue b/frontend/src/views/AgentChat.vue index db100df..fca2106 100644 --- a/frontend/src/views/AgentChat.vue +++ b/frontend/src/views/AgentChat.vue @@ -1,10 +1,10 @@ - 清空 + 清空
-
+

选择一个 Agent 开始对话

配置多个 Agent 后发送消息进行编排对话

Agent 可以使用内置工具帮你完成任务

-
+
@@ -137,13 +137,21 @@
- {{ msg.role === 'user' ? '用户' : 'Agent' }} · {{ formatTime(msg.timestamp) }} + {{ msg.role === 'user' ? '用户' : 'Agent' }} · {{ relativeTime(msg.timestamp) }} · {{ msg.iterations }} 步 · {{ msg.tool_calls_made }} 次工具调用 + + + + + + + +
-
+
@@ -159,7 +167,7 @@
- +
@@ -188,14 +196,15 @@ 关闭 -
+