"""Node template API.

CRUD endpoints for reusable LLM node templates, plus a ``use`` endpoint that
increments a template's usage counter. A user can read their own templates and
every public template, but can only update or delete templates they created.
"""
import logging
from typing import List, Optional, Dict, Any

from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from sqlalchemy import or_, and_, func
from pydantic import BaseModel, Field

from app.core.database import get_db
from app.api.auth import get_current_user
from app.models.user import User
from app.models.node_template import NodeTemplate
from app.core.exceptions import NotFoundError, ConflictError, ValidationError

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/node-templates", tags=["node-templates"])

# NOTE(review): the handlers below are ``async def`` but use a synchronous
# SQLAlchemy ``Session``, so every query blocks the event loop while it runs.
# FastAPI runs plain ``def`` handlers in a threadpool instead; consider
# switching once it is confirmed that nothing awaits these functions directly.

# Columns that must never become NULL through the update endpoint: the create
# endpoint always stores a value (or a default) for them, and
# NodeTemplateResponse declares them non-optional. An explicit ``null`` for one
# of these in an update body is treated as "leave unchanged".
_NON_NULLABLE_UPDATE_FIELDS = frozenset(
    {"name", "prompt", "provider", "model", "temperature", "max_tokens", "is_public"}
)

# List-valued columns. Create stores ``[]`` instead of NULL; update does the same
# so both endpoints leave the row in the same shape.
_LIST_UPDATE_FIELDS = frozenset({"tags", "variables"})

# Escape character for LIKE patterns built from user input. "/" is used rather
# than a backslash because backslash is itself special inside string literals
# on some databases (e.g. MySQL).
_LIKE_ESCAPE = "/"


class NodeTemplateCreate(BaseModel):
    """Request body for creating a node template."""

    name: str = Field(..., min_length=1, max_length=100, description="模板名称")
    description: Optional[str] = Field(None, description="模板描述")
    category: Optional[str] = Field(None, description="分类")
    tags: Optional[List[str]] = Field(None, description="标签列表")
    prompt: str = Field(..., min_length=1, description="提示词模板")
    variables: Optional[List[Dict[str, Any]]] = Field(None, description="变量定义列表")
    provider: Optional[str] = Field("deepseek", description="默认LLM提供商")
    model: Optional[str] = Field("deepseek-chat", description="默认模型")
    # Temperature is carried as a string end to end (request, column, response).
    temperature: Optional[str] = Field("0.7", description="默认温度参数")
    max_tokens: Optional[int] = Field(1500, description="默认最大token数")
    is_public: Optional[bool] = Field(False, description="是否公开")


class NodeTemplateUpdate(BaseModel):
    """Request body for updating a node template.

    Every field is optional; only fields present in the request body are
    applied (see ``update_node_template``).
    """

    name: Optional[str] = Field(None, min_length=1, max_length=100, description="模板名称")
    description: Optional[str] = Field(None, description="模板描述")
    category: Optional[str] = Field(None, description="分类")
    tags: Optional[List[str]] = Field(None, description="标签列表")
    prompt: Optional[str] = Field(None, min_length=1, description="提示词模板")
    variables: Optional[List[Dict[str, Any]]] = Field(None, description="变量定义列表")
    provider: Optional[str] = Field(None, description="默认LLM提供商")
    model: Optional[str] = Field(None, description="默认模型")
    temperature: Optional[str] = Field(None, description="默认温度参数")
    max_tokens: Optional[int] = Field(None, description="默认最大token数")
    is_public: Optional[bool] = Field(None, description="是否公开")


class NodeTemplateResponse(BaseModel):
    """Response shape for a node template.

    Endpoints return ``NodeTemplate.to_dict()``, which is validated against this
    model. NOTE(review): ``created_at``/``updated_at`` are declared ``str``, so
    ``to_dict()`` presumably serialises timestamps to strings — confirm there.
    """

    id: str
    name: str
    description: Optional[str]
    category: Optional[str]
    tags: Optional[List[str]]
    prompt: str
    variables: Optional[List[Dict[str, Any]]]
    provider: str
    model: str
    temperature: str
    max_tokens: int
    is_public: bool
    is_featured: bool
    use_count: int
    user_id: str
    created_at: str
    updated_at: str

    class Config:
        from_attributes = True


