#!/usr/bin/env python3
"""
Database sync script: local <-> cloud.

Usage:
    python scripts/sync_db.py push                      # local -> cloud (core + user data + knowledge)
    python scripts/sync_db.py push --tables all         # every table
    python scripts/sync_db.py push --tables agents,users
    python scripts/sync_db.py push --exclude executions,agent_llm_logs
    python scripts/sync_db.py pull                      # cloud -> local
    python scripts/sync_db.py push --dry-run            # preview only
    python scripts/sync_db.py push --mode replace
"""
from __future__ import annotations

import argparse
import os
import re
import sys
from collections import OrderedDict
from contextlib import closing
from typing import Dict, List, Tuple
from urllib.parse import unquote

import pymysql

# ── Read database connection settings from .env ─────────────────────────
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_ENV_FILE = os.path.join(os.path.dirname(_SCRIPT_DIR), ".env")

# KEY=value, optionally prefixed with "export " (shell-style .env files).
_ENV_LINE_RE = re.compile(r"^(?:export\s+)?([A-Za-z_][A-Za-z0-9_]*)\s*=\s*(.+)$")

# SQLAlchemy-style URL: mysql+pymysql://user:pass@host[:port]/db[?params]
_DB_URL_RE = re.compile(
    r"mysql\+pymysql://(?P<user>[^:]+):(?P<password>[^@]+)"
    r"@(?P<host>[^:/]+):?(?P<port>\d+)?/(?P<database>[^?]+)"
)


