Files
aiagent/backend/scripts/sync_db.py

426 lines
13 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
Database sync script: local <-> cloud.

Usage:
    python scripts/sync_db.py push                     # local -> cloud (core + user data + knowledge)
    python scripts/sync_db.py push --tables all        # every known table
    python scripts/sync_db.py push --tables agents,users
    python scripts/sync_db.py push --exclude executions,agent_llm_logs
    python scripts/sync_db.py pull                     # cloud -> local
    python scripts/sync_db.py push --dry-run           # preview only, nothing is written
    python scripts/sync_db.py push --mode replace      # DELETE + INSERT instead of upsert
"""
from __future__ import annotations

import argparse
import os
import re
import sys
from collections import OrderedDict
from typing import Dict, List, Tuple
from urllib.parse import unquote

import pymysql
# ── Read the database connection settings from .env ─────────────────────
# Absolute path of the directory that contains this script.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# The .env file is looked up one directory above the script directory.
_ENV_FILE = os.path.join(os.path.dirname(_SCRIPT_DIR), ".env")
def _load_env() -> Dict[str, str]:
    """Parse the ``.env`` file at ``_ENV_FILE`` into a dict of strings.

    Blank lines and lines starting with ``#`` are skipped. Each remaining
    line must look like ``KEY=value`` (optionally prefixed with ``export``);
    lines that do not match are ignored. One pair of matching surrounding
    quotes is removed from the value.

    Returns:
        Mapping of variable name to value. Empty when the file does not
        exist.
    """
    env: Dict[str, str] = {}
    if not os.path.exists(_ENV_FILE):
        return env

    # Compiled once; the optional prefix accepts shell-style "export KEY=value".
    assignment = re.compile(
        r"^(?:export\s+)?([A-Za-z_][A-Za-z0-9_]*)\s*=\s*(.+)$"
    )

    # "utf-8-sig" drops a leading byte-order mark. With plain "utf-8" the BOM
    # stays glued to the first key, which then fails to match the pattern.
    with open(_ENV_FILE, "r", encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            m = assignment.match(line)
            if not m:
                continue
            value = m.group(2)
            # Remove only a matching pair of quotes, so a quote character that
            # is part of the value (e.g. a password ending in ') is preserved.
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            env[m.group(1)] = value
    return env
def _parse_db_url(url: str) -> dict:
    """Convert a ``mysql+pymysql://user:pass@host:port/db?params`` URL to connect kwargs.

    The user name and password are percent-decoded, so credentials written
    in URL-escaped form (for example ``%40`` for ``@``) reach the driver as
    the real characters. Query parameters after ``?`` are ignored.

    Args:
        url: Database URL using the ``mysql+pymysql`` scheme.

    Returns:
        Keyword arguments for ``pymysql.connect``: host, port (3306 when the
        URL has none), user, password, database, plus a fixed charset and
        fixed connect/read/write timeouts.

    Raises:
        ValueError: If ``url`` does not match the expected layout.
    """
    m = re.match(
        r"mysql\+pymysql://(?P<user>[^:]+):(?P<password>[^@]+)"
        r"@(?P<host>[^:/]+):?(?P<port>\d+)?/(?P<database>[^?]+)",
        url,
    )
    if not m:
        raise ValueError(f"无法解析数据库 URL: {url}")
    return {
        "host": m.group("host"),
        "port": int(m.group("port") or 3306),
        # The regex cannot accept a raw "@" in the password, so such
        # characters must arrive escaped; decode them before use.
        "user": unquote(m.group("user")),
        "password": unquote(m.group("password")),
        "database": m.group("database"),
        "charset": "utf8mb4",
        "connect_timeout": 30,
        "read_timeout": 60,
        "write_timeout": 60,
    }
# ── Table group configuration ────────────────────────────────────────────
# Named groups of tables. resolve_tables() selects whole groups when no
# explicit --tables list is given.
TABLE_GROUPS: Dict[str, List[str]] = {
    "core": [
        "roles",
        "permissions",
        "role_permissions",
        "tools",
        "node_templates",
        "node_plugins",
        "workflow_templates",
        "template_favorites",
        "template_ratings",
        "orchestration_templates",
        "agents",
        "agent_extensions",
        "agent_favorites",
        "agent_ratings",
        "agent_permissions",
        "knowledge_bases",
        "kb_documents",
        "kb_document_chunks",
        "data_sources",
        "model_configs",
        "scene_contracts",
        "alert_rules",
        "alert_logs",
        "workflows",
        "workflow_versions",
        "workflow_permissions",
        "conversation_branches",
        "push_subscriptions",
        "fcm_tokens",
    ],
    "user_data": [
        "users",
        "user_roles",
        "user_feishu_open_ids",
        "user_fingerprints",
        "workspaces",
        "workspace_memberships",
        "teams",
        "team_members",
        "goals",
        "tasks",
        "agent_schedules",
        # NOTE(review): also listed under "logs" below, so a selection that
        # skips "logs" but includes "user_data" still syncs it — confirm
        # this is intended.
        "agent_execution_logs",
        "feedback_records",
    ],
    "knowledge": [
        "global_knowledge",
        "knowledge_entries",
        "knowledge_entities",
        "knowledge_relations",
    ],
    "logs": [
        "executions",
        "execution_logs",
        "agent_execution_logs",
        "agent_llm_logs",
        "agent_learning_patterns",
        "agent_vector_memories",
        "audit_logs",
        "notifications",
        "user_behavior_logs",
        "persistent_user_memories",
        "shadow_comparisons",
        "chat_messages",
    ],
}

# System tables that are never synchronised.
SYSTEM_SKIP = {"alembic_version"}
def resolve_tables(tables_arg: str | None, exclude_arg: str | None) -> List[str]:
    """Turn the ``--tables`` / ``--exclude`` arguments into an ordered table list.

    Args:
        tables_arg: ``None`` selects the "core", "user_data" and "knowledge"
            groups; ``"all"`` selects every group in ``TABLE_GROUPS``; any
            other value is treated as a comma-separated list of table names.
        exclude_arg: Optional comma-separated list of table names to drop
            from the selection.

    Returns:
        The selected tables, minus ``SYSTEM_SKIP``, ordered by
        ``_sort_by_fk``.
    """
    if tables_arg is None:
        # Default selection: everything except the "logs" group.
        default_groups = ["core", "user_data", "knowledge"]
        selected = {
            name
            for group in default_groups
            for name in TABLE_GROUPS.get(group, [])
        }
    elif tables_arg == "all":
        selected = {
            name for members in TABLE_GROUPS.values() for name in members
        }
    else:
        selected = {
            part.strip() for part in tables_arg.split(",") if part.strip()
        }

    if exclude_arg:
        selected -= {
            part.strip() for part in exclude_arg.split(",") if part.strip()
        }

    selected -= SYSTEM_SKIP
    return _sort_by_fk(selected)
# Foreign-key dependency order: parent tables come before the tables that
# reference them. Used by _sort_by_fk() to order the selected tables.
FK_ORDER: List[str] = [
    # Roles, users, workspaces and teams
    "roles",
    "permissions",
    "role_permissions",
    "users",
    "user_roles",
    "user_feishu_open_ids",
    "user_fingerprints",
    "workspaces",
    "workspace_memberships",
    "teams",
    "team_members",
    # Templates, tools and configuration
    "workflow_templates",
    "template_favorites",
    "template_ratings",
    "node_templates",
    "node_plugins",
    "orchestration_templates",
    "tools",
    "data_sources",
    "model_configs",
    # Knowledge
    "knowledge_bases",
    "kb_documents",
    "kb_document_chunks",
    "global_knowledge",
    "knowledge_entries",
    "knowledge_entities",
    "knowledge_relations",
    # Agents
    "agents",
    "agent_extensions",
    "agent_favorites",
    "agent_ratings",
    "agent_permissions",
    "agent_schedules",
    # Workflows and executions
    "workflows",
    "workflow_versions",
    "workflow_permissions",
    "executions",
    "execution_logs",
    "agent_execution_logs",
    "agent_llm_logs",
    "agent_learning_patterns",
    "agent_vector_memories",
    # Everything else
    "goals",
    "tasks",
    "alert_rules",
    "alert_logs",
    "notifications",
    "audit_logs",
    "user_behavior_logs",
    "persistent_user_memories",
    "shadow_comparisons",
    "feedback_records",
    "conversation_branches",
    "scene_contracts",
    "push_subscriptions",
    "fcm_tokens",
    "chat_messages",
]
def _sort_by_fk(table_names: set) -> List[str]:
    """Order ``table_names`` by their position in ``FK_ORDER``.

    Tables missing from ``FK_ORDER`` are placed after all known tables and
    sorted by name among themselves, so the result is the same on every
    run regardless of set iteration order.

    Args:
        table_names: Table names to order.

    Returns:
        A new list holding the same names in dependency order.
    """
    order_map = {name: i for i, name in enumerate(FK_ORDER)}
    unknown_rank = len(FK_ORDER)
    # The name acts as a tie-breaker: unknown tables all share unknown_rank,
    # and without it their relative order would follow set iteration order.
    return sorted(
        table_names, key=lambda t: (order_map.get(t, unknown_rank), t)
    )
# ── Core sync logic ──────────────────────────────────────────────────────
def get_connection(env: dict, key: str = "DATABASE_URL") -> pymysql.Connection:
    """Open a pymysql connection for the database URL stored at ``env[key]``.

    Args:
        env: Mapping that holds the database URLs.
        key: Name of the entry to use.

    Raises:
        KeyError: If ``key`` is not present in ``env``.
    """
    connect_kwargs = _parse_db_url(env[key])
    connection = pymysql.connect(**connect_kwargs)
    return connection
def get_column_names(cursor, table: str) -> List[str]:
    """Return the column names of ``table`` in the order ``DESC`` reports them."""
    cursor.execute(f"DESC `{table}`")
    names = []
    for description_row in cursor.fetchall():
        # The first field of each DESC row is taken as the column name.
        names.append(description_row[0])
    return names
def get_primary_key(cursor, table: str) -> str | None:
    """Return the first column of ``table`` that ``DESC`` flags as ``PRI``.

    Returns ``None`` when no column is flagged. Only the first flagged
    column is returned, even if several are.
    """
    cursor.execute(f"DESC `{table}`")
    description_rows = cursor.fetchall()
    # Field 3 of a DESC row is compared against "PRI"; field 0 is returned.
    return next(
        (row[0] for row in description_rows if row[3] == "PRI"),
        None,
    )
def fetch_table(cursor, table: str, columns: List[str]) -> List[Tuple]:
    """Select the given ``columns`` of every row in ``table``.

    Returns whatever ``cursor.fetchall()`` yields for that query.
    """
    quoted_columns = [f"`{column}`" for column in columns]
    select_list = ", ".join(quoted_columns)
    query = f"SELECT {select_list} FROM `{table}`"
    cursor.execute(query)
    return cursor.fetchall()
def sync_table(
    src_conn: pymysql.Connection,
    dst_conn: pymysql.Connection,
    table: str,
    mode: str = "upsert",
    dry_run: bool = False,
):
    """Copy all rows of ``table`` from the source to the destination.

    The column list, primary key and rows are read from the source. Writes
    to the destination are wrapped in a SAVEPOINT: if any statement fails,
    the destination is rolled back to its state before this table was
    touched, so a later commit by the caller cannot persist a half-applied
    table (for example a ``DELETE`` whose re-insert failed). This function
    never commits; that is left to the caller.

    Args:
        src_conn: Connection to read from.
        dst_conn: Connection to write to.
        table: Name of the table to copy.
        mode: ``"replace"`` deletes all destination rows and inserts;
            ``"upsert"`` uses ``INSERT ... ON DUPLICATE KEY UPDATE`` when the
            table has a primary key; ``"append"`` uses ``INSERT IGNORE``.
            ``"upsert"`` on a table without a primary key, and any other
            value, falls back to delete-then-insert.
        dry_run: When true, only report the row count and write nothing.

    Returns:
        The number of source rows written, ``0`` when the table was skipped
        or ``dry_run`` is set, and ``-1`` when writing failed.
    """
    savepoint = "sync_table_sp"

    # Context managers close both cursors on every exit path.
    with src_conn.cursor() as src_cur, dst_conn.cursor() as dst_cur:
        columns = get_column_names(src_cur, table)
        if not columns:
            print(f" [SKIP] {table}: 无法读取列信息")
            return 0

        pk = get_primary_key(src_cur, table)
        rows = fetch_table(src_cur, table, columns)
        row_count = len(rows)
        if row_count == 0:
            print(f" [SKIP] {table}: 源表无数据")
            return 0

        if dry_run:
            print(f" [DRY-RUN] {table}: {row_count} 行 (模式: {mode})")
            return 0

        col_placeholders = ", ".join(["%s"] * len(columns))
        cols_quoted = ", ".join(f"`{c}`" for c in columns)
        target = f"`{table}` ({cols_quoted}) VALUES ({col_placeholders})"

        wipe_first = False
        if mode == "replace":
            wipe_first = True
            insert_sql = f"INSERT INTO {target}"
        elif mode == "upsert" and pk:
            update_cols = [c for c in columns if c != pk]
            if update_cols:
                update_clause = ", ".join(
                    f"`{c}` = VALUES(`{c}`)" for c in update_cols
                )
                insert_sql = (
                    f"INSERT INTO {target} "
                    f"ON DUPLICATE KEY UPDATE {update_clause}"
                )
            else:
                # The primary key is the only column, so there is nothing to
                # update; an empty UPDATE clause would be invalid SQL.
                insert_sql = f"INSERT IGNORE INTO {target}"
        elif mode == "append":
            insert_sql = f"INSERT IGNORE INTO {target}"
        else:
            # Upsert requested on a table without a primary key: fall back to
            # delete-then-insert.
            wipe_first = True
            insert_sql = f"INSERT INTO {target}"

        try:
            dst_cur.execute(f"SAVEPOINT {savepoint}")
            if wipe_first:
                dst_cur.execute(f"DELETE FROM `{table}`")
            dst_cur.executemany(insert_sql, rows)
            dst_cur.execute(f"RELEASE SAVEPOINT {savepoint}")
        except Exception as e:
            print(f" [ERROR] {table}: {e}")
            # Undo this table's partial writes so they cannot be committed.
            try:
                dst_cur.execute(f"ROLLBACK TO SAVEPOINT {savepoint}")
            except Exception as rollback_err:
                print(f" [ERROR] {table}: 回滚失败: {rollback_err}")
            return -1

        print(f" [OK] {table}: {row_count}")
        return row_count
def run_sync(
    env: dict,
    tables: List[str],
    direction: str,
    mode: str,
    dry_run: bool,
):
    """Synchronise ``tables`` between the local and cloud databases.

    Opens one connection per side, disables foreign-key checks on the
    destination session (unless ``dry_run``), runs ``sync_table`` for each
    table in order, commits once at the end, and prints a summary.

    Args:
        env: Mapping holding ``DATABASE_URL`` (local) and
            ``CLOUD_DATABASE_URL`` (cloud).
        tables: Table names, in the order they should be processed.
        direction: ``"push"`` copies local to cloud; any other value copies
            cloud to local.
        mode: Write mode forwarded to ``sync_table``.
        dry_run: When true, nothing is written or committed.

    Exits the process with status 1 when either database URL is missing
    from ``env``.
    """
    if direction == "push":
        src_label, dst_label = "本地", "云端"
        src_key, dst_key = "DATABASE_URL", "CLOUD_DATABASE_URL"
    else:
        src_label, dst_label = "云端", "本地"
        src_key, dst_key = "CLOUD_DATABASE_URL", "DATABASE_URL"

    # Both sides are connected below, so both URLs are required. Checking
    # only the destination let a missing source URL surface as a KeyError.
    for required_key in (src_key, dst_key):
        if required_key not in env:
            print(f"错误: .env 中未配置 {required_key}")
            sys.exit(1)

    print(f"同步方向: {src_label} → {dst_label}")
    print(f"模式: {mode} | 表数: {len(tables)} | {'预览模式' if dry_run else '执行模式'}")
    print(f"表列表: {', '.join(tables)}")
    print("-" * 60)

    src_conn = get_connection(env, src_key)
    try:
        dst_conn = get_connection(env, dst_key)
    except Exception:
        # Do not leak the source connection if the destination is unreachable.
        src_conn.close()
        raise

    total_rows = 0
    errors = 0
    try:
        # Disable foreign-key checks for this destination session so that
        # table order and delete-then-insert do not trip FK constraints.
        if not dry_run:
            with dst_conn.cursor() as dst_cur:
                dst_cur.execute("SET FOREIGN_KEY_CHECKS = 0")

        for table in tables:
            result = sync_table(src_conn, dst_conn, table, mode, dry_run)
            if result > 0:
                total_rows += result
            elif result < 0:
                errors += 1

        if not dry_run:
            dst_conn.commit()
    finally:
        try:
            if not dry_run:
                with dst_conn.cursor() as dst_cur:
                    dst_cur.execute("SET FOREIGN_KEY_CHECKS = 1")
        except Exception as e:
            # Best effort: failing to restore the flag must neither hide the
            # original error nor skip closing the connections.
            print(f"警告: 恢复外键检查失败: {e}")
        finally:
            src_conn.close()
            dst_conn.close()

    print("-" * 60)
    status = "预览完成" if dry_run else "同步完成"
    if errors:
        print(f"{status}: {total_rows} 行, {errors} 个错误")
    else:
        print(f"{status}: {total_rows} 行, 无错误")
# ── CLI ───────────────────────────────────────────────────────────────────
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser used by ``main``."""
    parser = argparse.ArgumentParser(
        description="数据库同步脚本:本地 ↔ 云端",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            "\n"
            "示例:\n"
            "%(prog)s push # 本地→云端(核心+用户数据+知识库)\n"
            "%(prog)s push --tables all # 同步所有表\n"
            "%(prog)s push --tables agents,users # 只同步指定表\n"
            "%(prog)s push --exclude executions # 排除指定表\n"
            "%(prog)s push --dry-run # 预览模式\n"
            "%(prog)s push --mode replace # 全量替换DELETE+INSERT\n"
            "%(prog)s pull # 云端→本地\n"
        ),
    )
    parser.add_argument(
        "direction",
        choices=["push", "pull"],
        help="push = 本地→云端, pull = 云端→本地",
    )
    parser.add_argument(
        "--tables",
        default=None,
        help="要同步的表(逗号分隔),默认=核心+用户+知识库, 'all'=全部表",
    )
    parser.add_argument(
        "--exclude",
        default=None,
        help="排除的表(逗号分隔)",
    )
    parser.add_argument(
        "--mode",
        choices=["upsert", "replace", "append"],
        default="upsert",
        help="upsert=有主键则更新(recommended), replace=删后插入, append=跳过重复 (默认: upsert)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="预览模式,不实际写入数据",
    )
    return parser


def main():
    """Command-line entry point: parse arguments, load ``.env``, run the sync.

    Exits with status 1 when ``DATABASE_URL`` is missing, or when the
    direction is ``push`` and ``CLOUD_DATABASE_URL`` is missing.
    """
    options = _build_arg_parser().parse_args()

    env = _load_env()
    if "DATABASE_URL" not in env:
        print("错误: .env 文件中未找到 DATABASE_URL")
        sys.exit(1)

    pushing = options.direction == "push"
    if pushing and "CLOUD_DATABASE_URL" not in env:
        print("错误: .env 文件中未找到 CLOUD_DATABASE_URL推送需要云端连接信息")
        sys.exit(1)

    selected_tables = resolve_tables(options.tables, options.exclude)
    run_sync(
        env=env,
        tables=selected_tables,
        direction=options.direction,
        mode=options.mode,
        dry_run=options.dry_run,
    )
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()