"""Scene template API.

Served from a standalone router so these routes do not depend on the
matching order of ``/api/v1/agents/{agent_id}`` in some deployments.
"""
from fastapi import APIRouter, Depends, status, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from typing import Any, Dict, List, Optional
import logging
import uuid
import json

from app.core.database import get_db
from app.api.auth import get_current_user
from app.models.user import User
from app.models.agent import Agent
from app.models.execution import Execution
from app.core.exceptions import ValidationError, ConflictError
from app.services.workflow_validator import validate_workflow
from app.services.scene_templates import build_workflow_for_template, list_scene_template_meta
from app.api.agents import AgentResponse, SceneTemplateItem, AgentFromSceneTemplateCreate

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/platform", tags=["platform-templates"])

# Defaults used by the synchronous fallback when the workflow has no
# agent/llm/template node (or the node omits a setting).
_DEFAULT_SYSTEM_PROMPT = "你是一个有用的AI助手。"
_DEFAULT_MODEL = "deepseek-v4-flash"
_DEFAULT_PROVIDER = "deepseek"
_DEFAULT_TEMPERATURE = 0.7
_DEFAULT_MAX_ITERATIONS = 10
_LLM_NODE_TYPES = ("agent", "llm", "template")

# Coarse progress percentage reported for each execution status; any other
# status (e.g. "pending") maps to 0.
_PROGRESS_BY_STATUS = {"running": 50, "completed": 100, "failed": 100}


# ---------- Execute-template request / response bodies ----------
class ExecuteTemplateRequest(BaseModel):
    """Request body for one-click template execution."""

    message: str = Field(..., description="用户输入/任务描述")
    parameters: Optional[Dict[str, Any]] = Field(default_factory=dict, description="模板参数覆盖")


class ExecuteTemplateResponse(BaseModel):
    """Response returned right after a template execution is triggered."""

    execution_id: str
    status: str
    message: str


# ---------- Execution progress response ----------
class ExecutionProgressResponse(BaseModel):
    """Polling response describing the state of a template execution."""

    execution_id: str
    status: str
    progress_pct: int = 0
    output: Optional[str] = None
    error: Optional[str] = None
    execution_time_ms: Optional[int] = None


