116 lines
3.5 KiB
Python
116 lines
3.5 KiB
Python
|
|
"""
|
|||
|
|
Agent Runtime ⇄ WorkflowEngine 桥接。
|
|||
|
|
|
|||
|
|
让 workflow_engine.execute_node() 通过寥寥几行调用 Agent Runtime。
|
|||
|
|
"""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
from typing import Any, Dict, Optional
|
|||
|
|
|
|||
|
|
from app.agent_runtime.core import AgentRuntime
|
|||
|
|
from app.agent_runtime.schemas import (
|
|||
|
|
AgentConfig,
|
|||
|
|
AgentLLMConfig,
|
|||
|
|
AgentToolConfig,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def run_agent_node(
    node_data: Dict[str, Any],
    input_data: Dict[str, Any],
    execution_logger: Optional[Any] = None,
    user_id: Optional[str] = None,
    on_tool_executed: Optional[Any] = None,
) -> Dict[str, Any]:
    """Execute an Agent node inside a workflow.

    Bridges ``workflow_engine.execute_node()`` to the Agent Runtime: builds
    an ``AgentConfig`` from the node's settings, runs the agent on the user
    query, and maps the result back to the workflow engine's output format.

    Args:
        node_data: Node configuration. Supported keys:
            system_prompt    -- agent persona/instructions; rendered as a
                                ``str.format`` template against input_data
            tools            -- optional tool whitelist (default: all tools)
            exclude_tools    -- optional tool blacklist
            model            -- model name
            provider         -- provider (openai/deepseek)
            temperature      -- sampling temperature
            max_iterations   -- maximum ReAct steps
            memory           -- enable long-term memory (and DB persistence)
            api_key/base_url -- optional inline LLM credentials
            label            -- node display name, used as the agent name
        input_data: Workflow input. The "query", "input" or "text" field is
            used as the user message; all keys are also available to the
            system_prompt template.
        execution_logger: Optional logger forwarded to the runtime.
        user_id: Optional user id, used for memory scoping.
        on_tool_executed: Optional callback invoked after each tool call.

    Returns:
        Dict with "output" and "status" ("success"/"error"). On success it
        also carries "agent_meta" (iterations, tool_calls, truncated); on
        failure an "error" field.
    """
    # 1. Resolve the user input. Falsy candidates (None, "", 0) fall through
    #    to the next key; a truthy non-string value is stringified.
    query = (
        input_data.get("query")
        or input_data.get("input")
        or input_data.get("text", "")
    )
    if not isinstance(query, str):
        query = str(query) if query else ""

    if not query:
        return {"output": "错误:Agent 节点未收到用户输入", "status": "error"}

    # 2. Render the system prompt template against input_data, falling back
    #    to the raw prompt on any formatting problem so the node never
    #    crashes here. IndexError is included because a positional
    #    placeholder such as "{0}" makes str.format raise IndexError, not
    #    KeyError, when only keyword arguments are supplied.
    raw_prompt = node_data.get("system_prompt", "你是一个有用的AI助手。")
    try:
        formatted_prompt = raw_prompt.format(**input_data)
    except (KeyError, IndexError, ValueError):
        formatted_prompt = raw_prompt

    # 3. Build the agent configuration.
    llm_config = AgentLLMConfig(
        provider=node_data.get("provider", "openai"),
        model=node_data.get("model", "gpt-4o-mini"),
        temperature=float(node_data.get("temperature", 0.7)),
        max_iterations=int(node_data.get("max_iterations", 10)),
    )
    # Allow per-node inline api_key/base_url overrides.
    if node_data.get("api_key"):
        llm_config.api_key = node_data["api_key"]
    if node_data.get("base_url"):
        llm_config.base_url = node_data["base_url"]

    # A single "memory" flag drives both in-memory use and DB persistence.
    memory_enabled = node_data.get("memory", True)
    agent_config = AgentConfig(
        name=node_data.get("label", "agent_node"),
        system_prompt=formatted_prompt,
        llm=llm_config,
        tools=AgentToolConfig(
            include_tools=node_data.get("tools", []),
            exclude_tools=node_data.get("exclude_tools", []),
        ),
        memory={
            "enabled": memory_enabled,
            "persist_to_db": memory_enabled,
        },
        user_id=user_id,
    )

    # 4. Run the agent.
    runtime = AgentRuntime(
        config=agent_config,
        execution_logger=execution_logger,
        on_tool_executed=on_tool_executed,
    )
    result = await runtime.run(query)

    # 5. Map the result to the workflow engine's output contract.
    if result.success:
        return {
            "output": result.content,
            "status": "success",
            "agent_meta": {
                "iterations": result.iterations_used,
                "tool_calls": result.tool_calls_made,
                "truncated": result.truncated,
            },
        }
    return {
        "output": result.content,
        "status": "error",
        "error": result.error,
    }