Files
aiagent/backend/scripts/sync_db.py
renjianbo a5152b2ea5
Some checks failed
CI/CD Pipeline / Backend — Lint & Test (push) Has been cancelled
CI/CD Pipeline / Frontend — Lint & Build (push) Has been cancelled
CI/CD Pipeline / Docker — Build Check (push) Has been cancelled
feat: add database sync script for local ↔ cloud data sync
Supports push (local→cloud) and pull (cloud→local) with:
- Table grouping (core/user_data/knowledge/logs), logs excluded by default
- --tables/--exclude for fine-grained control
- --mode upsert/replace/append
- --dry-run preview
- FK dependency ordering, auto-disable foreign key checks

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-06-30 00:15:22 +08:00

426 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Database sync script: local ↔ cloud.

Usage:
    python scripts/sync_db.py push                  # local → cloud (core + user data + knowledge)
    python scripts/sync_db.py push --tables all     # every table
    python scripts/sync_db.py push --tables agents,users
    python scripts/sync_db.py push --exclude executions,agent_llm_logs
    python scripts/sync_db.py pull                  # cloud → local
    python scripts/sync_db.py push --dry-run        # preview only
    python scripts/sync_db.py push --mode replace
"""
from __future__ import annotations
import argparse
import os
import re
import sys
from collections import OrderedDict
from typing import Dict, List, Tuple
import pymysql
# ── 从 .env 读取数据库连接 ──────────────────────────────────────────────
# Absolute path of the directory that contains this script.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# The .env file is looked up in the parent directory of the script directory.
_ENV_FILE = os.path.join(os.path.dirname(_SCRIPT_DIR), ".env")
def _load_env() -> Dict[str, str]:
    """Parse the ``.env`` file at ``_ENV_FILE`` into a dict of strings.

    Blank lines, ``#`` comment lines and lines that are not ``KEY=VALUE``
    assignments are ignored. An optional leading ``export`` is accepted,
    ``KEY=`` yields an empty string, and one pair of *matching* surrounding
    quotes is removed from the value (unbalanced quotes are kept because
    they are part of the value).

    Returns:
        Mapping of variable name to value; empty when the file does not exist.
    """
    env: Dict[str, str] = {}
    if not os.path.exists(_ENV_FILE):
        return env

    assignment = re.compile(
        r"^(?:export\s+)?([A-Za-z_][A-Za-z0-9_]*)\s*=\s*(.*)$"
    )
    # utf-8-sig also reads plain UTF-8 and drops a BOM that would otherwise
    # make the first assignment fail to match.
    with open(_ENV_FILE, "r", encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            m = assignment.match(line)
            if m is None:
                continue
            value = m.group(2).strip()
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            env[m.group(1)] = value
    return env
def _parse_db_url(url: str) -> dict:
    """Convert a MySQL database URL into ``pymysql.connect`` keyword arguments.

    Accepted form: ``mysql[+driver]://user[:password]@host[:port]/db[?params]``.
    The driver suffix is ignored because the connection is always opened
    with pymysql. User name and password are percent-decoded, so special
    characters written as ``%40`` etc. reach the server as the real
    characters. Query parameters are ignored.

    Args:
        url: The database URL.

    Returns:
        Dict with host, port (3306 when absent), user, password, database,
        charset and the connect/read/write timeouts.

    Raises:
        ValueError: If the URL does not match the accepted form. The
            credentials are masked in the message so they do not end up in
            logs or tracebacks.
    """
    from urllib.parse import unquote

    m = re.match(
        r"mysql(?:\+\w+)?://"
        r"(?P<user>[^:@/]+)(?::(?P<password>[^@]*))?"
        r"@(?P<host>[^:/?]+)(?::(?P<port>\d*))?"
        r"/(?P<database>[^?]+)",
        url,
    )
    if not m:
        # Greedy match up to the last "@" so a raw "@" inside the password
        # cannot leave part of it visible.
        redacted = re.sub(r"://.*@", "://***@", url)
        raise ValueError(f"无法解析数据库 URL: {redacted}")
    return {
        "host": m.group("host"),
        "port": int(m.group("port") or 3306),
        "user": unquote(m.group("user")),
        "password": unquote(m.group("password") or ""),
        "database": m.group("database"),
        "charset": "utf8mb4",
        "connect_timeout": 30,
        "read_timeout": 60,
        "write_timeout": 60,
    }
# ── 表分组配置 ───────────────────────────────────────────────────────────
# Maps a group name to the tables that belong to it.
# NOTE(review): "agent_execution_logs" is listed under both "user_data" and
# "logs", so any selection that includes "user_data" also includes it even
# when "logs" is left out. Confirm that this is intended.
TABLE_GROUPS: Dict[str, List[str]] = {
    "core": [
        "roles", "permissions", "role_permissions",
        "tools", "node_templates", "node_plugins",
        "workflow_templates", "template_favorites", "template_ratings",
        "orchestration_templates",
        "agents", "agent_extensions", "agent_favorites", "agent_ratings",
        "agent_permissions",
        "knowledge_bases", "kb_documents", "kb_document_chunks",
        "data_sources", "model_configs", "scene_contracts",
        "alert_rules", "alert_logs",
        "workflows", "workflow_versions", "workflow_permissions",
        "conversation_branches", "push_subscriptions", "fcm_tokens",
    ],
    "user_data": [
        "users", "user_roles", "user_feishu_open_ids", "user_fingerprints",
        "workspaces", "workspace_memberships",
        "teams", "team_members",
        "goals", "tasks",
        "agent_schedules", "agent_execution_logs", "feedback_records",
    ],
    "knowledge": [
        "global_knowledge", "knowledge_entries",
        "knowledge_entities", "knowledge_relations",
    ],
    "logs": [
        "executions", "execution_logs",
        "agent_execution_logs", "agent_llm_logs",
        "agent_learning_patterns", "agent_vector_memories",
        "audit_logs", "notifications", "user_behavior_logs",
        "persistent_user_memories", "shadow_comparisons", "chat_messages",
    ],
}
# System tables that do not need to be synced.
SYSTEM_SKIP = {"alembic_version"}
def resolve_tables(tables_arg: str | None, exclude_arg: str | None) -> List[str]:
    """Turn the ``--tables`` / ``--exclude`` arguments into an ordered table list.

    Args:
        tables_arg: ``None`` selects the ``core``, ``user_data`` and
            ``knowledge`` groups; ``"all"`` selects every table of every
            group; otherwise a comma-separated list whose items are table
            names or ``TABLE_GROUPS`` group names (a group name expands to
            all of its tables).
        exclude_arg: Optional comma-separated list of table or group names
            to remove from the selection.

    Returns:
        The selected tables, minus ``SYSTEM_SKIP``, ordered by ``_sort_by_fk``.
    """

    def _expand(arg: str) -> set:
        # Split a comma-separated argument, expanding group names.
        names: set = set()
        for token in arg.split(","):
            token = token.strip()
            if token:
                names.update(TABLE_GROUPS.get(token, [token]))
        return names

    if tables_arg is None:
        # Default selection: everything except the "logs" group.
        selected = set()
        for group in ("core", "user_data", "knowledge"):
            selected.update(TABLE_GROUPS.get(group, []))
    elif tables_arg.strip() == "all":
        selected = {name for tables in TABLE_GROUPS.values() for name in tables}
    else:
        selected = _expand(tables_arg)

    if exclude_arg:
        selected -= _expand(exclude_arg)
    selected -= SYSTEM_SKIP
    return _sort_by_fk(selected)
# Foreign-key dependency order: parent tables come before the tables that
# reference them.
FK_ORDER: List[str] = [
    # access control, users, workspaces, teams
    "roles", "permissions", "role_permissions",
    "users", "user_roles", "user_feishu_open_ids", "user_fingerprints",
    "workspaces", "workspace_memberships",
    "teams", "team_members",
    # templates, tools, configuration
    "workflow_templates", "template_favorites", "template_ratings",
    "node_templates", "node_plugins", "orchestration_templates",
    "tools", "data_sources", "model_configs",
    # knowledge
    "knowledge_bases", "kb_documents", "kb_document_chunks",
    "global_knowledge", "knowledge_entries", "knowledge_entities",
    "knowledge_relations",
    # agents and workflows
    "agents", "agent_extensions", "agent_favorites", "agent_ratings",
    "agent_permissions", "agent_schedules",
    "workflows", "workflow_versions", "workflow_permissions",
    # executions and agent runtime data
    "executions", "execution_logs",
    "agent_execution_logs", "agent_llm_logs", "agent_learning_patterns",
    "agent_vector_memories",
    # everything else
    "goals", "tasks",
    "alert_rules", "alert_logs",
    "notifications", "audit_logs", "user_behavior_logs",
    "persistent_user_memories", "shadow_comparisons",
    "feedback_records",
    "conversation_branches",
    "scene_contracts",
    "push_subscriptions", "fcm_tokens",
    "chat_messages",
]
def _sort_by_fk(table_names: set) -> List[str]:
    """Return ``table_names`` ordered by their position in ``FK_ORDER``.

    Tables that are not listed in ``FK_ORDER`` are placed after all listed
    ones. Ties are broken by table name so the result does not depend on
    set iteration order and is the same on every run.
    """
    order_map = {name: i for i, name in enumerate(FK_ORDER)}
    unknown_rank = len(FK_ORDER)
    return sorted(
        table_names,
        key=lambda t: (order_map.get(t, unknown_rank), t),
    )
# ── 核心同步逻辑 ──────────────────────────────────────────────────────────
def get_connection(env: dict, key: str = "DATABASE_URL") -> pymysql.Connection:
    """Open a pymysql connection for the database URL stored in ``env[key]``.

    Raises ``KeyError`` when ``key`` is absent from ``env`` and whatever
    ``_parse_db_url`` / ``pymysql.connect`` raise for a bad URL or a failed
    connection.
    """
    database_url = env[key]
    connect_kwargs = _parse_db_url(database_url)
    connection = pymysql.connect(**connect_kwargs)
    return connection
def get_column_names(cursor, table: str) -> List[str]:
    """Return the column names of ``table`` as reported by ``DESC``.

    The first field of every ``DESC`` result row is taken as the name.
    """
    cursor.execute(f"DESC `{table}`")
    names: List[str] = []
    for described in cursor.fetchall():
        names.append(described[0])
    return names
def get_primary_key(cursor, table: str) -> str | None:
    """Return the name of the first ``DESC`` row whose Key field is ``"PRI"``.

    Only the first such column is returned; ``None`` is returned when no
    row is marked ``"PRI"``.
    """
    cursor.execute(f"DESC `{table}`")
    # Field 0 is the column name, field 3 is the Key column of DESC output.
    primary_columns = (
        described[0] for described in cursor.fetchall() if described[3] == "PRI"
    )
    return next(primary_columns, None)
def fetch_table(cursor, table: str, columns: List[str]) -> List[Tuple]:
    """Select ``columns`` from every row of ``table`` and return ``fetchall()``.

    Column and table names are wrapped in backticks; the whole result set
    is fetched at once.
    """
    quoted = [f"`{name}`" for name in columns]
    query = "SELECT " + ", ".join(quoted) + f" FROM `{table}`"
    cursor.execute(query)
    return cursor.fetchall()
def sync_table(
    src_conn: pymysql.Connection,
    dst_conn: pymysql.Connection,
    table: str,
    mode: str = "upsert",
    dry_run: bool = False,
):
    """Copy all rows of ``table`` from ``src_conn`` to ``dst_conn``.

    Modes:
        * ``replace`` - ``DELETE`` every destination row, then ``INSERT``.
        * ``upsert``  - ``INSERT ... ON DUPLICATE KEY UPDATE``; when the
          source table has no primary key this falls back to the
          ``replace`` behaviour.
        * ``append``  - ``INSERT IGNORE``.

    The destination writes of one table run under a savepoint. If any
    statement fails, the table's partial work (for example a ``DELETE``
    whose ``INSERT`` failed) is rolled back to that savepoint, so a later
    commit by the caller cannot persist a half-synced table. Nothing is
    committed here; the caller owns the transaction.

    Args:
        src_conn: Connection to read from.
        dst_conn: Connection to write to.
        table: Table name.
        mode: ``"upsert"``, ``"replace"`` or ``"append"``.
        dry_run: When true, only report the source row count.

    Returns:
        The number of source rows written, ``0`` when the table was
        skipped or in dry-run mode, ``-1`` when an error occurred.
    """
    savepoint = "sync_db_table"
    savepoint_set = False
    src_cur = src_conn.cursor()
    dst_cur = dst_conn.cursor()
    try:
        # Source reads are inside the try block so that one unreadable or
        # missing table is reported as an error instead of aborting the run.
        columns = get_column_names(src_cur, table)
        if not columns:
            print(f" [SKIP] {table}: 无法读取列信息")
            return 0
        pk = get_primary_key(src_cur, table)
        rows = fetch_table(src_cur, table, columns)
        row_count = len(rows)
        if row_count == 0:
            print(f" [SKIP] {table}: 源表无数据")
            return 0
        if dry_run:
            print(f" [DRY-RUN] {table}: {row_count} 行 (模式: {mode})")
            return 0

        col_placeholders = ", ".join(["%s"] * len(columns))
        cols_quoted = ", ".join(f"`{c}`" for c in columns)
        plain_insert = (
            f"INSERT INTO `{table}` ({cols_quoted}) VALUES ({col_placeholders})"
        )

        dst_cur.execute(f"SAVEPOINT {savepoint}")
        savepoint_set = True

        if mode == "append":
            insert_sql = (
                f"INSERT IGNORE INTO `{table}` ({cols_quoted}) "
                f"VALUES ({col_placeholders})"
            )
            dst_cur.executemany(insert_sql, rows)
        elif mode == "upsert" and pk:
            update_cols = [c for c in columns if c != pk]
            # A table whose only column is the primary key has nothing to
            # update; a no-op assignment keeps the statement valid.
            update_clause = ", ".join(
                f"`{c}` = VALUES(`{c}`)" for c in update_cols
            ) or f"`{pk}` = `{pk}`"
            insert_sql = f"{plain_insert} ON DUPLICATE KEY UPDATE {update_clause}"
            dst_cur.executemany(insert_sql, rows)
        else:
            # "replace" mode, or "upsert" on a table without a primary key.
            dst_cur.execute(f"DELETE FROM `{table}`")
            dst_cur.executemany(plain_insert, rows)

        dst_cur.execute(f"RELEASE SAVEPOINT {savepoint}")
        print(f" [OK] {table}: {row_count}")
        return row_count
    except Exception as e:
        print(f" [ERROR] {table}: {e}")
        if savepoint_set:
            try:
                dst_cur.execute(f"ROLLBACK TO SAVEPOINT {savepoint}")
            except Exception as rollback_error:
                # E.g. the connection is gone; report it, the table is
                # already counted as failed.
                print(f" [ERROR] {table}: 回滚失败: {rollback_error}")
        return -1
    finally:
        src_cur.close()
        dst_cur.close()
def run_sync(
    env: dict,
    tables: List[str],
    direction: str,
    mode: str,
    dry_run: bool,
):
    """Sync every table in ``tables`` in the given direction.

    ``"push"`` reads from ``DATABASE_URL`` and writes to
    ``CLOUD_DATABASE_URL``; any other direction does the opposite. Foreign
    key checks are switched off on the destination session while writing,
    all tables are written in one transaction that is committed at the
    end, and both connections are always closed.

    Args:
        env: Parsed ``.env`` values.
        tables: Table names in the order they should be synced.
        direction: ``"push"`` or ``"pull"``.
        mode: Sync mode passed through to ``sync_table``.
        dry_run: When true, nothing is written or committed.

    Exits the process with status 1 when a required URL is missing from
    ``env`` or when at least one table reported an error.
    """
    if direction == "push":
        src_label, dst_label = "本地", "云端"
        src_key, dst_key = "DATABASE_URL", "CLOUD_DATABASE_URL"
    else:
        src_label, dst_label = "云端", "本地"
        src_key, dst_key = "CLOUD_DATABASE_URL", "DATABASE_URL"

    # Both ends are needed; checking only the destination would surface a
    # missing source URL as a KeyError traceback.
    for required_key in (src_key, dst_key):
        if not env.get(required_key):
            print(f"错误: .env 中未配置 {required_key}")
            sys.exit(1)

    print(f"同步方向: {src_label} → {dst_label}")
    print(f"模式: {mode} | 表数: {len(tables)} | {'预览模式' if dry_run else '执行模式'}")
    print(f"表列表: {', '.join(tables)}")
    print("-" * 60)

    src_conn = get_connection(env, src_key)
    try:
        dst_conn = get_connection(env, dst_key)
    except Exception:
        # Do not leak the source connection when the destination is
        # unreachable.
        src_conn.close()
        raise

    total_rows = 0
    errors = 0
    try:
        if not dry_run:
            # Disable foreign key checks for this destination session.
            with dst_conn.cursor() as fk_cur:
                fk_cur.execute("SET FOREIGN_KEY_CHECKS = 0")
        for table in tables:
            result = sync_table(src_conn, dst_conn, table, mode, dry_run)
            if result > 0:
                total_rows += result
            elif result < 0:
                errors += 1
        if not dry_run:
            dst_conn.commit()
    finally:
        if not dry_run:
            try:
                with dst_conn.cursor() as fk_cur:
                    fk_cur.execute("SET FOREIGN_KEY_CHECKS = 1")
            except Exception as e:
                # The connection may already be broken; still close both.
                print(f"警告: 无法恢复外键检查: {e}")
        for conn in (src_conn, dst_conn):
            try:
                conn.close()
            except Exception:
                # Best effort: the connection may already be closed or lost.
                pass

    print("-" * 60)
    status = "预览完成" if dry_run else "同步完成"
    if errors:
        print(f"{status}: {total_rows} 行, {errors} 个错误")
        # Non-zero status so shells and CI can detect a failed sync.
        sys.exit(1)
    else:
        print(f"{status}: {total_rows} 行, 无错误")
# ── CLI ───────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point: parse arguments, load ``.env``, run the sync.

    Exits with status 1 when ``DATABASE_URL`` or ``CLOUD_DATABASE_URL`` is
    missing or empty (both are needed for push and for pull, since one is
    always the source and the other the destination), or when the
    ``--tables`` / ``--exclude`` combination selects no table.
    """
    parser = argparse.ArgumentParser(
        description="数据库同步脚本:本地 ↔ 云端",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
%(prog)s push # 本地→云端(核心+用户数据+知识库)
%(prog)s push --tables all # 同步所有表
%(prog)s push --tables agents,users # 只同步指定表
%(prog)s push --exclude executions # 排除指定表
%(prog)s push --dry-run # 预览模式
%(prog)s push --mode replace # 全量替换DELETE+INSERT
%(prog)s pull # 云端→本地
""",
    )
    parser.add_argument(
        "direction",
        choices=["push", "pull"],
        help="push = 本地→云端, pull = 云端→本地",
    )
    parser.add_argument(
        "--tables",
        default=None,
        help="要同步的表(逗号分隔),默认=核心+用户+知识库, 'all'=全部表",
    )
    parser.add_argument(
        "--exclude",
        default=None,
        help="排除的表(逗号分隔)",
    )
    parser.add_argument(
        "--mode",
        choices=["upsert", "replace", "append"],
        default="upsert",
        help="upsert=有主键则更新(recommended), replace=删后插入, append=跳过重复 (默认: upsert)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="预览模式,不实际写入数据",
    )
    args = parser.parse_args()

    env = _load_env()
    # Truthiness rather than membership: an empty value is as unusable as a
    # missing one.
    if not env.get("DATABASE_URL"):
        print("错误: .env 文件中未找到 DATABASE_URL")
        sys.exit(1)
    # Pull reads from the cloud database, so it needs this URL as much as
    # push does.
    if not env.get("CLOUD_DATABASE_URL"):
        print("错误: .env 文件中未找到 CLOUD_DATABASE_URL, 同步需要云端连接信息")
        sys.exit(1)

    tables = resolve_tables(args.tables, args.exclude)
    if not tables:
        print("错误: 没有可同步的表")
        sys.exit(1)

    run_sync(
        env=env,
        tables=tables,
        direction=args.direction,
        mode=args.mode,
        dry_run=args.dry_run,
    )
# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()