Supports push (local→cloud) and pull (cloud→local) with: - Table grouping (core/user_data/knowledge/logs), logs excluded by default - --tables/--exclude for fine-grained control - --mode upsert/replace/append - --dry-run preview - FK dependency ordering, auto-disable foreign key checks Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
426 lines
13 KiB
Python
426 lines
13 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
数据库同步脚本:本地 ↔ 云端
|
||
|
||
用法:
|
||
python scripts/sync_db.py push # 本地→云端(核心+用户数据+知识库)
|
||
python scripts/sync_db.py push --tables all # 全部表
|
||
python scripts/sync_db.py push --tables agents,users
|
||
python scripts/sync_db.py push --exclude executions,agent_llm_logs
|
||
python scripts/sync_db.py pull # 云端→本地
|
||
python scripts/sync_db.py push --dry-run # 预览
|
||
python scripts/sync_db.py push --mode replace
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import os
|
||
import re
|
||
import sys
|
||
from collections import OrderedDict
|
||
from typing import Dict, List, Tuple
|
||
|
||
import pymysql
|
||
|
||
# ── 从 .env 读取数据库连接 ──────────────────────────────────────────────
|
||
|
||
# Absolute directory that holds this script; the .env file is looked up in
# the directory one level above it.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_ENV_FILE = os.path.join(os.path.split(_SCRIPT_DIR)[0], ".env")
|
||
|
||
|
||
def _load_env(path: str | None = None) -> Dict[str, str]:
    """Parse a ``.env`` file into a dict of string key/value pairs.

    Blank lines, lines starting with ``#`` and lines that are not of the form
    ``KEY=VALUE`` are ignored. An optional leading ``export`` is accepted.
    When the value is wrapped in one matching pair of single or double quotes,
    that pair (and only that pair) is removed.

    Args:
        path: File to read. Defaults to the module-level ``_ENV_FILE``.

    Returns:
        Mapping of variable name to value; empty when the file does not exist.
    """
    env_path = _ENV_FILE if path is None else path
    env: Dict[str, str] = {}
    assignment = re.compile(
        r"^(?:export\s+)?([A-Za-z_][A-Za-z0-9_]*)\s*=\s*(.+)$"
    )

    try:
        handle = open(env_path, "r", encoding="utf-8")
    except FileNotFoundError:
        return env

    with handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            m = assignment.match(line)
            if not m:
                continue
            value = m.group(2)
            # Remove exactly one matching pair of surrounding quotes. Stripping
            # every quote character from both ends would corrupt values that
            # really start or end with a quote (e.g. a password).
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            env[m.group(1)] = value
    return env
|
||
|
||
|
||
def _parse_db_url(url: str) -> dict:
    """Convert a SQLAlchemy-style MySQL URL into ``pymysql.connect`` kwargs.

    Accepts ``mysql://`` with or without a ``+driver`` suffix such as
    ``mysql+pymysql://user:pass@host:port/db?params``. The driver suffix and
    any query parameters are ignored. The user name and password are
    percent-decoded, because special characters in them have to be
    percent-encoded to fit in a URL.

    Args:
        url: The database URL.

    Returns:
        Keyword arguments for ``pymysql.connect``. The port defaults to 3306;
        charset and timeouts are fixed values.

    Raises:
        ValueError: If the URL does not match the expected shape. The password
            is masked in the error message.
    """
    from urllib.parse import unquote

    m = re.match(
        r"mysql(?:\+\w+)?://(?P<user>[^:]+):(?P<password>[^@]+)"
        r"@(?P<host>[^:/]+):?(?P<port>\d+)?/(?P<database>[^?]+)",
        url,
    )
    if not m:
        # Never echo the credentials back into logs or tracebacks.
        redacted = re.sub(r"(://[^:/@]*):[^@]*@", r"\1:***@", url)
        raise ValueError(f"无法解析数据库 URL: {redacted}")
    return {
        "host": m.group("host"),
        "port": int(m.group("port") or 3306),
        "user": unquote(m.group("user")),
        "password": unquote(m.group("password")),
        "database": m.group("database"),
        "charset": "utf8mb4",
        "connect_timeout": 30,
        "read_timeout": 60,
        "write_timeout": 60,
    }
