Files
aiagent/backend/app/api/template_market.py

573 lines
19 KiB
Python
Raw Normal View History

2026-01-19 00:09:36 +08:00
"""
工作流模板市场API
支持用户分享、搜索、评分、收藏模板
"""
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from sqlalchemy import func, or_, and_
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from datetime import datetime
import logging
from app.core.database import get_db
from app.models.workflow_template import WorkflowTemplate, TemplateRating, TemplateFavorite
from app.models.workflow import Workflow
from app.api.auth import get_current_user
from app.models.user import User
from app.core.exceptions import NotFoundError, ValidationError
from app.services.workflow_validator import validate_workflow
# Module-level logger, plus the router that mounts every endpoint below
# under the /api/v1/template-market prefix.
logger = logging.getLogger(name=__name__)
router = APIRouter(
    tags=["template-market"],
    prefix="/api/v1/template-market",
)
class TemplateCreate(BaseModel):
    """Request body for sharing (creating) a template.

    ``nodes`` and ``edges`` are required; ``create_template`` runs them through
    ``validate_workflow`` before anything is stored.
    """
    name: str
    description: Optional[str] = None
    category: Optional[str] = None
    tags: Optional[List[str]] = None  # create_template stores [] when omitted
    nodes: List[Dict[str, Any]]
    edges: List[Dict[str, Any]]
    thumbnail: Optional[str] = None
    is_public: bool = True  # templates are public unless the caller opts out
class TemplateUpdate(BaseModel):
    """Request body for partially updating a template's metadata.

    Every field is optional. ``update_template`` only applies fields whose value
    is not ``None``, so a field cannot be cleared to ``None`` through this model.
    Nodes and edges are not part of this model and cannot be changed here.
    """
    name: Optional[str] = None
    description: Optional[str] = None
    category: Optional[str] = None
    tags: Optional[List[str]] = None
    thumbnail: Optional[str] = None
    is_public: Optional[bool] = None
class TemplateResponse(BaseModel):
    """Template representation returned by the template-market endpoints."""
    id: str
    name: str
    description: Optional[str]
    category: Optional[str]
    tags: Optional[List[str]]
    nodes: List[Dict[str, Any]]
    edges: List[Dict[str, Any]]
    thumbnail: Optional[str]
    is_public: bool
    is_featured: bool
    view_count: int
    use_count: int
    rating_count: int
    rating_avg: float
    user_id: str
    creator_username: Optional[str] = None  # owner's username; None when not resolved
    is_favorited: Optional[bool] = None  # whether the current user favorited it; None when not computed
    user_rating: Optional[int] = None  # the current user's rating; None when unrated or not computed
    created_at: datetime
    updated_at: datetime

    class Config:
        # Lets the model be built from ORM objects via attribute access
        # (endpoints that return a WorkflowTemplate instance rely on this).
        from_attributes = True
class RatingCreate(BaseModel):
    """Request body for rating a template."""
    rating: int  # must be 1-5; enforced in rate_template, not by this model
    comment: Optional[str] = None
