"""
节点模板API (Node template API).
"""
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import or_, and_
|
|
from pydantic import BaseModel, Field
|
|
from typing import List, Optional, Dict, Any
|
|
import logging
|
|
from app.core.database import get_db
|
|
from app.api.auth import get_current_user
|
|
from app.models.user import User
|
|
from app.models.node_template import NodeTemplate
|
|
from app.core.exceptions import NotFoundError, ConflictError, ValidationError
|
|
|
|
# Logger named after this module.
logger = logging.getLogger(__name__)

# Every node-template endpoint is mounted under this prefix and grouped
# under a single OpenAPI tag.
router = APIRouter(
    prefix="/api/v1/node-templates",
    tags=["node-templates"],
)
|
|
|
|
|
|
class NodeTemplateCreate(BaseModel):
    """Request body for creating a node template (创建节点模板请求模型)."""

    # Display name; required, 1-100 characters.
    name: str = Field(..., min_length=1, max_length=100, description="模板名称")
    # Optional descriptive metadata.
    description: Optional[str] = Field(default=None, description="模板描述")
    category: Optional[str] = Field(default=None, description="分类")
    tags: Optional[List[str]] = Field(default=None, description="标签列表")
    # The prompt template text itself; required and non-empty.
    prompt: str = Field(..., min_length=1, description="提示词模板")
    # Variable definitions; each entry is a free-form dict (schema not enforced here).
    variables: Optional[List[Dict[str, Any]]] = Field(default=None, description="变量定义列表")
    # Default LLM settings. Note that temperature is carried as a string.
    provider: Optional[str] = Field(default="deepseek", description="默认LLM提供商")
    model: Optional[str] = Field(default="deepseek-chat", description="默认模型")
    temperature: Optional[str] = Field(default="0.7", description="默认温度参数")
    max_tokens: Optional[int] = Field(default=1500, description="默认最大token数")
    # Whether other users may see and use this template.
    is_public: Optional[bool] = Field(default=False, description="是否公开")
|
|
|
|
|
|
class NodeTemplateUpdate(BaseModel):
    """Request body for updating a node template (更新节点模板请求模型).

    Every field is optional; the update endpoint applies only the fields the
    client actually sent (``exclude_unset``). This model itself accepts an
    explicit null for any field.
    """

    # Identity and descriptive metadata.
    name: Optional[str] = Field(default=None, min_length=1, max_length=100, description="模板名称")
    description: Optional[str] = Field(default=None, description="模板描述")
    category: Optional[str] = Field(default=None, description="分类")
    tags: Optional[List[str]] = Field(default=None, description="标签列表")
    # Prompt template text; when sent as a string it must be non-empty.
    prompt: Optional[str] = Field(default=None, min_length=1, description="提示词模板")
    variables: Optional[List[Dict[str, Any]]] = Field(default=None, description="变量定义列表")
    # Default LLM settings (temperature is a string, as in the create model).
    provider: Optional[str] = Field(default=None, description="默认LLM提供商")
    model: Optional[str] = Field(default=None, description="默认模型")
    temperature: Optional[str] = Field(default=None, description="默认温度参数")
    max_tokens: Optional[int] = Field(default=None, description="默认最大token数")
    # Visibility flag.
    is_public: Optional[bool] = Field(default=None, description="是否公开")
|
|
|
|
|
|
class NodeTemplateResponse(BaseModel):
    """Response model for a node template (节点模板响应模型)."""

    # Template identifier (string-typed).
    id: str
    name: str
    # Optional descriptive metadata; may be null.
    description: Optional[str]
    category: Optional[str]
    tags: Optional[List[str]]
    # Prompt template text.
    prompt: str
    # Variable definitions; each entry is a free-form dict.
    variables: Optional[List[Dict[str, Any]]]
    # Default LLM settings. These are non-optional here, so a stored null
    # would fail response validation.
    provider: str
    model: str
    # Temperature is a string, matching the request models.
    temperature: str
    max_tokens: int
    is_public: bool
    # Featured flag; the list endpoint sorts on it first.
    is_featured: bool
    # Number of times the template has been "used" via the /use endpoint.
    use_count: int
    # Owner's user id.
    user_id: str
    # Timestamps as strings; presumably ISO-formatted by NodeTemplate.to_dict()
    # -- TODO confirm against the model.
    created_at: str
    updated_at: str

    class Config:
        # Pydantic v2 option (named `orm_mode` in v1): allows building this
        # model from an object's attributes rather than only from a dict.
        from_attributes = True