|
||
|
||
|
||
# ── 表分组配置 ───────────────────────────────────────────────────────────
|
||
|
||
TABLE_GROUPS: Dict[str, List[str]] = {
    "core": [
        "roles", "permissions", "role_permissions",
        "tools", "node_templates", "node_plugins",
        "workflow_templates", "template_favorites", "template_ratings",
        "orchestration_templates",
        "agents", "agent_extensions", "agent_favorites", "agent_ratings",
        "agent_permissions",
        "knowledge_bases", "kb_documents", "kb_document_chunks",
        "data_sources", "model_configs", "scene_contracts",
        "alert_rules", "alert_logs",
        "workflows", "workflow_versions", "workflow_permissions",
        "conversation_branches", "push_subscriptions", "fcm_tokens",
    ],
    "user_data": [
        "users", "user_roles", "user_feishu_open_ids", "user_fingerprints",
        "workspaces", "workspace_memberships",
        "teams", "team_members",
        "goals", "tasks",
        # NOTE(review): "agent_execution_logs" is also listed under "logs"
        # below, so it is part of the default selection even though the other
        # log tables are not - confirm which group it belongs in.
        "agent_schedules", "agent_execution_logs", "feedback_records",
    ],
    "knowledge": [
        "global_knowledge", "knowledge_entries",
        "knowledge_entities", "knowledge_relations",
    ],
    "logs": [
        "executions", "execution_logs",
        "agent_execution_logs", "agent_llm_logs",
        "agent_learning_patterns", "agent_vector_memories",
        "audit_logs", "notifications", "user_behavior_logs",
        "persistent_user_memories", "shadow_comparisons", "chat_messages",
    ],
}

# System tables that are always removed from the selection.
SYSTEM_SKIP = {"alembic_version"}
|
||
|
||
|
||
def resolve_tables(tables_arg: str | None, exclude_arg: str | None) -> List[str]:
    """Turn the ``--tables`` / ``--exclude`` values into an ordered table list.

    Args:
        tables_arg: ``None`` selects the "core", "user_data" and "knowledge"
            groups; ``"all"`` selects every group; any other value is read as
            a comma-separated list of table names.
        exclude_arg: Optional comma-separated table names to drop.

    Returns:
        The selected tables, minus ``SYSTEM_SKIP``, ordered by ``_sort_by_fk``.
    """

    def _split_names(csv: str) -> set:
        # Blank entries (e.g. from a trailing comma) are discarded.
        return {part.strip() for part in csv.split(",") if part.strip()}

    if tables_arg is None:
        # Default selection: everything except the "logs" group.
        chosen = set()
        for group in ("core", "user_data", "knowledge"):
            chosen.update(TABLE_GROUPS.get(group, []))
    elif tables_arg == "all":
        chosen = set()
        for members in TABLE_GROUPS.values():
            chosen.update(members)
    else:
        chosen = _split_names(tables_arg)

    if exclude_arg:
        chosen -= _split_names(exclude_arg)

    chosen -= SYSTEM_SKIP
    return _sort_by_fk(chosen)
|
||
|
||
|
||
# Foreign-key dependency order: referenced (parent) tables come first.
FK_ORDER: List[str] = [
    "roles", "permissions", "role_permissions", "users", "user_roles",
    "user_feishu_open_ids", "user_fingerprints", "workspaces",
    "workspace_memberships", "teams", "team_members",
    "workflow_templates", "template_favorites", "template_ratings",
    "node_templates", "node_plugins", "orchestration_templates",
    "tools", "data_sources", "model_configs",
    "knowledge_bases", "kb_documents", "kb_document_chunks",
    "global_knowledge", "knowledge_entries", "knowledge_entities",
    "knowledge_relations",
    "agents", "agent_extensions", "agent_favorites", "agent_ratings",
    "agent_permissions", "agent_schedules",
    "workflows", "workflow_versions", "workflow_permissions",
    "executions", "execution_logs",
    "agent_execution_logs", "agent_llm_logs", "agent_learning_patterns",
    "agent_vector_memories",
    "goals", "tasks",
    "alert_rules", "alert_logs",
    "notifications", "audit_logs", "user_behavior_logs",
    "persistent_user_memories", "shadow_comparisons",
    "feedback_records",
    "conversation_branches",
    "scene_contracts",
    "push_subscriptions", "fcm_tokens",
    "chat_messages",
]


