Files
aiagent/backend/app/api/template_market.py
2026-01-19 00:09:36 +08:00

573 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
工作流模板市场API
支持用户分享、搜索、评分、收藏模板
"""
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from sqlalchemy import func, or_, and_
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from datetime import datetime
import logging
from app.core.database import get_db
from app.models.workflow_template import WorkflowTemplate, TemplateRating, TemplateFavorite
from app.models.workflow import Workflow
from app.api.auth import get_current_user
from app.models.user import User
from app.core.exceptions import NotFoundError, ValidationError
from app.services.workflow_validator import validate_workflow
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router for every template-market endpoint in this module; all paths are
# mounted under /api/v1/template-market and grouped under the
# "template-market" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/api/v1/template-market", tags=["template-market"])
class TemplateCreate(BaseModel):
    """Request body for sharing (creating) a workflow template.

    Only ``name``, ``nodes`` and ``edges`` are required. ``create_template``
    runs ``validate_workflow`` on ``nodes``/``edges`` before saving.
    """
    name: str
    description: Optional[str] = None
    category: Optional[str] = None
    tags: Optional[List[str]] = None  # create_template stores [] when omitted
    nodes: List[Dict[str, Any]]  # workflow graph nodes (schema defined by the workflow validator)
    edges: List[Dict[str, Any]]  # workflow graph edges
    thumbnail: Optional[str] = None
    is_public: bool = True  # only public templates appear in the market listing
class TemplateUpdate(BaseModel):
    """Partial-update body for an existing template.

    Every field is optional. ``update_template`` applies only the fields that
    are not ``None``, so a field cannot be cleared back to ``None`` through
    this model. Nodes and edges cannot be changed through this model.
    """
    name: Optional[str] = None
    description: Optional[str] = None
    category: Optional[str] = None
    tags: Optional[List[str]] = None
    thumbnail: Optional[str] = None
    is_public: Optional[bool] = None
class TemplateResponse(BaseModel):
    """Template representation returned by the template-market endpoints.

    ``creator_username``, ``is_favorited`` and ``user_rating`` are filled in
    by the endpoints when they look that information up; otherwise they keep
    their ``None`` defaults.
    """
    id: str
    name: str
    description: Optional[str]
    category: Optional[str]
    tags: Optional[List[str]]
    nodes: List[Dict[str, Any]]
    edges: List[Dict[str, Any]]
    thumbnail: Optional[str]
    is_public: bool
    is_featured: bool
    view_count: int
    use_count: int
    rating_count: int
    rating_avg: float
    user_id: str
    creator_username: Optional[str] = None  # username of the template's owner
    is_favorited: Optional[bool] = None  # whether the current user has favorited it
    user_rating: Optional[int] = None  # the current user's own rating, if any
    created_at: datetime
    updated_at: datetime

    # Pydantic v2 configuration: allow validation from ORM objects (some
    # endpoints return WorkflowTemplate instances directly). This replaces the
    # inner ``class Config``, which Pydantic v2 deprecates.
    model_config = {"from_attributes": True}
class RatingCreate(BaseModel):
    """Request body for rating a template.

    The 1-5 range is not enforced by this model; ``rate_template`` checks it
    and raises ``ValidationError`` for values outside the range.
    """
    rating: int  # expected range 1-5 (checked in rate_template)
    comment: Optional[str] = None
@router.get("", response_model=List[TemplateResponse])
async def get_templates(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    search: Optional[str] = None,
    category: Optional[str] = None,
    tags: Optional[str] = None,  # comma-separated tag list
    sort_by: Optional[str] = Query("created_at", regex="^(created_at|rating_avg|use_count|view_count)$"),
    sort_order: Optional[str] = Query("desc", regex="^(asc|desc)$"),
    featured_only: bool = Query(False),
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user)
):
    """List public templates with optional search, filters, sorting and paging.

    Args:
        skip: number of rows to skip (offset pagination).
        limit: page size, 1-100.
        search: substring matched against name or description. LIKE
            wildcards (``%``, ``_``) in the search text are matched literally.
        category: exact category match.
        tags: comma-separated tags; each listed tag must be contained in the
            template's tags. Blank entries (e.g. a trailing comma) are ignored.
        sort_by: sort column (created_at, rating_avg, use_count, view_count).
        sort_order: "asc" or "desc".
        featured_only: restrict the result to featured templates.

    Returns:
        A list of ``TemplateResponse``. For an authenticated caller
        ``is_favorited`` is a bool and ``user_rating`` is the caller's rating
        (or ``None``); without a current user both are ``None``.
    """
    query = db.query(WorkflowTemplate).filter(WorkflowTemplate.is_public == True)

    if search:
        # autoescape makes '%', '_' (and the '/' escape char) in user input literal.
        query = query.filter(
            or_(
                WorkflowTemplate.name.contains(search, autoescape=True),
                WorkflowTemplate.description.contains(search, autoescape=True),
            )
        )

    if category:
        query = query.filter(WorkflowTemplate.category == category)

    if tags:
        # Drop blank entries so "a,,b" or "a," does not add a filter on "".
        tag_list = [tag.strip() for tag in tags.split(",") if tag.strip()]
        # Simplified JSON containment filter; a real MySQL JSON query may need
        # to be more elaborate.
        for tag in tag_list:
            query = query.filter(WorkflowTemplate.tags.contains([tag]))

    if featured_only:
        query = query.filter(WorkflowTemplate.is_featured == True)

    # Any sort_by not in the map (i.e. "created_at") sorts by creation time.
    sort_columns = {
        "rating_avg": WorkflowTemplate.rating_avg,
        "use_count": WorkflowTemplate.use_count,
        "view_count": WorkflowTemplate.view_count,
    }
    sort_column = sort_columns.get(sort_by, WorkflowTemplate.created_at)
    query = query.order_by(sort_column.desc() if sort_order == "desc" else sort_column.asc())

    templates = query.offset(skip).limit(limit).all()

    # Look up the caller's favorites and ratings for the whole page with two
    # IN queries instead of two queries per template.
    favorited_ids = set()
    user_ratings = {}
    if current_user and templates:
        page_ids = [template.id for template in templates]
        favorited_ids = {
            row[0]
            for row in db.query(TemplateFavorite.template_id).filter(
                TemplateFavorite.user_id == current_user.id,
                TemplateFavorite.template_id.in_(page_ids),
            )
        }
        user_ratings = {
            row[0]: row[1]
            for row in db.query(TemplateRating.template_id, TemplateRating.rating).filter(
                TemplateRating.user_id == current_user.id,
                TemplateRating.template_id.in_(page_ids),
            )
        }

    result = []
    for template in templates:
        template_dict = {
            "id": template.id,
            "name": template.name,
            "description": template.description,
            "category": template.category,
            "tags": template.tags,
            "nodes": template.nodes,
            "edges": template.edges,
            "thumbnail": template.thumbnail,
            "is_public": template.is_public,
            "is_featured": template.is_featured,
            "view_count": template.view_count,
            "use_count": template.use_count,
            "rating_count": template.rating_count,
            "rating_avg": template.rating_avg,
            "user_id": template.user_id,
            "creator_username": template.user.username if template.user else None,
            "created_at": template.created_at,
            "updated_at": template.updated_at
        }
        if current_user:
            template_dict["is_favorited"] = template.id in favorited_ids
            template_dict["user_rating"] = user_ratings.get(template.id)
        else:
            template_dict["is_favorited"] = None
            template_dict["user_rating"] = None
        result.append(TemplateResponse(**template_dict))
    return result
@router.get("/{template_id}", response_model=TemplateResponse)
async def get_template(
    template_id: str,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user)
):
    """Return a single template and record one view of it.

    Private templates are visible only to their owner.

    Raises:
        NotFoundError: no template with ``template_id`` exists.
        HTTPException: 403 when the template is private and the caller is
            not its owner.
    """
    template = db.query(WorkflowTemplate).filter(WorkflowTemplate.id == template_id).first()
    if not template:
        raise NotFoundError("模板", template_id)

    if not template.is_public and (not current_user or template.user_id != current_user.id):
        raise HTTPException(status_code=403, detail="无权访问此模板")

    # Let the database perform the increment (SET view_count = view_count + 1)
    # so concurrent requests cannot overwrite each other's increments, then
    # reload the row to read the stored value.
    template.view_count = WorkflowTemplate.view_count + 1
    db.commit()
    db.refresh(template)

    template_dict = {
        "id": template.id,
        "name": template.name,
        "description": template.description,
        "category": template.category,
        "tags": template.tags,
        "nodes": template.nodes,
        "edges": template.edges,
        "thumbnail": template.thumbnail,
        "is_public": template.is_public,
        "is_featured": template.is_featured,
        "view_count": template.view_count,
        "use_count": template.use_count,
        "rating_count": template.rating_count,
        "rating_avg": template.rating_avg,
        "user_id": template.user_id,
        "creator_username": template.user.username if template.user else None,
        "created_at": template.created_at,
        "updated_at": template.updated_at
    }

    # Per-viewer fields: only known when there is a current user.
    if current_user:
        favorite = db.query(TemplateFavorite).filter(
            TemplateFavorite.template_id == template.id,
            TemplateFavorite.user_id == current_user.id
        ).first()
        template_dict["is_favorited"] = favorite is not None
        rating = db.query(TemplateRating).filter(
            TemplateRating.template_id == template.id,
            TemplateRating.user_id == current_user.id
        ).first()
        template_dict["user_rating"] = rating.rating if rating else None
    else:
        template_dict["is_favorited"] = None
        template_dict["user_rating"] = None
    return TemplateResponse(**template_dict)
@router.post("", response_model=TemplateResponse, status_code=status.HTTP_201_CREATED)
async def create_template(
    template_data: TemplateCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Share a new template owned by the current user.

    The submitted nodes/edges are checked with ``validate_workflow`` first.

    Raises:
        ValidationError: the workflow graph failed validation; the message
            lists the validator's errors.

    Returns:
        The created template. ``creator_username`` is the caller's username,
        ``is_favorited`` is False and ``user_rating`` is None (a just-created
        template has no favorites or ratings yet).
    """
    validation_result = validate_workflow(template_data.nodes, template_data.edges)
    if not validation_result["valid"]:
        raise ValidationError(f"工作流验证失败: {', '.join(validation_result['errors'])}")

    template = WorkflowTemplate(
        name=template_data.name,
        description=template_data.description,
        category=template_data.category,
        tags=template_data.tags or [],
        nodes=template_data.nodes,
        edges=template_data.edges,
        thumbnail=template_data.thumbnail,
        is_public=template_data.is_public,
        user_id=current_user.id
    )
    db.add(template)
    db.commit()
    # Reload DB-generated values (id, counters, timestamps) for the response.
    db.refresh(template)

    # Build the response explicitly so creator_username is populated, matching
    # the other endpoints (the ORM object itself has no such attribute).
    return TemplateResponse(
        id=template.id,
        name=template.name,
        description=template.description,
        category=template.category,
        tags=template.tags,
        nodes=template.nodes,
        edges=template.edges,
        thumbnail=template.thumbnail,
        is_public=template.is_public,
        is_featured=template.is_featured,
        view_count=template.view_count,
        use_count=template.use_count,
        rating_count=template.rating_count,
        rating_avg=template.rating_avg,
        user_id=template.user_id,
        creator_username=current_user.username,
        is_favorited=False,
        user_rating=None,
        created_at=template.created_at,
        updated_at=template.updated_at,
    )