@router.get("", response_model=List[TemplateResponse])
async def get_templates(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    search: Optional[str] = None,
    category: Optional[str] = None,
    tags: Optional[str] = None,  # comma-separated tag list
    sort_by: Optional[str] = Query("created_at", regex="^(created_at|rating_avg|use_count|view_count)$"),
    sort_order: Optional[str] = Query("desc", regex="^(asc|desc)$"),
    featured_only: bool = Query(False),
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user)
):
    """List public templates with optional search, filters, sorting and paging.

    Args:
        skip: Number of rows to skip (offset pagination).
        limit: Page size, 1-100.
        search: Substring matched (SQL LIKE) against name or description.
        category: Exact category match.
        tags: Comma-separated tags; a template must contain every listed tag.
            Blank entries (e.g. from "a,,b" or a trailing comma) are ignored.
        sort_by: Sort column; anything unrecognised falls back to created_at.
        sort_order: "desc" for descending, anything else ascending.
        featured_only: Restrict to featured templates.
        db: Database session.
        current_user: Caller, used to fill per-user favorite/rating fields.

    Returns:
        A list of ``TemplateResponse``. When ``current_user`` is set,
        ``is_favorited`` and ``user_rating`` are filled from that user's
        favorites and ratings; otherwise both are ``None``.
    """
    query = db.query(WorkflowTemplate).filter(WorkflowTemplate.is_public == True)

    # Free-text search on name/description.
    if search:
        pattern = f"%{search}%"
        query = query.filter(
            or_(
                WorkflowTemplate.name.like(pattern),
                WorkflowTemplate.description.like(pattern),
            )
        )

    if category:
        query = query.filter(WorkflowTemplate.category == category)

    if tags:
        # Drop blank entries so "a,,b" or "a," do not filter on an empty tag.
        tag_list = [tag.strip() for tag in tags.split(",") if tag.strip()]
        # Simplified MySQL JSON query; a real deployment may need a more
        # specific JSON containment query. NOTE(review): verify `.contains`
        # semantics for the JSON column on the target database.
        for tag in tag_list:
            query = query.filter(WorkflowTemplate.tags.contains([tag]))

    if featured_only:
        query = query.filter(WorkflowTemplate.is_featured == True)

    sortable_columns = {
        "rating_avg": WorkflowTemplate.rating_avg,
        "use_count": WorkflowTemplate.use_count,
        "view_count": WorkflowTemplate.view_count,
    }
    sort_column = sortable_columns.get(sort_by, WorkflowTemplate.created_at)
    order_by = sort_column.desc() if sort_order == "desc" else sort_column.asc()
    # Secondary key on id keeps offset pagination stable when sort values tie.
    query = query.order_by(order_by, WorkflowTemplate.id)

    templates = query.offset(skip).limit(limit).all()

    # Fetch the caller's favorites and ratings for the whole page in two
    # queries instead of two queries per template.
    favorited_ids = set()
    user_ratings = {}
    if current_user and templates:
        page_ids = [template.id for template in templates]
        favorited_ids = {
            favorite_template_id
            for (favorite_template_id,) in db.query(TemplateFavorite.template_id).filter(
                TemplateFavorite.user_id == current_user.id,
                TemplateFavorite.template_id.in_(page_ids),
            )
        }
        user_ratings = {
            rated_template_id: rating_value
            for rated_template_id, rating_value in db.query(
                TemplateRating.template_id, TemplateRating.rating
            ).filter(
                TemplateRating.user_id == current_user.id,
                TemplateRating.template_id.in_(page_ids),
            )
        }

    result = []
    for template in templates:
        result.append(TemplateResponse(
            id=template.id,
            name=template.name,
            description=template.description,
            category=template.category,
            tags=template.tags,
            nodes=template.nodes,
            edges=template.edges,
            thumbnail=template.thumbnail,
            is_public=template.is_public,
            is_featured=template.is_featured,
            view_count=template.view_count,
            use_count=template.use_count,
            rating_count=template.rating_count,
            rating_avg=template.rating_avg,
            user_id=template.user_id,
            creator_username=template.user.username if template.user else None,
            is_favorited=(template.id in favorited_ids) if current_user else None,
            user_rating=user_ratings.get(template.id) if current_user else None,
            created_at=template.created_at,
            updated_at=template.updated_at,
        ))
    return result
@router.get("/{template_id}", response_model=TemplateResponse)
async def get_template(
    template_id: str,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user)
):
    """Return one template's details and record a view.

    Args:
        template_id: Template primary key.
        db: Database session.
        current_user: Caller; needed to view private templates and to fill
            ``is_favorited``/``user_rating``.

    Returns:
        The ``TemplateResponse`` with the incremented ``view_count``.

    Raises:
        NotFoundError: No template with ``template_id``.
        HTTPException: 403 if the template is private and the caller is not its owner.
    """
    template = db.query(WorkflowTemplate).filter(WorkflowTemplate.id == template_id).first()
    if not template:
        raise NotFoundError("模板", template_id)

    # Private templates are only visible to their owner.
    if not template.is_public and (not current_user or template.user_id != current_user.id):
        raise HTTPException(status_code=403, detail="无权访问此模板")

    # Increment in SQL (view_count = view_count + 1) rather than Python-side
    # `+= 1`, which loses views when requests race; then reload the stored value.
    db.query(WorkflowTemplate).filter(WorkflowTemplate.id == template.id).update(
        {WorkflowTemplate.view_count: WorkflowTemplate.view_count + 1},
        synchronize_session=False,
    )
    db.commit()
    db.refresh(template)

    is_favorited = None
    user_rating = None
    if current_user:
        favorite = db.query(TemplateFavorite).filter(
            TemplateFavorite.template_id == template.id,
            TemplateFavorite.user_id == current_user.id
        ).first()
        is_favorited = favorite is not None
        rating = db.query(TemplateRating).filter(
            TemplateRating.template_id == template.id,
            TemplateRating.user_id == current_user.id
        ).first()
        user_rating = rating.rating if rating else None

    return TemplateResponse(
        id=template.id,
        name=template.name,
        description=template.description,
        category=template.category,
        tags=template.tags,
        nodes=template.nodes,
        edges=template.edges,
        thumbnail=template.thumbnail,
        is_public=template.is_public,
        is_featured=template.is_featured,
        view_count=template.view_count,
        use_count=template.use_count,
        rating_count=template.rating_count,
        rating_avg=template.rating_avg,
        user_id=template.user_id,
        creator_username=template.user.username if template.user else None,
        is_favorited=is_favorited,
        user_rating=user_rating,
        created_at=template.created_at,
        updated_at=template.updated_at,
    )