def _sort_by_fk(table_names: set) -> List[str]:
    """Order table names by their position in ``FK_ORDER``.

    Args:
        table_names: Table names to order.

    Returns:
        A list with the tables known to ``FK_ORDER`` first, in that order,
        followed by all unknown tables sorted alphabetically.
    """
    rank = {name: position for position, name in enumerate(FK_ORDER)}
    unknown_rank = len(FK_ORDER)
    # The name is a tie-breaker: unknown tables all share one rank, and the
    # input is a set, so without it their order would differ between runs.
    return sorted(table_names, key=lambda t: (rank.get(t, unknown_rank), t))
|
||
|
||
|
||
# ── 核心同步逻辑 ──────────────────────────────────────────────────────────
|
||
|
||
|
||
def get_connection(env: dict, key: str = "DATABASE_URL") -> pymysql.Connection:
    """Open a pymysql connection for the database URL stored at ``env[key]``.

    Raises ``KeyError`` when ``key`` is absent from ``env``; parsing and
    connection errors propagate unchanged.
    """
    return pymysql.connect(**_parse_db_url(env[key]))
|
||
|
||
|
||
def get_column_names(cursor, table: str) -> List[str]:
    """Return the column names of ``table`` in the order ``DESC`` lists them."""
    cursor.execute(f"DESC `{table}`")
    names: List[str] = []
    for description_row in cursor.fetchall():
        # The first field of each DESC row is the column name.
        names.append(description_row[0])
    return names
|
||
|
||
|
||
def get_primary_key(cursor, table: str) -> str | None:
    """Return the first column that ``DESC`` flags as ``PRI``, or ``None``.

    Only the first such column is reported, so for a composite primary key
    the remaining key columns are not returned.
    """
    cursor.execute(f"DESC `{table}`")
    # Index 3 of a DESC row is the Key field; index 0 is the column name.
    return next(
        (row[0] for row in cursor.fetchall() if row[3] == "PRI"),
        None,
    )
|
||
|
||
|
||
def fetch_table(cursor, table: str, columns: List[str]) -> List[Tuple]:
    """Select the given columns of every row in ``table`` and return the rows."""
    select_list = ", ".join(["`" + name + "`" for name in columns])
    query = f"SELECT {select_list} FROM `{table}`"
    cursor.execute(query)
    return cursor.fetchall()
