第一次提交
This commit is contained in:
1
backend/app/services/__init__.py
Normal file
1
backend/app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Services package
|
||||
391
backend/app/services/alert_service.py
Normal file
391
backend/app/services/alert_service.py
Normal file
@@ -0,0 +1,391 @@
|
||||
"""
|
||||
告警服务
|
||||
提供告警检测和通知功能
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, and_
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
import logging
|
||||
from app.models.alert_rule import AlertRule, AlertLog
|
||||
from app.models.execution import Execution
|
||||
from app.models.workflow import Workflow
|
||||
from app.models.execution_log import ExecutionLog
|
||||
import httpx
|
||||
import aiosmtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AlertService:
|
||||
"""告警服务类"""
|
||||
|
||||
@staticmethod
|
||||
async def check_execution_failed(
|
||||
db: Session,
|
||||
rule: AlertRule,
|
||||
execution: Execution
|
||||
) -> bool:
|
||||
"""
|
||||
检查执行失败告警
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
rule: 告警规则
|
||||
execution: 执行记录
|
||||
|
||||
Returns:
|
||||
是否触发告警
|
||||
"""
|
||||
if execution.status != 'failed':
|
||||
return False
|
||||
|
||||
# 检查目标是否匹配
|
||||
if rule.target_type == 'workflow' and rule.target_id:
|
||||
if execution.workflow_id != rule.target_id:
|
||||
return False
|
||||
elif rule.target_type == 'agent' and rule.target_id:
|
||||
if execution.agent_id != rule.target_id:
|
||||
return False
|
||||
|
||||
# 检查时间窗口内的失败次数
|
||||
conditions = rule.conditions
|
||||
threshold = conditions.get('threshold', 1)
|
||||
time_window = conditions.get('time_window', 3600) # 默认1小时
|
||||
comparison = conditions.get('comparison', 'gt') # gt, gte, eq
|
||||
|
||||
start_time = datetime.utcnow() - timedelta(seconds=time_window)
|
||||
|
||||
# 构建查询条件
|
||||
query = db.query(func.count(Execution.id)).filter(
|
||||
Execution.status == 'failed',
|
||||
Execution.created_at >= start_time
|
||||
)
|
||||
|
||||
if rule.target_type == 'workflow' and rule.target_id:
|
||||
query = query.filter(Execution.workflow_id == rule.target_id)
|
||||
elif rule.target_type == 'agent' and rule.target_id:
|
||||
query = query.filter(Execution.agent_id == rule.target_id)
|
||||
|
||||
failed_count = query.scalar() or 0
|
||||
|
||||
# 根据比较操作符判断
|
||||
if comparison == 'gt':
|
||||
return failed_count > threshold
|
||||
elif comparison == 'gte':
|
||||
return failed_count >= threshold
|
||||
elif comparison == 'eq':
|
||||
return failed_count == threshold
|
||||
else:
|
||||
return failed_count > threshold
|
||||
|
||||
@staticmethod
|
||||
async def check_execution_timeout(
|
||||
db: Session,
|
||||
rule: AlertRule,
|
||||
execution: Execution
|
||||
) -> bool:
|
||||
"""
|
||||
检查执行超时告警
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
rule: 告警规则
|
||||
execution: 执行记录
|
||||
|
||||
Returns:
|
||||
是否触发告警
|
||||
"""
|
||||
if execution.status != 'running':
|
||||
return False
|
||||
|
||||
# 检查目标是否匹配
|
||||
if rule.target_type == 'workflow' and rule.target_id:
|
||||
if execution.workflow_id != rule.target_id:
|
||||
return False
|
||||
elif rule.target_type == 'agent' and rule.target_id:
|
||||
if execution.agent_id != rule.target_id:
|
||||
return False
|
||||
|
||||
# 检查执行时间
|
||||
conditions = rule.conditions
|
||||
timeout_seconds = conditions.get('timeout_seconds', 3600) # 默认1小时
|
||||
|
||||
if execution.created_at:
|
||||
elapsed = (datetime.utcnow() - execution.created_at).total_seconds()
|
||||
return elapsed > timeout_seconds
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def check_error_rate(
|
||||
db: Session,
|
||||
rule: AlertRule
|
||||
) -> bool:
|
||||
"""
|
||||
检查错误率告警
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
rule: 告警规则
|
||||
|
||||
Returns:
|
||||
是否触发告警
|
||||
"""
|
||||
conditions = rule.conditions
|
||||
threshold = conditions.get('threshold', 0.1) # 默认10%
|
||||
time_window = conditions.get('time_window', 3600) # 默认1小时
|
||||
|
||||
start_time = datetime.utcnow() - timedelta(seconds=time_window)
|
||||
|
||||
# 构建查询条件
|
||||
query = db.query(Execution).filter(
|
||||
Execution.created_at >= start_time
|
||||
)
|
||||
|
||||
if rule.target_type == 'workflow' and rule.target_id:
|
||||
query = query.filter(Execution.workflow_id == rule.target_id)
|
||||
elif rule.target_type == 'agent' and rule.target_id:
|
||||
query = query.filter(Execution.agent_id == rule.target_id)
|
||||
|
||||
executions = query.all()
|
||||
|
||||
if not executions:
|
||||
return False
|
||||
|
||||
total_count = len(executions)
|
||||
failed_count = sum(1 for e in executions if e.status == 'failed')
|
||||
error_rate = failed_count / total_count if total_count > 0 else 0
|
||||
|
||||
return error_rate >= threshold
|
||||
|
||||
@staticmethod
|
||||
async def check_alerts_for_execution(
|
||||
db: Session,
|
||||
execution: Execution
|
||||
) -> List[AlertLog]:
|
||||
"""
|
||||
检查执行记录相关的告警规则
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
execution: 执行记录
|
||||
|
||||
Returns:
|
||||
触发的告警日志列表
|
||||
"""
|
||||
triggered_logs = []
|
||||
|
||||
# 获取相关的告警规则
|
||||
query = db.query(AlertRule).filter(AlertRule.enabled == True)
|
||||
|
||||
# 根据执行记录筛选相关规则
|
||||
if execution.workflow_id:
|
||||
query = query.filter(
|
||||
(AlertRule.target_type == 'workflow') &
|
||||
((AlertRule.target_id == execution.workflow_id) | (AlertRule.target_id.is_(None)))
|
||||
)
|
||||
elif execution.agent_id:
|
||||
query = query.filter(
|
||||
(AlertRule.target_type == 'agent') &
|
||||
((AlertRule.target_id == execution.agent_id) | (AlertRule.target_id.is_(None)))
|
||||
)
|
||||
|
||||
# 也包含系统级告警
|
||||
query = query.filter(AlertRule.target_type == 'system')
|
||||
|
||||
rules = query.all()
|
||||
|
||||
for rule in rules:
|
||||
try:
|
||||
should_trigger = False
|
||||
alert_message = ""
|
||||
alert_details = {}
|
||||
|
||||
if rule.alert_type == 'execution_failed':
|
||||
should_trigger = await AlertService.check_execution_failed(db, rule, execution)
|
||||
if should_trigger:
|
||||
alert_message = f"执行失败告警: 工作流 {execution.workflow_id} 执行失败"
|
||||
alert_details = {
|
||||
"execution_id": execution.id,
|
||||
"workflow_id": execution.workflow_id,
|
||||
"status": execution.status,
|
||||
"error_message": execution.error_message
|
||||
}
|
||||
|
||||
elif rule.alert_type == 'execution_timeout':
|
||||
should_trigger = await AlertService.check_execution_timeout(db, rule, execution)
|
||||
if should_trigger:
|
||||
elapsed = (datetime.utcnow() - execution.created_at).total_seconds() if execution.created_at else 0
|
||||
alert_message = f"执行超时告警: 工作流 {execution.workflow_id} 执行超时 ({elapsed:.0f}秒)"
|
||||
alert_details = {
|
||||
"execution_id": execution.id,
|
||||
"workflow_id": execution.workflow_id,
|
||||
"elapsed_seconds": elapsed
|
||||
}
|
||||
|
||||
elif rule.alert_type == 'error_rate':
|
||||
should_trigger = await AlertService.check_error_rate(db, rule)
|
||||
if should_trigger:
|
||||
alert_message = f"错误率告警: {rule.target_type} 错误率超过阈值"
|
||||
alert_details = {
|
||||
"target_type": rule.target_type,
|
||||
"target_id": rule.target_id
|
||||
}
|
||||
|
||||
if should_trigger:
|
||||
# 创建告警日志
|
||||
alert_log = AlertLog(
|
||||
rule_id=rule.id,
|
||||
alert_type=rule.alert_type,
|
||||
severity=rule.conditions.get('severity', 'warning'),
|
||||
message=alert_message,
|
||||
details=alert_details,
|
||||
status='pending',
|
||||
notification_type=rule.notification_type,
|
||||
triggered_at=datetime.utcnow()
|
||||
)
|
||||
db.add(alert_log)
|
||||
|
||||
# 更新规则统计
|
||||
rule.trigger_count += 1
|
||||
rule.last_triggered_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
db.refresh(alert_log)
|
||||
|
||||
# 发送通知
|
||||
await AlertService.send_notification(db, alert_log, rule)
|
||||
|
||||
triggered_logs.append(alert_log)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查告警规则失败 {rule.id}: {str(e)}")
|
||||
continue
|
||||
|
||||
return triggered_logs
|
||||
|
||||
@staticmethod
|
||||
async def send_notification(
|
||||
db: Session,
|
||||
alert_log: AlertLog,
|
||||
rule: AlertRule
|
||||
):
|
||||
"""
|
||||
发送告警通知
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
alert_log: 告警日志
|
||||
rule: 告警规则
|
||||
"""
|
||||
try:
|
||||
if rule.notification_type == 'email':
|
||||
await AlertService.send_email_notification(alert_log, rule)
|
||||
elif rule.notification_type == 'webhook':
|
||||
await AlertService.send_webhook_notification(alert_log, rule)
|
||||
elif rule.notification_type == 'internal':
|
||||
# 站内通知,只需要记录日志即可
|
||||
pass
|
||||
|
||||
alert_log.status = 'sent'
|
||||
alert_log.notification_result = '通知发送成功'
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送告警通知失败: {str(e)}")
|
||||
alert_log.status = 'failed'
|
||||
alert_log.notification_result = f"通知发送失败: {str(e)}"
|
||||
|
||||
finally:
|
||||
db.commit()
|
||||
|
||||
@staticmethod
|
||||
async def send_email_notification(
|
||||
alert_log: AlertLog,
|
||||
rule: AlertRule
|
||||
):
|
||||
"""
|
||||
发送邮件通知
|
||||
|
||||
Args:
|
||||
alert_log: 告警日志
|
||||
rule: 告警规则
|
||||
"""
|
||||
config = rule.notification_config or {}
|
||||
smtp_host = config.get('smtp_host', 'smtp.gmail.com')
|
||||
smtp_port = config.get('smtp_port', 587)
|
||||
smtp_user = config.get('smtp_user')
|
||||
smtp_password = config.get('smtp_password')
|
||||
to_email = config.get('to_email')
|
||||
|
||||
if not to_email:
|
||||
raise ValueError("邮件通知配置缺少收件人地址")
|
||||
|
||||
# 创建邮件
|
||||
message = MIMEMultipart()
|
||||
message['From'] = smtp_user
|
||||
message['To'] = to_email
|
||||
message['Subject'] = f"告警通知: {rule.name}"
|
||||
|
||||
body = f"""
|
||||
告警规则: {rule.name}
|
||||
告警类型: {alert_log.alert_type}
|
||||
严重程度: {alert_log.severity}
|
||||
告警消息: {alert_log.message}
|
||||
触发时间: {alert_log.triggered_at}
|
||||
"""
|
||||
|
||||
if alert_log.details:
|
||||
body += f"\n详细信息:\n{alert_log.details}"
|
||||
|
||||
message.attach(MIMEText(body, 'plain', 'utf-8'))
|
||||
|
||||
# 发送邮件
|
||||
await aiosmtplib.send(
|
||||
message,
|
||||
hostname=smtp_host,
|
||||
port=smtp_port,
|
||||
username=smtp_user,
|
||||
password=smtp_password,
|
||||
use_tls=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def send_webhook_notification(
|
||||
alert_log: AlertLog,
|
||||
rule: AlertRule
|
||||
):
|
||||
"""
|
||||
发送Webhook通知
|
||||
|
||||
Args:
|
||||
alert_log: 告警日志
|
||||
rule: 告警规则
|
||||
"""
|
||||
config = rule.notification_config or {}
|
||||
webhook_url = config.get('webhook_url')
|
||||
|
||||
if not webhook_url:
|
||||
raise ValueError("Webhook通知配置缺少URL")
|
||||
|
||||
# 构建请求数据
|
||||
payload = {
|
||||
"rule_name": rule.name,
|
||||
"alert_type": alert_log.alert_type,
|
||||
"severity": alert_log.severity,
|
||||
"message": alert_log.message,
|
||||
"details": alert_log.details,
|
||||
"triggered_at": alert_log.triggered_at.isoformat() if alert_log.triggered_at else None
|
||||
}
|
||||
|
||||
# 发送HTTP请求
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
webhook_url,
|
||||
json=payload,
|
||||
headers=config.get('headers', {}),
|
||||
timeout=10
|
||||
)
|
||||
response.raise_for_status()
|
||||
276
backend/app/services/condition_parser.py
Normal file
276
backend/app/services/condition_parser.py
Normal file
@@ -0,0 +1,276 @@
|
||||
"""
|
||||
条件表达式解析器
|
||||
支持更复杂的条件判断表达式
|
||||
"""
|
||||
import re
|
||||
import json
|
||||
from typing import Dict, Any, Union
|
||||
|
||||
|
||||
class ConditionParser:
|
||||
"""条件表达式解析器"""
|
||||
|
||||
# 支持的运算符
|
||||
OPERATORS = {
|
||||
'==': lambda a, b: a == b,
|
||||
'!=': lambda a, b: a != b,
|
||||
'>': lambda a, b: a > b,
|
||||
'>=': lambda a, b: a >= b,
|
||||
'<': lambda a, b: a < b,
|
||||
'<=': lambda a, b: a <= b,
|
||||
'in': lambda a, b: a in b if isinstance(b, (list, str, dict)) else False,
|
||||
'not in': lambda a, b: a not in b if isinstance(b, (list, str, dict)) else False,
|
||||
'contains': lambda a, b: b in str(a) if a is not None else False,
|
||||
'not contains': lambda a, b: b not in str(a) if a is not None else True,
|
||||
}
|
||||
|
||||
# 逻辑运算符
|
||||
LOGICAL_OPERATORS = {
|
||||
'and': lambda a, b: a and b,
|
||||
'or': lambda a, b: a or b,
|
||||
'not': lambda a: not a,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_value(path: str, data: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
从数据中获取值(支持嵌套路径)
|
||||
|
||||
Args:
|
||||
path: 路径,如 'user.name' 或 'items[0].price'
|
||||
data: 数据字典
|
||||
|
||||
Returns:
|
||||
值,如果不存在返回None
|
||||
"""
|
||||
try:
|
||||
# 处理数组索引,如 items[0]
|
||||
if '[' in path and ']' in path:
|
||||
parts = re.split(r'\[|\]', path)
|
||||
value = data
|
||||
for part in parts:
|
||||
if not part:
|
||||
continue
|
||||
if part.isdigit():
|
||||
value = value[int(part)]
|
||||
else:
|
||||
value = value.get(part) if isinstance(value, dict) else None
|
||||
if value is None:
|
||||
return None
|
||||
return value
|
||||
|
||||
# 处理嵌套路径,如 user.name
|
||||
keys = path.split('.')
|
||||
value = data
|
||||
for key in keys:
|
||||
if isinstance(value, dict):
|
||||
value = value.get(key)
|
||||
elif isinstance(value, list) and key.isdigit():
|
||||
value = value[int(key)] if int(key) < len(value) else None
|
||||
else:
|
||||
return None
|
||||
if value is None:
|
||||
return None
|
||||
return value
|
||||
except (KeyError, IndexError, TypeError, AttributeError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_value(value_str: str) -> Any:
|
||||
"""
|
||||
解析值字符串(支持字符串、数字、布尔值、JSON)
|
||||
|
||||
Args:
|
||||
value_str: 值字符串
|
||||
|
||||
Returns:
|
||||
解析后的值
|
||||
"""
|
||||
value_str = value_str.strip()
|
||||
|
||||
# 布尔值
|
||||
if value_str.lower() == 'true':
|
||||
return True
|
||||
if value_str.lower() == 'false':
|
||||
return False
|
||||
|
||||
# None
|
||||
if value_str.lower() == 'null' or value_str.lower() == 'none':
|
||||
return None
|
||||
|
||||
# 数字
|
||||
try:
|
||||
if '.' in value_str:
|
||||
return float(value_str)
|
||||
return int(value_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# JSON
|
||||
if value_str.startswith('{') or value_str.startswith('['):
|
||||
try:
|
||||
return json.loads(value_str)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 字符串(移除引号)
|
||||
if (value_str.startswith('"') and value_str.endswith('"')) or \
|
||||
(value_str.startswith("'") and value_str.endswith("'")):
|
||||
return value_str[1:-1]
|
||||
|
||||
return value_str
|
||||
|
||||
@staticmethod
|
||||
def evaluate_simple_condition(condition: str, data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
评估简单条件表达式
|
||||
|
||||
支持的格式:
|
||||
- {key} == value
|
||||
- {key} > value
|
||||
- {key} in [value1, value2]
|
||||
- {key} contains "text"
|
||||
|
||||
Args:
|
||||
condition: 条件表达式
|
||||
data: 输入数据
|
||||
|
||||
Returns:
|
||||
条件结果
|
||||
"""
|
||||
condition = condition.strip()
|
||||
|
||||
# 替换变量 {key}
|
||||
for key, value in data.items():
|
||||
placeholder = f'{{{key}}}'
|
||||
if placeholder in condition:
|
||||
# 如果值是复杂类型,转换为JSON字符串
|
||||
if isinstance(value, (dict, list)):
|
||||
condition = condition.replace(placeholder, json.dumps(value, ensure_ascii=False))
|
||||
else:
|
||||
condition = condition.replace(placeholder, str(value))
|
||||
|
||||
# 尝试解析为Python表达式(安全方式)
|
||||
try:
|
||||
# 只允许安全的操作
|
||||
safe_dict = {
|
||||
'__builtins__': {},
|
||||
'True': True,
|
||||
'False': False,
|
||||
'None': None,
|
||||
'null': None,
|
||||
}
|
||||
|
||||
# 添加数据中的值到安全字典
|
||||
for key, value in data.items():
|
||||
# 只添加简单的值,避免复杂对象
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
safe_dict[key] = value
|
||||
|
||||
# 尝试评估
|
||||
result = eval(condition, safe_dict)
|
||||
if isinstance(result, bool):
|
||||
return result
|
||||
except:
|
||||
pass
|
||||
|
||||
# 如果eval失败,尝试手动解析
|
||||
# 匹配运算符
|
||||
for op in ConditionParser.OPERATORS.keys():
|
||||
if op in condition:
|
||||
parts = condition.split(op, 1)
|
||||
if len(parts) == 2:
|
||||
left = parts[0].strip()
|
||||
right = parts[1].strip()
|
||||
|
||||
# 获取左侧值
|
||||
if left.startswith('{') and left.endswith('}'):
|
||||
key = left[1:-1]
|
||||
left_value = ConditionParser.get_value(key, data)
|
||||
else:
|
||||
left_value = ConditionParser.parse_value(left)
|
||||
|
||||
# 获取右侧值
|
||||
right_value = ConditionParser.parse_value(right)
|
||||
|
||||
# 执行比较
|
||||
if left_value is not None:
|
||||
return ConditionParser.OPERATORS[op](left_value, right_value)
|
||||
|
||||
# 默认返回False
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def evaluate_condition(condition: str, data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
评估条件表达式(支持复杂表达式)
|
||||
|
||||
支持的格式:
|
||||
- 简单条件: {key} == value
|
||||
- 逻辑组合: {key} > 10 and {key} < 20
|
||||
- 括号分组: ({key} == 'a' or {key} == 'b') and {other} > 0
|
||||
|
||||
Args:
|
||||
condition: 条件表达式
|
||||
data: 输入数据
|
||||
|
||||
Returns:
|
||||
条件结果
|
||||
"""
|
||||
if not condition:
|
||||
return False
|
||||
|
||||
condition = condition.strip()
|
||||
|
||||
# 处理括号表达式(递归处理)
|
||||
def process_parentheses(expr: str) -> str:
|
||||
"""处理括号表达式"""
|
||||
while '(' in expr and ')' in expr:
|
||||
# 找到最内层的括号
|
||||
start = expr.rfind('(')
|
||||
end = expr.find(')', start)
|
||||
if end == -1:
|
||||
break
|
||||
|
||||
# 提取括号内的表达式
|
||||
inner_expr = expr[start+1:end]
|
||||
inner_result = ConditionParser.evaluate_condition(inner_expr, data)
|
||||
|
||||
# 替换括号表达式为结果
|
||||
expr = expr[:start] + str(inner_result) + expr[end+1:]
|
||||
return expr
|
||||
|
||||
condition = process_parentheses(condition)
|
||||
|
||||
# 分割逻辑运算符,按优先级处理
|
||||
# 先处理 and(优先级更高)
|
||||
if ' and ' in condition.lower():
|
||||
parts = re.split(r'\s+and\s+', condition, flags=re.IGNORECASE)
|
||||
results = []
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if part:
|
||||
results.append(ConditionParser.evaluate_simple_condition(part, data))
|
||||
return all(results)
|
||||
|
||||
# 再处理 or
|
||||
if ' or ' in condition.lower():
|
||||
parts = re.split(r'\s+or\s+', condition, flags=re.IGNORECASE)
|
||||
results = []
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if part:
|
||||
results.append(ConditionParser.evaluate_simple_condition(part, data))
|
||||
return any(results)
|
||||
|
||||
# 处理 not
|
||||
if condition.lower().startswith('not '):
|
||||
inner = condition[4:].strip()
|
||||
return not ConditionParser.evaluate_simple_condition(inner, data)
|
||||
|
||||
# 最终评估简单条件
|
||||
return ConditionParser.evaluate_simple_condition(condition, data)
|
||||
|
||||
|
||||
# 全局实例
|
||||
condition_parser = ConditionParser()
|
||||
285
backend/app/services/data_source_connector.py
Normal file
285
backend/app/services/data_source_connector.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
数据源连接器服务
|
||||
"""
|
||||
from typing import Dict, Any, List, Optional
|
||||
import logging
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataSourceConnector:
|
||||
"""数据源连接器基类"""
|
||||
|
||||
def __init__(self, source_type: str, config: Dict[str, Any]):
|
||||
"""
|
||||
初始化数据源连接器
|
||||
|
||||
Args:
|
||||
source_type: 数据源类型
|
||||
config: 连接配置
|
||||
"""
|
||||
self.source_type = source_type
|
||||
self.config = config
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
|
||||
"""
|
||||
测试连接
|
||||
|
||||
Returns:
|
||||
连接测试结果
|
||||
"""
|
||||
raise NotImplementedError("子类必须实现test_connection方法")
|
||||
|
||||
def query(self, query_params: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
查询数据
|
||||
|
||||
Args:
|
||||
query_params: 查询参数
|
||||
|
||||
Returns:
|
||||
查询结果
|
||||
"""
|
||||
raise NotImplementedError("子类必须实现query方法")
|
||||
|
||||
|
||||
class MySQLConnector(DataSourceConnector):
|
||||
"""MySQL连接器"""
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
|
||||
try:
|
||||
import pymysql
|
||||
connection = pymysql.connect(
|
||||
host=self.config.get('host'),
|
||||
port=self.config.get('port', 3306),
|
||||
user=self.config.get('user'),
|
||||
password=self.config.get('password'),
|
||||
database=self.config.get('database'),
|
||||
connect_timeout=5
|
||||
)
|
||||
connection.close()
|
||||
return {"status": "success", "message": "连接成功"}
|
||||
except Exception as e:
|
||||
raise Exception(f"MySQL连接失败: {str(e)}")
|
||||
|
||||
def query(self, query_params: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
import pymysql
|
||||
sql = query_params.get('sql')
|
||||
if not sql:
|
||||
raise ValueError("缺少SQL查询语句")
|
||||
|
||||
connection = pymysql.connect(
|
||||
host=self.config.get('host'),
|
||||
port=self.config.get('port', 3306),
|
||||
user=self.config.get('user'),
|
||||
password=self.config.get('password'),
|
||||
database=self.config.get('database')
|
||||
)
|
||||
|
||||
try:
|
||||
with connection.cursor(pymysql.cursors.DictCursor) as cursor:
|
||||
cursor.execute(sql)
|
||||
result = cursor.fetchall()
|
||||
return result
|
||||
finally:
|
||||
connection.close()
|
||||
except Exception as e:
|
||||
raise Exception(f"MySQL查询失败: {str(e)}")
|
||||
|
||||
|
||||
class PostgreSQLConnector(DataSourceConnector):
|
||||
"""PostgreSQL连接器"""
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
|
||||
try:
|
||||
import psycopg2
|
||||
connection = psycopg2.connect(
|
||||
host=self.config.get('host'),
|
||||
port=self.config.get('port', 5432),
|
||||
user=self.config.get('user'),
|
||||
password=self.config.get('password'),
|
||||
database=self.config.get('database'),
|
||||
connect_timeout=5
|
||||
)
|
||||
connection.close()
|
||||
return {"status": "success", "message": "连接成功"}
|
||||
except Exception as e:
|
||||
raise Exception(f"PostgreSQL连接失败: {str(e)}")
|
||||
|
||||
def query(self, query_params: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
sql = query_params.get('sql')
|
||||
if not sql:
|
||||
raise ValueError("缺少SQL查询语句")
|
||||
|
||||
connection = psycopg2.connect(
|
||||
host=self.config.get('host'),
|
||||
port=self.config.get('port', 5432),
|
||||
user=self.config.get('user'),
|
||||
password=self.config.get('password'),
|
||||
database=self.config.get('database')
|
||||
)
|
||||
|
||||
try:
|
||||
with connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
|
||||
cursor.execute(sql)
|
||||
result = cursor.fetchall()
|
||||
return [dict(row) for row in result]
|
||||
finally:
|
||||
connection.close()
|
||||
except Exception as e:
|
||||
raise Exception(f"PostgreSQL查询失败: {str(e)}")
|
||||
|
||||
|
||||
class APIConnector(DataSourceConnector):
|
||||
"""API连接器"""
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
|
||||
try:
|
||||
import httpx
|
||||
url = self.config.get('base_url')
|
||||
if not url:
|
||||
raise ValueError("缺少base_url配置")
|
||||
|
||||
headers = self.config.get('headers', {})
|
||||
timeout = self.config.get('timeout', 10)
|
||||
|
||||
response = httpx.get(url, headers=headers, timeout=timeout)
|
||||
response.raise_for_status()
|
||||
|
||||
return {"status": "success", "message": "连接成功", "status_code": response.status_code}
|
||||
except Exception as e:
|
||||
raise Exception(f"API连接失败: {str(e)}")
|
||||
|
||||
def query(self, query_params: Dict[str, Any]) -> Any:
|
||||
try:
|
||||
import httpx
|
||||
method = query_params.get('method', 'GET').upper()
|
||||
endpoint = query_params.get('endpoint', '')
|
||||
params = query_params.get('params', {})
|
||||
data = query_params.get('data', {})
|
||||
headers = self.config.get('headers', {})
|
||||
timeout = self.config.get('timeout', 10)
|
||||
|
||||
base_url = self.config.get('base_url', '').rstrip('/')
|
||||
url = f"{base_url}/{endpoint.lstrip('/')}"
|
||||
|
||||
if method == 'GET':
|
||||
response = httpx.get(url, params=params, headers=headers, timeout=timeout)
|
||||
elif method == 'POST':
|
||||
response = httpx.post(url, json=data, headers=headers, timeout=timeout)
|
||||
elif method == 'PUT':
|
||||
response = httpx.put(url, json=data, headers=headers, timeout=timeout)
|
||||
elif method == 'DELETE':
|
||||
response = httpx.delete(url, headers=headers, timeout=timeout)
|
||||
else:
|
||||
raise ValueError(f"不支持的HTTP方法: {method}")
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json() if response.content else {}
|
||||
except Exception as e:
|
||||
raise Exception(f"API查询失败: {str(e)}")
|
||||
|
||||
|
||||
class JSONFileConnector(DataSourceConnector):
|
||||
"""JSON文件连接器"""
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
|
||||
try:
|
||||
import os
|
||||
file_path = self.config.get('file_path')
|
||||
if not file_path:
|
||||
raise ValueError("缺少file_path配置")
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
return {"status": "success", "message": "文件存在"}
|
||||
except Exception as e:
|
||||
raise Exception(f"JSON文件连接失败: {str(e)}")
|
||||
|
||||
def query(self, query_params: Dict[str, Any]) -> Any:
|
||||
try:
|
||||
import json
|
||||
import os
|
||||
file_path = self.config.get('file_path')
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 支持简单的查询过滤
|
||||
filter_path = query_params.get('path')
|
||||
if filter_path:
|
||||
# 支持JSONPath风格的路径查询
|
||||
parts = filter_path.split('.')
|
||||
result = data
|
||||
for part in parts:
|
||||
if isinstance(result, dict):
|
||||
result = result.get(part)
|
||||
elif isinstance(result, list):
|
||||
try:
|
||||
index = int(part)
|
||||
result = result[index]
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
return result
|
||||
|
||||
return data
|
||||
except Exception as e:
|
||||
raise Exception(f"JSON文件查询失败: {str(e)}")
|
||||
|
||||
|
||||
# 连接器工厂
|
||||
_connector_classes = {
|
||||
'mysql': MySQLConnector,
|
||||
'postgresql': PostgreSQLConnector,
|
||||
'api': APIConnector,
|
||||
'json': JSONFileConnector,
|
||||
}
|
||||
|
||||
|
||||
def create_connector(source_type: str, config: Dict[str, Any]):
|
||||
"""
|
||||
创建数据源连接器
|
||||
|
||||
Args:
|
||||
source_type: 数据源类型
|
||||
config: 连接配置
|
||||
|
||||
Returns:
|
||||
数据源连接器实例
|
||||
"""
|
||||
connector_class = _connector_classes.get(source_type)
|
||||
if not connector_class:
|
||||
raise ValueError(f"不支持的数据源类型: {source_type}")
|
||||
|
||||
return connector_class(source_type, config)
|
||||
|
||||
|
||||
# 为了兼容API,创建一个统一的DataSourceConnector包装类
|
||||
class DataSourceConnectorWrapper:
|
||||
"""统一的数据源连接器包装类(用于API调用)"""
|
||||
|
||||
def __init__(self, source_type: str, config: Dict[str, Any]):
|
||||
self.connector = create_connector(source_type, config)
|
||||
self.source_type = source_type
|
||||
self.config = config
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
|
||||
return self.connector.test_connection()
|
||||
|
||||
def query(self, query_params: Dict[str, Any]) -> Any:
|
||||
return self.connector.query(query_params)
|
||||
|
||||
|
||||
# 导出时使用包装类,这样API可以统一使用DataSourceConnector
|
||||
# 但实际返回的是具体的连接器实现
|
||||
255
backend/app/services/data_transformer.py
Normal file
255
backend/app/services/data_transformer.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""
|
||||
数据转换服务
|
||||
支持字段映射、数据过滤、数据转换等功能
|
||||
"""
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
import json
|
||||
import re
|
||||
|
||||
|
||||
class DataTransformer:
|
||||
"""数据转换器"""
|
||||
|
||||
@staticmethod
|
||||
def get_nested_value(data: Dict[str, Any], path: str) -> Any:
|
||||
"""
|
||||
从嵌套字典中获取值
|
||||
|
||||
Args:
|
||||
data: 数据字典
|
||||
path: 路径,如 'user.name' 或 'items[0].price'
|
||||
|
||||
Returns:
|
||||
值,如果不存在返回None
|
||||
"""
|
||||
try:
|
||||
# 处理混合路径,如 items[0].price
|
||||
if '[' in path and ']' in path:
|
||||
# 先处理数组索引部分
|
||||
bracket_match = re.search(r'(\w+)\[(\d+)\]', path)
|
||||
if bracket_match:
|
||||
array_key = bracket_match.group(1)
|
||||
array_index = int(bracket_match.group(2))
|
||||
rest_path = path[bracket_match.end():]
|
||||
|
||||
# 获取数组
|
||||
if array_key in data and isinstance(data[array_key], list):
|
||||
if array_index < len(data[array_key]):
|
||||
array_item = data[array_key][array_index]
|
||||
# 如果还有后续路径,继续获取
|
||||
if rest_path.startswith('.'):
|
||||
return DataTransformer.get_nested_value(array_item, rest_path[1:])
|
||||
else:
|
||||
return array_item
|
||||
return None
|
||||
|
||||
# 处理嵌套路径,如 user.name
|
||||
keys = path.split('.')
|
||||
value = data
|
||||
for key in keys:
|
||||
if isinstance(value, dict):
|
||||
value = value.get(key)
|
||||
elif isinstance(value, list) and key.isdigit():
|
||||
value = value[int(key)] if int(key) < len(value) else None
|
||||
else:
|
||||
return None
|
||||
if value is None:
|
||||
return None
|
||||
return value
|
||||
except (KeyError, IndexError, TypeError, AttributeError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def set_nested_value(data: Dict[str, Any], path: str, value: Any) -> None:
|
||||
"""
|
||||
在嵌套字典中设置值
|
||||
|
||||
Args:
|
||||
data: 数据字典
|
||||
path: 路径,如 'user.name'
|
||||
value: 要设置的值
|
||||
"""
|
||||
keys = path.split('.')
|
||||
current = data
|
||||
|
||||
# 创建嵌套结构
|
||||
for key in keys[:-1]:
|
||||
if key not in current:
|
||||
current[key] = {}
|
||||
current = current[key]
|
||||
|
||||
# 设置值
|
||||
current[keys[-1]] = value
|
||||
|
||||
@staticmethod
|
||||
def transform_mapping(input_data: Dict[str, Any], mapping: Dict[str, str]) -> Dict[str, Any]:
|
||||
"""
|
||||
字段映射转换
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
mapping: 映射规则,格式: {"target_key": "source_key"}
|
||||
|
||||
Returns:
|
||||
转换后的数据
|
||||
"""
|
||||
result = {}
|
||||
for target_key, source_key in mapping.items():
|
||||
value = DataTransformer.get_nested_value(input_data, source_key)
|
||||
if value is not None:
|
||||
# 如果目标键包含点或方括号,使用嵌套设置
|
||||
if '.' in target_key or '[' in target_key:
|
||||
DataTransformer.set_nested_value(result, target_key, value)
|
||||
else:
|
||||
# 简单键直接设置
|
||||
result[target_key] = value
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def transform_filter(input_data: Dict[str, Any], filter_rules: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
数据过滤
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
filter_rules: 过滤规则列表,格式: [{"field": "key", "operator": ">", "value": 10}]
|
||||
|
||||
Returns:
|
||||
过滤后的数据
|
||||
"""
|
||||
result = {}
|
||||
|
||||
for rule in filter_rules:
|
||||
field = rule.get('field')
|
||||
operator = rule.get('operator', '==')
|
||||
value = rule.get('value')
|
||||
|
||||
if not field:
|
||||
continue
|
||||
|
||||
field_value = DataTransformer.get_nested_value(input_data, field)
|
||||
|
||||
# 应用过滤规则
|
||||
should_include = False
|
||||
if operator == '==' and field_value == value:
|
||||
should_include = True
|
||||
elif operator == '!=' and field_value != value:
|
||||
should_include = True
|
||||
elif operator == '>' and field_value > value:
|
||||
should_include = True
|
||||
elif operator == '>=' and field_value >= value:
|
||||
should_include = True
|
||||
elif operator == '<' and field_value < value:
|
||||
should_include = True
|
||||
elif operator == '<=' and field_value <= value:
|
||||
should_include = True
|
||||
elif operator == 'in' and field_value in value:
|
||||
should_include = True
|
||||
elif operator == 'not in' and field_value not in value:
|
||||
should_include = True
|
||||
|
||||
if should_include:
|
||||
# 包含该字段
|
||||
if field in input_data:
|
||||
result[field] = input_data[field]
|
||||
else:
|
||||
# 如果是嵌套字段,需要重建结构
|
||||
DataTransformer.set_nested_value(result, field, field_value)
|
||||
|
||||
return result if result else input_data
|
||||
|
||||
@staticmethod
|
||||
def transform_compute(input_data: Dict[str, Any], compute_rules: Dict[str, str]) -> Dict[str, Any]:
|
||||
"""
|
||||
数据计算转换
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
compute_rules: 计算规则,格式: {"result": "{a} + {b}"}
|
||||
|
||||
Returns:
|
||||
转换后的数据
|
||||
"""
|
||||
result = input_data.copy()
|
||||
|
||||
for target_key, expression in compute_rules.items():
|
||||
try:
|
||||
# 替换变量
|
||||
computed_expression = expression
|
||||
for key, value in input_data.items():
|
||||
placeholder = f'{{{key}}}'
|
||||
if placeholder in computed_expression:
|
||||
if isinstance(value, (dict, list)):
|
||||
computed_expression = computed_expression.replace(
|
||||
placeholder,
|
||||
json.dumps(value, ensure_ascii=False)
|
||||
)
|
||||
else:
|
||||
computed_expression = computed_expression.replace(
|
||||
placeholder,
|
||||
str(value)
|
||||
)
|
||||
|
||||
# 安全评估表达式
|
||||
safe_dict = {
|
||||
'__builtins__': {},
|
||||
'abs': abs,
|
||||
'min': min,
|
||||
'max': max,
|
||||
'sum': sum,
|
||||
'len': len,
|
||||
}
|
||||
|
||||
# 添加输入数据中的值
|
||||
for key, value in input_data.items():
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
safe_dict[key] = value
|
||||
|
||||
computed_value = eval(computed_expression, safe_dict)
|
||||
result[target_key] = computed_value
|
||||
except Exception as e:
|
||||
# 计算失败,跳过该字段
|
||||
result[target_key] = None
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def transform_data(
|
||||
input_data: Dict[str, Any],
|
||||
mapping: Optional[Dict[str, str]] = None,
|
||||
filter_rules: Optional[List[Dict[str, Any]]] = None,
|
||||
compute_rules: Optional[Dict[str, str]] = None,
|
||||
mode: str = 'mapping'
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
数据转换(综合方法)
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
mapping: 字段映射规则
|
||||
filter_rules: 过滤规则
|
||||
compute_rules: 计算规则
|
||||
mode: 转换模式 ('mapping', 'filter', 'compute', 'all')
|
||||
|
||||
Returns:
|
||||
转换后的数据
|
||||
"""
|
||||
result = input_data.copy()
|
||||
|
||||
if mode == 'mapping' or mode == 'all':
|
||||
if mapping:
|
||||
result = DataTransformer.transform_mapping(result, mapping)
|
||||
|
||||
if mode == 'filter' or mode == 'all':
|
||||
if filter_rules:
|
||||
result = DataTransformer.transform_filter(result, filter_rules)
|
||||
|
||||
if mode == 'compute' or mode == 'all':
|
||||
if compute_rules:
|
||||
result = DataTransformer.transform_compute(result, compute_rules)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 全局实例
|
||||
data_transformer = DataTransformer()
|
||||
132
backend/app/services/encryption_service.py
Normal file
132
backend/app/services/encryption_service.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
加密服务
|
||||
提供敏感数据的加密和解密功能
|
||||
"""
|
||||
from cryptography.fernet import Fernet
|
||||
from app.core.config import settings
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EncryptionService:
|
||||
"""加密服务类"""
|
||||
|
||||
_fernet: Fernet = None
|
||||
|
||||
@classmethod
|
||||
def _get_fernet(cls) -> Fernet:
|
||||
"""获取Fernet加密实例(单例模式)"""
|
||||
if cls._fernet is None:
|
||||
# 使用SECRET_KEY生成Fernet密钥
|
||||
# Fernet需要32字节的密钥,我们使用SHA256哈希SECRET_KEY
|
||||
key = hashlib.sha256(settings.SECRET_KEY.encode()).digest()
|
||||
# Fernet需要base64编码的32字节密钥
|
||||
fernet_key = base64.urlsafe_b64encode(key)
|
||||
cls._fernet = Fernet(fernet_key)
|
||||
return cls._fernet
|
||||
|
||||
@classmethod
|
||||
def encrypt(cls, plaintext: str) -> str:
|
||||
"""
|
||||
加密明文
|
||||
|
||||
Args:
|
||||
plaintext: 要加密的明文
|
||||
|
||||
Returns:
|
||||
加密后的密文(base64编码)
|
||||
"""
|
||||
if not plaintext:
|
||||
return ""
|
||||
|
||||
try:
|
||||
fernet = cls._get_fernet()
|
||||
encrypted = fernet.encrypt(plaintext.encode('utf-8'))
|
||||
return encrypted.decode('utf-8')
|
||||
except Exception as e:
|
||||
logger.error(f"加密失败: {e}")
|
||||
raise ValueError(f"加密失败: {str(e)}")
|
||||
|
||||
@classmethod
|
||||
def decrypt(cls, ciphertext: str) -> str:
|
||||
"""
|
||||
解密密文
|
||||
|
||||
Args:
|
||||
ciphertext: 要解密的密文(base64编码)
|
||||
|
||||
Returns:
|
||||
解密后的明文
|
||||
"""
|
||||
if not ciphertext:
|
||||
return ""
|
||||
|
||||
try:
|
||||
fernet = cls._get_fernet()
|
||||
decrypted = fernet.decrypt(ciphertext.encode('utf-8'))
|
||||
return decrypted.decode('utf-8')
|
||||
except Exception as e:
|
||||
logger.error(f"解密失败: {e}")
|
||||
# 如果解密失败,可能是旧数据未加密,直接返回原值
|
||||
# 或者抛出异常,让调用者处理
|
||||
raise ValueError(f"解密失败: {str(e)}")
|
||||
|
||||
@classmethod
|
||||
def encrypt_dict_value(cls, data: dict, key: str) -> dict:
|
||||
"""
|
||||
加密字典中指定键的值
|
||||
|
||||
Args:
|
||||
data: 字典数据
|
||||
key: 要加密的键名
|
||||
|
||||
Returns:
|
||||
加密后的字典
|
||||
"""
|
||||
if key in data and data[key]:
|
||||
data[key] = cls.encrypt(str(data[key]))
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def decrypt_dict_value(cls, data: dict, key: str) -> dict:
|
||||
"""
|
||||
解密字典中指定键的值
|
||||
|
||||
Args:
|
||||
data: 字典数据
|
||||
key: 要解密的键名
|
||||
|
||||
Returns:
|
||||
解密后的字典
|
||||
"""
|
||||
if key in data and data[key]:
|
||||
try:
|
||||
data[key] = cls.decrypt(str(data[key]))
|
||||
except ValueError:
|
||||
# 如果解密失败,可能是未加密的数据,保持原值
|
||||
pass
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def is_encrypted(cls, text: str) -> bool:
|
||||
"""
|
||||
判断文本是否已加密
|
||||
|
||||
Args:
|
||||
text: 要检查的文本
|
||||
|
||||
Returns:
|
||||
是否已加密
|
||||
"""
|
||||
if not text:
|
||||
return False
|
||||
|
||||
try:
|
||||
# 尝试解密,如果成功则说明已加密
|
||||
cls.decrypt(text)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
117
backend/app/services/execution_logger.py
Normal file
117
backend/app/services/execution_logger.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
执行日志服务
|
||||
"""
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.execution_log import ExecutionLog
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExecutionLogger:
|
||||
"""执行日志记录器"""
|
||||
|
||||
def __init__(self, execution_id: str, db: Session):
|
||||
"""
|
||||
初始化日志记录器
|
||||
|
||||
Args:
|
||||
execution_id: 执行ID
|
||||
db: 数据库会话
|
||||
"""
|
||||
self.execution_id = execution_id
|
||||
self.db = db
|
||||
|
||||
def log(
|
||||
self,
|
||||
level: str,
|
||||
message: str,
|
||||
node_id: Optional[str] = None,
|
||||
node_type: Optional[str] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
duration: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
记录日志
|
||||
|
||||
Args:
|
||||
level: 日志级别 (INFO/WARN/ERROR/DEBUG)
|
||||
message: 日志消息
|
||||
node_id: 节点ID(可选)
|
||||
node_type: 节点类型(可选)
|
||||
data: 附加数据(可选)
|
||||
duration: 执行耗时(毫秒,可选)
|
||||
"""
|
||||
try:
|
||||
log_entry = ExecutionLog(
|
||||
execution_id=self.execution_id,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
level=level.upper(),
|
||||
message=message,
|
||||
data=data,
|
||||
duration=duration,
|
||||
timestamp=datetime.utcnow()
|
||||
)
|
||||
self.db.add(log_entry)
|
||||
self.db.commit()
|
||||
|
||||
# 同时输出到标准日志
|
||||
log_method = getattr(logger, level.lower(), logger.info)
|
||||
log_msg = f"[执行 {self.execution_id}]"
|
||||
if node_id:
|
||||
log_msg += f" [节点 {node_id}]"
|
||||
log_msg += f" {message}"
|
||||
log_method(log_msg)
|
||||
|
||||
except Exception as e:
|
||||
# 如果数据库记录失败,至少输出到标准日志
|
||||
logger.error(f"记录执行日志失败: {str(e)}")
|
||||
logger.error(f"[执行 {self.execution_id}] {message}")
|
||||
|
||||
def info(self, message: str, **kwargs):
|
||||
"""记录INFO级别日志"""
|
||||
self.log("INFO", message, **kwargs)
|
||||
|
||||
def warn(self, message: str, **kwargs):
|
||||
"""记录WARN级别日志"""
|
||||
self.log("WARN", message, **kwargs)
|
||||
|
||||
def error(self, message: str, **kwargs):
|
||||
"""记录ERROR级别日志"""
|
||||
self.log("ERROR", message, **kwargs)
|
||||
|
||||
def debug(self, message: str, **kwargs):
|
||||
"""记录DEBUG级别日志"""
|
||||
self.log("DEBUG", message, **kwargs)
|
||||
|
||||
def log_node_start(self, node_id: str, node_type: str, input_data: Optional[Dict[str, Any]] = None):
|
||||
"""记录节点开始执行"""
|
||||
self.info(
|
||||
f"节点 {node_id} ({node_type}) 开始执行",
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
data={"input": input_data} if input_data else None
|
||||
)
|
||||
|
||||
def log_node_complete(self, node_id: str, node_type: str, output_data: Optional[Dict[str, Any]] = None, duration: Optional[int] = None):
|
||||
"""记录节点执行完成"""
|
||||
self.info(
|
||||
f"节点 {node_id} ({node_type}) 执行完成",
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
data={"output": output_data} if output_data else None,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
def log_node_error(self, node_id: str, node_type: str, error: Exception, duration: Optional[int] = None):
|
||||
"""记录节点执行错误"""
|
||||
self.error(
|
||||
f"节点 {node_id} ({node_type}) 执行失败: {str(error)}",
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
data={"error": str(error), "error_type": type(error).__name__},
|
||||
duration=duration
|
||||
)
|
||||
220
backend/app/services/llm_service.py
Normal file
220
backend/app/services/llm_service.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
LLM服务 - 处理各种LLM提供商的调用
|
||||
"""
|
||||
from typing import Dict, Any, Optional
|
||||
import json
|
||||
from openai import AsyncOpenAI
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class LLMService:
|
||||
"""LLM服务类"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化LLM服务"""
|
||||
self.openai_client = None
|
||||
self.deepseek_client = None
|
||||
|
||||
# 初始化OpenAI客户端
|
||||
if settings.OPENAI_API_KEY:
|
||||
self.openai_client = AsyncOpenAI(
|
||||
api_key=settings.OPENAI_API_KEY,
|
||||
base_url=settings.OPENAI_BASE_URL
|
||||
)
|
||||
|
||||
# 初始化DeepSeek客户端(兼容OpenAI API)
|
||||
if settings.DEEPSEEK_API_KEY:
|
||||
self.deepseek_client = AsyncOpenAI(
|
||||
api_key=settings.DEEPSEEK_API_KEY,
|
||||
base_url=settings.DEEPSEEK_BASE_URL
|
||||
)
|
||||
|
||||
async def call_openai(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
调用OpenAI API
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model: 模型名称,默认gpt-3.5-turbo
|
||||
temperature: 温度参数,默认0.7
|
||||
max_tokens: 最大token数
|
||||
api_key: API密钥(可选,如果不提供则使用默认配置)
|
||||
base_url: API地址(可选,如果不提供则使用默认配置)
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
LLM返回的文本
|
||||
"""
|
||||
# 如果提供了api_key或base_url,创建临时客户端
|
||||
# 注意:api_key 可能是空字符串,需要检查是否为 None
|
||||
if api_key is not None or base_url is not None:
|
||||
# 如果提供了 api_key,使用它;否则使用系统默认配置
|
||||
final_api_key = api_key if api_key else settings.OPENAI_API_KEY
|
||||
final_base_url = base_url if base_url else settings.OPENAI_BASE_URL
|
||||
|
||||
if not final_api_key:
|
||||
raise ValueError("OpenAI API Key未配置,请在节点配置中设置API Key或在环境变量中设置OPENAI_API_KEY")
|
||||
|
||||
client = AsyncOpenAI(
|
||||
api_key=final_api_key,
|
||||
base_url=final_base_url
|
||||
)
|
||||
else:
|
||||
# 如果 openai_client 未初始化,尝试从 settings 重新读取并初始化
|
||||
if not self.openai_client:
|
||||
if settings.OPENAI_API_KEY:
|
||||
self.openai_client = AsyncOpenAI(
|
||||
api_key=settings.OPENAI_API_KEY,
|
||||
base_url=settings.OPENAI_BASE_URL
|
||||
)
|
||||
else:
|
||||
raise ValueError("OpenAI API Key未配置,请在节点配置中设置API Key或在环境变量中设置OPENAI_API_KEY")
|
||||
client = self.openai_client
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
raise Exception("OpenAI API返回的内容为空,请检查API配置和模型名称")
|
||||
return content
|
||||
except Exception as e:
|
||||
raise Exception(f"OpenAI API调用失败: {str(e)}")
|
||||
|
||||
async def call_deepseek(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str = "deepseek-chat",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
调用DeepSeek API
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model: 模型名称,默认deepseek-chat
|
||||
temperature: 温度参数,默认0.7
|
||||
max_tokens: 最大token数
|
||||
api_key: API密钥(可选,如果不提供则使用默认配置)
|
||||
base_url: API地址(可选,如果不提供则使用默认配置)
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
LLM返回的文本
|
||||
"""
|
||||
# 如果提供了api_key或base_url,创建临时客户端
|
||||
# 注意:api_key 可能是空字符串,需要检查是否为 None
|
||||
if api_key is not None or base_url is not None:
|
||||
# 如果提供了 api_key,使用它;否则使用系统默认配置
|
||||
final_api_key = api_key if api_key else settings.DEEPSEEK_API_KEY
|
||||
final_base_url = base_url if base_url else settings.DEEPSEEK_BASE_URL
|
||||
|
||||
if not final_api_key:
|
||||
raise ValueError("DeepSeek API Key未配置,请在节点配置中设置API Key或在环境变量中设置DEEPSEEK_API_KEY")
|
||||
|
||||
client = AsyncOpenAI(
|
||||
api_key=final_api_key,
|
||||
base_url=final_base_url
|
||||
)
|
||||
else:
|
||||
# 如果 deepseek_client 未初始化,尝试从 settings 重新读取并初始化
|
||||
if not self.deepseek_client:
|
||||
if settings.DEEPSEEK_API_KEY:
|
||||
self.deepseek_client = AsyncOpenAI(
|
||||
api_key=settings.DEEPSEEK_API_KEY,
|
||||
base_url=settings.DEEPSEEK_BASE_URL
|
||||
)
|
||||
else:
|
||||
raise ValueError("DeepSeek API Key未配置,请在节点配置中设置API Key或在环境变量中设置DEEPSEEK_API_KEY")
|
||||
client = self.deepseek_client
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
raise Exception("DeepSeek API返回的内容为空,请检查API配置和模型名称")
|
||||
return content
|
||||
except Exception as e:
|
||||
raise Exception(f"DeepSeek API调用失败: {str(e)}")
|
||||
|
||||
async def call_llm(
|
||||
self,
|
||||
prompt: str,
|
||||
provider: str = "openai",
|
||||
model: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
通用LLM调用接口
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
provider: 提供商,支持openai、deepseek
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
LLM返回的文本
|
||||
"""
|
||||
if provider == "openai":
|
||||
# 默认模型
|
||||
if not model:
|
||||
model = "gpt-3.5-turbo"
|
||||
return await self.call_openai(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
**kwargs
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
# 默认模型
|
||||
if not model:
|
||||
model = "deepseek-chat"
|
||||
return await self.call_deepseek(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"不支持的LLM提供商: {provider},目前支持: openai, deepseek")
|
||||
|
||||
|
||||
# 全局LLM服务实例
|
||||
llm_service = LLMService()
|
||||
276
backend/app/services/monitoring_service.py
Normal file
276
backend/app/services/monitoring_service.py
Normal file
@@ -0,0 +1,276 @@
|
||||
"""
|
||||
系统监控服务
|
||||
提供系统状态、执行统计、性能指标等监控数据
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, and_, case
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List
|
||||
from app.models.user import User
|
||||
from app.models.workflow import Workflow
|
||||
from app.models.agent import Agent
|
||||
from app.models.execution import Execution
|
||||
from app.models.execution_log import ExecutionLog
|
||||
from app.models.data_source import DataSource
|
||||
from app.models.model_config import ModelConfig
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MonitoringService:
|
||||
"""系统监控服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_system_overview(db: Session, user_id: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
获取系统概览统计
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID(如果提供,则只统计该用户的数据)
|
||||
|
||||
Returns:
|
||||
系统概览数据
|
||||
"""
|
||||
# 构建基础查询条件
|
||||
user_filter = Workflow.user_id == user_id if user_id else True
|
||||
|
||||
# 统计工作流数量
|
||||
workflow_count = db.query(func.count(Workflow.id)).filter(user_filter).scalar() or 0
|
||||
|
||||
# 统计Agent数量
|
||||
agent_filter = Agent.user_id == user_id if user_id else True
|
||||
agent_count = db.query(func.count(Agent.id)).filter(agent_filter).scalar() or 0
|
||||
|
||||
# 统计执行记录数量
|
||||
execution_filter = None
|
||||
if user_id:
|
||||
execution_filter = Execution.workflow_id.in_(
|
||||
db.query(Workflow.id).filter(Workflow.user_id == user_id)
|
||||
)
|
||||
execution_count = db.query(func.count(Execution.id)).filter(
|
||||
execution_filter if execution_filter else True
|
||||
).scalar() or 0
|
||||
|
||||
# 统计数据源数量
|
||||
data_source_filter = DataSource.user_id == user_id if user_id else True
|
||||
data_source_count = db.query(func.count(DataSource.id)).filter(
|
||||
data_source_filter
|
||||
).scalar() or 0
|
||||
|
||||
# 统计模型配置数量
|
||||
model_config_filter = ModelConfig.user_id == user_id if user_id else True
|
||||
model_config_count = db.query(func.count(ModelConfig.id)).filter(
|
||||
model_config_filter
|
||||
).scalar() or 0
|
||||
|
||||
# 统计用户数量(仅管理员可见)
|
||||
user_count = None
|
||||
if not user_id:
|
||||
user_count = db.query(func.count(User.id)).scalar() or 0
|
||||
|
||||
return {
|
||||
"workflows": workflow_count,
|
||||
"agents": agent_count,
|
||||
"executions": execution_count,
|
||||
"data_sources": data_source_count,
|
||||
"model_configs": model_config_count,
|
||||
"users": user_count
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_execution_statistics(
|
||||
db: Session,
|
||||
user_id: str = None,
|
||||
days: int = 7
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取执行统计信息
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID(如果提供,则只统计该用户的数据)
|
||||
days: 统计天数(默认7天)
|
||||
|
||||
Returns:
|
||||
执行统计数据
|
||||
"""
|
||||
# 构建时间范围
|
||||
end_time = datetime.utcnow()
|
||||
start_time = end_time - timedelta(days=days)
|
||||
|
||||
# 构建查询条件
|
||||
execution_filter = Execution.created_at >= start_time
|
||||
if user_id:
|
||||
execution_filter = and_(
|
||||
execution_filter,
|
||||
Execution.workflow_id.in_(
|
||||
db.query(Workflow.id).filter(Workflow.user_id == user_id)
|
||||
)
|
||||
)
|
||||
|
||||
# 统计总执行数
|
||||
total_executions = db.query(func.count(Execution.id)).filter(
|
||||
execution_filter
|
||||
).scalar() or 0
|
||||
|
||||
# 统计各状态执行数
|
||||
status_stats = db.query(
|
||||
Execution.status,
|
||||
func.count(Execution.id).label('count')
|
||||
).filter(execution_filter).group_by(Execution.status).all()
|
||||
|
||||
status_counts = {status: count for status, count in status_stats}
|
||||
|
||||
# 计算成功率
|
||||
completed = status_counts.get('completed', 0)
|
||||
failed = status_counts.get('failed', 0)
|
||||
success_rate = (completed / total_executions * 100) if total_executions > 0 else 0
|
||||
|
||||
# 统计平均执行时间
|
||||
avg_execution_time = db.query(
|
||||
func.avg(Execution.execution_time)
|
||||
).filter(
|
||||
and_(execution_filter, Execution.execution_time.isnot(None))
|
||||
).scalar() or 0
|
||||
|
||||
# 统计最近24小时的执行趋势
|
||||
hourly_trends = []
|
||||
for i in range(24):
|
||||
hour_start = end_time - timedelta(hours=24-i)
|
||||
hour_end = hour_start + timedelta(hours=1)
|
||||
hour_filter = and_(
|
||||
execution_filter,
|
||||
Execution.created_at >= hour_start,
|
||||
Execution.created_at < hour_end
|
||||
)
|
||||
hour_count = db.query(func.count(Execution.id)).filter(
|
||||
hour_filter
|
||||
).scalar() or 0
|
||||
hourly_trends.append({
|
||||
"hour": hour_start.strftime("%H:00"),
|
||||
"count": hour_count
|
||||
})
|
||||
|
||||
return {
|
||||
"total": total_executions,
|
||||
"status_counts": status_counts,
|
||||
"success_rate": round(success_rate, 2),
|
||||
"avg_execution_time": round(avg_execution_time, 2) if avg_execution_time else 0,
|
||||
"hourly_trends": hourly_trends
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_node_type_statistics(
|
||||
db: Session,
|
||||
user_id: str = None,
|
||||
days: int = 7
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取节点类型统计
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID(如果提供,则只统计该用户的数据)
|
||||
days: 统计天数(默认7天)
|
||||
|
||||
Returns:
|
||||
节点类型统计数据
|
||||
"""
|
||||
# 构建时间范围
|
||||
end_time = datetime.utcnow()
|
||||
start_time = end_time - timedelta(days=days)
|
||||
|
||||
# 构建查询条件
|
||||
execution_filter = Execution.created_at >= start_time
|
||||
if user_id:
|
||||
execution_filter = and_(
|
||||
execution_filter,
|
||||
Execution.workflow_id.in_(
|
||||
db.query(Workflow.id).filter(Workflow.user_id == user_id)
|
||||
)
|
||||
)
|
||||
|
||||
# 获取符合条件的执行ID列表
|
||||
execution_ids_query = db.query(Execution.id).filter(execution_filter)
|
||||
execution_ids = [row[0] for row in execution_ids_query.all()]
|
||||
|
||||
if not execution_ids:
|
||||
return []
|
||||
|
||||
# 统计各节点类型的执行情况
|
||||
node_stats = db.query(
|
||||
ExecutionLog.node_type,
|
||||
func.count(ExecutionLog.id).label('execution_count'),
|
||||
func.sum(ExecutionLog.duration).label('total_duration'),
|
||||
func.avg(ExecutionLog.duration).label('avg_duration'),
|
||||
func.count(
|
||||
case((ExecutionLog.level == 'ERROR', 1))
|
||||
).label('error_count')
|
||||
).filter(
|
||||
and_(
|
||||
ExecutionLog.execution_id.in_(execution_ids),
|
||||
ExecutionLog.node_type.isnot(None),
|
||||
ExecutionLog.duration.isnot(None)
|
||||
)
|
||||
).group_by(ExecutionLog.node_type).all()
|
||||
|
||||
result = []
|
||||
for node_type, exec_count, total_dur, avg_dur, error_count in node_stats:
|
||||
result.append({
|
||||
"node_type": node_type,
|
||||
"execution_count": exec_count,
|
||||
"total_duration": round(total_dur or 0, 2),
|
||||
"avg_duration": round(avg_dur or 0, 2),
|
||||
"error_count": error_count,
|
||||
"success_rate": round((exec_count - error_count) / exec_count * 100, 2) if exec_count > 0 else 0
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_recent_activities(
|
||||
db: Session,
|
||||
user_id: str = None,
|
||||
limit: int = 10
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取最近的活动记录
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID(如果提供,则只统计该用户的数据)
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
最近活动列表
|
||||
"""
|
||||
# 构建查询条件
|
||||
execution_filter = True
|
||||
if user_id:
|
||||
execution_filter = Execution.workflow_id.in_(
|
||||
db.query(Workflow.id).filter(Workflow.user_id == user_id)
|
||||
)
|
||||
|
||||
# 获取最近的执行记录
|
||||
recent_executions = db.query(Execution).filter(
|
||||
execution_filter
|
||||
).order_by(Execution.created_at.desc()).limit(limit).all()
|
||||
|
||||
result = []
|
||||
for execution in recent_executions:
|
||||
workflow = db.query(Workflow).filter(
|
||||
Workflow.id == execution.workflow_id
|
||||
).first() if execution.workflow_id else None
|
||||
|
||||
result.append({
|
||||
"id": execution.id,
|
||||
"type": "execution",
|
||||
"workflow_name": workflow.name if workflow else "未知工作流",
|
||||
"status": execution.status,
|
||||
"created_at": execution.created_at.isoformat() if execution.created_at else None,
|
||||
"execution_time": execution.execution_time
|
||||
})
|
||||
|
||||
return result
|
||||
110
backend/app/services/permission_service.py
Normal file
110
backend/app/services/permission_service.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
权限服务
|
||||
提供权限检查的辅助函数
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.permission import WorkflowPermission, AgentPermission
|
||||
from app.models.user import User
|
||||
from app.models.workflow import Workflow
|
||||
from app.models.agent import Agent
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def check_workflow_permission(
|
||||
db: Session,
|
||||
user: User,
|
||||
workflow: Workflow,
|
||||
permission_type: str
|
||||
) -> bool:
|
||||
"""
|
||||
检查用户对工作流的权限
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user: 用户对象
|
||||
workflow: 工作流对象
|
||||
permission_type: 权限类型(read/write/execute/share)
|
||||
|
||||
Returns:
|
||||
bool: 是否有权限
|
||||
"""
|
||||
# 管理员拥有所有权限
|
||||
if user.role == "admin":
|
||||
return True
|
||||
|
||||
# 工作流所有者拥有所有权限
|
||||
if workflow.user_id == user.id:
|
||||
return True
|
||||
|
||||
# 检查用户直接权限
|
||||
user_permission = db.query(WorkflowPermission).filter(
|
||||
WorkflowPermission.workflow_id == workflow.id,
|
||||
WorkflowPermission.user_id == user.id,
|
||||
WorkflowPermission.permission_type == permission_type
|
||||
).first()
|
||||
|
||||
if user_permission:
|
||||
return True
|
||||
|
||||
# 检查角色权限
|
||||
for role in user.roles:
|
||||
role_permission = db.query(WorkflowPermission).filter(
|
||||
WorkflowPermission.workflow_id == workflow.id,
|
||||
WorkflowPermission.role_id == role.id,
|
||||
WorkflowPermission.permission_type == permission_type
|
||||
).first()
|
||||
|
||||
if role_permission:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def check_agent_permission(
|
||||
db: Session,
|
||||
user: User,
|
||||
agent: Agent,
|
||||
permission_type: str
|
||||
) -> bool:
|
||||
"""
|
||||
检查用户对Agent的权限
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user: 用户对象
|
||||
agent: Agent对象
|
||||
permission_type: 权限类型(read/write/execute/deploy)
|
||||
|
||||
Returns:
|
||||
bool: 是否有权限
|
||||
"""
|
||||
# 管理员拥有所有权限
|
||||
if user.role == "admin":
|
||||
return True
|
||||
|
||||
# Agent所有者拥有所有权限
|
||||
if agent.user_id == user.id:
|
||||
return True
|
||||
|
||||
# 检查用户直接权限
|
||||
user_permission = db.query(AgentPermission).filter(
|
||||
AgentPermission.agent_id == agent.id,
|
||||
AgentPermission.user_id == user.id,
|
||||
AgentPermission.permission_type == permission_type
|
||||
).first()
|
||||
|
||||
if user_permission:
|
||||
return True
|
||||
|
||||
# 检查角色权限
|
||||
for role in user.roles:
|
||||
role_permission = db.query(AgentPermission).filter(
|
||||
AgentPermission.agent_id == agent.id,
|
||||
AgentPermission.role_id == role.id,
|
||||
AgentPermission.permission_type == permission_type
|
||||
).first()
|
||||
|
||||
if role_permission:
|
||||
return True
|
||||
|
||||
return False
|
||||
1666
backend/app/services/workflow_engine.py
Normal file
1666
backend/app/services/workflow_engine.py
Normal file
File diff suppressed because it is too large
Load Diff
323
backend/app/services/workflow_templates.py
Normal file
323
backend/app/services/workflow_templates.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""
|
||||
工作流模板服务
|
||||
提供预设的工作流模板,支持快速创建
|
||||
"""
|
||||
from typing import Dict, Any, List
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 预设工作流模板
|
||||
WORKFLOW_TEMPLATES = {
|
||||
"simple_llm": {
|
||||
"name": "简单LLM工作流",
|
||||
"description": "一个简单的LLM调用工作流,包含开始、LLM和结束节点",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start-1",
|
||||
"type": "start",
|
||||
"position": {"x": 100, "y": 100},
|
||||
"data": {"label": "开始"}
|
||||
},
|
||||
{
|
||||
"id": "llm-1",
|
||||
"type": "llm",
|
||||
"position": {"x": 100, "y": 250},
|
||||
"data": {
|
||||
"label": "LLM处理",
|
||||
"prompt": "请处理以下输入:\n{input}",
|
||||
"provider": "openai",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"temperature": 0.7
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "end-1",
|
||||
"type": "end",
|
||||
"position": {"x": 100, "y": 400},
|
||||
"data": {"label": "结束"}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "e1",
|
||||
"source": "start-1",
|
||||
"target": "llm-1",
|
||||
"sourceHandle": "bottom",
|
||||
"targetHandle": "top"
|
||||
},
|
||||
{
|
||||
"id": "e2",
|
||||
"source": "llm-1",
|
||||
"target": "end-1",
|
||||
"sourceHandle": "bottom",
|
||||
"targetHandle": "top"
|
||||
}
|
||||
]
|
||||
},
|
||||
"conditional_llm": {
|
||||
"name": "条件判断LLM工作流",
|
||||
"description": "根据条件判断调用不同的LLM处理",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start-1",
|
||||
"type": "start",
|
||||
"position": {"x": 100, "y": 100},
|
||||
"data": {"label": "开始"}
|
||||
},
|
||||
{
|
||||
"id": "condition-1",
|
||||
"type": "condition",
|
||||
"position": {"x": 100, "y": 200},
|
||||
"data": {
|
||||
"label": "条件判断",
|
||||
"condition": "{value} > 10"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "llm-true",
|
||||
"type": "llm",
|
||||
"position": {"x": -100, "y": 350},
|
||||
"data": {
|
||||
"label": "True分支LLM",
|
||||
"prompt": "值大于10,请分析:{input}",
|
||||
"provider": "openai",
|
||||
"model": "gpt-3.5-turbo"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "llm-false",
|
||||
"type": "llm",
|
||||
"position": {"x": 300, "y": 350},
|
||||
"data": {
|
||||
"label": "False分支LLM",
|
||||
"prompt": "值小于等于10,请分析:{input}",
|
||||
"provider": "openai",
|
||||
"model": "gpt-3.5-turbo"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "end-1",
|
||||
"type": "end",
|
||||
"position": {"x": 100, "y": 500},
|
||||
"data": {"label": "结束"}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "e1",
|
||||
"source": "start-1",
|
||||
"target": "condition-1"
|
||||
},
|
||||
{
|
||||
"id": "e2",
|
||||
"source": "condition-1",
|
||||
"target": "llm-true",
|
||||
"sourceHandle": "true"
|
||||
},
|
||||
{
|
||||
"id": "e3",
|
||||
"source": "condition-1",
|
||||
"target": "llm-false",
|
||||
"sourceHandle": "false"
|
||||
},
|
||||
{
|
||||
"id": "e4",
|
||||
"source": "llm-true",
|
||||
"target": "end-1"
|
||||
},
|
||||
{
|
||||
"id": "e5",
|
||||
"source": "llm-false",
|
||||
"target": "end-1"
|
||||
}
|
||||
]
|
||||
},
|
||||
"data_transform_llm": {
|
||||
"name": "数据转换+LLM工作流",
|
||||
"description": "先进行数据转换,再调用LLM处理",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start-1",
|
||||
"type": "start",
|
||||
"position": {"x": 100, "y": 100},
|
||||
"data": {"label": "开始"}
|
||||
},
|
||||
{
|
||||
"id": "transform-1",
|
||||
"type": "transform",
|
||||
"position": {"x": 100, "y": 200},
|
||||
"data": {
|
||||
"label": "数据转换",
|
||||
"mode": "mapping",
|
||||
"mapping": {
|
||||
"input_text": "raw_input",
|
||||
"user_id": "id"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "llm-1",
|
||||
"type": "llm",
|
||||
"position": {"x": 100, "y": 300},
|
||||
"data": {
|
||||
"label": "LLM处理",
|
||||
"prompt": "处理转换后的数据:{input_text}",
|
||||
"provider": "openai",
|
||||
"model": "gpt-3.5-turbo"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "end-1",
|
||||
"type": "end",
|
||||
"position": {"x": 100, "y": 400},
|
||||
"data": {"label": "结束"}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "e1",
|
||||
"source": "start-1",
|
||||
"target": "transform-1"
|
||||
},
|
||||
{
|
||||
"id": "e2",
|
||||
"source": "transform-1",
|
||||
"target": "llm-1"
|
||||
},
|
||||
{
|
||||
"id": "e3",
|
||||
"source": "llm-1",
|
||||
"target": "end-1"
|
||||
}
|
||||
]
|
||||
},
|
||||
"multi_llm_chain": {
|
||||
"name": "多LLM链式工作流",
|
||||
"description": "多个LLM节点链式调用,实现复杂处理流程",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start-1",
|
||||
"type": "start",
|
||||
"position": {"x": 100, "y": 100},
|
||||
"data": {"label": "开始"}
|
||||
},
|
||||
{
|
||||
"id": "llm-1",
|
||||
"type": "llm",
|
||||
"position": {"x": 100, "y": 200},
|
||||
"data": {
|
||||
"label": "第一步分析",
|
||||
"prompt": "第一步:分析输入数据:{input}",
|
||||
"provider": "openai",
|
||||
"model": "gpt-3.5-turbo"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "llm-2",
|
||||
"type": "llm",
|
||||
"position": {"x": 100, "y": 300},
|
||||
"data": {
|
||||
"label": "第二步处理",
|
||||
"prompt": "第二步:基于第一步的结果进行处理:{input}",
|
||||
"provider": "openai",
|
||||
"model": "gpt-3.5-turbo"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "llm-3",
|
||||
"type": "llm",
|
||||
"position": {"x": 100, "y": 400},
|
||||
"data": {
|
||||
"label": "第三步总结",
|
||||
"prompt": "第三步:总结最终结果:{input}",
|
||||
"provider": "openai",
|
||||
"model": "gpt-3.5-turbo"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "end-1",
|
||||
"type": "end",
|
||||
"position": {"x": 100, "y": 500},
|
||||
"data": {"label": "结束"}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "e1",
|
||||
"source": "start-1",
|
||||
"target": "llm-1"
|
||||
},
|
||||
{
|
||||
"id": "e2",
|
||||
"source": "llm-1",
|
||||
"target": "llm-2"
|
||||
},
|
||||
{
|
||||
"id": "e3",
|
||||
"source": "llm-2",
|
||||
"target": "llm-3"
|
||||
},
|
||||
{
|
||||
"id": "e4",
|
||||
"source": "llm-3",
|
||||
"target": "end-1"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_template(template_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取工作流模板
|
||||
|
||||
Args:
|
||||
template_id: 模板ID
|
||||
|
||||
Returns:
|
||||
模板数据
|
||||
"""
|
||||
return WORKFLOW_TEMPLATES.get(template_id)
|
||||
|
||||
|
||||
def list_templates() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取所有模板列表
|
||||
|
||||
Returns:
|
||||
模板列表,每个模板包含id、name、description
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"id": template_id,
|
||||
"name": template["name"],
|
||||
"description": template["description"]
|
||||
}
|
||||
for template_id, template in WORKFLOW_TEMPLATES.items()
|
||||
]
|
||||
|
||||
|
||||
def create_from_template(template_id: str, name: str = None, description: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
从模板创建工作流数据
|
||||
|
||||
Args:
|
||||
template_id: 模板ID
|
||||
name: 工作流名称(可选,默认使用模板名称)
|
||||
description: 工作流描述(可选,默认使用模板描述)
|
||||
|
||||
Returns:
|
||||
工作流数据
|
||||
"""
|
||||
template = get_template(template_id)
|
||||
if not template:
|
||||
raise ValueError(f"模板不存在: {template_id}")
|
||||
|
||||
return {
|
||||
"name": name or template["name"],
|
||||
"description": description or template["description"],
|
||||
"nodes": template["nodes"],
|
||||
"edges": template["edges"]
|
||||
}
|
||||
268
backend/app/services/workflow_validator.py
Normal file
268
backend/app/services/workflow_validator.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
工作流验证服务
|
||||
验证工作流的节点连接、数据流、循环检测等
|
||||
"""
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from collections import defaultdict, deque
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowValidator:
|
||||
"""工作流验证器"""
|
||||
|
||||
def __init__(self, nodes: List[Dict[str, Any]], edges: List[Dict[str, Any]]):
|
||||
"""
|
||||
初始化验证器
|
||||
|
||||
Args:
|
||||
nodes: 节点列表
|
||||
edges: 边列表
|
||||
"""
|
||||
self.nodes = {node['id']: node for node in nodes}
|
||||
self.edges = edges
|
||||
self.errors = []
|
||||
self.warnings = []
|
||||
|
||||
def validate(self) -> Tuple[bool, List[str], List[str]]:
|
||||
"""
|
||||
执行完整验证
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误列表, 警告列表)
|
||||
"""
|
||||
self.errors = []
|
||||
self.warnings = []
|
||||
|
||||
# 基础验证
|
||||
self._validate_nodes()
|
||||
self._validate_edges()
|
||||
|
||||
# 结构验证
|
||||
self._validate_has_start_node()
|
||||
self._validate_has_end_node()
|
||||
self._validate_no_cycles()
|
||||
self._validate_all_nodes_reachable()
|
||||
|
||||
# 连接验证
|
||||
self._validate_node_connections()
|
||||
self._validate_condition_branches()
|
||||
|
||||
# 配置验证
|
||||
self._validate_node_configs()
|
||||
|
||||
return len(self.errors) == 0, self.errors, self.warnings
|
||||
|
||||
def _validate_nodes(self):
|
||||
"""验证节点基础信息"""
|
||||
if not self.nodes:
|
||||
self.errors.append("工作流必须包含至少一个节点")
|
||||
return
|
||||
|
||||
node_ids = set()
|
||||
for node_id, node in self.nodes.items():
|
||||
# 检查节点ID唯一性
|
||||
if node_id in node_ids:
|
||||
self.errors.append(f"节点ID重复: {node_id}")
|
||||
node_ids.add(node_id)
|
||||
|
||||
# 检查节点类型
|
||||
node_type = node.get('type')
|
||||
if not node_type:
|
||||
self.errors.append(f"节点 {node_id} 缺少类型")
|
||||
elif node_type not in ['start', 'input', 'llm', 'condition', 'transform', 'output', 'end', 'default', 'loop', 'foreach', 'loop_end', 'agent', 'http', 'request', 'database', 'db', 'file', 'file_operation', 'schedule', 'delay', 'timer', 'webhook', 'email', 'mail', 'message_queue', 'mq', 'rabbitmq', 'kafka']:
|
||||
self.warnings.append(f"节点 {node_id} 使用了未知类型: {node_type}")
|
||||
|
||||
def _validate_edges(self):
|
||||
"""验证边的基础信息"""
|
||||
for edge in self.edges:
|
||||
source = edge.get('source')
|
||||
target = edge.get('target')
|
||||
|
||||
if not source or not target:
|
||||
self.errors.append(f"边缺少源节点或目标节点: {edge.get('id', 'unknown')}")
|
||||
continue
|
||||
|
||||
# 检查源节点是否存在
|
||||
if source not in self.nodes:
|
||||
self.errors.append(f"边的源节点不存在: {source}")
|
||||
|
||||
# 检查目标节点是否存在
|
||||
if target not in self.nodes:
|
||||
self.errors.append(f"边的目标节点不存在: {target}")
|
||||
|
||||
# 检查自环
|
||||
if source == target:
|
||||
self.errors.append(f"节点 {source} 不能连接到自身")
|
||||
|
||||
def _validate_has_start_node(self):
|
||||
"""验证是否有开始节点"""
|
||||
start_nodes = [node for node in self.nodes.values() if node.get('type') == 'start']
|
||||
if not start_nodes:
|
||||
self.errors.append("工作流必须包含至少一个开始节点")
|
||||
elif len(start_nodes) > 1:
|
||||
self.warnings.append(f"工作流包含多个开始节点: {len(start_nodes)}")
|
||||
|
||||
def _validate_has_end_node(self):
|
||||
"""验证是否有结束节点"""
|
||||
end_nodes = [node for node in self.nodes.values() if node.get('type') == 'end']
|
||||
if not end_nodes:
|
||||
self.warnings.append("工作流建议包含至少一个结束节点")
|
||||
|
||||
def _validate_no_cycles(self):
|
||||
"""验证工作流中是否有循环(使用DFS)"""
|
||||
# 构建邻接表
|
||||
graph = defaultdict(list)
|
||||
for edge in self.edges:
|
||||
source = edge.get('source')
|
||||
target = edge.get('target')
|
||||
if source and target:
|
||||
graph[source].append(target)
|
||||
|
||||
# DFS检测循环
|
||||
visited = set()
|
||||
rec_stack = set()
|
||||
|
||||
def has_cycle(node_id: str) -> bool:
|
||||
visited.add(node_id)
|
||||
rec_stack.add(node_id)
|
||||
|
||||
for neighbor in graph.get(node_id, []):
|
||||
if neighbor not in visited:
|
||||
if has_cycle(neighbor):
|
||||
return True
|
||||
elif neighbor in rec_stack:
|
||||
# 找到循环
|
||||
self.errors.append(f"检测到循环: {node_id} -> {neighbor}")
|
||||
return True
|
||||
|
||||
rec_stack.remove(node_id)
|
||||
return False
|
||||
|
||||
for node_id in self.nodes.keys():
|
||||
if node_id not in visited:
|
||||
has_cycle(node_id)
|
||||
|
||||
def _validate_all_nodes_reachable(self):
|
||||
"""验证所有节点是否可达"""
|
||||
# 找到所有开始节点
|
||||
start_nodes = [node_id for node_id, node in self.nodes.items() if node.get('type') == 'start']
|
||||
|
||||
if not start_nodes:
|
||||
return # 如果没有开始节点,跳过此验证
|
||||
|
||||
# 从开始节点BFS遍历
|
||||
reachable = set()
|
||||
queue = deque(start_nodes)
|
||||
|
||||
while queue:
|
||||
node_id = queue.popleft()
|
||||
if node_id in reachable:
|
||||
continue
|
||||
reachable.add(node_id)
|
||||
|
||||
# 添加所有可达的节点
|
||||
for edge in self.edges:
|
||||
if edge.get('source') == node_id:
|
||||
target = edge.get('target')
|
||||
if target and target not in reachable:
|
||||
queue.append(target)
|
||||
|
||||
# 检查未达节点
|
||||
unreachable = set(self.nodes.keys()) - reachable
|
||||
if unreachable:
|
||||
self.warnings.append(f"以下节点不可达: {', '.join(unreachable)}")
|
||||
|
||||
def _validate_node_connections(self):
|
||||
"""验证节点连接的正确性"""
|
||||
# 检查开始节点是否有入边
|
||||
for node_id, node in self.nodes.items():
|
||||
if node.get('type') == 'start':
|
||||
has_incoming = any(edge.get('target') == node_id for edge in self.edges)
|
||||
if has_incoming:
|
||||
self.warnings.append(f"开始节点 {node_id} 不应该有入边")
|
||||
|
||||
# 检查结束节点是否有出边
|
||||
for node_id, node in self.nodes.items():
|
||||
if node.get('type') == 'end':
|
||||
has_outgoing = any(edge.get('source') == node_id for edge in self.edges)
|
||||
if has_outgoing:
|
||||
self.warnings.append(f"结束节点 {node_id} 不应该有出边")
|
||||
|
||||
def _validate_condition_branches(self):
|
||||
"""验证条件节点的分支"""
|
||||
for node_id, node in self.nodes.items():
|
||||
if node.get('type') == 'condition':
|
||||
# 检查是否有条件表达式
|
||||
condition = node.get('data', {}).get('condition', '')
|
||||
if not condition:
|
||||
self.warnings.append(f"条件节点 {node_id} 没有配置条件表达式")
|
||||
|
||||
# 检查是否有true和false分支
|
||||
true_edges = [e for e in self.edges if e.get('source') == node_id and e.get('sourceHandle') == 'true']
|
||||
false_edges = [e for e in self.edges if e.get('source') == node_id and e.get('sourceHandle') == 'false']
|
||||
|
||||
if not true_edges and not false_edges:
|
||||
self.warnings.append(f"条件节点 {node_id} 没有配置分支连接")
|
||||
elif not true_edges:
|
||||
self.warnings.append(f"条件节点 {node_id} 缺少True分支")
|
||||
elif not false_edges:
|
||||
self.warnings.append(f"条件节点 {node_id} 缺少False分支")
|
||||
|
||||
def _validate_node_configs(self):
|
||||
"""验证节点配置"""
|
||||
for node_id, node in self.nodes.items():
|
||||
node_type = node.get('type')
|
||||
node_data = node.get('data', {})
|
||||
|
||||
# LLM节点验证
|
||||
if node_type == 'llm':
|
||||
prompt = node_data.get('prompt', '')
|
||||
if not prompt:
|
||||
self.warnings.append(f"LLM节点 {node_id} 没有配置提示词")
|
||||
|
||||
provider = node_data.get('provider', 'openai')
|
||||
model = node_data.get('model')
|
||||
if not model:
|
||||
self.warnings.append(f"LLM节点 {node_id} 没有配置模型")
|
||||
|
||||
# 转换节点验证
|
||||
elif node_type == 'transform' or node_type == 'data':
|
||||
mode = node_data.get('mode', 'mapping')
|
||||
mapping = node_data.get('mapping', {})
|
||||
filter_rules = node_data.get('filter_rules', [])
|
||||
compute_rules = node_data.get('compute_rules', {})
|
||||
|
||||
if mode == 'mapping' and not mapping:
|
||||
self.warnings.append(f"转换节点 {node_id} 选择了映射模式但没有配置映射规则")
|
||||
elif mode == 'filter' and not filter_rules:
|
||||
self.warnings.append(f"转换节点 {node_id} 选择了过滤模式但没有配置过滤规则")
|
||||
elif mode == 'compute' and not compute_rules:
|
||||
self.warnings.append(f"转换节点 {node_id} 选择了计算模式但没有配置计算规则")
|
||||
|
||||
|
||||
def validate_workflow(nodes: List[Dict[str, Any]], edges: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
验证工作流
|
||||
|
||||
Args:
|
||||
nodes: 节点列表
|
||||
edges: 边列表
|
||||
|
||||
Returns:
|
||||
验证结果字典
|
||||
{
|
||||
"valid": bool,
|
||||
"errors": List[str],
|
||||
"warnings": List[str]
|
||||
}
|
||||
"""
|
||||
validator = WorkflowValidator(nodes, edges)
|
||||
valid, errors, warnings = validator.validate()
|
||||
|
||||
return {
|
||||
"valid": valid,
|
||||
"errors": errors,
|
||||
"warnings": warnings
|
||||
}
|
||||
Reference in New Issue
Block a user