From a5152b2ea58e04e4939af364fa0f2c816453f5db Mon Sep 17 00:00:00 2001 From: renjianbo <263303411@qq.com> Date: Tue, 30 Jun 2026 00:15:22 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20add=20database=20sync=20script=20for=20?= =?UTF-8?q?local=20=E2=86=94=20cloud=20data=20sync?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Supports push (local→cloud) and pull (cloud→local) with: - Table grouping (core/user_data/knowledge/logs), logs excluded by default - --tables/--exclude for fine-grained control - --mode upsert/replace/append - --dry-run preview - FK dependency ordering, auto-disable foreign key checks Co-Authored-By: Claude Opus 4.6 --- backend/scripts/sync_db.py | 425 +++++++++++++++++++++++++++++++++++++ 1 file changed, 425 insertions(+) create mode 100644 backend/scripts/sync_db.py diff --git a/backend/scripts/sync_db.py b/backend/scripts/sync_db.py new file mode 100644 index 0000000..035674b --- /dev/null +++ b/backend/scripts/sync_db.py @@ -0,0 +1,425 @@ +#!/usr/bin/env python3 +""" +数据库同步脚本:本地 ↔ 云端 + +用法: + python scripts/sync_db.py push # 本地→云端(核心+用户数据) + python scripts/sync_db.py push --tables all # 全部表 + python scripts/sync_db.py push --tables agents,users + python scripts/sync_db.py push --exclude executions,agent_llm_logs + python scripts/sync_db.py pull # 云端→本地 + python scripts/sync_db.py push --dry-run # 预览 + python scripts/sync_db.py push --mode replace +""" +from __future__ import annotations + +import argparse +import os +import re +import sys +from collections import OrderedDict +from typing import Dict, List, Tuple + +import pymysql + +# ── 从 .env 读取数据库连接 ────────────────────────────────────────────── + +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +_ENV_FILE = os.path.join(os.path.dirname(_SCRIPT_DIR), ".env") + + +def _load_env() -> Dict[str, str]: + """解析 .env 文件,返回键值对。""" + env: Dict[str, str] = {} + if not os.path.exists(_ENV_FILE): + return env + with open(_ENV_FILE, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + m = re.match(r"^([A-Za-z_][A-Za-z0-9_]*)\s*=\s*(.+)$", line) + if m: + env[m.group(1)] = m.group(2).strip("\"'") + return env + + +def _parse_db_url(url: str) -> dict: + """解析 mysql+pymysql://user:pass@host:port/db?params → pymysql.connect kwargs.""" + m = re.match( + r"mysql\+pymysql://(?P[^:]+):(?P[^@]+)" + r"@(?P[^:/]+):?(?P\d+)?/(?P[^?]+)", + url, + ) + if not m: + raise ValueError(f"无法解析数据库 URL: {url}") + return { + "host": m.group("host"), + "port": int(m.group("port") or 3306), + "user": m.group("user"), + "password": m.group("password"), + "database": m.group("database"), + "charset": "utf8mb4", + "connect_timeout": 30, + "read_timeout": 60, + "write_timeout": 60, + } + + +# ── 表分组配置 ─────────────────────────────────────────────────────────── + +TABLE_GROUPS: Dict[str, List[str]] = { + "core": [ + "roles", + "permissions", + "role_permissions", + "tools", + "node_templates", + "node_plugins", + "workflow_templates", + "template_favorites", + "template_ratings", + "orchestration_templates", + "agents", + "agent_extensions", + "agent_favorites", + "agent_ratings", + "agent_permissions", + "knowledge_bases", + "kb_documents", + "kb_document_chunks", + "data_sources", + "model_configs", + "scene_contracts", + "alert_rules", + "alert_logs", + "workflows", + "workflow_versions", + "workflow_permissions", + "conversation_branches", + "push_subscriptions", + "fcm_tokens", + ], + "user_data": [ + "users", + "user_roles", + "user_feishu_open_ids", + "user_fingerprints", + "workspaces", + "workspace_memberships", + "teams", + "team_members", + "goals", + "tasks", + "agent_schedules", + "agent_execution_logs", + "feedback_records", + ], + "knowledge": [ + "global_knowledge", + "knowledge_entries", + "knowledge_entities", + "knowledge_relations", + ], + "logs": [ + "executions", + "execution_logs", + "agent_execution_logs", + "agent_llm_logs", + "agent_learning_patterns", + "agent_vector_memories", + "audit_logs", + "notifications", + "user_behavior_logs", + "persistent_user_memories", + "shadow_comparisons", + "chat_messages", + ], +} + +# 不需要同步的系统表 +SYSTEM_SKIP = {"alembic_version"} + + +def resolve_tables(tables_arg: str | None, exclude_arg: str | None) -> List[str]: + """解析命令行参数,返回要同步的表列表(按 FK 依赖排序)。""" + if tables_arg == "all" or tables_arg is None: + selected = set() + if tables_arg is None: + # 默认:core + user_data + knowledge(排除 logs) + for group in ["core", "user_data", "knowledge"]: + selected.update(TABLE_GROUPS.get(group, [])) + else: + # all:所有表 + for tables in TABLE_GROUPS.values(): + selected.update(tables) + else: + selected = {t.strip() for t in tables_arg.split(",") if t.strip()} + + if exclude_arg: + excluded = {t.strip() for t in exclude_arg.split(",") if t.strip()} + selected -= excluded + + selected -= SYSTEM_SKIP + return _sort_by_fk(selected) + + +# FK 依赖顺序(主表在前) +FK_ORDER: List[str] = [ + "roles", "permissions", "role_permissions", "users", "user_roles", + "user_feishu_open_ids", "user_fingerprints", "workspaces", + "workspace_memberships", "teams", "team_members", + "workflow_templates", "template_favorites", "template_ratings", + "node_templates", "node_plugins", "orchestration_templates", + "tools", "data_sources", "model_configs", + "knowledge_bases", "kb_documents", "kb_document_chunks", + "global_knowledge", "knowledge_entries", "knowledge_entities", + "knowledge_relations", + "agents", "agent_extensions", "agent_favorites", "agent_ratings", + "agent_permissions", "agent_schedules", + "workflows", "workflow_versions", "workflow_permissions", + "executions", "execution_logs", + "agent_execution_logs", "agent_llm_logs", "agent_learning_patterns", + "agent_vector_memories", + "goals", "tasks", + "alert_rules", "alert_logs", + "notifications", "audit_logs", "user_behavior_logs", + "persistent_user_memories", "shadow_comparisons", + "feedback_records", + "conversation_branches", + "scene_contracts", + "push_subscriptions", "fcm_tokens", + "chat_messages", +] + + +def _sort_by_fk(table_names: set) -> List[str]: + """按 FK 依赖顺序排列表名。""" + order_map = {name: i for i, name in enumerate(FK_ORDER)} + sorted_tables = sorted( + table_names, key=lambda t: order_map.get(t, len(FK_ORDER)) + ) + return sorted_tables + + +# ── 核心同步逻辑 ────────────────────────────────────────────────────────── + + +def get_connection(env: dict, key: str = "DATABASE_URL") -> pymysql.Connection: + kwargs = _parse_db_url(env[key]) + return pymysql.connect(**kwargs) + + +def get_column_names(cursor, table: str) -> List[str]: + cursor.execute(f"DESC `{table}`") + return [row[0] for row in cursor.fetchall()] + + +def get_primary_key(cursor, table: str) -> str | None: + cursor.execute(f"DESC `{table}`") + for row in cursor.fetchall(): + if row[3] == "PRI": # Key column + return row[0] + return None + + +def fetch_table(cursor, table: str, columns: List[str]) -> List[Tuple]: + cols = ", ".join(f"`{c}`" for c in columns) + cursor.execute(f"SELECT {cols} FROM `{table}`") + return cursor.fetchall() + + +def sync_table( + src_conn: pymysql.Connection, + dst_conn: pymysql.Connection, + table: str, + mode: str = "upsert", + dry_run: bool = False, +): + """同步单张表的数据。""" + src_cur = src_conn.cursor() + dst_cur = dst_conn.cursor() + + columns = get_column_names(src_cur, table) + if not columns: + print(f" [SKIP] {table}: 无法读取列信息") + return 0 + + pk = get_primary_key(src_cur, table) + rows = fetch_table(src_cur, table, columns) + row_count = len(rows) + + if row_count == 0: + print(f" [SKIP] {table}: 源表无数据") + return 0 + + col_placeholders = ", ".join(["%s"] * len(columns)) + cols_quoted = ", ".join(f"`{c}`" for c in columns) + + if dry_run: + print(f" [DRY-RUN] {table}: {row_count} 行 (模式: {mode})") + return 0 + + try: + if mode == "replace": + dst_cur.execute(f"DELETE FROM `{table}`") + insert_sql = f"INSERT INTO `{table}` ({cols_quoted}) VALUES ({col_placeholders})" + dst_cur.executemany(insert_sql, rows) + + elif mode == "upsert" and pk: + update_cols = [c for c in columns if c != pk] + update_clause = ", ".join(f"`{c}` = VALUES(`{c}`)" for c in update_cols) + insert_sql = ( + f"INSERT INTO `{table}` ({cols_quoted}) VALUES ({col_placeholders}) " + f"ON DUPLICATE KEY UPDATE {update_clause}" + ) + dst_cur.executemany(insert_sql, rows) + + elif mode == "append": + insert_sql = f"INSERT IGNORE INTO `{table}` ({cols_quoted}) VALUES ({col_placeholders})" + dst_cur.executemany(insert_sql, rows) + + else: + # 没有主键的表,直接用 replace 模式 + dst_cur.execute(f"DELETE FROM `{table}`") + insert_sql = f"INSERT INTO `{table}` ({cols_quoted}) VALUES ({col_placeholders})" + dst_cur.executemany(insert_sql, rows) + + print(f" [OK] {table}: {row_count} 行") + return row_count + + except Exception as e: + print(f" [ERROR] {table}: {e}") + return -1 + + +def run_sync( + env: dict, + tables: List[str], + direction: str, + mode: str, + dry_run: bool, +): + """执行所有表的同步。""" + if direction == "push": + src_label, dst_label = "本地", "云端" + src_key, dst_key = "DATABASE_URL", "CLOUD_DATABASE_URL" + else: + src_label, dst_label = "云端", "本地" + src_key, dst_key = "CLOUD_DATABASE_URL", "DATABASE_URL" + + if dst_key not in env: + print(f"错误: .env 中未配置 {dst_key}") + sys.exit(1) + + print(f"同步方向: {src_label} → {dst_label}") + print(f"模式: {mode} | 表数: {len(tables)} | {'预览模式' if dry_run else '执行模式'}") + print(f"表列表: {', '.join(tables)}") + print("-" * 60) + + src_conn = get_connection(env, src_key) + dst_conn = get_connection(env, dst_key) + + # 关闭外键检查 + if not dry_run: + dst_cur = dst_conn.cursor() + dst_cur.execute("SET FOREIGN_KEY_CHECKS = 0") + + total_rows = 0 + errors = 0 + + try: + for table in tables: + result = sync_table(src_conn, dst_conn, table, mode, dry_run) + if result > 0: + total_rows += result + elif result < 0: + errors += 1 + + if not dry_run: + dst_conn.commit() + + finally: + if not dry_run: + dst_cur = dst_conn.cursor() + dst_cur.execute("SET FOREIGN_KEY_CHECKS = 1") + src_conn.close() + dst_conn.close() + + print("-" * 60) + status = "预览完成" if dry_run else "同步完成" + if errors: + print(f"{status}: {total_rows} 行, {errors} 个错误") + else: + print(f"{status}: {total_rows} 行, 无错误") + + +# ── CLI ─────────────────────────────────────────────────────────────────── + + +def main(): + parser = argparse.ArgumentParser( + description="数据库同步脚本:本地 ↔ 云端", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +示例: + %(prog)s push # 本地→云端(核心+用户数据+知识库) + %(prog)s push --tables all # 同步所有表 + %(prog)s push --tables agents,users # 只同步指定表 + %(prog)s push --exclude executions # 排除指定表 + %(prog)s push --dry-run # 预览模式 + %(prog)s push --mode replace # 全量替换(DELETE+INSERT) + %(prog)s pull # 云端→本地 + """, + ) + parser.add_argument( + "direction", + choices=["push", "pull"], + help="push = 本地→云端, pull = 云端→本地", + ) + parser.add_argument( + "--tables", + default=None, + help="要同步的表(逗号分隔),默认=核心+用户+知识库, 'all'=全部表", + ) + parser.add_argument( + "--exclude", + default=None, + help="排除的表(逗号分隔)", + ) + parser.add_argument( + "--mode", + choices=["upsert", "replace", "append"], + default="upsert", + help="upsert=有主键则更新(recommended), replace=删后插入, append=跳过重复 (默认: upsert)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="预览模式,不实际写入数据", + ) + + args = parser.parse_args() + env = _load_env() + + if "DATABASE_URL" not in env: + print("错误: .env 文件中未找到 DATABASE_URL") + sys.exit(1) + + if args.direction == "push" and "CLOUD_DATABASE_URL" not in env: + print("错误: .env 文件中未找到 CLOUD_DATABASE_URL(推送需要云端连接信息)") + sys.exit(1) + + tables = resolve_tables(args.tables, args.exclude) + + run_sync( + env=env, + tables=tables, + direction=args.direction, + mode=args.mode, + dry_run=args.dry_run, + ) + + +if __name__ == "__main__": + main()