Files
aitsc/src/flask_prompt_master/app.py

239 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from flask import Flask, jsonify, request
from flask_cors import CORS
from datetime import datetime
# Flask application object for this mock API; cross-origin requests are
# enabled for it via flask_cors.
app = Flask(import_name=__name__)
CORS(app=app)
# Mock template data standing in for a database table. This is a plain
# module-level list of dicts, so any changes made to it at runtime live only
# in this process.
MOCK_TEMPLATES = [
    dict(
        id="1",
        name="后端开发工程师",
        description="专业的后端开发工程师模板",
        category="软件开发",
        industry="互联网",
        profession="开发工程师",
        sub_category="后端开发",
        system_prompt="你是一位专业的后端开发工程师...",
    ),
    dict(
        id="2",
        name="前端开发工程师",
        description="专业的前端开发工程师模板",
        category="软件开发",
        industry="互联网",
        profession="开发工程师",
        sub_category="前端开发",
        system_prompt="你是一位专业的前端开发工程师...",
    ),
]
# Authentication decorator
def require_auth(f):
    """Protect a view so it only runs for requests with the expected bearer token.

    The request must carry an ``Authorization: Bearer <token>`` header.
    If the header is missing or does not use the Bearer scheme, a JSON error
    ``'No authorization token provided'`` is returned with HTTP 401. If the
    token does not match, ``'Invalid token'`` is returned with HTTP 401.

    Args:
        f: The view function to protect.

    Returns:
        The wrapping function. ``functools.wraps`` copies ``f``'s name and
        docstring onto it.
    """
    from functools import wraps
    import hmac

    @wraps(f)
    def decorated(*args, **kwargs):
        auth_header = request.headers.get('Authorization') or ''
        # Split "<scheme> <credentials>". The scheme name is case-insensitive
        # (RFC 7235), and RFC 6750 allows one or more spaces before the token.
        scheme, sep, credentials = auth_header.partition(' ')
        if not sep or scheme.lower() != 'bearer':
            return jsonify({'error': 'No authorization token provided'}), 401
        # Use the whole remainder rather than split(' ')[1], so that
        # "Bearer <token> trailing-garbage" is not accepted as "<token>".
        token = credentials.strip()
        # NOTE(review): hard-coded demo token; a real deployment should load
        # it from configuration and validate it more strictly.
        # compare_digest avoids a timing side channel on the secret. It is
        # given bytes because it raises TypeError for non-ASCII str input.
        if not hmac.compare_digest(token.encode('utf-8'), b'test_token_123'):
            return jsonify({'error': 'Invalid token'}), 401
        return f(*args, **kwargs)
    return decorated
# List templates (filtering + pagination)
@app.route('/api/v1/templates', methods=['GET'])
@require_auth
def get_templates():
    """Return one page of templates, optionally filtered.

    Query parameters:
        page: 1-based page number (default 1).
        size: number of templates per page (default 10).
        category, industry, profession: optional exact-match filters. Every
            non-empty filter that is supplied must match.

    Returns:
        JSON containing ``total`` (the number of matches after filtering),
        ``pages``, ``current_page`` and the ``templates`` on the requested
        page. Returns HTTP 400 if ``page`` or ``size`` is not a positive
        integer.
    """
    # Validate paging input up front. Otherwise a non-numeric value makes
    # int() raise (HTTP 500), a negative page slices items from the end of the
    # list, and size=0 divides by zero in the page count below.
    try:
        page = int(request.args.get('page', 1))
        size = int(request.args.get('size', 10))
    except ValueError:
        return jsonify({'error': 'page and size must be integers'}), 400
    if page < 1 or size < 1:
        return jsonify({'error': 'page and size must be positive integers'}), 400

    # Only non-empty filter values take part in matching.
    filters = {
        'category': request.args.get('category'),
        'industry': request.args.get('industry'),
        'profession': request.args.get('profession'),
    }
    active_filters = {key: value for key, value in filters.items() if value}
    filtered_templates = [
        t for t in MOCK_TEMPLATES
        if all(t.get(key) == value for key, value in active_filters.items())
    ]

    start = (page - 1) * size
    return jsonify({
        'code': 200,
        'data': {
            'total': len(filtered_templates),
            # Ceiling division: the number of pages needed to show every match.
            'pages': (len(filtered_templates) + size - 1) // size,
            'current_page': page,
            'templates': filtered_templates[start:start + size]
        }
    })
# Fetch a single template
@app.route('/api/v1/templates/<template_id>', methods=['GET'])
@require_auth
def get_template(template_id):
    """Look up one template by its ``id``.

    Returns:
        JSON ``{'code': 200, 'data': <template>}`` for the first template
        whose ``id`` equals ``template_id``. If there is no such template,
        returns HTTP 404 with an ``error`` message.
    """
    for candidate in MOCK_TEMPLATES:
        if candidate['id'] == template_id:
            return jsonify({'code': 200, 'data': candidate})
    return jsonify({'error': 'Template not found'}), 404
