98 lines
3.2 KiB
Python
98 lines
3.2 KiB
Python
import pymysql
|
|
from flask_prompt_master.init_db import templates
|
|
|
|
def _open_default_connection():
    """Open the pymysql connection used when ``sync_templates`` is given none.

    These are the same connection settings the sync script has always used.
    """
    # NOTE(review): credentials are hard-coded in source; consider moving them
    # to configuration or environment variables.
    return pymysql.connect(
        host='localhost',
        user='root',
        password='123456',
        database='food_db',
        charset='utf8mb4',
    )


def _template_columns(template):
    """Return the non-key column values of *template* in table-column order.

    Order: description, category, industry, profession, sub_category,
    system_prompt, is_default. Missing optional keys default to ``''``
    (``False`` for ``is_default``); a missing ``description`` or
    ``system_prompt`` raises ``KeyError``.
    """
    return (
        template['description'],
        template.get('category', ''),
        template.get('industry', ''),
        template.get('profession', ''),
        template.get('sub_category', ''),
        template['system_prompt'],
        template.get('is_default', False),
    )


def sync_templates(template_list=None, connection=None):
    """Synchronize prompt-template data into the ``prompt_template`` table.

    Each template is matched by ``name``. A row whose name already exists is
    updated in place; otherwise a new row is inserted. All changes are
    committed as one transaction. On any error the transaction is rolled
    back and the error is printed (not raised), keeping the script's
    best-effort behaviour.

    Args:
        template_list: Iterable of template dicts. ``name``, ``description``
            and ``system_prompt`` are required; ``category``, ``industry``,
            ``profession``, ``sub_category`` and ``is_default`` are optional.
            Defaults to ``templates`` from ``flask_prompt_master.init_db``.
        connection: Optional open DB-API connection using the ``%s``
            paramstyle (e.g. pymysql). When omitted, a pymysql connection is
            opened with the script's default settings and closed afterwards.
            A caller-supplied connection is committed or rolled back but left
            open for the caller to manage.

    Returns:
        ``(inserted_count, updated_count)`` on success, or ``None`` if the
        sync failed.
    """
    insert_sql = """
        INSERT INTO prompt_template
        (name, description, category, industry, profession, sub_category, system_prompt, is_default)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
    """
    update_sql = """
        UPDATE prompt_template
        SET description = %s,
            category = %s,
            industry = %s,
            profession = %s,
            sub_category = %s,
            system_prompt = %s,
            is_default = %s
        WHERE name = %s
    """

    owns_connection = connection is None
    conn = None
    cursor = None
    try:
        # Materialise once so the summary can report len() even for generators.
        items = list(templates if template_list is None else template_list)

        conn = _open_default_connection() if owns_connection else connection
        cursor = conn.cursor()

        # Fetch the names already stored so each template can be routed to
        # UPDATE or INSERT.
        cursor.execute("SELECT name FROM prompt_template")
        existing_names = {row[0] for row in cursor.fetchall()}

        inserted_count = 0
        updated_count = 0
        for template in items:
            template_name = template['name']
            columns = _template_columns(template)
            if template_name in existing_names:
                cursor.execute(update_sql, (*columns, template_name))
                updated_count += 1
                print(f"更新模板: {template_name}")
            else:
                cursor.execute(insert_sql, (template_name, *columns))
                # Record the new name so a later duplicate in the list updates
                # this row instead of inserting a second one.
                existing_names.add(template_name)
                inserted_count += 1
                print(f"新增模板: {template_name}")

        conn.commit()

        print("\n=== 模板同步完成 ===")
        print(f"新增模板数: {inserted_count}")
        print(f"更新模板数: {updated_count}")
        print(f"总模板数: {len(items)}")
        print("===================")
        return inserted_count, updated_count

    except Exception as e:
        # Best-effort script: report the failure and undo partial writes.
        print(f"同步模板失败: {e}")
        if conn is not None:
            conn.rollback()
        return None
    finally:
        if cursor is not None:
            cursor.close()
        # Only close connections this function opened itself.
        if owns_connection and conn is not None:
            conn.close()
|
|
|
|
# Script entry point: run the template sync only when executed directly,
# never as a side effect of importing this module.
if __name__ == "__main__":
    sync_templates()