|
|
|
|
|
|
@router.get("", response_model=List[NodeTemplateResponse])
async def get_node_templates(
    skip: int = Query(0, ge=0, description="跳过记录数"),
    limit: int = Query(100, ge=1, le=100, description="每页记录数"),
    category: Optional[str] = Query(None, description="分类筛选"),
    tag: Optional[str] = Query(None, description="标签筛选"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    is_public: Optional[bool] = Query(None, description="是否公开"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    List node templates visible to the current user (获取节点模板列表).

    A user can see:
    - every template they created themselves
    - every public template

    The optional filters (category, tag, search keyword, public flag) are
    ANDed onto that visibility rule. Results are ordered featured first, then
    by use count, then by most recent update, and paginated with skip/limit.
    """
    # NOTE(review): this handler is `async` but makes blocking calls on a
    # synchronous SQLAlchemy Session, which stalls the event loop while the
    # queries run. Consider a plain `def` endpoint (FastAPI runs those in a
    # threadpool) or an async session.
    query = db.query(NodeTemplate)

    # Visibility: the caller's own templates, or anyone's public templates.
    query = query.filter(
        or_(
            NodeTemplate.user_id == current_user.id,
            # SQL expression, not a Python comparison; `is True` would not work.
            NodeTemplate.is_public == True,  # noqa: E712
        )
    )

    if category:
        query = query.filter(NodeTemplate.category == category)

    if tag:
        # Collection-membership test on the tags column.
        query = query.filter(NodeTemplate.tags.contains([tag]))

    if search:
        # Substring match on name/description/prompt. autoescape=True escapes
        # LIKE wildcards ("%", "_") in the user's keyword so they are matched
        # literally; interpolating the raw keyword into "%...%" let them act
        # as patterns and return the wrong rows.
        query = query.filter(
            or_(
                NodeTemplate.name.contains(search, autoescape=True),
                NodeTemplate.description.contains(search, autoescape=True),
                NodeTemplate.prompt.contains(search, autoescape=True),
            )
        )

    if is_public is not None:
        query = query.filter(NodeTemplate.is_public == is_public)

    # Ordering: featured > use count > last update.
    templates = (
        query.order_by(
            NodeTemplate.is_featured.desc(),
            NodeTemplate.use_count.desc(),
            NodeTemplate.updated_at.desc(),
        )
        .offset(skip)
        .limit(limit)
        .all()
    )

    return [template.to_dict() for template in templates]
|
|
|
|
|
|
@router.get("/{template_id}", response_model=NodeTemplateResponse)
async def get_node_template(
    template_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Return one node template by id (获取节点模板详情).

    Raises NotFoundError when no template has this id, and HTTP 403 when the
    template is private and owned by another user.
    """
    found = (
        db.query(NodeTemplate)
        .filter(NodeTemplate.id == template_id)
        .first()
    )
    if not found:
        raise NotFoundError(f"节点模板不存在: {template_id}")

    # Readable by its owner, or by anyone once it is public.
    is_owner = found.user_id == current_user.id
    if not (is_owner or found.is_public):
        raise HTTPException(status_code=403, detail="无权访问此模板")

    return found.to_dict()
|
|
|
|
|
|
@router.post("", response_model=NodeTemplateResponse, status_code=status.HTTP_201_CREATED)
async def create_node_template(
    template_data: NodeTemplateCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Create a node template owned by the current user (创建节点模板).

    Raises ConflictError if the user already owns a template with the same
    name. Missing or falsy settings fall back to provider "deepseek", model
    "deepseek-chat", temperature "0.7", max_tokens 1500 and is_public False;
    missing tags/variables become empty lists.
    """
    owner_id = current_user.id

    # Duplicate names are rejected only among this user's own templates.
    name_taken = (
        db.query(NodeTemplate)
        .filter(NodeTemplate.name == template_data.name, NodeTemplate.user_id == owner_id)
        .first()
    )
    if name_taken:
        raise ConflictError(f"模板名称 '{template_data.name}' 已存在")

    new_template = NodeTemplate(
        name=template_data.name,
        description=template_data.description,
        category=template_data.category,
        tags=template_data.tags or [],
        prompt=template_data.prompt,
        variables=template_data.variables or [],
        provider=template_data.provider or "deepseek",
        model=template_data.model or "deepseek-chat",
        temperature=template_data.temperature or "0.7",
        max_tokens=template_data.max_tokens or 1500,
        is_public=template_data.is_public or False,
        user_id=owner_id,
    )

    db.add(new_template)
    db.commit()
    # Reload from the database (e.g. any server-generated id/timestamps).
    db.refresh(new_template)

    logger.info(f"用户 {current_user.username} 创建了节点模板: {new_template.name} ({new_template.id})")
    return new_template.to_dict()
|
|
|
|
|
|
@router.put("/{template_id}", response_model=NodeTemplateResponse)
async def update_node_template(
    template_id: str,
    template_data: NodeTemplateUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Update a node template owned by the current user (更新节点模板).

    Only fields explicitly present in the request body are applied. Fields
    that NodeTemplateResponse declares non-optional (name, prompt, provider,
    model, temperature, max_tokens, is_public) may not be set to null;
    nullable metadata (description, category, tags, variables) may be
    cleared with null.

    Raises:
        NotFoundError: no template has this id.
        HTTPException(403): the template belongs to another user.
        HTTPException(422): a non-nullable field was sent as null.
        ConflictError: the new name is already used by another of the
            user's templates.
    """
    template = db.query(NodeTemplate).filter(NodeTemplate.id == template_id).first()

    if not template:
        raise NotFoundError(f"节点模板不存在: {template_id}")

    # Only the owner may update a template.
    if template.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权更新此模板")

    # Only fields the client actually sent; omitted fields stay unchanged.
    update_data = template_data.dict(exclude_unset=True)

    # NodeTemplateUpdate accepts an explicit null for every field, but these
    # are required in NodeTemplateResponse. Writing None would persist a row
    # that then fails response serialization (and possibly a DB constraint),
    # so reject it before anything is written.
    non_nullable = (
        "name", "prompt", "provider", "model",
        "temperature", "max_tokens", "is_public",
    )
    null_fields = [
        field for field in non_nullable
        if field in update_data and update_data[field] is None
    ]
    if null_fields:
        raise HTTPException(
            status_code=422,
            detail=f"以下字段不能为 null: {', '.join(null_fields)}",
        )

    # A renamed template must keep a unique name among this user's templates.
    if template_data.name and template_data.name != template.name:
        existing = db.query(NodeTemplate).filter(
            NodeTemplate.name == template_data.name,
            NodeTemplate.user_id == current_user.id,
            NodeTemplate.id != template_id
        ).first()

        if existing:
            raise ConflictError(f"模板名称 '{template_data.name}' 已存在")

    for key, value in update_data.items():
        setattr(template, key, value)

    db.commit()
    db.refresh(template)

    logger.info(f"用户 {current_user.username} 更新了节点模板: {template.name} ({template.id})")
    return template.to_dict()
|
|
|
|
|
|
@router.delete("/{template_id}", status_code=status.HTTP_200_OK)
async def delete_node_template(
    template_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Delete a node template owned by the current user (删除节点模板).

    Raises NotFoundError when no template has this id, and HTTP 403 when the
    template belongs to another user. Returns a confirmation message.
    """
    doomed = (
        db.query(NodeTemplate)
        .filter(NodeTemplate.id == template_id)
        .first()
    )
    if not doomed:
        raise NotFoundError(f"节点模板不存在: {template_id}")

    # Only the owner may delete, regardless of the public flag.
    if doomed.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权删除此模板")

    # Grab the values for the log line while the row still exists.
    removed_name, removed_id = doomed.name, doomed.id

    db.delete(doomed)
    db.commit()

    logger.info(f"用户 {current_user.username} 删除了节点模板: {removed_name} ({removed_id})")
    return {"message": "节点模板已删除"}
|
|
|
|
|
|
@router.post("/{template_id}/use", status_code=status.HTTP_200_OK)
async def use_node_template(
    template_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Record one use of a node template and return it (使用节点模板(增加使用次数)).

    The counter is incremented by a single UPDATE evaluated in the database,
    so concurrent calls cannot overwrite each other's increments (a
    Python-side read-modify-write of use_count loses updates under load).

    Raises NotFoundError when no template has this id, and HTTP 403 when the
    template is private and owned by another user.
    """
    from sqlalchemy import func  # only this endpoint needs SQL functions

    template = db.query(NodeTemplate).filter(NodeTemplate.id == template_id).first()

    if not template:
        raise NotFoundError(f"节点模板不存在: {template_id}")

    # Usable by its owner, or by anyone once it is public.
    if template.user_id != current_user.id and not template.is_public:
        raise HTTPException(status_code=403, detail="无权使用此模板")

    # Atomic "use_count = COALESCE(use_count, 0) + 1"; COALESCE keeps the
    # previous treatment of a NULL counter as 0.
    db.query(NodeTemplate).filter(NodeTemplate.id == template_id).update(
        {NodeTemplate.use_count: func.coalesce(NodeTemplate.use_count, 0) + 1},
        synchronize_session=False,
    )
    db.commit()
    # synchronize_session=False does not touch the loaded object; reload it
    # so the response reflects the new count.
    db.refresh(template)

    return template.to_dict()
|