def _model_dump_set_fields(model: BaseModel) -> Dict[str, Any]:
    """Return only the fields the client explicitly sent in *model*.

    Uses ``model_dump`` (Pydantic v2) when available and falls back to
    ``dict`` (Pydantic v1, deprecated in v2).
    """
    dump = getattr(model, "model_dump", None)
    if dump is None:
        return model.dict(exclude_unset=True)
    return dump(exclude_unset=True)


def _like_contains_pattern(text: str) -> str:
    """Build a ``%text%`` LIKE pattern that matches *text* literally.

    The escape character, ``%`` and ``_`` inside *text* are escaped with
    ``_LIKE_ESCAPE``. The pattern must be used with
    ``.like(pattern, escape=_LIKE_ESCAPE)``.
    """
    escaped = (
        text.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )
    return f"%{escaped}%"


def _get_template_or_404(db: Session, template_id: str) -> NodeTemplate:
    """Load the template with *template_id*.

    Raises:
        NotFoundError: if no template has that id.
    """
    template = db.query(NodeTemplate).filter(NodeTemplate.id == template_id).first()
    if template is None:
        raise NotFoundError(f"节点模板不存在: {template_id}")
    return template


def _require_owner(template: NodeTemplate, user: User, detail: str) -> None:
    """Raise HTTP 403 with *detail* unless *user* created *template*."""
    if template.user_id != user.id:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)


def _require_readable(template: NodeTemplate, user: User, detail: str) -> None:
    """Raise HTTP 403 with *detail* unless *template* is public or owned by *user*."""
    if template.user_id != user.id and not template.is_public:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)


def _ensure_name_available(
    db: Session, user: User, name: str, exclude_id: Optional[str] = None
) -> None:
    """Raise ``ConflictError`` if *user* already owns a template called *name*.

    *exclude_id* skips one template (the one being renamed) in the check.
    NOTE(review): this is check-then-write; two concurrent requests can both
    pass unless the table has a unique constraint on (user_id, name), which is
    not visible here.
    """
    query = db.query(NodeTemplate).filter(
        NodeTemplate.name == name,
        NodeTemplate.user_id == user.id,
    )
    if exclude_id is not None:
        query = query.filter(NodeTemplate.id != exclude_id)
    if query.first() is not None:
        raise ConflictError(f"模板名称 '{name}' 已存在")


