新增专业级提示词生成接口

This commit is contained in:
2025-04-02 21:37:24 +08:00
parent b115462260
commit 59ff67595c
10 changed files with 2381 additions and 0 deletions

239
flask_prompt_master/app.py Normal file
View File

@@ -0,0 +1,239 @@
from flask import Flask, jsonify, request
from flask_cors import CORS
from datetime import datetime
app = Flask(__name__)
CORS(app) # 允许跨域请求
# 模拟数据库中的模板数据
MOCK_TEMPLATES = [
{
"id": "1",
"name": "后端开发工程师",
"description": "专业的后端开发工程师模板",
"category": "软件开发",
"industry": "互联网",
"profession": "开发工程师",
"sub_category": "后端开发",
"system_prompt": "你是一位专业的后端开发工程师..."
},
{
"id": "2",
"name": "前端开发工程师",
"description": "专业的前端开发工程师模板",
"category": "软件开发",
"industry": "互联网",
"profession": "开发工程师",
"sub_category": "前端开发",
"system_prompt": "你是一位专业的前端开发工程师..."
}
]
# 身份验证装饰器
def require_auth(f):
from functools import wraps
@wraps(f)
def decorated(*args, **kwargs):
auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Bearer '):
return jsonify({'error': 'No authorization token provided'}), 401
# 这里简单验证token实际应该更严格
token = auth_header.split(' ')[1]
if token != 'test_token_123':
return jsonify({'error': 'Invalid token'}), 401
return f(*args, **kwargs)
return decorated
# 获取模板列表接口
@app.route('/api/v1/templates', methods=['GET'])
@require_auth
def get_templates():
# 获取查询参数
page = int(request.args.get('page', 1))
size = int(request.args.get('size', 10))
category = request.args.get('category')
industry = request.args.get('industry')
profession = request.args.get('profession')
# 过滤模板
filtered_templates = MOCK_TEMPLATES
if category:
filtered_templates = [t for t in filtered_templates if t['category'] == category]
if industry:
filtered_templates = [t for t in filtered_templates if t['industry'] == industry]
if profession:
filtered_templates = [t for t in filtered_templates if t['profession'] == profession]
# 计算分页
start = (page - 1) * size
end = start + size
paginated_templates = filtered_templates[start:end]
return jsonify({
'code': 200,
'data': {
'total': len(filtered_templates),
'pages': (len(filtered_templates) + size - 1) // size,
'current_page': page,
'templates': paginated_templates
}
})
# 获取单个模板接口
@app.route('/api/v1/templates/<template_id>', methods=['GET'])
@require_auth
def get_template(template_id):
template = next((t for t in MOCK_TEMPLATES if t['id'] == template_id), None)
if not template:
return jsonify({'error': 'Template not found'}), 404
return jsonify({
'code': 200,
'data': template
})
# 创建模板接口
@app.route('/api/v1/templates', methods=['POST'])
@require_auth
def create_template():
data = request.get_json()
# 验证必要字段
required_fields = ['name', 'description', 'category', 'industry', 'profession']
for field in required_fields:
if field not in data:
return jsonify({'error': f'Missing required field: {field}'}), 400
# 创建新模板
new_template = {
'id': str(len(MOCK_TEMPLATES) + 1),
**data,
'created_at': datetime.now().isoformat()
}
MOCK_TEMPLATES.append(new_template)
return jsonify({
'code': 200,
'data': {
'id': new_template['id'],
'message': 'Template created successfully'
}
})
# 更新模板接口
@app.route('/api/v1/templates/<template_id>', methods=['PUT'])
@require_auth
def update_template(template_id):
template = next((t for t in MOCK_TEMPLATES if t['id'] == template_id), None)
if not template:
return jsonify({'error': 'Template not found'}), 404
data = request.get_json()
template.update(data)
return jsonify({
'code': 200,
'data': {
'message': 'Template updated successfully'
}
})
# 删除模板接口
@app.route('/api/v1/templates/<template_id>', methods=['DELETE'])
@require_auth
def delete_template(template_id):
template_index = next((i for i, t in enumerate(MOCK_TEMPLATES) if t['id'] == template_id), None)
if template_index is None:
return jsonify({'error': 'Template not found'}), 404
MOCK_TEMPLATES.pop(template_index)
return jsonify({
'code': 200,
'data': {
'message': 'Template deleted successfully'
}
})
# 模糊搜索提示词模板接口
@app.route('/api/v1/templates/search', methods=['GET'])
@require_auth
def search_templates():
"""模糊搜索提示词模板"""
# 获取搜索参数
keyword = request.args.get('keyword', '').strip()
page = int(request.args.get('page', 1))
size = int(request.args.get('size', 10))
# 如果没有关键词,返回空结果
if not keyword:
return jsonify({
'code': 200,
'data': {
'total': 0,
'pages': 0,
'current_page': page,
'templates': []
}
})
# 执行模糊搜索
# 搜索范围:名称、描述、分类、行业、职业、子分类、系统提示词
search_results = []
for template in MOCK_TEMPLATES:
# 在各个字段中搜索关键词
if any(keyword.lower() in str(value).lower() for value in [
template.get('name', ''),
template.get('description', ''),
template.get('category', ''),
template.get('industry', ''),
template.get('profession', ''),
template.get('sub_category', ''),
template.get('system_prompt', '')
]):
# 添加匹配度信息
match_score = sum(
str(value).lower().count(keyword.lower())
for value in [
template.get('name', ''),
template.get('description', ''),
template.get('category', ''),
template.get('industry', ''),
template.get('profession', ''),
template.get('sub_category', ''),
template.get('system_prompt', '')
]
)
search_results.append({
**template,
'_match_score': match_score
})
# 按匹配度排序
search_results.sort(key=lambda x: x['_match_score'], reverse=True)
# 移除匹配度信息
search_results = [{k: v for k, v in t.items() if k != '_match_score'}
for t in search_results]
# 计算分页
total = len(search_results)
total_pages = (total + size - 1) // size
start = (page - 1) * size
end = start + size
paginated_results = search_results[start:end]
# 返回结果
return jsonify({
'code': 200,
'data': {
'total': total,
'pages': total_pages,
'current_page': page,
'keyword': keyword,
'templates': paginated_results
}
})
if __name__ == '__main__':
app.run(debug=True, port=5000)

