272 lines
9.0 KiB
Python
272 lines
9.0 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
腾讯云数据库初始化脚本
|
||
独立运行,不依赖Flask应用
|
||
"""
|
||
import pymysql
|
||
import sys
|
||
import os
|
||
|
||
# 添加项目根目录到Python路径
|
||
project_root = os.path.dirname(os.path.abspath(__file__))
|
||
sys.path.append(project_root)
|
||
|
||
def get_templates():
|
||
"""获取模板数据"""
|
||
# 从promptsTemplates.py文件导入模板数据
|
||
try:
|
||
from src.flask_prompt_master.promptsTemplates import templates
|
||
return templates
|
||
except ImportError:
|
||
print("❌ 无法导入模板数据,请确保在项目根目录运行此脚本")
|
||
return []
|
||
|
||
def init_tencent_database(force_insert=False):
|
||
"""初始化腾讯云数据库"""
|
||
print("🚀 开始初始化腾讯云数据库...")
|
||
|
||
# 腾讯云数据库配置
|
||
config = {
|
||
'host': 'gz-cynosdbmysql-grp-d26pzce5.sql.tencentcdb.com',
|
||
'port': 24936,
|
||
'user': 'root',
|
||
'password': '!Rjb12191',
|
||
'database': 'pro_db',
|
||
'charset': 'utf8mb4'
|
||
}
|
||
|
||
try:
|
||
# 连接数据库
|
||
print("🔗 连接到腾讯云数据库...")
|
||
conn = pymysql.connect(**config)
|
||
cursor = conn.cursor()
|
||
print("✅ 数据库连接成功")
|
||
|
||
# 创建 prompt_template 表
|
||
print("📋 创建 prompt_template 表...")
|
||
cursor.execute("""
|
||
CREATE TABLE IF NOT EXISTS prompt_template (
|
||
id INT PRIMARY KEY AUTO_INCREMENT,
|
||
name VARCHAR(100) NOT NULL,
|
||
description TEXT,
|
||
category VARCHAR(50),
|
||
industry VARCHAR(50),
|
||
profession VARCHAR(50),
|
||
sub_category VARCHAR(50),
|
||
system_prompt TEXT NOT NULL,
|
||
is_default BOOLEAN DEFAULT FALSE,
|
||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||
""")
|
||
print("✅ prompt_template 表创建/检查完成")
|
||
|
||
# 检查是否已有模板数据
|
||
cursor.execute("SELECT COUNT(*) FROM prompt_template")
|
||
count = cursor.fetchone()[0]
|
||
|
||
if count == 0:
|
||
print("📝 开始插入模板数据...")
|
||
|
||
# 获取模板数据
|
||
templates = get_templates()
|
||
if not templates:
|
||
print("❌ 无法获取模板数据,退出")
|
||
return
|
||
|
||
# 插入模板数据
|
||
sql = """
|
||
INSERT INTO prompt_template
|
||
(name, description, category, industry, profession, sub_category, system_prompt, is_default)
|
||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
|
||
"""
|
||
|
||
success_count = 0
|
||
error_count = 0
|
||
|
||
for template in templates:
|
||
try:
|
||
cursor.execute(sql, (
|
||
template['name'],
|
||
template['description'],
|
||
template.get('category', ''),
|
||
template.get('industry', ''),
|
||
template.get('profession', ''),
|
||
template.get('sub_category', ''),
|
||
template['system_prompt'],
|
||
template.get('is_default', False)
|
||
))
|
||
success_count += 1
|
||
except Exception as e:
|
||
print(f"⚠️ 插入模板 '{template['name']}' 失败: {str(e)}")
|
||
error_count += 1
|
||
|
||
print(f"✅ 成功插入 {success_count} 个模板数据!")
|
||
if error_count > 0:
|
||
print(f"⚠️ {error_count} 个模板插入失败")
|
||
else:
|
||
print(f"ℹ️ 模板数据已存在 ({count} 条记录),跳过初始化。")
|
||
|
||
# 提交事务
|
||
conn.commit()
|
||
print("🎉 腾讯云数据库初始化完成!")
|
||
|
||
except Exception as e:
|
||
print(f"❌ 初始化数据库失败: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
if 'conn' in locals():
|
||
conn.rollback()
|
||
finally:
|
||
if 'cursor' in locals():
|
||
cursor.close()
|
||
if 'conn' in locals():
|
||
conn.close()
|
||
|
||
def init_local_database():
|
||
"""初始化本地数据库"""
|
||
print("🚀 开始初始化本地数据库...")
|
||
|
||
# 本地数据库配置
|
||
config = {
|
||
'host': 'localhost',
|
||
'user': 'root',
|
||
'password': '123456',
|
||
'database': 'pro_db',
|
||
'charset': 'utf8mb4'
|
||
}
|
||
|
||
try:
|
||
# 连接数据库
|
||
print("🔗 连接到本地数据库...")
|
||
conn = pymysql.connect(**config)
|
||
cursor = conn.cursor()
|
||
print("✅ 数据库连接成功")
|
||
|
||
# 创建 prompt_template 表
|
||
print("📋 创建 prompt_template 表...")
|
||
cursor.execute("""
|
||
CREATE TABLE IF NOT EXISTS prompt_template (
|
||
id INT PRIMARY KEY AUTO_INCREMENT,
|
||
name VARCHAR(100) NOT NULL,
|
||
description TEXT,
|
||
category VARCHAR(50),
|
||
industry VARCHAR(50),
|
||
profession VARCHAR(50),
|
||
sub_category VARCHAR(50),
|
||
system_prompt TEXT NOT NULL,
|
||
is_default BOOLEAN DEFAULT FALSE,
|
||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||
""")
|
||
print("✅ prompt_template 表创建/检查完成")
|
||
|
||
# 检查是否已有模板数据
|
||
cursor.execute("SELECT COUNT(*) FROM prompt_template")
|
||
count = cursor.fetchone()[0]
|
||
|
||
if count == 0:
|
||
print("📝 开始插入模板数据...")
|
||
|
||
# 获取模板数据
|
||
templates = get_templates()
|
||
if not templates:
|
||
print("❌ 无法获取模板数据,退出")
|
||
return
|
||
|
||
# 插入模板数据
|
||
sql = """
|
||
INSERT INTO prompt_template
|
||
(name, description, category, industry, profession, sub_category, system_prompt, is_default)
|
||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
|
||
"""
|
||
|
||
success_count = 0
|
||
error_count = 0
|
||
|
||
for template in templates:
|
||
try:
|
||
cursor.execute(sql, (
|
||
template['name'],
|
||
template['description'],
|
||
template.get('category', ''),
|
||
template.get('industry', ''),
|
||
template.get('profession', ''),
|
||
template.get('sub_category', ''),
|
||
template['system_prompt'],
|
||
template.get('is_default', False)
|
||
))
|
||
success_count += 1
|
||
except Exception as e:
|
||
print(f"⚠️ 插入模板 '{template['name']}' 失败: {str(e)}")
|
||
error_count += 1
|
||
|
||
print(f"✅ 成功插入 {success_count} 个模板数据!")
|
||
if error_count > 0:
|
||
print(f"⚠️ {error_count} 个模板插入失败")
|
||
else:
|
||
print(f"ℹ️ 模板数据已存在 ({count} 条记录),跳过初始化。")
|
||
|
||
# 提交事务
|
||
conn.commit()
|
||
print("🎉 本地数据库初始化完成!")
|
||
|
||
except Exception as e:
|
||
print(f"❌ 初始化数据库失败: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
if 'conn' in locals():
|
||
conn.rollback()
|
||
finally:
|
||
if 'cursor' in locals():
|
||
cursor.close()
|
||
if 'conn' in locals():
|
||
conn.close()
|
||
|
||
def main():
|
||
"""主函数"""
|
||
print("=" * 60)
|
||
print("🗄️ 数据库初始化工具")
|
||
print("=" * 60)
|
||
|
||
if len(sys.argv) > 1:
|
||
db_type = sys.argv[1].lower()
|
||
if db_type in ['tencent', 't']:
|
||
init_tencent_database()
|
||
elif db_type in ['local', 'l']:
|
||
init_local_database()
|
||
else:
|
||
print("❌ 无效的数据库类型参数")
|
||
print("用法: python init_tencent_db.py [local|tencent]")
|
||
print(" local 或 l - 初始化本地数据库")
|
||
print(" tencent 或 t - 初始化腾讯云数据库")
|
||
sys.exit(1)
|
||
else:
|
||
# 交互式选择
|
||
while True:
|
||
print("\n请选择要初始化的数据库:")
|
||
print("1. 本地数据库 (localhost)")
|
||
print("2. 腾讯云数据库")
|
||
print("3. 退出")
|
||
|
||
choice = input("\n请输入选择 (1-3): ").strip()
|
||
|
||
if choice == '1':
|
||
print("\n" + "="*40)
|
||
init_local_database()
|
||
print("="*40)
|
||
break
|
||
elif choice == '2':
|
||
print("\n" + "="*40)
|
||
init_tencent_database()
|
||
print("="*40)
|
||
break
|
||
elif choice == '3':
|
||
print("👋 退出程序")
|
||
break
|
||
else:
|
||
print("❌ 无效选择,请重新输入")
|
||
|
||
if __name__ == '__main__':
|
||
main()
|