|
||
|
||
|
||
def sync_table(
    src_conn: pymysql.Connection,
    dst_conn: pymysql.Connection,
    table: str,
    mode: str = "upsert",
    dry_run: bool = False,
):
    """Copy every row of one table from the source to the destination.

    The write strategy depends on ``mode``:

    * ``"replace"``: delete all destination rows, then insert the source rows.
    * ``"upsert"``: ``INSERT ... ON DUPLICATE KEY UPDATE``. Requires a primary
      key; if the table has no other column to update, ``INSERT IGNORE`` is
      used instead.
    * ``"append"``: ``INSERT IGNORE``, so colliding rows are skipped.
    * Anything else, including ``"upsert"`` on a table without a primary key,
      falls back to delete + insert.

    Nothing is committed here; the caller owns the destination transaction.
    This table's writes run inside a savepoint so that a failure undoes only
    this table's changes (for example the ``DELETE`` of a failed replace)
    and leaves earlier tables in the transaction untouched.

    Args:
        src_conn: Connection the rows are read from.
        dst_conn: Connection the rows are written to.
        table: Name of the table to copy.
        mode: Write strategy, see above.
        dry_run: When true, only report the source row count.

    Returns:
        The number of source rows written, ``0`` when the table was skipped
        or only previewed, or ``-1`` when an error occurred.
    """
    savepoint = "sync_table_sp"
    savepoint_active = False
    src_cur = src_conn.cursor()
    dst_cur = dst_conn.cursor()

    try:
        # Reading the source sits inside the try as well: a table missing on
        # the source is reported as an error for that table, not a crash.
        columns = get_column_names(src_cur, table)
        if not columns:
            print(f" [SKIP] {table}: 无法读取列信息")
            return 0

        pk = get_primary_key(src_cur, table)
        rows = fetch_table(src_cur, table, columns)
        row_count = len(rows)

        if row_count == 0:
            print(f" [SKIP] {table}: 源表无数据")
            return 0

        if dry_run:
            print(f" [DRY-RUN] {table}: {row_count} 行 (模式: {mode})")
            return 0

        col_placeholders = ", ".join(["%s"] * len(columns))
        cols_quoted = ", ".join(f"`{c}`" for c in columns)
        plain_insert = (
            f"INSERT INTO `{table}` ({cols_quoted}) VALUES ({col_placeholders})"
        )
        ignore_insert = (
            f"INSERT IGNORE INTO `{table}` ({cols_quoted}) VALUES ({col_placeholders})"
        )

        dst_cur.execute(f"SAVEPOINT {savepoint}")
        savepoint_active = True

        update_cols = [c for c in columns if c != pk]

        if mode == "replace":
            dst_cur.execute(f"DELETE FROM `{table}`")
            dst_cur.executemany(plain_insert, rows)

        elif mode == "upsert" and pk and update_cols:
            update_clause = ", ".join(
                f"`{c}` = VALUES(`{c}`)" for c in update_cols
            )
            upsert_sql = f"{plain_insert} ON DUPLICATE KEY UPDATE {update_clause}"
            dst_cur.executemany(upsert_sql, rows)

        elif mode == "upsert" and pk:
            # The key is the only column, so there is nothing to update; an
            # empty ON DUPLICATE KEY UPDATE clause would be a syntax error.
            dst_cur.executemany(ignore_insert, rows)

        elif mode == "append":
            dst_cur.executemany(ignore_insert, rows)

        else:
            # No primary key to upsert on: fall back to delete + insert.
            dst_cur.execute(f"DELETE FROM `{table}`")
            dst_cur.executemany(plain_insert, rows)

        dst_cur.execute(f"RELEASE SAVEPOINT {savepoint}")
        savepoint_active = False

        print(f" [OK] {table}: {row_count} 行")
        return row_count

    except Exception as e:
        # Deliberately broad: one failing table must not stop the others.
        print(f" [ERROR] {table}: {e}")
        if savepoint_active:
            try:
                dst_cur.execute(f"ROLLBACK TO SAVEPOINT {savepoint}")
            except Exception as rollback_error:
                print(f" [ERROR] {table}: 回滚失败: {rollback_error}")
        return -1

    finally:
        src_cur.close()
        dst_cur.close()
|
||
|
||
|
||
def _set_foreign_key_checks(conn, enabled: bool) -> None:
    """Switch the session-level FOREIGN_KEY_CHECKS flag on ``conn``."""
    cur = conn.cursor()
    try:
        cur.execute(
            "SET FOREIGN_KEY_CHECKS = 1" if enabled else "SET FOREIGN_KEY_CHECKS = 0"
        )
    finally:
        cur.close()