File diff suppressed because it is too large Load Diff

View File

@@ -686,4 +686,487 @@ def wx_delete_prompt(prompt_id):
'data': None
})
@main_bp.route('/api/wx/prompts/search', methods=['GET'])
def wx_search_prompts():
"""搜索提示词接口"""
try:
# 获取参数
uid = request.args.get('uid')
keyword = request.args.get('keyword', '').strip()
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 10, type=int)
if not uid:
return jsonify({
'code': 400,
'message': '缺少用户ID',
'data': None
})
# 构建查询
query = Prompt.query.filter_by(wx_user_id=uid)
# 如果有关键词,添加搜索条件
if keyword:
search_condition = (
Prompt.input_text.ilike(f'%{keyword}%') | # 搜索输入文本
Prompt.generated_text.ilike(f'%{keyword}%') # 搜索生成的提示词
)
query = query.filter(search_condition)
# 按时间倒序排序并分页
query = query.order_by(Prompt.created_at.desc())
pagination = query.paginate(page=page, per_page=per_page, error_out=False)
prompts = pagination.items
return jsonify({
'code': 200,
'message': 'success',
'data': {
'prompts': [{
'id': p.id,
'input_text': p.input_text,
'generated_text': p.generated_text,
'created_at': p.created_at.strftime('%Y-%m-%d %H:%M:%S')
} for p in prompts],
'pagination': {
'total': pagination.total,
'pages': pagination.pages,
'current_page': page,
'per_page': per_page,
'has_next': pagination.has_next,
'has_prev': pagination.has_prev
}
}
})
except Exception as e:
current_app.logger.error(f"搜索提示词失败: {str(e)}")
return jsonify({
'code': 500,
'message': str(e),
'data': None
})
@main_bp.route('/api/wx/templates/search', methods=['GET'])
def wx_search_templates():
"""搜索提示词模板接口"""
try:
# 获取搜索参数
keyword = request.args.get('keyword', '').strip()
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 10, type=int)
# 构建基础查询
query = PromptTemplate.query
# 添加搜索条件
if keyword:
search_condition = (
PromptTemplate.name.ilike(f'%{keyword}%') | # 搜索模板名称
PromptTemplate.description.ilike(f'%{keyword}%') | # 搜索模板描述
PromptTemplate.category.ilike(f'%{keyword}%') | # 搜索分类
PromptTemplate.industry.ilike(f'%{keyword}%') | # 搜索行业
PromptTemplate.profession.ilike(f'%{keyword}%') | # 搜索职业
PromptTemplate.system_prompt.ilike(f'%{keyword}%') # 搜索系统提示词
)
query = query.filter(search_condition)
# 获取筛选参数(可选)
industry = request.args.get('industry')
profession = request.args.get('profession')
category = request.args.get('category')
# 添加筛选条件
if industry:
query = query.filter_by(industry=industry)
if profession:
query = query.filter_by(profession=profession)
if category:
query = query.filter_by(category=category)
# 按是否默认模板和创建时间排序
query = query.order_by(PromptTemplate.is_default.desc(),
PromptTemplate.created_at.desc())
# 分页
pagination = query.paginate(page=page, per_page=per_page, error_out=False)
templates = pagination.items
return jsonify({
'code': 200,
'message': 'success',
'data': {
'templates': [{
'id': t.id,
'name': t.name,
'description': t.description,
'system_prompt': t.system_prompt, # 添加system_prompt字段
'category': t.category,
'industry': t.industry,
'profession': t.profession,
'sub_category': t.sub_category,
'is_default': t.is_default,
'created_at': t.created_at.strftime('%Y-%m-%d %H:%M:%S') if t.created_at else None
} for t in templates],
'pagination': {
'total': pagination.total,
'pages': pagination.pages,
'current_page': page,
'per_page': per_page,
'has_next': pagination.has_next,
'has_prev': pagination.has_prev
}
}
})
except Exception as e:
current_app.logger.error(f"搜索模板失败: {str(e)}")
return jsonify({
'code': 500,
'message': str(e),
'data': None
})
@main_bp.route('/api/wx/templates/intent', methods=['POST'])
def wx_get_template_by_intent():
"""根据意图获取提示词模板"""
try:
# 获取参数
data = request.get_json()
user_input = data.get('input_text', '').strip()
# 意图识别系统提示词
intent_system_prompt = """你是一位出色的意图识别专家。请分析用户输入的意图,并仅返回以下类别之一:
- 新闻获取
- 生成图片
- 网站开发
- 文案创作
- 代码开发
- 数据分析
- 市场营销
- 产品设计
- 其它
只返回分类名称,不要其他任何内容。"""
# 调用意图识别
response = client.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": intent_system_prompt},
{"role": "user", "content": user_input}
],
temperature=0.1
)
intent = response.choices[0].message.content.strip()
# 根据意图获取对应的模板提示词
intent_prompts = {
"新闻获取": "你是一位专业的新闻编辑,擅长整理和总结新闻信息。请帮助用户获取和理解新闻内容,注意:\n1. 确保信息的准确性和时效性\n2. 提供客观中立的视角\n3. 突出重要信息要点\n4. 适当添加背景信息解释",
"生成图片": "你是一位专业的图像生成提示词专家,擅长将文字需求转化为详细的图像生成提示词。请注意:\n1. 详细描述图像的视觉元素\n2. 指定图像的风格和氛围\n3. 添加技术参数说明\n4. 包含构图和视角建议",
"网站开发": "你是一位专业的网站开发专家,擅长将需求转化为具体的开发方案。请注意:\n1. 明确网站的目标用户和核心功能\n2. 建议合适的技术栈\n3. 考虑性能和安全性要求\n4. 提供响应式设计建议",
"文案创作": "你是一位专业的文案创作专家,擅长创作各类营销和品牌文案。请注意:\n1. 确定目标受众和传播渠道\n2. 突出产品/服务的核心价值\n3. 使用适当的语言风格\n4. 注意文案的节奏和结构",
"代码开发": "你是一位专业的软件开发工程师,擅长编写高质量的代码。请注意:\n1. 遵循编码规范和最佳实践\n2. 考虑代码的可维护性和扩展性\n3. 注重性能优化\n4. 添加适当的注释和文档",
"数据分析": "你是一位专业的数据分析师,擅长数据处理和分析。请注意:\n1. 明确分析目标和范围\n2. 选择合适的分析方法\n3. 关注数据质量和准确性\n4. 提供可操作的洞察建议",
"市场营销": "你是一位专业的市场营销专家,擅长制定营销策略。请注意:\n1. 分析目标市场和竞争环境\n2. 制定明确的营销目标\n3. 选择合适的营销渠道\n4. 设计有效的营销活动",
"产品设计": "你是一位专业的产品设计师,擅长用户体验和界面设计。请注意:\n1. 理解用户需求和痛点\n2. 遵循设计原则和规范\n3. 注重交互体验\n4. 考虑可实现性",
"其它": "你是一位专业的AI助手擅长理解和解决各类问题。请注意\n1. 仔细理解用户需求\n2. 提供清晰的解决方案\n3. 使用专业的语言表达\n4. 确保回答的实用性"
}
template_prompt = intent_prompts.get(intent, intent_prompts["其它"])
return jsonify({
'code': 200,
'message': 'success',
'data': {
'intent': intent,
'template_prompt': template_prompt
}
})
except Exception as e:
current_app.logger.error(f"获取意图模板失败: {str(e)}")
return jsonify({
'code': 500,
'message': str(e),
'data': None
})
@main_bp.route('/api/wx/generate/expert', methods=['POST'])
def wx_generate_expert_prompt():
"""两阶段专家提示词生成系统"""
try:
# 检查请求数据
if not request.is_json:
return jsonify({
'code': 400,
'message': '请求必须是JSON格式',
'data': None
})
data = request.get_json()
if not data:
return jsonify({
'code': 400,
'message': '请求数据为空',
'data': None
})
# 验证必要参数
user_input = data.get('input_text')
uid = data.get('uid')
if not user_input or not uid:
return jsonify({
'code': 400,
'message': '缺少必要参数input_text 或 uid',
'data': None
})
user_input = user_input.strip()
# 修改第一阶段:意图识别专家的提示词,使其更严格
intent_analyst_prompt = """你是一位资深的意图分析专家,请分析用户输入的意图和需求。
你必须严格按照以下JSON格式返回不要添加任何其他内容
{
"core_intent": "技术", // 必须是以下选项之一:技术、创意、分析、咨询
"domain": "web开发", // 具体的专业领域
"key_requirements": [ // 2-4个关键需求
"需求1",
"需求2"
],
"expected_output": "期望输出的具体形式", // 简短描述
"constraints": [ // 1-3个主要约束
"约束1",
"约束2"
],
"keywords": [ // 2-4个关键词
"关键词1",
"关键词2"
]
}
注意:
1. 严格遵守JSON格式
2. core_intent必须是四个选项之一
3. 数组至少包含1个元素
4. 所有字段都必须存在
5. 不要包含注释
6. 不要添加任何额外的文本"""
try:
# 获取意图分析结果
intent_response = client.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": intent_analyst_prompt},
{"role": "user", "content": user_input}
],
temperature=0.1 # 降低温度,使输出更确定
)
intent_analysis_text = intent_response.choices[0].message.content.strip()
# 添加日志记录
current_app.logger.info(f"AI返回的意图分析结果: {intent_analysis_text}")
# 尝试清理和解析JSON
try:
# 移除可能的markdown代码块标记
intent_analysis_text = intent_analysis_text.replace('```json', '').replace('```', '').strip()
intent_analysis = json.loads(intent_analysis_text)
# 验证必要字段
required_fields = ['core_intent', 'domain', 'key_requirements',
'expected_output', 'constraints', 'keywords']
for field in required_fields:
if field not in intent_analysis:
raise ValueError(f"缺少必要字段: {field}")
# 验证core_intent是否为有效值
valid_intents = ['技术', '创意', '分析', '咨询']
if intent_analysis['core_intent'] not in valid_intents:
intent_analysis['core_intent'] = '技术' # 默认使用技术
# 确保数组字段非空
array_fields = ['key_requirements', 'constraints', 'keywords']
for field in array_fields:
if not isinstance(intent_analysis[field], list) or len(intent_analysis[field]) == 0:
intent_analysis[field] = ['未指定']
except json.JSONDecodeError as e:
current_app.logger.error(f"JSON解析失败: {str(e)}, 原始文本: {intent_analysis_text}")
return jsonify({
'code': 500,
'message': 'AI返回的格式有误请重试',
'data': None
})
except ValueError as e:
current_app.logger.error(f"数据验证失败: {str(e)}")
return jsonify({
'code': 500,
'message': str(e),
'data': None
})
except Exception as e:
current_app.logger.error(f"意图分析失败: {str(e)}")
return jsonify({
'code': 500,
'message': '意图分析过程出错,请重试',
'data': None
})
# 第二阶段:领域专家提示生成
domain_expert_templates = {
"技术": """你是一位专业的技术领域提示工程师。基于以下意图分析,生成一个专业的技术任务提示词:
意图分析:
{analysis}
请生成的提示词包含:
1. 明确的技术背景和上下文
2. 具体的技术要求和规范
3. 性能和质量标准
4. 技术约束条件
5. 预期交付成果
6. 评估标准
使用专业技术术语,确保提示词的可执行性和可验证性。""",
"创意": """你是一位专业的创意领域提示工程师。基于以下意图分析,生成一个创意设计提示词:
意图分析:
{analysis}
请生成的提示词包含:
1. 创意方向和灵感来源
2. 风格和氛围要求
3. 目标受众定义
4. 设计元素规范
5. 创意表现形式
6. 评估标准
使用专业创意术语,确保提示词的创新性和可执行性。""",
"分析": """你是一位专业的数据分析提示工程师。基于以下意图分析,生成一个数据分析提示词:
意图分析:
{analysis}
请生成的提示词包含:
1. 分析目标和范围
2. 数据要求和规范
3. 分析方法和工具
4. 输出格式要求
5. 关键指标定义
6. 质量控制标准
使用专业分析术语,确保提示词的科学性和可操作性。""",
"咨询": """你是一位专业的咨询领域提示工程师。基于以下意图分析,生成一个咨询服务提示词:
意图分析:
{analysis}
请生成的提示词包含:
1. 咨询问题界定
2. 背景信息要求
3. 分析框架设定
4. 建议输出格式
5. 实施考虑因素
6. 效果评估标准
使用专业咨询术语,确保提示词的专业性和实用性。"""
}
# 选择领域专家模板
expert_prompt = domain_expert_templates.get(
intent_analysis['core_intent'],
"""你是一位专业的通用领域提示工程师。基于以下意图分析,生成一个专业的提示词:
意图分析:
{analysis}
请生成的提示词包含:
1. 明确的目标定义
2. 具体要求和规范
3. 质量标准
4. 约束条件
5. 预期输出
6. 评估标准
确保提示词的清晰性和可执行性。"""
)
try:
# 生成最终提示词
final_response = client.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": expert_prompt.format(
analysis=json.dumps(intent_analysis, ensure_ascii=False, indent=2)
)},
{"role": "user", "content": user_input}
],
temperature=0.7
)
generated_prompt = final_response.choices[0].message.content.strip()
except Exception as e:
current_app.logger.error(f"生成提示词失败: {str(e)}")
return jsonify({
'code': 500,
'message': '生成提示词过程出错',
'data': None
})
try:
# 保存到数据库
prompt = Prompt(
input_text=user_input,
generated_text=generated_prompt,
wx_user_id=uid,
intent_analysis=json.dumps(intent_analysis, ensure_ascii=False),
created_at=datetime.utcnow()
)
db.session.add(prompt)
db.session.commit()
except Exception as e:
current_app.logger.error(f"保存到数据库失败: {str(e)}")
db.session.rollback()
# 即使保存失败,也返回生成的结果
return jsonify({
'code': 200,
'message': 'success',
'data': {
'prompt_id': prompt.id if 'prompt' in locals() else None,
'intent_analysis': intent_analysis,
'generated_prompt': generated_prompt,
'created_at': prompt.created_at.strftime('%Y-%m-%d %H:%M:%S') if 'prompt' in locals() else None
}
})
except Exception as e:
current_app.logger.error(f"生成专家提示词失败: {str(e)}")
return jsonify({
'code': 500,
'message': str(e),
'data': None
})
# ... 其他路由保持不变,但要把 @app 改成 @main_bp ...