# Create a template
@app.route('/api/v1/templates', methods=['POST'])
@require_auth
def create_template():
    """Create a template from the JSON request body.

    The body must be a JSON object containing at least ``name``,
    ``description``, ``category``, ``industry`` and ``profession``. All other
    supplied fields are stored as given. ``id`` and ``created_at`` are set by
    the server, and any client-supplied values for them are ignored.

    Returns:
        JSON with the new template's ``id`` and a success message. Returns
        HTTP 400 if the body is not a JSON object or a required field is
        missing.
    """
    # silent=True: a missing or unparsable JSON body yields None instead of
    # raising, so it is reported with the same 400 error shape as other
    # validation failures.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    required_fields = ['name', 'description', 'category', 'industry', 'profession']
    for field in required_fields:
        if field not in data:
            return jsonify({'error': f'Missing required field: {field}'}), 400

    # len(MOCK_TEMPLATES) + 1 by itself is not unique. For example, with ids
    # "1" and "2", deleting "1" makes the next id "2" again. Start from that
    # value and skip any id that is still in use.
    taken_ids = {str(t.get('id')) for t in MOCK_TEMPLATES}
    next_id = len(MOCK_TEMPLATES) + 1
    while str(next_id) in taken_ids:
        next_id += 1

    new_template = {
        'id': str(next_id),
        # Server-managed fields come from the server, never from the body.
        **{k: v for k, v in data.items() if k not in ('id', 'created_at')},
        'created_at': datetime.now().isoformat()
    }
    MOCK_TEMPLATES.append(new_template)
    return jsonify({
        'code': 200,
        'data': {
            'id': new_template['id'],
            'message': 'Template created successfully'
        }
    })
# Update a template
@app.route('/api/v1/templates/<template_id>', methods=['PUT'])
@require_auth
def update_template(template_id):
    """Merge the JSON request body into an existing template.

    Fields in the body overwrite or add to the stored template. Fields not
    mentioned in the body are left unchanged. ``id`` and ``created_at`` are
    server-managed and are ignored if present in the body.

    Returns:
        A success message. Returns HTTP 404 if no template has
        ``template_id``, or HTTP 400 if the body is not a JSON object.
    """
    template = next((t for t in MOCK_TEMPLATES if t['id'] == template_id), None)
    if template is None:
        return jsonify({'error': 'Template not found'}), 404

    # silent=True: a missing or invalid body becomes None and is rejected
    # below, instead of crashing dict.update().
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    # Letting the client overwrite 'id' could duplicate another template's id
    # or make this template unreachable at its own URL.
    template.update({k: v for k, v in data.items() if k not in ('id', 'created_at')})
    return jsonify({
        'code': 200,
        'data': {
            'message': 'Template updated successfully'
        }
    })
# Delete a template
@app.route('/api/v1/templates/<template_id>', methods=['DELETE'])
@require_auth
def delete_template(template_id):
    """Remove the first template whose ``id`` equals ``template_id``.

    The template is removed from MOCK_TEMPLATES in place.

    Returns:
        A success message, or HTTP 404 if no template has that id.
    """
    for position, candidate in enumerate(MOCK_TEMPLATES):
        if candidate['id'] == template_id:
            break
    else:
        # The loop finished without a break: nothing matched.
        return jsonify({'error': 'Template not found'}), 404

    del MOCK_TEMPLATES[position]
    return jsonify({
        'code': 200,
        'data': {'message': 'Template deleted successfully'},
    })
# Fuzzy (substring) search over prompt templates
@app.route('/api/v1/templates/search', methods=['GET'])
@require_auth
def search_templates():
    """Case-insensitive substring search over prompt templates.

    Query parameters:
        keyword: text to search for, with surrounding whitespace stripped. An
            empty keyword matches nothing.
        page: 1-based page number (default 1).
        size: number of results per page (default 10).

    A template matches when the keyword occurs in any of these fields: name,
    description, category, industry, profession, sub_category or
    system_prompt. Results are ordered by the total number of occurrences
    across those fields, highest first. Ties keep their MOCK_TEMPLATES order.

    Returns:
        JSON containing ``total``, ``pages``, ``current_page``, ``keyword``
        and the ``templates`` on the requested page. Returns HTTP 400 if
        ``page`` or ``size`` is not a positive integer.
    """
    searchable_fields = ('name', 'description', 'category', 'industry',
                         'profession', 'sub_category', 'system_prompt')

    keyword = request.args.get('keyword', '').strip()
    # Validate paging input. Otherwise a non-numeric value makes int() raise
    # (HTTP 500), a negative page slices from the end, and size=0 divides by
    # zero in the page count.
    try:
        page = int(request.args.get('page', 1))
        size = int(request.args.get('size', 10))
    except ValueError:
        return jsonify({'error': 'page and size must be integers'}), 400
    if page < 1 or size < 1:
        return jsonify({'error': 'page and size must be positive integers'}), 400

    scored = []
    if keyword:
        needle = keyword.lower()
        for template in MOCK_TEMPLATES:
            # The occurrence count doubles as the match test: score > 0 exactly
            # when the keyword appears in some field. Missing or None fields
            # are skipped so that "none" does not match str(None).
            score = sum(
                str(value).lower().count(needle)
                for value in (template.get(field) for field in searchable_fields)
                if value is not None
            )
            if score > 0:
                scored.append((score, template))
    # list.sort is stable, including with reverse=True, so equal scores keep
    # their original relative order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    search_results = [template for _, template in scored]

    total = len(search_results)
    start = (page - 1) * size
    return jsonify({
        'code': 200,
        'data': {
            'total': total,
            # Ceiling division: the number of pages needed to show every match.
            'pages': (total + size - 1) // size,
            'current_page': page,
            'keyword': keyword,
            'templates': search_results[start:start + size]
        }
    })
if __name__ == '__main__':
    # Run the development server. By default debug mode is on and the port is
    # 5000, as before. Set FLASK_DEBUG=0/false/no (or empty) to disable debug
    # mode, and PORT to choose another port.
    # NOTE(review): Flask's documentation warns against enabling debug mode
    # in production; keep it off on any host reachable from outside.
    import os

    debug_flag = os.environ.get('FLASK_DEBUG', '1').strip().lower()
    debug = debug_flag not in ('', '0', 'false', 'no')
    port = int(os.environ.get('PORT', '5000'))
    app.run(debug=debug, port=port)