def run_sync(
    env: dict,
    tables: List[str],
    direction: str,
    mode: str,
    dry_run: bool,
):
    """Sync all requested tables in one destination transaction.

    Args:
        env: Settings that hold ``DATABASE_URL`` (local) and
            ``CLOUD_DATABASE_URL`` (cloud).
        tables: Table names, processed in the given order.
        direction: ``"push"`` copies local to cloud; any other value copies
            cloud to local.
        mode: Write strategy forwarded to ``sync_table``.
        dry_run: When true nothing is written or committed.

    Returns:
        The number of tables whose sync reported an error.

    Exits the process with status 1 when either database URL is missing.
    """
    if direction == "push":
        src_label, dst_label = "本地", "云端"
        src_key, dst_key = "DATABASE_URL", "CLOUD_DATABASE_URL"
    else:
        src_label, dst_label = "云端", "本地"
        src_key, dst_key = "CLOUD_DATABASE_URL", "DATABASE_URL"

    # Both endpoints are needed. Checking only the destination would let a
    # missing source URL surface later as a bare KeyError.
    for required_key in (dst_key, src_key):
        if required_key not in env:
            print(f"错误: .env 中未配置 {required_key}")
            sys.exit(1)

    print(f"同步方向: {src_label} → {dst_label}")
    print(f"模式: {mode} | 表数: {len(tables)} | {'预览模式' if dry_run else '执行模式'}")
    print(f"表列表: {', '.join(tables)}")
    print("-" * 60)

    src_conn = get_connection(env, src_key)
    try:
        dst_conn = get_connection(env, dst_key)
    except Exception:
        # Do not leak the source connection when the destination is down.
        src_conn.close()
        raise

    total_rows = 0
    errors = 0

    try:
        # Foreign-key checks are switched off for the destination session
        # while the tables are being written.
        if not dry_run:
            _set_foreign_key_checks(dst_conn, False)

        for table in tables:
            result = sync_table(src_conn, dst_conn, table, mode, dry_run)
            if result > 0:
                total_rows += result
            elif result < 0:
                errors += 1

        if not dry_run:
            dst_conn.commit()

    finally:
        try:
            if not dry_run:
                _set_foreign_key_checks(dst_conn, True)
        except Exception as e:
            # Best effort: a failure here must not prevent the connections
            # from being closed or hide the original error.
            print(f"警告: 恢复外键检查失败: {e}")
        finally:
            src_conn.close()
            dst_conn.close()

    print("-" * 60)
    status = "预览完成" if dry_run else "同步完成"
    if errors:
        print(f"{status}: {total_rows} 行, {errors} 个错误")
    else:
        print(f"{status}: {total_rows} 行, 无错误")

    return errors
|
||
|
||
|
||
# ── CLI ───────────────────────────────────────────────────────────────────
|
||
|
||
|
||
def main():
    """Parse the command line, validate the .env settings and run the sync.

    Exits with status 1 when a required database URL is missing from the
    .env file or when the sync reports that at least one table failed.
    """
    parser = argparse.ArgumentParser(
        description="数据库同步脚本:本地 ↔ 云端",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
%(prog)s push # 本地→云端(核心+用户数据+知识库)
%(prog)s push --tables all # 同步所有表
%(prog)s push --tables agents,users # 只同步指定表
%(prog)s push --exclude executions # 排除指定表
%(prog)s push --dry-run # 预览模式
%(prog)s push --mode replace # 全量替换(DELETE+INSERT)
%(prog)s pull # 云端→本地
""",
    )
    parser.add_argument(
        "direction",
        choices=["push", "pull"],
        help="push = 本地→云端, pull = 云端→本地",
    )
    parser.add_argument(
        "--tables",
        default=None,
        help="要同步的表(逗号分隔),默认=核心+用户+知识库, 'all'=全部表",
    )
    parser.add_argument(
        "--exclude",
        default=None,
        help="排除的表(逗号分隔)",
    )
    parser.add_argument(
        "--mode",
        choices=["upsert", "replace", "append"],
        default="upsert",
        help="upsert=有主键则更新(recommended), replace=删后插入, append=跳过重复 (默认: upsert)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="预览模式,不实际写入数据",
    )

    args = parser.parse_args()
    env = _load_env()

    if "DATABASE_URL" not in env:
        print("错误: .env 文件中未找到 DATABASE_URL")
        sys.exit(1)

    # Both directions talk to the cloud database: push writes to it and pull
    # reads from it, so the URL is required either way.
    if "CLOUD_DATABASE_URL" not in env:
        if args.direction == "push":
            print("错误: .env 文件中未找到 CLOUD_DATABASE_URL(推送需要云端连接信息)")
        else:
            print("错误: .env 文件中未找到 CLOUD_DATABASE_URL(拉取需要云端连接信息)")
        sys.exit(1)

    tables = resolve_tables(args.tables, args.exclude)

    failed = run_sync(
        env=env,
        tables=tables,
        direction=args.direction,
        mode=args.mode,
        dry_run=args.dry_run,
    )
    # A truthy result means at least one table failed; report that through
    # the exit status so callers such as shell scripts can detect it.
    if failed:
        sys.exit(1)


if __name__ == "__main__":
    main()
|