View File

@@ -0,0 +1,18 @@
.prompt-container {
display: flex;
flex-wrap: wrap;
gap: 16px;
padding: 20px;
max-width: 1200px;
margin: 0 auto;
}
.prompt-card {
width: calc((100% - 64px) / 5);
min-width: 180px;
margin: 0;
background: #fff;
border-radius: 8px;
padding: 12px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}

3
requirements-test.txt Normal file
View File

@@ -0,0 +1,3 @@
pytest==7.4.0
pytest-cov==4.1.0
requests==2.31.0

2
requirements.txt Normal file
View File

@@ -0,0 +1,2 @@
flask==2.0.1
flask-cors==3.0.10

24
tests/conftest.py Normal file
View File

@@ -0,0 +1,24 @@
import pytest
import os
import sys
# 添加项目根目录到Python路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# 测试数据库配置
@pytest.fixture(scope="session")
def test_db():
# 这里可以设置测试数据库的配置
return {
"host": "localhost",
"database": "test_prompt_template",
"user": "test_user",
"password": "test_password"
}
# 测试客户端配置
@pytest.fixture(scope="session")
def test_client():
from flask_prompt_master import create_app
app = create_app('testing')
return app.test_client()

173
tests/test_api.py Normal file
View File

@@ -0,0 +1,173 @@
import pytest
import requests
from datetime import datetime
import json
# 测试配置
class TestConfig:
BASE_URL = "http://localhost:5000/api/v1"
TEST_TOKEN = "test_token_123" # 测试用token
HEADERS = {
"Authorization": f"Bearer {TEST_TOKEN}",
"Content-Type": "application/json"
}
# 测试夹具
@pytest.fixture
def api_client():
class APIClient:
def __init__(self):
self.base_url = TestConfig.BASE_URL
self.headers = TestConfig.HEADERS
def get(self, endpoint, params=None):
url = f"{self.base_url}{endpoint}"
return requests.get(url, headers=self.headers, params=params)
def post(self, endpoint, data):
url = f"{self.base_url}{endpoint}"
return requests.post(url, headers=self.headers, json=data)
def put(self, endpoint, data):
url = f"{self.base_url}{endpoint}"
return requests.put(url, headers=self.headers, json=data)
def delete(self, endpoint):
url = f"{self.base_url}{endpoint}"
return requests.delete(url, headers=self.headers)
return APIClient()
# 模板管理接口测试
class TestTemplateAPI:
def test_get_template_list(self, api_client):
"""测试获取模板列表"""
response = api_client.get("/templates")
assert response.status_code == 200
data = response.json()
assert "data" in data
assert "templates" in data["data"]
def test_get_template_with_filters(self, api_client):
"""测试带过滤条件的模板列表"""
params = {
"category": "软件开发",
"industry": "互联网",
"page": 1,
"size": 10
}
response = api_client.get("/templates", params=params)
assert response.status_code == 200
data = response.json()
assert "current_page" in data["data"]
assert data["data"]["current_page"] == 1
def test_create_template(self, api_client):
"""测试创建模板"""
template_data = {
"name": "测试模板",
"description": "这是一个测试模板",
"category": "软件开发",
"industry": "互联网",
"profession": "开发工程师",
"sub_category": "后端开发",
"system_prompt": "你是一个专业的后端开发工程师..."
}
response = api_client.post("/templates", template_data)
assert response.status_code == 200
data = response.json()
assert "id" in data["data"]
return data["data"]["id"]
def test_get_template_detail(self, api_client):
"""测试获取单个模板详情"""
# 先创建一个模板
template_id = self.test_create_template(api_client)
response = api_client.get(f"/templates/{template_id}")
assert response.status_code == 200
data = response.json()
assert data["data"]["name"] == "测试模板"
def test_update_template(self, api_client):
"""测试更新模板"""
# 先创建一个模板
template_id = self.test_create_template(api_client)
update_data = {
"name": "更新后的模板",
"description": "这是更新后的描述"
}
response = api_client.put(f"/templates/{template_id}", update_data)
assert response.status_code == 200
# 验证更新结果
response = api_client.get(f"/templates/{template_id}")
assert response.json()["data"]["name"] == "更新后的模板"
def test_delete_template(self, api_client):
"""测试删除模板"""
# 先创建一个模板
template_id = self.test_create_template(api_client)
response = api_client.delete(f"/templates/{template_id}")
assert response.status_code == 200
# 验证删除结果
response = api_client.get(f"/templates/{template_id}")
assert response.status_code == 404
# 分类管理接口测试
class TestCategoryAPI:
def test_get_categories(self, api_client):
"""测试获取分类列表"""
response = api_client.get("/categories")
assert response.status_code == 200
data = response.json()
assert "categories" in data["data"]
def test_create_category(self, api_client):
"""测试创建分类"""
category_data = {
"name": "测试分类",
"icon": "test-icon",
"description": "这是一个测试分类"
}
response = api_client.post("/categories", category_data)
assert response.status_code == 200
# 搜索接口测试
class TestSearchAPI:
def test_search_templates(self, api_client):
"""测试搜索模板"""
params = {
"keyword": "开发",
"category": "软件开发"
}
response = api_client.get("/search/templates", params=params)
assert response.status_code == 200
data = response.json()
assert "results" in data["data"]
# 统计接口测试
class TestStatisticsAPI:
def test_get_template_statistics(self, api_client):
"""测试获取模板统计信息"""
response = api_client.get("/statistics/templates")
assert response.status_code == 200
data = response.json()
assert "total_templates" in data["data"]
assert "category_distribution" in data["data"]
# 错误处理测试
class TestErrorHandling:
def test_invalid_token(self, api_client):
"""测试无效token"""
api_client.headers["Authorization"] = "Bearer invalid_token"
response = api_client.get("/templates")
assert response.status_code == 401
def test_invalid_request(self, api_client):
"""测试无效请求"""
invalid_data = {
"name": "" # 空名称,应该触发验证错误
}
response = api_client.post("/templates", invalid_data)
assert response.status_code == 400