def _build_validated_workflow(
    template_id: str, parameters: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """Build the workflow for a scene template and validate it.

    Raises:
        ValidationError: if the template cannot be built (``ValueError`` from
            the builder) or the resulting workflow fails validation.
    """
    try:
        workflow_config = build_workflow_for_template(template_id, parameters or {})
    except ValueError as exc:
        raise ValidationError(str(exc)) from exc

    validation_result = validate_workflow(
        workflow_config.get("nodes", []), workflow_config.get("edges", [])
    )
    if not validation_result["valid"]:
        raise ValidationError(
            "工作流配置验证失败: " + ", ".join(validation_result["errors"])
        )
    return workflow_config


@router.get("/scene-templates", response_model=List[SceneTemplateItem])
async def list_scene_templates_v1(current_user: User = Depends(get_current_user)):
    """List metadata of all available scene templates (requires authentication)."""
    _ = current_user  # dependency is only used to enforce authentication
    return list_scene_template_meta()


@router.post("/agents/from-template", response_model=AgentResponse, status_code=status.HTTP_201_CREATED)
async def create_agent_from_template_v1(
    body: AgentFromSceneTemplateCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a draft Agent for the current user from a scene template.

    Raises:
        ValidationError: if the template or the generated workflow is invalid.
        ConflictError: if the user already owns an agent with the same name.
    """
    workflow_config = _build_validated_workflow(body.template_id, body.parameters)

    existing_agent = (
        db.query(Agent)
        .filter(Agent.name == body.name, Agent.user_id == current_user.id)
        .first()
    )
    if existing_agent:
        raise ConflictError(f"Agent名称 '{body.name}' 已存在")

    agent = Agent(
        name=body.name,
        description=body.description or f"自场景模板 {body.template_id} 创建",
        workflow_config=workflow_config,
        budget_config=body.budget_config,
        user_id=current_user.id,
        status="draft",
    )
    db.add(agent)
    db.commit()
    db.refresh(agent)

    logger.info(
        "用户 %s 从模板 %s 创建 Agent: %s (%s)",
        current_user.username,
        body.template_id,
        agent.name,
        agent.id,
    )
    return agent


@router.post("/templates/{template_id}/execute", response_model=ExecuteTemplateResponse)
async def execute_template(
    template_id: str,
    body: ExecuteTemplateRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Execute a template in one step and return an ``execution_id`` for polling.

    Builds and validates the workflow, creates a one-off Agent plus an
    Execution row, then dispatches the Celery task. If dispatch fails, the
    template is run synchronously inside this request instead.

    Raises:
        ValidationError: if the template or the generated workflow is invalid.
    """
    # 1-2. Build the workflow from the template id + parameters and validate it.
    workflow_config = _build_validated_workflow(template_id, body.parameters)

    # 3. Create a one-off agent for this run. NOTE: the row IS persisted in the
    #    agents table; the "平台模板执行-" name prefix is what marks it as a
    #    template execution. The random suffix keeps names unique per run.
    temp_agent = Agent(
        name=f"平台模板执行-{template_id}-{uuid.uuid4().hex[:8]}",
        description=f"来自平台模板 {template_id} 的一次性执行",
        workflow_config=workflow_config,
        user_id=current_user.id,
        status="draft",
    )
    db.add(temp_agent)
    db.flush()  # assigns temp_agent.id

    # 4. Create the Execution record.
    task_input = {
        "USER_INPUT": body.message,
        "query": body.message,
        "message": body.message,
    }
    execution = Execution(
        agent_id=temp_agent.id,
        input_data={
            **task_input,
            "template_id": template_id,
            "parameters": body.parameters or {},
        },
        status="pending",
    )
    db.add(execution)
    db.flush()  # assigns execution.id

    execution_id = str(execution.id)
    agent_id_str = str(temp_agent.id)

    # Commit BEFORE dispatching: a worker on another DB connection cannot see
    # rows that were only flushed, and a later status write from this request
    # must not overwrite whatever status the worker has already recorded.
    execution.status = "running"
    db.commit()

    # 5. Trigger asynchronous execution. Only the dispatch itself is guarded so
    #    that an unrelated failure cannot cause the template to run twice.
    try:
        from app.tasks.agent_tasks import execute_agent_task

        task = execute_agent_task.delay(
            agent_id_str,
            task_input,
            execution_id=execution_id,
        )
    except Exception as exc:
        # Celery unavailable: fall back to running synchronously in-request.
        logger.warning("Celery 不可用,回退同步执行: %s", exc)
        _run_template_sync(db, temp_agent, execution, body.message)
    else:
        execution.task_id = task.id
        db.commit()
        logger.info(
            "模板执行已触发: template=%s execution=%s task=%s user=%s",
            template_id,
            execution_id,
            task.id,
            current_user.username,
        )

    return ExecuteTemplateResponse(
        execution_id=execution_id,
        status=execution.status,
        message="模板执行已触发,请通过 execution_id 轮询进度",
    )


def _llm_settings_from_workflow(workflow_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Read LLM settings from the first agent/llm/template node of a workflow.

    Missing values fall back to the module defaults. Non-dict nodes and nodes
    whose ``data`` is missing or ``None`` are tolerated.
    """
    settings = {
        "system_prompt": _DEFAULT_SYSTEM_PROMPT,
        "model": _DEFAULT_MODEL,
        "provider": _DEFAULT_PROVIDER,
        "temperature": _DEFAULT_TEMPERATURE,
        "max_iterations": _DEFAULT_MAX_ITERATIONS,
    }
    for node in (workflow_config or {}).get("nodes", []):
        if not isinstance(node, dict) or node.get("type") not in _LLM_NODE_TYPES:
            continue
        cfg = node.get("data") or {}
        settings["system_prompt"] = (
            cfg.get("system_prompt", "") or cfg.get("prompt", "") or settings["system_prompt"]
        )
        settings["model"] = cfg.get("model", settings["model"])
        settings["provider"] = cfg.get("provider", settings["provider"])
        settings["temperature"] = float(cfg.get("temperature", settings["temperature"]))
        settings["max_iterations"] = int(cfg.get("max_iterations", settings["max_iterations"]))
        break  # only the first matching node is used
    return settings


def _run_template_sync(db: Session, agent: Agent, execution: Execution, message: str):
    """Run a template synchronously (fallback when Celery is unavailable).

    Runs the agent runtime once with ``message`` and writes the output, final
    status, error message and elapsed time in milliseconds to ``execution``,
    then commits. Any exception raised while configuring or running the
    runtime is logged and recorded as a ``failed`` execution rather than
    leaving the row in ``running``.
    """
    import asyncio
    import time

    started = time.monotonic()

    def _elapsed_ms() -> int:
        return int((time.monotonic() - started) * 1000)

    try:
        from app.agent_runtime.core import AgentRuntime
        from app.agent_runtime.schemas import AgentConfig, AgentLLMConfig, AgentToolConfig, AgentMemoryConfig

        settings = _llm_settings_from_workflow(agent.workflow_config)

        async def _run():
            config = AgentConfig(
                name=agent.name,
                system_prompt=settings["system_prompt"],
                llm=AgentLLMConfig(
                    model=settings["model"],
                    provider=settings["provider"],
                    temperature=settings["temperature"],
                    max_iterations=settings["max_iterations"],
                ),
                tools=AgentToolConfig(),
                memory=AgentMemoryConfig(),
            )
            runtime = AgentRuntime(config=config)
            return await runtime.run(message)

        # When called from inside a running loop (e.g. an async endpoint),
        # asyncio.run() needs nest_asyncio to be re-entrant. The probe is kept
        # separate from the run so that a RuntimeError raised by the agent
        # itself is never mistaken for "no running loop" and re-executed.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # no loop running in this thread; plain asyncio.run() works
        else:
            import nest_asyncio

            nest_asyncio.apply()

        result = asyncio.run(_run())
    except Exception as exc:
        logger.exception("模板同步执行失败: execution=%s", execution.id)
        execution.status = "failed"
        execution.error_message = str(exc)
        execution.execution_time = _elapsed_ms()
        db.commit()
        return

    execution.output_data = {"result": result.content, "iterations": result.iterations_used}
    execution.status = "completed" if result.success else "failed"
    execution.execution_time = _elapsed_ms()
    if not result.success:
        execution.error_message = result.error
    db.commit()


def _output_as_text(output_data: Any) -> Optional[str]:
    """Reduce stored execution output to the string exposed by the progress API.

    For a dict, prefers its ``result`` then ``output`` entry and falls back to
    the JSON dump of the whole dict; a non-string entry is JSON-encoded so the
    ``Optional[str]`` response field always validates. Strings are returned
    as-is; empty or otherwise-typed output yields ``None``.
    """
    if not output_data:
        return None
    if isinstance(output_data, str):
        return output_data
    if not isinstance(output_data, dict):
        return None

    value = output_data.get("result") or output_data.get("output")
    if not value:
        return json.dumps(output_data, ensure_ascii=False)
    if isinstance(value, str):
        return value
    return json.dumps(value, ensure_ascii=False)


@router.get("/templates/executions/{execution_id}/progress", response_model=ExecutionProgressResponse)
async def get_execution_progress(
    execution_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the progress of a template execution owned by the current user.

    Raises:
        HTTPException: 404 if the execution does not exist or its agent does
            not belong to the current user (the two cases are deliberately
            indistinguishable so that ids cannot be probed).
    """
    execution = (
        db.query(Execution)
        .join(Agent, Agent.id == Execution.agent_id)
        .filter(Execution.id == execution_id, Agent.user_id == current_user.id)
        .first()
    )
    if not execution:
        raise HTTPException(status_code=404, detail="执行记录不存在")

    return ExecutionProgressResponse(
        execution_id=execution_id,
        status=execution.status,
        progress_pct=_PROGRESS_BY_STATUS.get(execution.status, 0),
        output=_output_as_text(execution.output_data),
        error=execution.error_message,
        execution_time_ms=execution.execution_time,
    )