@router.put("/{template_id}", response_model=TemplateResponse)
async def update_template(
    template_id: str,
    template_data: TemplateUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Apply a partial update to a template owned by the current user.

    Only the fields of ``template_data`` that are not ``None`` are written.

    Raises:
        NotFoundError: no template with ``template_id`` exists.
        HTTPException: 403 when the caller does not own the template.

    Returns:
        The updated template, including ``creator_username`` and the caller's
        favorite/rating state.
    """
    template = db.query(WorkflowTemplate).filter(WorkflowTemplate.id == template_id).first()
    if not template:
        raise NotFoundError("模板", template_id)

    if template.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权修改此模板")

    # None means "not provided"; False / "" / [] are real values and are applied.
    for field_name in ("name", "description", "category", "tags", "thumbnail", "is_public"):
        new_value = getattr(template_data, field_name)
        if new_value is not None:
            setattr(template, field_name, new_value)
    db.commit()
    db.refresh(template)

    favorite = db.query(TemplateFavorite).filter(
        TemplateFavorite.template_id == template.id,
        TemplateFavorite.user_id == current_user.id
    ).first()
    rating = db.query(TemplateRating).filter(
        TemplateRating.template_id == template.id,
        TemplateRating.user_id == current_user.id
    ).first()

    return TemplateResponse(
        id=template.id,
        name=template.name,
        description=template.description,
        category=template.category,
        tags=template.tags,
        nodes=template.nodes,
        edges=template.edges,
        thumbnail=template.thumbnail,
        is_public=template.is_public,
        is_featured=template.is_featured,
        view_count=template.view_count,
        use_count=template.use_count,
        rating_count=template.rating_count,
        rating_avg=template.rating_avg,
        user_id=template.user_id,
        creator_username=template.user.username if template.user else None,
        is_favorited=favorite is not None,
        user_rating=rating.rating if rating else None,
        created_at=template.created_at,
        updated_at=template.updated_at,
    )
@router.delete("/{template_id}", status_code=status.HTTP_200_OK)
async def delete_template(
    template_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Delete a template owned by the current user.

    Raises:
        NotFoundError: no template with ``template_id`` exists.
        HTTPException: 403 when the caller is not the template's owner.

    Returns:
        A confirmation message dict.
    """
    target = db.query(WorkflowTemplate).filter_by(id=template_id).first()
    if not target:
        raise NotFoundError("模板", template_id)

    # Only the owner may delete.
    is_owner = target.user_id == current_user.id
    if not is_owner:
        raise HTTPException(status_code=403, detail="无权删除此模板")

    db.delete(target)
    db.commit()
    return {"message": "模板已删除"}
@router.post("/{template_id}/use", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
async def use_template(
    template_id: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create a new workflow for the current user from a template's graph.

    Args:
        name: name of the new workflow; defaults to the template name followed
            by a " (副本)" ("copy") suffix.
        description: description of the new workflow; defaults to the
            template's description.

    Raises:
        NotFoundError: no template with ``template_id`` exists.
        HTTPException: 403 when the template is private and the caller is not
            its owner.

    Returns:
        A dict with a message and the new workflow's id and name.
    """
    template = db.query(WorkflowTemplate).filter(WorkflowTemplate.id == template_id).first()
    if not template:
        raise NotFoundError("模板", template_id)

    if not template.is_public and template.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权使用此模板")

    workflow = Workflow(
        name=name or f"{template.name} (副本)",
        description=description or template.description,
        nodes=template.nodes,
        edges=template.edges,
        user_id=current_user.id
    )
    db.add(workflow)

    # Increment in SQL (use_count = use_count + 1) so concurrent uses of the
    # same template are not lost to a read-modify-write race.
    template.use_count = WorkflowTemplate.use_count + 1
    db.commit()
    db.refresh(workflow)

    return {
        "message": "工作流已创建",
        "workflow_id": workflow.id,
        "workflow_name": workflow.name
    }
@router.post("/{template_id}/rate", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
async def rate_template(
    template_id: str,
    rating_data: RatingCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create or update the current user's rating of a template.

    After saving, the template's ``rating_count`` and ``rating_avg`` are
    recomputed from the stored ratings.

    Raises:
        ValidationError: the rating is outside 1-5.
        NotFoundError: no template with ``template_id`` exists.
        HTTPException: 403 when the template is private and the caller is not
            its owner.

    Returns:
        A dict with a message, the submitted rating, and the new average and
        count.
    """
    if rating_data.rating < 1 or rating_data.rating > 5:
        raise ValidationError("评分必须在1-5之间")

    template = db.query(WorkflowTemplate).filter(WorkflowTemplate.id == template_id).first()
    if not template:
        raise NotFoundError("模板", template_id)

    # Same visibility rule as get_template/use_template: a private template
    # may only be rated by its owner.
    if not template.is_public and template.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权访问此模板")

    existing_rating = db.query(TemplateRating).filter(
        TemplateRating.template_id == template_id,
        TemplateRating.user_id == current_user.id
    ).first()

    if existing_rating:
        existing_rating.rating = rating_data.rating
        existing_rating.comment = rating_data.comment
    else:
        db.add(TemplateRating(
            template_id=template_id,
            user_id=current_user.id,
            rating=rating_data.rating,
            comment=rating_data.comment
        ))

    # Flush explicitly so the aggregate query below sees the new/updated
    # rating even when the session is configured with autoflush=False.
    db.flush()
    rating_count, rating_avg = db.query(
        func.count(TemplateRating.rating),
        func.avg(TemplateRating.rating),
    ).filter(TemplateRating.template_id == template_id).one()
    # Recompute the count from the table instead of incrementing it, so the
    # stored counter cannot drift from the actual ratings.
    template.rating_count = rating_count or 0
    template.rating_avg = float(rating_avg) if rating_avg is not None else 0.0
    db.commit()

    return {
        "message": "评分成功",
        "rating": rating_data.rating,
        "rating_avg": template.rating_avg,
        "rating_count": template.rating_count
    }
@router.post("/{template_id}/favorite", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
async def favorite_template(
    template_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Add a template to the current user's favorites.

    Raises:
        NotFoundError: no template with ``template_id`` exists.
        HTTPException: 403 when the template is private and the caller is not
            its owner; 400 when the caller has already favorited it.

    Returns:
        A confirmation message dict.
    """
    template = db.query(WorkflowTemplate).filter(WorkflowTemplate.id == template_id).first()
    if not template:
        raise NotFoundError("模板", template_id)

    # Same visibility rule as get_template/use_template: a private template
    # may only be favorited by its owner.
    if not template.is_public and template.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权访问此模板")

    existing_favorite = db.query(TemplateFavorite).filter(
        TemplateFavorite.template_id == template_id,
        TemplateFavorite.user_id == current_user.id
    ).first()
    if existing_favorite:
        raise HTTPException(status_code=400, detail="已收藏此模板")

    db.add(TemplateFavorite(
        template_id=template_id,
        user_id=current_user.id
    ))
    db.commit()
    return {"message": "收藏成功"}
@router.delete("/{template_id}/favorite", status_code=status.HTTP_200_OK)
async def unfavorite_template(
    template_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Remove the current user's favorite on a template.

    Raises:
        HTTPException: 404 when the caller has not favorited this template.

    Returns:
        A confirmation message dict.
    """
    existing = (
        db.query(TemplateFavorite)
        .filter_by(template_id=template_id, user_id=current_user.id)
        .first()
    )
    if not existing:
        raise HTTPException(status_code=404, detail="未收藏此模板")

    db.delete(existing)
    db.commit()
    return {"message": "已取消收藏"}
@router.get("/my/favorites", response_model=List[TemplateResponse])
async def get_my_favorites(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List the templates the current user has favorited (offset-paginated).

    Favorites are resolved with an inner join, so a favorite whose template
    no longer exists is skipped instead of crashing. A favorited template
    that is now private is included only for its owner (the same visibility
    rule as ``get_template``). ``is_favorited`` is always True, and
    ``user_rating`` is the caller's own rating, if any.
    """
    templates = (
        db.query(WorkflowTemplate)
        .join(TemplateFavorite, TemplateFavorite.template_id == WorkflowTemplate.id)
        .filter(
            TemplateFavorite.user_id == current_user.id,
            or_(
                WorkflowTemplate.is_public == True,
                WorkflowTemplate.user_id == current_user.id,
            ),
        )
        .offset(skip)
        .limit(limit)
        .all()
    )

    # One IN query for the caller's ratings instead of one query per template.
    user_ratings = {}
    if templates:
        user_ratings = {
            row[0]: row[1]
            for row in db.query(TemplateRating.template_id, TemplateRating.rating).filter(
                TemplateRating.user_id == current_user.id,
                TemplateRating.template_id.in_([template.id for template in templates]),
            )
        }

    result = []
    for template in templates:
        template_dict = {
            "id": template.id,
            "name": template.name,
            "description": template.description,
            "category": template.category,
            "tags": template.tags,
            "nodes": template.nodes,
            "edges": template.edges,
            "thumbnail": template.thumbnail,
            "is_public": template.is_public,
            "is_featured": template.is_featured,
            "view_count": template.view_count,
            "use_count": template.use_count,
            "rating_count": template.rating_count,
            "rating_avg": template.rating_avg,
            "user_id": template.user_id,
            "creator_username": template.user.username if template.user else None,
            "is_favorited": True,
            "user_rating": user_ratings.get(template.id),
            "created_at": template.created_at,
            "updated_at": template.updated_at
        }
        result.append(TemplateResponse(**template_dict))
    return result
@router.get("/my/shared", response_model=List[TemplateResponse])
async def get_my_shared_templates(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List templates owned by the current user, newest first (offset-paginated).

    Both public and private templates are included. ``is_favorited`` and
    ``user_rating`` reflect the caller's own favorite/rating, as in
    ``get_templates`` for an authenticated caller.
    """
    templates = (
        db.query(WorkflowTemplate)
        .filter(WorkflowTemplate.user_id == current_user.id)
        # LIMIT/OFFSET without ORDER BY yields unstable pages; sort newest
        # first (get_templates' default) with id as a tie-breaker.
        .order_by(WorkflowTemplate.created_at.desc(), WorkflowTemplate.id)
        .offset(skip)
        .limit(limit)
        .all()
    )

    # Two IN queries for the whole page instead of per-template lookups.
    favorited_ids = set()
    user_ratings = {}
    if templates:
        page_ids = [template.id for template in templates]
        favorited_ids = {
            row[0]
            for row in db.query(TemplateFavorite.template_id).filter(
                TemplateFavorite.user_id == current_user.id,
                TemplateFavorite.template_id.in_(page_ids),
            )
        }
        user_ratings = {
            row[0]: row[1]
            for row in db.query(TemplateRating.template_id, TemplateRating.rating).filter(
                TemplateRating.user_id == current_user.id,
                TemplateRating.template_id.in_(page_ids),
            )
        }

    result = []
    for template in templates:
        template_dict = {
            "id": template.id,
            "name": template.name,
            "description": template.description,
            "category": template.category,
            "tags": template.tags,
            "nodes": template.nodes,
            "edges": template.edges,
            "thumbnail": template.thumbnail,
            "is_public": template.is_public,
            "is_featured": template.is_featured,
            "view_count": template.view_count,
            "use_count": template.use_count,
            "rating_count": template.rating_count,
            "rating_avg": template.rating_avg,
            "user_id": template.user_id,
            "creator_username": template.user.username if template.user else None,
            "is_favorited": template.id in favorited_ids,
            "user_rating": user_ratings.get(template.id),
            "created_at": template.created_at,
            "updated_at": template.updated_at
        }
        result.append(TemplateResponse(**template_dict))
    return result