@router.get("", response_model=List[NodeTemplateResponse])
async def get_node_templates(
    skip: int = Query(0, ge=0, description="跳过记录数"),
    limit: int = Query(100, ge=1, le=100, description="每页记录数"),
    category: Optional[str] = Query(None, description="分类筛选"),
    tag: Optional[str] = Query(None, description="标签筛选"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    is_public: Optional[bool] = Query(None, description="是否公开"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List node templates visible to the current user.

    Visible templates are all of the user's own templates plus every public
    template.
    \f
    Optional filters narrow the result: exact ``category``, a single ``tag``
    contained in ``tags``, a literal substring ``search`` over name,
    description and prompt, and ``is_public``. Results are ordered featured
    first, then by ``use_count`` descending, then by ``updated_at`` descending,
    and paginated with ``skip``/``limit``.
    """
    # Visibility: the caller's own templates or public ones.
    query = db.query(NodeTemplate).filter(
        or_(
            NodeTemplate.user_id == current_user.id,
            NodeTemplate.is_public == True,  # noqa: E712 - SQL expression, not a Python comparison
        )
    )

    if category:
        query = query.filter(NodeTemplate.category == category)

    if tag:
        query = query.filter(NodeTemplate.tags.contains([tag]))

    if search:
        # Escape % and _ so the user's text is matched literally, not as wildcards.
        pattern = _like_contains_pattern(search)
        query = query.filter(
            or_(
                NodeTemplate.name.like(pattern, escape=_LIKE_ESCAPE),
                NodeTemplate.description.like(pattern, escape=_LIKE_ESCAPE),
                NodeTemplate.prompt.like(pattern, escape=_LIKE_ESCAPE),
            )
        )

    if is_public is not None:
        query = query.filter(NodeTemplate.is_public == is_public)

    # Ordering: featured > use count > most recently updated.
    templates = (
        query.order_by(
            NodeTemplate.is_featured.desc(),
            NodeTemplate.use_count.desc(),
            NodeTemplate.updated_at.desc(),
        )
        .offset(skip)
        .limit(limit)
        .all()
    )
    return [template.to_dict() for template in templates]


@router.get("/{template_id}", response_model=NodeTemplateResponse)
async def get_node_template(
    template_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Get a single node template.
    \f
    Raises ``NotFoundError`` for an unknown id and HTTP 403 when the template
    is private and owned by another user.
    """
    template = _get_template_or_404(db, template_id)
    _require_readable(template, current_user, "无权访问此模板")
    return template.to_dict()


@router.post("", response_model=NodeTemplateResponse, status_code=status.HTTP_201_CREATED)
async def create_node_template(
    template_data: NodeTemplateCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a node template owned by the current user.
    \f
    Raises ``ConflictError`` if the user already has a template with the same
    name. Optional fields that are missing or falsy fall back to defaults
    (empty lists, "deepseek", "deepseek-chat", "0.7", 1500, private).
    """
    _ensure_name_available(db, current_user, template_data.name)

    template = NodeTemplate(
        name=template_data.name,
        description=template_data.description,
        category=template_data.category,
        tags=template_data.tags or [],
        prompt=template_data.prompt,
        variables=template_data.variables or [],
        provider=template_data.provider or "deepseek",
        model=template_data.model or "deepseek-chat",
        temperature=template_data.temperature or "0.7",
        max_tokens=template_data.max_tokens or 1500,
        is_public=template_data.is_public or False,
        user_id=current_user.id,
    )
    db.add(template)
    db.commit()
    db.refresh(template)

    logger.info(
        "用户 %s 创建了节点模板: %s (%s)",
        current_user.username,
        template.name,
        template.id,
    )
    return template.to_dict()


@router.put("/{template_id}", response_model=NodeTemplateResponse)
async def update_node_template(
    template_id: str,
    template_data: NodeTemplateUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Update a node template owned by the current user.

    Only fields present in the request body are changed.
    \f
    An explicit ``null`` for a required field (name, prompt, provider, model,
    temperature, max_tokens, is_public) leaves that field unchanged, and an
    explicit ``null`` for ``tags``/``variables`` stores an empty list, matching
    what the create endpoint stores.

    Raises ``NotFoundError`` for an unknown id, HTTP 403 if the caller is not
    the owner, and ``ConflictError`` if a rename collides with another of the
    caller's templates.
    """
    template = _get_template_or_404(db, template_id)
    _require_owner(template, current_user, "无权更新此模板")

    # A rename must not collide with another template of the same user.
    if template_data.name and template_data.name != template.name:
        _ensure_name_available(db, current_user, template_data.name, exclude_id=template_id)

    for key, value in _model_dump_set_fields(template_data).items():
        if value is None:
            if key in _NON_NULLABLE_UPDATE_FIELDS:
                # Writing NULL here would violate the response model (and
                # likely the column); treat it as "not provided".
                continue
            if key in _LIST_UPDATE_FIELDS:
                value = []
        setattr(template, key, value)

    db.commit()
    db.refresh(template)

    logger.info(
        "用户 %s 更新了节点模板: %s (%s)",
        current_user.username,
        template.name,
        template.id,
    )
    return template.to_dict()


@router.delete("/{template_id}", status_code=status.HTTP_200_OK)
async def delete_node_template(
    template_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a node template owned by the current user.
    \f
    Raises ``NotFoundError`` for an unknown id and HTTP 403 if the caller is
    not the owner. Returns a confirmation message.
    """
    template = _get_template_or_404(db, template_id)
    _require_owner(template, current_user, "无权删除此模板")

    # Capture identifying fields before deletion so the log line does not
    # read attributes from a deleted, post-commit instance.
    deleted_name, deleted_id = template.name, template.id
    db.delete(template)
    db.commit()

    logger.info(
        "用户 %s 删除了节点模板: %s (%s)",
        current_user.username,
        deleted_name,
        deleted_id,
    )
    return {"message": "节点模板已删除"}


@router.post("/{template_id}/use", status_code=status.HTTP_200_OK)
async def use_node_template(
    template_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Record one use of a node template and return the updated template.
    \f
    Raises ``NotFoundError`` for an unknown id and HTTP 403 when the template
    is private and owned by another user.
    """
    template = _get_template_or_404(db, template_id)
    _require_readable(template, current_user, "无权使用此模板")

    # Increment in SQL (UPDATE ... SET use_count = coalesce(use_count, 0) + 1)
    # so concurrent requests cannot overwrite each other's increment, which a
    # Python-side read-modify-write would allow.
    template.use_count = func.coalesce(NodeTemplate.use_count, 0) + 1
    db.commit()
    # Reload to pick up the value computed by the database.
    db.refresh(template)
    return template.to_dict()