@router.post("", response_model=TemplateResponse, status_code=status.HTTP_201_CREATED)
async def create_template(
    template_data: TemplateCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Share a new template owned by the current user.

    Args:
        template_data: Template fields; nodes/edges are validated first.
        db: Database session.
        current_user: The creator and owner of the new template.

    Returns:
        The stored template as a ``TemplateResponse``, with ``creator_username``
        set to the current user's username (the read endpoints also fill it).

    Raises:
        ValidationError: ``validate_workflow`` rejected the nodes/edges.
    """
    validation_result = validate_workflow(template_data.nodes, template_data.edges)
    if not validation_result["valid"]:
        # str() guards against non-string error entries breaking the join.
        error_text = ", ".join(str(error) for error in validation_result["errors"])
        raise ValidationError(f"工作流验证失败: {error_text}")

    template = WorkflowTemplate(
        name=template_data.name,
        description=template_data.description,
        category=template_data.category,
        tags=template_data.tags or [],
        nodes=template_data.nodes,
        edges=template_data.edges,
        thumbnail=template_data.thumbnail,
        is_public=template_data.is_public,
        user_id=current_user.id
    )
    db.add(template)
    db.commit()
    db.refresh(template)

    # Returning the bare ORM object would leave creator_username null; the
    # creator is by definition the current user.
    response = TemplateResponse.model_validate(template)
    response.creator_username = current_user.username
    return response
@router.put("/{template_id}", response_model=TemplateResponse)
async def update_template(
    template_id: str,
    template_data: TemplateUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Update a template's metadata; only the owner may do this.

    Fields of ``template_data`` that are ``None`` are left unchanged.

    Args:
        template_id: Template primary key.
        template_data: Partial update payload.
        db: Database session.
        current_user: Must be the template's owner.

    Returns:
        The updated template as a ``TemplateResponse`` with ``creator_username`` set.

    Raises:
        NotFoundError: No template with ``template_id``.
        HTTPException: 403 if the caller does not own the template.
    """
    template = db.query(WorkflowTemplate).filter(WorkflowTemplate.id == template_id).first()
    if not template:
        raise NotFoundError("模板", template_id)

    if template.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权修改此模板")

    # Apply only the fields the client actually supplied (non-None).
    updatable_fields = ("name", "description", "category", "tags", "thumbnail", "is_public")
    for field_name in updatable_fields:
        new_value = getattr(template_data, field_name)
        if new_value is not None:
            setattr(template, field_name, new_value)

    db.commit()
    db.refresh(template)

    # The owner was verified above, so the creator is the current user; the
    # bare ORM object would otherwise serialize creator_username as null.
    response = TemplateResponse.model_validate(template)
    response.creator_username = current_user.username
    return response
@router.delete("/{template_id}", status_code=status.HTTP_200_OK)
async def delete_template(
    template_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Delete a template owned by the current user.

    Args:
        template_id: Template primary key.
        db: Database session.
        current_user: Must be the template's owner.

    Returns:
        A confirmation message dict.

    Raises:
        NotFoundError: No template with ``template_id``.
        HTTPException: 403 if the caller does not own the template.
    """
    target = (
        db.query(WorkflowTemplate)
        .filter(WorkflowTemplate.id == template_id)
        .first()
    )
    if not target:
        raise NotFoundError("模板", template_id)

    owned_by_caller = target.user_id == current_user.id
    if not owned_by_caller:
        raise HTTPException(status_code=403, detail="无权删除此模板")

    db.delete(target)
    db.commit()
    return {"message": "模板已删除"}
@router.post("/{template_id}/use", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
async def use_template(
    template_id: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create a new workflow for the current user from a template.

    Args:
        template_id: Template primary key.
        name: Workflow name; defaults to "<template name> (副本)".
        description: Workflow description; defaults to the template's.
        db: Database session.
        current_user: Owner of the new workflow.

    Returns:
        Dict with a message plus the new workflow's id and name.

    Raises:
        NotFoundError: No template with ``template_id``.
        HTTPException: 403 if the template is private and not owned by the caller.
    """
    template = db.query(WorkflowTemplate).filter(WorkflowTemplate.id == template_id).first()
    if not template:
        raise NotFoundError("模板", template_id)

    if not template.is_public and template.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权使用此模板")

    workflow = Workflow(
        name=name or f"{template.name} (副本)",
        description=description or template.description,
        nodes=template.nodes,
        edges=template.edges,
        user_id=current_user.id
    )
    db.add(workflow)

    # Increment in SQL so concurrent uses are not lost to a read-modify-write
    # race; committed in the same transaction as the new workflow.
    db.query(WorkflowTemplate).filter(WorkflowTemplate.id == template.id).update(
        {WorkflowTemplate.use_count: WorkflowTemplate.use_count + 1},
        synchronize_session=False,
    )
    db.commit()
    db.refresh(workflow)

    return {
        "message": "工作流已创建",
        "workflow_id": workflow.id,
        "workflow_name": workflow.name
    }
@router.post("/{template_id}/rate", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
async def rate_template(
    template_id: str,
    rating_data: RatingCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create or replace the current user's rating of a template.

    Re-rating overwrites the user's previous rating and comment.
    ``rating_count`` and ``rating_avg`` are recomputed from the stored ratings,
    so they cannot drift from the data.

    Args:
        template_id: Template primary key.
        rating_data: Rating (1-5) and optional comment.
        db: Database session.
        current_user: The rater.

    Returns:
        Dict with a message, the submitted rating and the new aggregates.

    Raises:
        ValidationError: Rating outside 1-5.
        NotFoundError: No template with ``template_id``.
        HTTPException: 403 if the template is private and not owned by the caller.
    """
    if rating_data.rating < 1 or rating_data.rating > 5:
        raise ValidationError("评分必须在1-5之间")

    template = db.query(WorkflowTemplate).filter(WorkflowTemplate.id == template_id).first()
    if not template:
        raise NotFoundError("模板", template_id)

    # Same visibility rule as get_template/use_template: other users' private
    # templates cannot be rated.
    if not template.is_public and template.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权评分此模板")

    existing_rating = db.query(TemplateRating).filter(
        TemplateRating.template_id == template_id,
        TemplateRating.user_id == current_user.id
    ).first()

    if existing_rating:
        existing_rating.rating = rating_data.rating
        existing_rating.comment = rating_data.comment
    else:
        db.add(TemplateRating(
            template_id=template_id,
            user_id=current_user.id,
            rating=rating_data.rating,
            comment=rating_data.comment
        ))

    # Flush explicitly so the aggregate below includes this request's
    # insert/update even when the session runs with autoflush=False.
    db.flush()
    stored_count, stored_avg = db.query(
        func.count(TemplateRating.rating),
        func.avg(TemplateRating.rating),
    ).filter(
        TemplateRating.template_id == template_id
    ).one()
    template.rating_count = stored_count
    template.rating_avg = float(stored_avg) if stored_avg is not None else 0.0

    db.commit()

    return {
        "message": "评分成功",
        "rating": rating_data.rating,
        "rating_avg": template.rating_avg,
        "rating_count": template.rating_count
    }
@router.post("/{template_id}/favorite", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
async def favorite_template(
    template_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Add a template to the current user's favorites.

    Args:
        template_id: Template primary key.
        db: Database session.
        current_user: The user favoriting the template.

    Returns:
        A confirmation message dict.

    Raises:
        NotFoundError: No template with ``template_id``.
        HTTPException: 403 if the template is private and not owned by the
            caller; 400 if it is already favorited.
    """
    template = db.query(WorkflowTemplate).filter(WorkflowTemplate.id == template_id).first()
    if not template:
        raise NotFoundError("模板", template_id)

    # Without this check, favoriting someone else's private template would
    # expose its full content through the "my favorites" listing.
    if not template.is_public and template.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权收藏此模板")

    existing_favorite = db.query(TemplateFavorite).filter(
        TemplateFavorite.template_id == template_id,
        TemplateFavorite.user_id == current_user.id
    ).first()
    if existing_favorite:
        raise HTTPException(status_code=400, detail="已收藏此模板")

    db.add(TemplateFavorite(
        template_id=template_id,
        user_id=current_user.id
    ))
    db.commit()
    return {"message": "收藏成功"}
@router.delete("/{template_id}/favorite", status_code=status.HTTP_200_OK)
async def unfavorite_template(
    template_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Remove a template from the current user's favorites.

    Args:
        template_id: Template primary key.
        db: Database session.
        current_user: The user whose favorite is removed.

    Returns:
        A confirmation message dict.

    Raises:
        HTTPException: 404 if the user has not favorited the template.
    """
    own_favorite = (
        db.query(TemplateFavorite)
        .filter(
            TemplateFavorite.template_id == template_id,
            TemplateFavorite.user_id == current_user.id,
        )
        .first()
    )
    if not own_favorite:
        raise HTTPException(status_code=404, detail="未收藏此模板")

    db.delete(own_favorite)
    db.commit()
    return {"message": "已取消收藏"}
@router.get("/my/favorites", response_model=List[TemplateResponse])
async def get_my_favorites(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List templates the current user has favorited.

    Templates that are private and owned by someone else are excluded, for
    example one made private after it was favorited, so their content is not
    exposed. Results are ordered newest template first, with id as tie-breaker,
    so offset pagination is stable.

    Args:
        skip: Number of rows to skip.
        limit: Page size, 1-100.
        db: Database session.
        current_user: The user whose favorites are listed.

    Returns:
        A list of ``TemplateResponse`` with ``is_favorited=True`` and the
        user's own rating in ``user_rating`` (``None`` if unrated).
    """
    templates = (
        db.query(WorkflowTemplate)
        .join(TemplateFavorite, TemplateFavorite.template_id == WorkflowTemplate.id)
        .filter(
            TemplateFavorite.user_id == current_user.id,
            or_(
                WorkflowTemplate.is_public == True,
                WorkflowTemplate.user_id == current_user.id,
            ),
        )
        .order_by(WorkflowTemplate.created_at.desc(), WorkflowTemplate.id)
        .offset(skip)
        .limit(limit)
        .all()
    )

    # One query for the user's ratings on this page instead of one per template.
    user_ratings = {}
    if templates:
        user_ratings = {
            rated_template_id: rating_value
            for rated_template_id, rating_value in db.query(
                TemplateRating.template_id, TemplateRating.rating
            ).filter(
                TemplateRating.user_id == current_user.id,
                TemplateRating.template_id.in_([template.id for template in templates]),
            )
        }

    result = []
    for template in templates:
        result.append(TemplateResponse(
            id=template.id,
            name=template.name,
            description=template.description,
            category=template.category,
            tags=template.tags,
            nodes=template.nodes,
            edges=template.edges,
            thumbnail=template.thumbnail,
            is_public=template.is_public,
            is_featured=template.is_featured,
            view_count=template.view_count,
            use_count=template.use_count,
            rating_count=template.rating_count,
            rating_avg=template.rating_avg,
            user_id=template.user_id,
            creator_username=template.user.username if template.user else None,
            is_favorited=True,
            user_rating=user_ratings.get(template.id),
            created_at=template.created_at,
            updated_at=template.updated_at,
        ))
    return result
@router.get("/my/shared", response_model=List[TemplateResponse])
async def get_my_shared_templates(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List all templates owned by the current user, public or private.

    Results are ordered newest first, with id as tie-breaker. Without an
    ORDER BY, offset/limit paging can repeat or skip rows between pages.

    Args:
        skip: Number of rows to skip.
        limit: Page size, 1-100.
        db: Database session.
        current_user: The owner whose templates are listed.

    Returns:
        A list of ``TemplateResponse``; ``is_favorited`` and ``user_rating``
        are not computed here and are ``None``.
    """
    templates = (
        db.query(WorkflowTemplate)
        .filter(WorkflowTemplate.user_id == current_user.id)
        .order_by(WorkflowTemplate.created_at.desc(), WorkflowTemplate.id)
        .offset(skip)
        .limit(limit)
        .all()
    )

    result = []
    for template in templates:
        result.append(TemplateResponse(
            id=template.id,
            name=template.name,
            description=template.description,
            category=template.category,
            tags=template.tags,
            nodes=template.nodes,
            edges=template.edges,
            thumbnail=template.thumbnail,
            is_public=template.is_public,
            is_featured=template.is_featured,
            view_count=template.view_count,
            use_count=template.use_count,
            rating_count=template.rating_count,
            rating_avg=template.rating_avg,
            user_id=template.user_id,
            creator_username=template.user.username if template.user else None,
            is_favorited=None,
            user_rating=None,
            created_at=template.created_at,
            updated_at=template.updated_at,
        ))
    return result