426 lines
13 KiB
Python
426 lines
13 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
数据库同步脚本:本地 ↔ 云端
|
|||
|
|
|
|||
|
|
用法:
|
|||
|
|
python scripts/sync_db.py push # 本地→云端(核心+用户数据+知识库)
|
|||
|
|
python scripts/sync_db.py push --tables all # 全部表
|
|||
|
|
python scripts/sync_db.py push --tables agents,users
|
|||
|
|
python scripts/sync_db.py push --exclude executions,agent_llm_logs
|
|||
|
|
python scripts/sync_db.py pull # 云端→本地
|
|||
|
|
python scripts/sync_db.py push --dry-run # 预览
|
|||
|
|
python scripts/sync_db.py push --mode replace
|
|||
|
|
"""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import argparse
import os
import re
import sys
from collections import OrderedDict
from typing import Dict, List, Tuple
from urllib.parse import unquote

import pymysql
|
|||
|
|
|
|||
|
|
# ── Database connection settings (read from the project's .env) ──────────

# Absolute path of the directory that holds this script.
_SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
# The .env file is looked up one directory above the script directory.
_ENV_FILE = os.path.join(os.path.split(_SCRIPT_DIR)[0], ".env")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _load_env() -> Dict[str, str]:
|
|||
|
|
"""解析 .env 文件,返回键值对。"""
|
|||
|
|
env: Dict[str, str] = {}
|
|||
|
|
if not os.path.exists(_ENV_FILE):
|
|||
|
|
return env
|
|||
|
|
with open(_ENV_FILE, "r", encoding="utf-8") as f:
|
|||
|
|
for line in f:
|
|||
|
|
line = line.strip()
|
|||
|
|
if not line or line.startswith("#"):
|
|||
|
|
continue
|
|||
|
|
m = re.match(r"^([A-Za-z_][A-Za-z0-9_]*)\s*=\s*(.+)$", line)
|
|||
|
|
if m:
|
|||
|
|
env[m.group(1)] = m.group(2).strip("\"'")
|
|||
|
|
return env
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _parse_db_url(url: str) -> dict:
|
|||
|
|
"""解析 mysql+pymysql://user:pass@host:port/db?params → pymysql.connect kwargs."""
|
|||
|
|
m = re.match(
|
|||
|
|
r"mysql\+pymysql://(?P<user>[^:]+):(?P<password>[^@]+)"
|
|||
|
|
r"@(?P<host>[^:/]+):?(?P<port>\d+)?/(?P<database>[^?]+)",
|
|||
|
|
url,
|
|||
|
|
)
|
|||
|
|
if not m:
|
|||
|
|
raise ValueError(f"无法解析数据库 URL: {url}")
|
|||
|
|
return {
|
|||
|
|
"host": m.group("host"),
|
|||
|
|
"port": int(m.group("port") or 3306),
|
|||
|
|
"user": m.group("user"),
|
|||
|
|
"password": m.group("password"),
|
|||
|
|
"database": m.group("database"),
|
|||
|
|
"charset": "utf8mb4",
|
|||
|
|
"connect_timeout": 30,
|
|||
|
|
"read_timeout": 60,
|
|||
|
|
"write_timeout": 60,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ── Table groups ─────────────────────────────────────────────────────────

# Named groups of tables selectable on the command line. By default
# resolve_tables() syncs core + user_data + knowledge and leaves logs out.
# NOTE(review): "agent_execution_logs" is listed in both user_data and logs,
# so it is part of the default selection even though logs is excluded by
# default — confirm this is intended.
TABLE_GROUPS: Dict[str, List[str]] = dict(
    core=[
        "roles", "permissions", "role_permissions",
        "tools",
        "node_templates", "node_plugins",
        "workflow_templates", "template_favorites", "template_ratings",
        "orchestration_templates",
        "agents", "agent_extensions", "agent_favorites", "agent_ratings",
        "agent_permissions",
        "knowledge_bases", "kb_documents", "kb_document_chunks",
        "data_sources", "model_configs", "scene_contracts",
        "alert_rules", "alert_logs",
        "workflows", "workflow_versions", "workflow_permissions",
        "conversation_branches",
        "push_subscriptions", "fcm_tokens",
    ],
    user_data=[
        "users", "user_roles", "user_feishu_open_ids", "user_fingerprints",
        "workspaces", "workspace_memberships",
        "teams", "team_members",
        "goals", "tasks",
        "agent_schedules", "agent_execution_logs",
        "feedback_records",
    ],
    knowledge=[
        "global_knowledge", "knowledge_entries",
        "knowledge_entities", "knowledge_relations",
    ],
    logs=[
        "executions", "execution_logs",
        "agent_execution_logs", "agent_llm_logs", "agent_learning_patterns",
        "agent_vector_memories",
        "audit_logs", "notifications", "user_behavior_logs",
        "persistent_user_memories", "shadow_comparisons",
        "chat_messages",
    ],
)

# Bookkeeping tables that are never synced, whatever the selection.
SYSTEM_SKIP = {"alembic_version"}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def resolve_tables(tables_arg: str | None, exclude_arg: str | None) -> List[str]:
    """Resolve the ``--tables`` / ``--exclude`` options into an ordered table list.

    Args:
        tables_arg: ``None`` selects the default groups core + user_data +
            knowledge (the logs group is left out). Otherwise it is a
            comma-separated list. Each item may be:

            * ``all`` — every table in every group;
            * a group name from ``TABLE_GROUPS`` (e.g. ``core``);
            * a plain table name. Table names are kept as-is and are not
              validated here.
        exclude_arg: Optional comma-separated list, accepting the same kinds
            of items, whose tables are removed from the selection.

    Returns:
        The selected table names minus ``SYSTEM_SKIP``, ordered by
        ``_sort_by_fk``.
    """

    def expand(spec: str) -> set:
        # Turn one comma-separated option value into a set of table names.
        names: set = set()
        for item in (part.strip() for part in spec.split(",")):
            if not item:
                continue
            if item == "all":
                for group_tables in TABLE_GROUPS.values():
                    names.update(group_tables)
            elif item in TABLE_GROUPS:
                names.update(TABLE_GROUPS[item])
            else:
                names.add(item)
        return names

    if tables_arg is None:
        selected: set = set()
        for group in ("core", "user_data", "knowledge"):
            selected.update(TABLE_GROUPS.get(group, []))
    else:
        selected = expand(tables_arg)

    if exclude_arg:
        selected -= expand(exclude_arg)

    selected -= SYSTEM_SKIP
    return _sort_by_fk(selected)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Foreign-key dependency order: parent tables are listed before the tables
# that reference them. Used by _sort_by_fk() as the sort reference.
FK_ORDER: List[str] = [
    # roles, permissions and users
    "roles", "permissions", "role_permissions",
    "users", "user_roles", "user_feishu_open_ids", "user_fingerprints",
    # workspaces and teams
    "workspaces", "workspace_memberships",
    "teams", "team_members",
    # templates, plugins, tools, models
    "workflow_templates", "template_favorites", "template_ratings",
    "node_templates", "node_plugins", "orchestration_templates",
    "tools", "data_sources", "model_configs",
    # knowledge bases and knowledge graph
    "knowledge_bases", "kb_documents", "kb_document_chunks",
    "global_knowledge", "knowledge_entries", "knowledge_entities",
    "knowledge_relations",
    # agents
    "agents", "agent_extensions", "agent_favorites", "agent_ratings",
    "agent_permissions", "agent_schedules",
    # workflows and executions
    "workflows", "workflow_versions", "workflow_permissions",
    "executions", "execution_logs",
    "agent_execution_logs", "agent_llm_logs", "agent_learning_patterns",
    "agent_vector_memories",
    # goals, alerts, logs and everything else
    "goals", "tasks",
    "alert_rules", "alert_logs",
    "notifications", "audit_logs", "user_behavior_logs",
    "persistent_user_memories", "shadow_comparisons",
    "feedback_records",
    "conversation_branches",
    "scene_contracts",
    "push_subscriptions", "fcm_tokens",
    "chat_messages",
]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _sort_by_fk(table_names: set) -> List[str]:
|
|||
|
|
"""按 FK 依赖顺序排列表名。"""
|
|||
|
|
order_map = {name: i for i, name in enumerate(FK_ORDER)}
|
|||
|
|
sorted_tables = sorted(
|
|||
|
|
table_names, key=lambda t: order_map.get(t, len(FK_ORDER))
|
|||
|
|
)
|
|||
|
|
return sorted_tables
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ── 核心同步逻辑 ──────────────────────────────────────────────────────────
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_connection(env: dict, key: str = "DATABASE_URL") -> pymysql.Connection:
    """Open a pymysql connection to the database URL stored as ``env[key]``.

    Raises KeyError when *key* is missing from *env*. A URL the parser cannot
    handle raises ValueError from ``_parse_db_url``.
    """
    url = env[key]
    connect_kwargs = _parse_db_url(url)
    return pymysql.connect(**connect_kwargs)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_column_names(cursor, table: str) -> List[str]:
    """Return the column names of *table*, in the order ``DESC`` lists them.

    The name is the first field of each ``DESC`` row. This relies on the
    cursor returning rows as tuples.
    """
    cursor.execute(f"DESC `{table}`")
    return [field_name for field_name, *_ in cursor.fetchall()]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_primary_key(cursor, table: str) -> str | None:
    """Return the first column of *table* whose ``DESC`` Key field is ``PRI``.

    Returns None when no column is marked ``PRI``. For a composite primary
    key only the first key column listed by ``DESC`` is returned.
    """
    cursor.execute(f"DESC `{table}`")
    key_columns = (row[0] for row in cursor.fetchall() if row[3] == "PRI")
    return next(key_columns, None)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def fetch_table(cursor, table: str, columns: List[str]) -> List[Tuple]:
    """Select every row of *table*, restricted to *columns* in that order.

    The complete result is returned from a single ``fetchall()`` call.
    """
    select_list = ", ".join("`{}`".format(name) for name in columns)
    query = "SELECT {} FROM `{}`".format(select_list, table)
    cursor.execute(query)
    return cursor.fetchall()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def sync_table(
    src_conn: pymysql.Connection,
    dst_conn: pymysql.Connection,
    table: str,
    mode: str = "upsert",
    dry_run: bool = False,
):
    """Copy all rows of one table from ``src_conn`` to ``dst_conn``.

    Write modes:
        * ``replace`` — ``DELETE FROM`` the destination table, then plain
          ``INSERT``.
        * ``upsert`` — ``INSERT ... ON DUPLICATE KEY UPDATE`` of every
          non-key column. Tables without a primary key fall back to the
          replace behaviour.
        * ``append`` — ``INSERT IGNORE``: rows that collide with an
          existing key are skipped.

    The destination writes run inside a per-table SAVEPOINT. On failure,
    only this table's changes are rolled back, so a later COMMIT cannot
    persist a half-written or, in replace mode, emptied table. This assumes
    ``dst_conn`` is not in autocommit mode, which is pymysql's default. The
    COMMIT itself is left to the caller.

    When the source table is empty the destination is left untouched, even
    in replace mode.

    Returns:
        The source row count, including in dry-run mode. Returns 0 when the
        table was skipped and -1 on error. Errors are printed, not raised.
    """
    rows: List[Tuple] = []

    # ---- Read phase: schema plus data (or just a row count) from the source.
    try:
        with src_conn.cursor() as src_cur:
            columns = get_column_names(src_cur, table)
            if not columns:
                print(f"  [SKIP] {table}: 无法读取列信息")
                return 0
            pk = get_primary_key(src_cur, table)
            if dry_run:
                # A preview only needs the size; do not pull the whole table.
                src_cur.execute(f"SELECT COUNT(*) FROM `{table}`")
                row_count = src_cur.fetchone()[0]
            else:
                rows = fetch_table(src_cur, table, columns)
                row_count = len(rows)
    except Exception as e:
        # e.g. the table does not exist on the source side: report it as a
        # per-table error instead of aborting the whole run.
        print(f"  [ERROR] {table}: {e}")
        return -1

    if row_count == 0:
        print(f"  [SKIP] {table}: 源表无数据")
        return 0

    if dry_run:
        print(f"  [DRY-RUN] {table}: {row_count} 行 (模式: {mode})")
        return row_count

    # ---- Build the write statement for the requested mode.
    cols_quoted = ", ".join(f"`{c}`" for c in columns)
    placeholders = ", ".join(["%s"] * len(columns))
    plain_insert = f"INSERT INTO `{table}` ({cols_quoted}) VALUES ({placeholders})"

    clear_first = False
    if mode == "upsert" and pk:
        update_cols = [c for c in columns if c != pk]
        if update_cols:
            update_clause = ", ".join(f"`{c}` = VALUES(`{c}`)" for c in update_cols)
        else:
            # The table only has its key column. An empty UPDATE list is a
            # syntax error, so use a no-op assignment instead.
            update_clause = f"`{pk}` = `{pk}`"
        insert_sql = f"{plain_insert} ON DUPLICATE KEY UPDATE {update_clause}"
    elif mode == "append":
        insert_sql = f"INSERT IGNORE INTO `{table}` ({cols_quoted}) VALUES ({placeholders})"
    else:
        # "replace", or "upsert" on a table without a primary key: wipe and reload.
        clear_first = True
        insert_sql = plain_insert

    # ---- Write phase, isolated by a savepoint.
    with dst_conn.cursor() as dst_cur:
        try:
            dst_cur.execute("SAVEPOINT sync_table")
            if clear_first:
                dst_cur.execute(f"DELETE FROM `{table}`")
            dst_cur.executemany(insert_sql, rows)
            dst_cur.execute("RELEASE SAVEPOINT sync_table")
        except Exception as e:
            print(f"  [ERROR] {table}: {e}")
            try:
                dst_cur.execute("ROLLBACK TO SAVEPOINT sync_table")
            except Exception as rollback_error:
                print(f"  [ERROR] {table}: 回滚失败: {rollback_error}")
            return -1

    print(f"  [OK] {table}: {row_count} 行")
    return row_count
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_sync(
    env: dict,
    tables: List[str],
    direction: str,
    mode: str,
    dry_run: bool,
) -> int:
    """Sync every table in *tables* between the local and cloud databases.

    Args:
        env: Parsed .env mapping. Must contain both ``DATABASE_URL`` (local)
            and ``CLOUD_DATABASE_URL`` (cloud).
        tables: Table names, already in the order they should be processed.
        direction: ``"push"`` copies local → cloud; any other value copies
            cloud → local.
        mode: Write mode forwarded to ``sync_table``.
        dry_run: Only report what would be synced. Nothing is written and
            FOREIGN_KEY_CHECKS is left untouched.

    Returns:
        The number of tables whose sync failed (0 on full success).

    Exits the process with status 1 if a required connection URL is missing.
    All writes go into one destination transaction that is committed once,
    after the last table. Tables that succeeded are committed even if others
    failed.
    """
    if direction == "push":
        src_label, dst_label = "本地", "云端"
        src_key, dst_key = "DATABASE_URL", "CLOUD_DATABASE_URL"
    else:
        src_label, dst_label = "云端", "本地"
        src_key, dst_key = "CLOUD_DATABASE_URL", "DATABASE_URL"

    # Check both ends up front. Otherwise a missing *source* URL surfaces as
    # a bare KeyError inside get_connection().
    for key in (dst_key, src_key):
        if key not in env:
            print(f"错误: .env 中未配置 {key}")
            sys.exit(1)

    print(f"同步方向: {src_label} → {dst_label}")
    print(f"模式: {mode} | 表数: {len(tables)} | {'预览模式' if dry_run else '执行模式'}")
    print(f"表列表: {', '.join(tables)}")
    print("-" * 60)

    src_conn = get_connection(env, src_key)
    try:
        dst_conn = get_connection(env, dst_key)
    except Exception:
        src_conn.close()
        raise

    total_rows = 0
    errors = 0
    try:
        if not dry_run:
            # Disable FK checks for this session so tables can be loaded in
            # any order.
            with dst_conn.cursor() as cur:
                cur.execute("SET FOREIGN_KEY_CHECKS = 0")

        for table in tables:
            result = sync_table(src_conn, dst_conn, table, mode, dry_run)
            if result < 0:
                errors += 1
            else:
                total_rows += result

        if not dry_run:
            dst_conn.commit()
    finally:
        if not dry_run and dst_conn.open:
            try:
                with dst_conn.cursor() as cur:
                    cur.execute("SET FOREIGN_KEY_CHECKS = 1")
            except pymysql.MySQLError as e:
                # The setting ends with the session anyway. Report the
                # failure rather than masking an exception already in flight.
                print(f"警告: 无法恢复 FOREIGN_KEY_CHECKS: {e}")
        for conn in (src_conn, dst_conn):
            # close() on an already-dropped pymysql connection raises, so
            # only close live ones.
            if conn.open:
                conn.close()

    print("-" * 60)
    status = "预览完成" if dry_run else "同步完成"
    if errors:
        print(f"{status}: {total_rows} 行, {errors} 个错误")
    else:
        print(f"{status}: {total_rows} 行, 无错误")
    return errors
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ── CLI ───────────────────────────────────────────────────────────────────
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
    """Command-line entry point: parse arguments, load .env and run the sync.

    Exits with status 1 in these cases:

    * a required connection URL is missing;
    * the table selection is empty;
    * ``run_sync`` returns a truthy failure indicator.
    """
    parser = argparse.ArgumentParser(
        description="数据库同步脚本:本地 ↔ 云端",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
%(prog)s push # 本地→云端(核心+用户数据+知识库)
%(prog)s push --tables all # 同步所有表
%(prog)s push --tables agents,users # 只同步指定表
%(prog)s push --exclude executions # 排除指定表
%(prog)s push --dry-run # 预览模式
%(prog)s push --mode replace # 全量替换(DELETE+INSERT)
%(prog)s pull # 云端→本地
""",
    )
    parser.add_argument(
        "direction",
        choices=["push", "pull"],
        help="push = 本地→云端, pull = 云端→本地",
    )
    parser.add_argument(
        "--tables",
        default=None,
        help="要同步的表(逗号分隔),默认=核心+用户+知识库, 'all'=全部表",
    )
    parser.add_argument(
        "--exclude",
        default=None,
        help="排除的表(逗号分隔)",
    )
    parser.add_argument(
        "--mode",
        choices=["upsert", "replace", "append"],
        default="upsert",
        help="upsert=有主键则更新(recommended), replace=删后插入, append=跳过重复 (默认: upsert)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="预览模式,不实际写入数据",
    )

    args = parser.parse_args()
    env = _load_env()

    if "DATABASE_URL" not in env:
        print("错误: .env 文件中未找到 DATABASE_URL")
        sys.exit(1)

    # Both directions need the cloud database: push writes to it, pull reads
    # from it. Previously only push was checked, so pull crashed later.
    if "CLOUD_DATABASE_URL" not in env:
        action = "推送" if args.direction == "push" else "拉取"
        print(f"错误: .env 文件中未找到 CLOUD_DATABASE_URL({action}需要云端连接信息)")
        sys.exit(1)

    tables = resolve_tables(args.tables, args.exclude)
    if not tables:
        print("错误: 没有需要同步的表(请检查 --tables / --exclude 参数)")
        sys.exit(1)

    errors = run_sync(
        env=env,
        tables=tables,
        direction=args.direction,
        mode=args.mode,
        dry_run=args.dry_run,
    )
    # If run_sync reports a failure count, surface it as a non-zero exit
    # status so scripts and CI notice partial failures.
    if errors:
        sys.exit(1)
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Script entry point. main() returns None on success, so
    # SystemExit(None) exits with status 0, exactly like falling off the end.
    raise SystemExit(main())
|