def _load_env() -> Dict[str, str]:
    """Parse the project-root ``.env`` file into a dict.

    Blank lines and lines starting with ``#`` are skipped, and an optional
    leading ``export`` is accepted. Quote characters are stripped from both
    ends of each value. A missing file yields an empty dict.
    """
    env: Dict[str, str] = {}
    if not os.path.exists(_ENV_FILE):
        return env
    # utf-8-sig: a BOM written by some editors would otherwise glue itself to
    # the first key and make it unmatchable.
    with open(_ENV_FILE, "r", encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            m = _ENV_LINE_RE.match(line)
            if m:
                env[m.group(1)] = m.group(2).strip("\"'")
    return env


def _parse_db_url(url: str) -> dict:
    """Convert ``mysql+pymysql://user:pass@host:port/db?params`` into ``pymysql.connect`` kwargs.

    The port defaults to 3306, and query parameters are ignored. The user
    name and password are percent-decoded, because URL-style configs encode
    special characters (e.g. ``@`` as ``%40``).

    Raises:
        ValueError: if ``url`` does not match the expected shape.
    """
    m = _DB_URL_RE.match(url)
    if not m:
        raise ValueError(f"无法解析数据库 URL: {url}")
    return {
        "host": m.group("host"),
        "port": int(m.group("port") or 3306),
        "user": unquote(m.group("user")),
        "password": unquote(m.group("password")),
        "database": m.group("database"),
        "charset": "utf8mb4",
        "connect_timeout": 30,
        "read_timeout": 60,
        "write_timeout": 60,
    }


# ── Table group configuration ───────────────────────────────────────────
# NOTE(review): "agent_execution_logs" is listed in both "user_data" and
# "logs", so the default selection (which omits "logs") still syncs it.
# Confirm this is intended.
TABLE_GROUPS: Dict[str, List[str]] = {
    "core": [
        "roles", "permissions", "role_permissions",
        "tools", "node_templates", "node_plugins",
        "workflow_templates", "template_favorites", "template_ratings",
        "orchestration_templates",
        "agents", "agent_extensions", "agent_favorites", "agent_ratings",
        "agent_permissions",
        "knowledge_bases", "kb_documents", "kb_document_chunks",
        "data_sources", "model_configs", "scene_contracts",
        "alert_rules", "alert_logs",
        "workflows", "workflow_versions", "workflow_permissions",
        "conversation_branches", "push_subscriptions", "fcm_tokens",
    ],
    "user_data": [
        "users", "user_roles", "user_feishu_open_ids", "user_fingerprints",
        "workspaces", "workspace_memberships",
        "teams", "team_members",
        "goals", "tasks",
        "agent_schedules", "agent_execution_logs",
        "feedback_records",
    ],
    "knowledge": [
        "global_knowledge", "knowledge_entries",
        "knowledge_entities", "knowledge_relations",
    ],
    "logs": [
        "executions", "execution_logs",
        "agent_execution_logs", "agent_llm_logs",
        "agent_learning_patterns", "agent_vector_memories",
        "audit_logs", "notifications", "user_behavior_logs",
        "persistent_user_memories", "shadow_comparisons", "chat_messages",
    ],
}

# System tables that are never synced.
SYSTEM_SKIP = {"alembic_version"}


def _split_csv(value: str) -> set:
    """Split a comma-separated CLI value into a set of stripped, non-empty names."""
    return {t.strip() for t in value.split(",") if t.strip()}


def resolve_tables(tables_arg: str | None, exclude_arg: str | None) -> List[str]:
    """Resolve the ``--tables`` / ``--exclude`` CLI values into an ordered table list.

    Args:
        tables_arg: ``None`` selects the core + user_data + knowledge groups.
            ``"all"`` selects every group. Anything else is a comma-separated
            list of table names.
        exclude_arg: optional comma-separated names to drop from the selection.

    Returns:
        The selected tables, minus ``SYSTEM_SKIP``, in ``FK_ORDER`` order.
    """
    if tables_arg is None:
        # Default: core + user_data + knowledge (the "logs" group is excluded).
        selected = set()
        for group in ("core", "user_data", "knowledge"):
            selected.update(TABLE_GROUPS.get(group, []))
    elif tables_arg == "all":
        selected = {t for group_tables in TABLE_GROUPS.values() for t in group_tables}
    else:
        selected = _split_csv(tables_arg)

    if exclude_arg:
        selected -= _split_csv(exclude_arg)
    selected -= SYSTEM_SKIP
    return _sort_by_fk(selected)


# FK dependency order (parent tables first).
FK_ORDER: List[str] = [
    "roles", "permissions", "role_permissions",
    "users", "user_roles", "user_feishu_open_ids", "user_fingerprints",
    "workspaces", "workspace_memberships",
    "teams", "team_members",
    "workflow_templates", "template_favorites", "template_ratings",
    "node_templates", "node_plugins", "orchestration_templates",
    "tools", "data_sources", "model_configs",
    "knowledge_bases", "kb_documents", "kb_document_chunks",
    "global_knowledge", "knowledge_entries", "knowledge_entities", "knowledge_relations",
    "agents", "agent_extensions", "agent_favorites", "agent_ratings",
    "agent_permissions", "agent_schedules",
    "workflows", "workflow_versions", "workflow_permissions",
    "executions", "execution_logs",
    "agent_execution_logs", "agent_llm_logs",
    "agent_learning_patterns", "agent_vector_memories",
    "goals", "tasks",
    "alert_rules", "alert_logs",
    "notifications", "audit_logs", "user_behavior_logs",
    "persistent_user_memories", "shadow_comparisons",
    "feedback_records", "conversation_branches", "scene_contracts",
    "push_subscriptions", "fcm_tokens", "chat_messages",
]


def _sort_by_fk(table_names: set) -> List[str]:
    """Order ``table_names`` by ``FK_ORDER``.

    Tables not listed in ``FK_ORDER`` go last, sorted by name so that the
    order is deterministic between runs.
    """
    order_map = {name: i for i, name in enumerate(FK_ORDER)}
    unknown = len(FK_ORDER)
    return sorted(table_names, key=lambda t: (order_map.get(t, unknown), t))


# ── Core sync logic ─────────────────────────────────────────────────────
# Savepoint used to undo a single table's writes without touching the rest
# of the (single) destination transaction.
_SAVEPOINT = "sync_table_sp"


def get_connection(env: dict, key: str = "DATABASE_URL") -> pymysql.Connection:
    """Open a pymysql connection from the URL stored under ``env[key]``."""
    kwargs = _parse_db_url(env[key])
    return pymysql.connect(**kwargs)


def get_column_names(cursor, table: str) -> List[str]:
    """Return the column names of ``table`` as reported by ``DESC``."""
    cursor.execute(f"DESC `{table}`")
    return [row[0] for row in cursor.fetchall()]


def get_primary_keys(cursor, table: str) -> List[str]:
    """Return every column whose ``DESC`` Key field is ``PRI``, in column order."""
    cursor.execute(f"DESC `{table}`")
    return [row[0] for row in cursor.fetchall() if row[3] == "PRI"]


def get_primary_key(cursor, table: str) -> str | None:
    """Return the first primary-key column of ``table``, or ``None`` if it has none."""
    keys = get_primary_keys(cursor, table)
    return keys[0] if keys else None


def fetch_table(cursor, table: str, columns: List[str]) -> List[Tuple]:
    """Fetch all rows of ``table``, selecting ``columns`` in the given order."""
    cols = ", ".join(f"`{c}`" for c in columns)
    cursor.execute(f"SELECT {cols} FROM `{table}`")
    return cursor.fetchall()


def count_rows(cursor, table: str) -> int:
    """Return ``SELECT COUNT(*)`` for ``table``."""
    cursor.execute(f"SELECT COUNT(*) FROM `{table}`")
    return int(cursor.fetchone()[0])


def _build_write_sql(
    table: str, columns: List[str], pk_cols: List[str], mode: str
) -> Tuple[bool, str]:
    """Return ``(delete_first, insert_sql)`` for writing ``table`` in ``mode``.

    ``delete_first`` is True for ``replace`` mode, and for ``upsert`` on a
    table without a primary key (which falls back to replace).
    """
    cols_quoted = ", ".join(f"`{c}`" for c in columns)
    placeholders = ", ".join(["%s"] * len(columns))
    plain_insert = f"INSERT INTO `{table}` ({cols_quoted}) VALUES ({placeholders})"

    if mode == "upsert" and pk_cols:
        update_cols = [c for c in columns if c not in pk_cols]
        if update_cols:
            update_clause = ", ".join(f"`{c}` = VALUES(`{c}`)" for c in update_cols)
        else:
            # Every column is part of the key, so there is nothing to update.
            # A self-assignment keeps the statement valid and leaves the row as is.
            update_clause = f"`{pk_cols[0]}` = `{pk_cols[0]}`"
        return False, f"{plain_insert} ON DUPLICATE KEY UPDATE {update_clause}"
    if mode == "append":
        return False, f"INSERT IGNORE INTO `{table}` ({cols_quoted}) VALUES ({placeholders})"
    # replace mode, or a table without a primary key: fall back to replace.
    return True, plain_insert


def sync_table(
    src_conn: pymysql.Connection,
    dst_conn: pymysql.Connection,
    table: str,
    mode: str = "upsert",
    dry_run: bool = False,
):
    """Copy one table's rows from ``src_conn`` to ``dst_conn``.

    Modes:
        upsert:  INSERT ... ON DUPLICATE KEY UPDATE on the non-key columns.
            Tables without a primary key fall back to replace.
        replace: DELETE every destination row, then INSERT all source rows.
        append:  INSERT IGNORE, so rows that already exist are left untouched.

    An empty source table is skipped. In replace mode this means the
    destination is *not* cleared.

    Destination writes run inside a savepoint. If they fail (e.g. a column is
    missing on the destination), this table's DELETE/INSERT is rolled back
    rather than committed half-done by the caller. This function never
    commits.

    Returns:
        The number of rows written (in dry-run mode, the number that would be
        written); 0 when the table is skipped; -1 on error.
    """
    # ── Read from the source ──
    try:
        with src_conn.cursor() as src_cur:
            columns = get_column_names(src_cur, table)
            if not columns:
                print(f" [SKIP] {table}: 无法读取列信息")
                return 0
            pk_cols = get_primary_keys(src_cur, table)
            if dry_run:
                # Only the count is needed, so avoid pulling the whole table.
                row_count = count_rows(src_cur, table)
                rows = ()
            else:
                rows = fetch_table(src_cur, table, columns)
                row_count = len(rows)
    except pymysql.MySQLError as e:
        # e.g. the table does not exist in the source database. Report it and
        # let the caller continue with the remaining tables.
        print(f" [ERROR] {table}: {e}")
        return -1

    if row_count == 0:
        print(f" [SKIP] {table}: 源表无数据")
        return 0

    if dry_run:
        print(f" [DRY-RUN] {table}: {row_count} 行 (模式: {mode})")
        return row_count

    # ── Write to the destination ──
    delete_first, insert_sql = _build_write_sql(table, columns, pk_cols, mode)
    with dst_conn.cursor() as dst_cur:
        try:
            dst_cur.execute(f"SAVEPOINT {_SAVEPOINT}")
            if delete_first:
                dst_cur.execute(f"DELETE FROM `{table}`")
            dst_cur.executemany(insert_sql, rows)
        except Exception as e:
            print(f" [ERROR] {table}: {e}")
            # Undo this table's DELETE/partial INSERT so the caller's final
            # commit cannot persist an emptied or half-written table.
            try:
                dst_cur.execute(f"ROLLBACK TO SAVEPOINT {_SAVEPOINT}")
            except pymysql.MySQLError as rb_err:
                print(f" [ERROR] {table}: 回滚失败: {rb_err}")
            return -1

    print(f" [OK] {table}: {row_count} 行")
    return row_count


def run_sync(
    env: dict,
    tables: List[str],
    direction: str,
    mode: str,
    dry_run: bool,
):
    """Sync ``tables`` in ``direction`` and print a summary.

    ``push`` copies DATABASE_URL -> CLOUD_DATABASE_URL; any other value
    copies the reverse. All successful tables are committed together at the
    end; failed tables are rolled back individually by ``sync_table``.
    Exits with status 1 if a required URL is missing from ``env``.
    """
    if direction == "push":
        src_label, dst_label = "本地", "云端"
        src_key, dst_key = "DATABASE_URL", "CLOUD_DATABASE_URL"
    else:
        src_label, dst_label = "云端", "本地"
        src_key, dst_key = "CLOUD_DATABASE_URL", "DATABASE_URL"

    # Both ends are needed. Previously only the destination was checked, so
    # a pull without CLOUD_DATABASE_URL died with a KeyError.
    for key in (src_key, dst_key):
        if key not in env:
            print(f"错误: .env 中未配置 {key}")
            sys.exit(1)

    print(f"同步方向: {src_label} → {dst_label}")
    print(f"模式: {mode} | 表数: {len(tables)} | {'预览模式' if dry_run else '执行模式'}")
    print(f"表列表: {', '.join(tables)}")
    print("-" * 60)

    total_rows = 0
    errors = 0

    # closing() guarantees both connections are closed even if connecting to
    # the second one, or any statement below, raises. Closing without a
    # commit discards the uncommitted transaction.
    with closing(get_connection(env, src_key)) as src_conn, closing(
        get_connection(env, dst_key)
    ) as dst_conn:
        if not dry_run:
            # Disable FK checks for this session, so rows referencing parents
            # that have not been synced (or were just deleted) are accepted.
            with dst_conn.cursor() as cur:
                cur.execute("SET FOREIGN_KEY_CHECKS = 0")
        try:
            for table in tables:
                result = sync_table(src_conn, dst_conn, table, mode, dry_run)
                if result > 0:
                    total_rows += result
                elif result < 0:
                    errors += 1
            if not dry_run:
                dst_conn.commit()
        finally:
            if not dry_run:
                try:
                    with dst_conn.cursor() as cur:
                        cur.execute("SET FOREIGN_KEY_CHECKS = 1")
                except pymysql.MySQLError as e:
                    # Session-scoped setting; it is dropped with the connection
                    # anyway, so don't let this mask the original error.
                    print(f"警告: 恢复 FOREIGN_KEY_CHECKS 失败: {e}")

    print("-" * 60)
    status = "预览完成" if dry_run else "同步完成"
    if errors:
        print(f"{status}: {total_rows} 行, {errors} 个错误")
    else:
        print(f"{status}: {total_rows} 行, 无错误")


# ── CLI ─────────────────────────────────────────────────────────────────
def main():
    """Parse CLI arguments, load ``.env``, and run the sync."""
    parser = argparse.ArgumentParser(
        description="数据库同步脚本:本地 ↔ 云端",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  %(prog)s push                          # 本地→云端(核心+用户数据+知识库)
  %(prog)s push --tables all             # 同步所有表
  %(prog)s push --tables agents,users    # 只同步指定表
  %(prog)s push --exclude executions     # 排除指定表
  %(prog)s push --dry-run                # 预览模式
  %(prog)s push --mode replace           # 全量替换(DELETE+INSERT)
  %(prog)s pull                          # 云端→本地
""",
    )
    parser.add_argument(
        "direction",
        choices=["push", "pull"],
        help="push = 本地→云端, pull = 云端→本地",
    )
    parser.add_argument(
        "--tables",
        default=None,
        help="要同步的表(逗号分隔),默认=核心+用户+知识库, 'all'=全部表",
    )
    parser.add_argument(
        "--exclude",
        default=None,
        help="排除的表(逗号分隔)",
    )
    parser.add_argument(
        "--mode",
        choices=["upsert", "replace", "append"],
        default="upsert",
        help="upsert=有主键则更新(recommended), replace=删后插入, append=跳过重复 (默认: upsert)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="预览模式,不实际写入数据",
    )
    args = parser.parse_args()

    env = _load_env()
    if "DATABASE_URL" not in env:
        print("错误: .env 文件中未找到 DATABASE_URL")
        sys.exit(1)
    if args.direction == "push" and "CLOUD_DATABASE_URL" not in env:
        print("错误: .env 文件中未找到 CLOUD_DATABASE_URL(推送需要云端连接信息)")
        sys.exit(1)

    tables = resolve_tables(args.tables, args.exclude)
    run_sync(
        env=env,
        tables=tables,
        direction=args.direction,
        mode=args.mode,
        dry_run=args.dry_run,
    )


if __name__ == "__main__":
    main()