第一次提交

This commit is contained in:
rjb
2026-01-19 00:09:36 +08:00
parent de4b5059e9
commit 6674060f2f
191 changed files with 40940 additions and 0 deletions

View File

@@ -0,0 +1 @@
# Services package

View File

@@ -0,0 +1,391 @@
"""
告警服务
提供告警检测和通知功能
"""
from sqlalchemy.orm import Session
from sqlalchemy import func, and_
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
import logging
from app.models.alert_rule import AlertRule, AlertLog
from app.models.execution import Execution
from app.models.workflow import Workflow
from app.models.execution_log import ExecutionLog
import httpx
import aiosmtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
logger = logging.getLogger(__name__)
class AlertService:
    """Alert service: evaluates alert rules against executions and dispatches notifications."""

    @staticmethod
    async def check_execution_failed(
        db: Session,
        rule: AlertRule,
        execution: Execution
    ) -> bool:
        """
        Evaluate an execution-failure alert rule.

        Args:
            db: Database session.
            rule: Alert rule being evaluated.
            execution: Execution record that triggered the check.

        Returns:
            True if the rule should fire.
        """
        if execution.status != 'failed':
            return False
        # Ignore rules that target a different workflow/agent.
        if rule.target_type == 'workflow' and rule.target_id:
            if execution.workflow_id != rule.target_id:
                return False
        elif rule.target_type == 'agent' and rule.target_id:
            if execution.agent_id != rule.target_id:
                return False
        # Count failures inside the rule's time window.
        conditions = rule.conditions
        threshold = conditions.get('threshold', 1)
        time_window = conditions.get('time_window', 3600)  # seconds; default 1 hour
        comparison = conditions.get('comparison', 'gt')  # gt, gte or eq
        start_time = datetime.utcnow() - timedelta(seconds=time_window)
        query = db.query(func.count(Execution.id)).filter(
            Execution.status == 'failed',
            Execution.created_at >= start_time
        )
        if rule.target_type == 'workflow' and rule.target_id:
            query = query.filter(Execution.workflow_id == rule.target_id)
        elif rule.target_type == 'agent' and rule.target_id:
            query = query.filter(Execution.agent_id == rule.target_id)
        failed_count = query.scalar() or 0
        # Unknown comparison values deliberately fall back to 'gt'.
        if comparison == 'gte':
            return failed_count >= threshold
        if comparison == 'eq':
            return failed_count == threshold
        return failed_count > threshold

    @staticmethod
    async def check_execution_timeout(
        db: Session,
        rule: AlertRule,
        execution: Execution
    ) -> bool:
        """
        Evaluate an execution-timeout alert rule.

        Args:
            db: Database session (unused; kept for a uniform check signature).
            rule: Alert rule being evaluated.
            execution: Execution record to check.

        Returns:
            True if the execution has been running longer than the configured timeout.
        """
        if execution.status != 'running':
            return False
        # Ignore rules that target a different workflow/agent.
        if rule.target_type == 'workflow' and rule.target_id:
            if execution.workflow_id != rule.target_id:
                return False
        elif rule.target_type == 'agent' and rule.target_id:
            if execution.agent_id != rule.target_id:
                return False
        timeout_seconds = rule.conditions.get('timeout_seconds', 3600)  # default: 1 hour
        if execution.created_at:
            # NOTE(review): assumes created_at is naive UTC, matching datetime.utcnow().
            elapsed = (datetime.utcnow() - execution.created_at).total_seconds()
            return elapsed > timeout_seconds
        return False

    @staticmethod
    async def check_error_rate(
        db: Session,
        rule: AlertRule
    ) -> bool:
        """
        Evaluate an error-rate alert rule over the rule's time window.

        Args:
            db: Database session.
            rule: Alert rule being evaluated.

        Returns:
            True if the failure ratio meets or exceeds the threshold.
        """
        conditions = rule.conditions
        threshold = conditions.get('threshold', 0.1)  # default: 10%
        time_window = conditions.get('time_window', 3600)  # default: 1 hour
        start_time = datetime.utcnow() - timedelta(seconds=time_window)
        query = db.query(Execution).filter(Execution.created_at >= start_time)
        if rule.target_type == 'workflow' and rule.target_id:
            query = query.filter(Execution.workflow_id == rule.target_id)
        elif rule.target_type == 'agent' and rule.target_id:
            query = query.filter(Execution.agent_id == rule.target_id)
        executions = query.all()
        if not executions:
            return False
        failed_count = sum(1 for e in executions if e.status == 'failed')
        error_rate = failed_count / len(executions)
        return error_rate >= threshold

    @staticmethod
    async def check_alerts_for_execution(
        db: Session,
        execution: Execution
    ) -> List[AlertLog]:
        """
        Evaluate every relevant alert rule for an execution and fire notifications.

        Args:
            db: Database session.
            execution: Execution record.

        Returns:
            List of alert logs that were triggered.
        """
        triggered_logs = []
        # Enabled rules whose target matches this execution, plus system-level
        # rules. FIX: the target filters must be OR-ed together — the previous
        # version AND-ed `target_type == 'system'` onto the workflow/agent
        # filter, producing a predicate that could never match any rule.
        query = db.query(AlertRule).filter(AlertRule.enabled == True)  # noqa: E712 (SQLAlchemy expression)
        target_filter = (AlertRule.target_type == 'system')
        if execution.workflow_id:
            target_filter = target_filter | (
                (AlertRule.target_type == 'workflow') &
                ((AlertRule.target_id == execution.workflow_id) | (AlertRule.target_id.is_(None)))
            )
        elif execution.agent_id:
            target_filter = target_filter | (
                (AlertRule.target_type == 'agent') &
                ((AlertRule.target_id == execution.agent_id) | (AlertRule.target_id.is_(None)))
            )
        rules = query.filter(target_filter).all()
        for rule in rules:
            try:
                should_trigger = False
                alert_message = ""
                alert_details = {}
                if rule.alert_type == 'execution_failed':
                    should_trigger = await AlertService.check_execution_failed(db, rule, execution)
                    if should_trigger:
                        alert_message = f"执行失败告警: 工作流 {execution.workflow_id} 执行失败"
                        alert_details = {
                            "execution_id": execution.id,
                            "workflow_id": execution.workflow_id,
                            "status": execution.status,
                            "error_message": execution.error_message
                        }
                elif rule.alert_type == 'execution_timeout':
                    should_trigger = await AlertService.check_execution_timeout(db, rule, execution)
                    if should_trigger:
                        elapsed = (datetime.utcnow() - execution.created_at).total_seconds() if execution.created_at else 0
                        alert_message = f"执行超时告警: 工作流 {execution.workflow_id} 执行超时 ({elapsed:.0f}秒)"
                        alert_details = {
                            "execution_id": execution.id,
                            "workflow_id": execution.workflow_id,
                            "elapsed_seconds": elapsed
                        }
                elif rule.alert_type == 'error_rate':
                    should_trigger = await AlertService.check_error_rate(db, rule)
                    if should_trigger:
                        alert_message = f"错误率告警: {rule.target_type} 错误率超过阈值"
                        alert_details = {
                            "target_type": rule.target_type,
                            "target_id": rule.target_id
                        }
                if should_trigger:
                    # Persist the alert log entry.
                    alert_log = AlertLog(
                        rule_id=rule.id,
                        alert_type=rule.alert_type,
                        severity=rule.conditions.get('severity', 'warning'),
                        message=alert_message,
                        details=alert_details,
                        status='pending',
                        notification_type=rule.notification_type,
                        triggered_at=datetime.utcnow()
                    )
                    db.add(alert_log)
                    # Update rule statistics (guard against a NULL counter column).
                    rule.trigger_count = (rule.trigger_count or 0) + 1
                    rule.last_triggered_at = datetime.utcnow()
                    db.commit()
                    db.refresh(alert_log)
                    # Dispatch the notification for this alert.
                    await AlertService.send_notification(db, alert_log, rule)
                    triggered_logs.append(alert_log)
            except Exception as e:
                # One failing rule must not block evaluation of the others.
                logger.error(f"检查告警规则失败 {rule.id}: {str(e)}")
                continue
        return triggered_logs

    @staticmethod
    async def send_notification(
        db: Session,
        alert_log: AlertLog,
        rule: AlertRule
    ):
        """
        Send the notification for a triggered alert and record the outcome.

        Args:
            db: Database session.
            alert_log: Alert log entry to update with the send result.
            rule: Alert rule that fired.
        """
        try:
            if rule.notification_type == 'email':
                await AlertService.send_email_notification(alert_log, rule)
            elif rule.notification_type == 'webhook':
                await AlertService.send_webhook_notification(alert_log, rule)
            elif rule.notification_type == 'internal':
                # In-app notification: the persisted alert log itself is the notification.
                pass
            alert_log.status = 'sent'
            alert_log.notification_result = '通知发送成功'
        except Exception as e:
            logger.error(f"发送告警通知失败: {str(e)}")
            alert_log.status = 'failed'
            alert_log.notification_result = f"通知发送失败: {str(e)}"
        finally:
            # Always persist the send status, success or failure.
            db.commit()

    @staticmethod
    async def send_email_notification(
        alert_log: AlertLog,
        rule: AlertRule
    ):
        """
        Send an alert notification email via SMTP.

        Args:
            alert_log: Alert log entry to report.
            rule: Alert rule carrying the SMTP configuration.

        Raises:
            ValueError: If the recipient address is missing from the config.
        """
        config = rule.notification_config or {}
        smtp_host = config.get('smtp_host', 'smtp.gmail.com')
        smtp_port = config.get('smtp_port', 587)
        smtp_user = config.get('smtp_user')
        smtp_password = config.get('smtp_password')
        to_email = config.get('to_email')
        if not to_email:
            raise ValueError("邮件通知配置缺少收件人地址")
        # Build the message.
        message = MIMEMultipart()
        message['From'] = smtp_user
        message['To'] = to_email
        message['Subject'] = f"告警通知: {rule.name}"
        body = f"""
告警规则: {rule.name}
告警类型: {alert_log.alert_type}
严重程度: {alert_log.severity}
告警消息: {alert_log.message}
触发时间: {alert_log.triggered_at}
"""
        if alert_log.details:
            body += f"\n详细信息:\n{alert_log.details}"
        message.attach(MIMEText(body, 'plain', 'utf-8'))
        # Send over STARTTLS; aiosmtplib raises on failure, surfaced by send_notification.
        await aiosmtplib.send(
            message,
            hostname=smtp_host,
            port=smtp_port,
            username=smtp_user,
            password=smtp_password,
            use_tls=True
        )

    @staticmethod
    async def send_webhook_notification(
        alert_log: AlertLog,
        rule: AlertRule
    ):
        """
        POST an alert notification to the configured webhook URL.

        Args:
            alert_log: Alert log entry to report.
            rule: Alert rule carrying the webhook configuration.

        Raises:
            ValueError: If the webhook URL is missing from the config.
        """
        config = rule.notification_config or {}
        webhook_url = config.get('webhook_url')
        if not webhook_url:
            raise ValueError("Webhook通知配置缺少URL")
        payload = {
            "rule_name": rule.name,
            "alert_type": alert_log.alert_type,
            "severity": alert_log.severity,
            "message": alert_log.message,
            "details": alert_log.details,
            "triggered_at": alert_log.triggered_at.isoformat() if alert_log.triggered_at else None
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                webhook_url,
                json=payload,
                headers=config.get('headers', {}),
                timeout=10
            )
            # Surface HTTP errors to the caller (send_notification records them).
            response.raise_for_status()

View File

@@ -0,0 +1,276 @@
"""
条件表达式解析器
支持更复杂的条件判断表达式
"""
import re
import json
from typing import Dict, Any, Union
class ConditionParser:
    """Condition-expression parser supporting comparison and logical operators."""

    # Supported comparison operators; the right-hand side is parsed with parse_value.
    OPERATORS = {
        '==': lambda a, b: a == b,
        '!=': lambda a, b: a != b,
        '>': lambda a, b: a > b,
        '>=': lambda a, b: a >= b,
        '<': lambda a, b: a < b,
        '<=': lambda a, b: a <= b,
        'in': lambda a, b: a in b if isinstance(b, (list, str, dict)) else False,
        'not in': lambda a, b: a not in b if isinstance(b, (list, str, dict)) else False,
        'contains': lambda a, b: b in str(a) if a is not None else False,
        'not contains': lambda a, b: b not in str(a) if a is not None else True,
    }

    # Logical operators (kept for reference/extension).
    LOGICAL_OPERATORS = {
        'and': lambda a, b: a and b,
        'or': lambda a, b: a or b,
        'not': lambda a: not a,
    }

    @staticmethod
    def get_value(path: str, data: Dict[str, Any]) -> Any:
        """
        Fetch a value from data by path (supports nesting).

        Args:
            path: Path such as 'user.name' or 'items[0].price'.
            data: Source dict.

        Returns:
            The value, or None when any step is missing.
        """
        try:
            if '[' in path and ']' in path:
                # Tokenize on dots AND brackets so mixed paths like
                # 'items[0].price' resolve. FIX: the old split on brackets only
                # left a literal '.price' token that never matched a key.
                parts = [p for p in re.split(r'[\[\].]', path) if p]
                value = data
                for part in parts:
                    if part.isdigit() and isinstance(value, list):
                        value = value[int(part)] if int(part) < len(value) else None
                    else:
                        value = value.get(part) if isinstance(value, dict) else None
                    if value is None:
                        return None
                return value
            # Plain dotted path, e.g. 'user.name'.
            keys = path.split('.')
            value = data
            for key in keys:
                if isinstance(value, dict):
                    value = value.get(key)
                elif isinstance(value, list) and key.isdigit():
                    value = value[int(key)] if int(key) < len(value) else None
                else:
                    return None
                if value is None:
                    return None
            return value
        except (KeyError, IndexError, TypeError, AttributeError):
            return None

    @staticmethod
    def parse_value(value_str: str) -> Any:
        """
        Parse a value string into a Python object.

        Supports booleans, null/none, ints/floats, JSON objects/arrays and
        quoted strings; anything else is returned as the raw string.

        Args:
            value_str: Value string.

        Returns:
            The parsed value.
        """
        value_str = value_str.strip()
        # Booleans.
        if value_str.lower() == 'true':
            return True
        if value_str.lower() == 'false':
            return False
        # None.
        if value_str.lower() == 'null' or value_str.lower() == 'none':
            return None
        # Numbers.
        try:
            if '.' in value_str:
                return float(value_str)
            return int(value_str)
        except ValueError:
            pass
        # JSON.
        if value_str.startswith('{') or value_str.startswith('['):
            try:
                return json.loads(value_str)
            except json.JSONDecodeError:
                pass
        # Quoted string — strip the quotes.
        if (value_str.startswith('"') and value_str.endswith('"')) or \
           (value_str.startswith("'") and value_str.endswith("'")):
            return value_str[1:-1]
        return value_str

    @staticmethod
    def evaluate_simple_condition(condition: str, data: Dict[str, Any]) -> bool:
        """
        Evaluate a simple condition expression.

        Supported forms:
            - {key} == value
            - {key} > value
            - {key} in [value1, value2]
            - {key} contains "text"

        Args:
            condition: Condition expression.
            data: Input data.

        Returns:
            Result of the condition.
        """
        condition = condition.strip()
        # Substitute {key} placeholders with their values.
        for key, value in data.items():
            placeholder = f'{{{key}}}'
            if placeholder in condition:
                # Complex values are injected as JSON text.
                if isinstance(value, (dict, list)):
                    condition = condition.replace(placeholder, json.dumps(value, ensure_ascii=False))
                else:
                    condition = condition.replace(placeholder, str(value))
        # SECURITY NOTE: this evaluates the condition with eval(). Builtins are
        # stripped and only scalar data values are exposed, but conditions from
        # untrusted users should not reach this code path.
        try:
            safe_dict = {
                '__builtins__': {},
                'True': True,
                'False': False,
                'None': None,
                'null': None,
            }
            # Expose only simple scalar values from the input data.
            for key, value in data.items():
                if isinstance(value, (str, int, float, bool, type(None))):
                    safe_dict[key] = value
            result = eval(condition, safe_dict)
            if isinstance(result, bool):
                return result
        except Exception:
            pass
        # Manual fallback parse. FIX: try longer operators first so '>=' is not
        # split at '>' and 'not contains' is not split at the 'in' inside
        # 'contains'; word operators must stand alone surrounded by whitespace.
        for op in sorted(ConditionParser.OPERATORS, key=len, reverse=True):
            if op[0].isalpha():
                match = re.search(rf'\s{re.escape(op)}\s', condition)
                if not match:
                    continue
                left = condition[:match.start()].strip()
                right = condition[match.end():].strip()
            else:
                if op not in condition:
                    continue
                left_part, right_part = condition.split(op, 1)
                left = left_part.strip()
                right = right_part.strip()
            # Resolve the left-hand side: a {path} placeholder or a literal.
            if left.startswith('{') and left.endswith('}'):
                left_value = ConditionParser.get_value(left[1:-1], data)
            else:
                left_value = ConditionParser.parse_value(left)
            right_value = ConditionParser.parse_value(right)
            if left_value is not None:
                return ConditionParser.OPERATORS[op](left_value, right_value)
        # Nothing matched.
        return False

    @staticmethod
    def evaluate_condition(condition: str, data: Dict[str, Any]) -> bool:
        """
        Evaluate a condition expression (supports compound expressions).

        Supported forms:
            - simple: {key} == value
            - logical: {key} > 10 and {key} < 20
            - grouped: ({key} == 'a' or {key} == 'b') and {other} > 0

        Args:
            condition: Condition expression.
            data: Input data.

        Returns:
            Result of the condition.
        """
        if not condition:
            return False
        condition = condition.strip()

        def process_parentheses(expr: str) -> str:
            """Recursively reduce parenthesized sub-expressions to True/False."""
            while '(' in expr and ')' in expr:
                # Innermost group first.
                start = expr.rfind('(')
                end = expr.find(')', start)
                if end == -1:
                    break
                inner_expr = expr[start + 1:end]
                inner_result = ConditionParser.evaluate_condition(inner_expr, data)
                expr = expr[:start] + str(inner_result) + expr[end + 1:]
            return expr

        condition = process_parentheses(condition)
        # 'and' binds tighter, so split on it first.
        if ' and ' in condition.lower():
            parts = re.split(r'\s+and\s+', condition, flags=re.IGNORECASE)
            results = []
            for part in parts:
                part = part.strip()
                if part:
                    results.append(ConditionParser.evaluate_simple_condition(part, data))
            return all(results)
        # Then 'or'.
        if ' or ' in condition.lower():
            parts = re.split(r'\s+or\s+', condition, flags=re.IGNORECASE)
            results = []
            for part in parts:
                part = part.strip()
                if part:
                    results.append(ConditionParser.evaluate_simple_condition(part, data))
            return any(results)
        # Leading 'not'.
        if condition.lower().startswith('not '):
            inner = condition[4:].strip()
            return not ConditionParser.evaluate_simple_condition(inner, data)
        # Plain simple condition.
        return ConditionParser.evaluate_simple_condition(condition, data)


# Module-level singleton.
condition_parser = ConditionParser()

View File

@@ -0,0 +1,285 @@
"""
数据源连接器服务
"""
from typing import Dict, Any, List, Optional
import logging
import json
logger = logging.getLogger(__name__)
class DataSourceConnector:
    """Abstract base class for data-source connectors."""

    def __init__(self, source_type: str, config: Dict[str, Any]):
        """
        Store the connector's identity and configuration.

        Args:
            source_type: Kind of data source this connector talks to.
            config: Connection configuration.
        """
        self.source_type = source_type
        self.config = config

    def test_connection(self) -> Dict[str, Any]:
        """
        Verify that the data source is reachable.

        Returns:
            A result dict describing the outcome.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("子类必须实现test_connection方法")

    def query(self, query_params: Dict[str, Any]) -> Any:
        """
        Fetch data from the source.

        Args:
            query_params: Query parameters.

        Returns:
            Query result.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("子类必须实现query方法")
class MySQLConnector(DataSourceConnector):
    """Connector for MySQL databases (backed by pymysql)."""

    def _open(self, **extra):
        """Open a pymysql connection built from the stored config."""
        import pymysql
        return pymysql.connect(
            host=self.config.get('host'),
            port=self.config.get('port', 3306),
            user=self.config.get('user'),
            password=self.config.get('password'),
            database=self.config.get('database'),
            **extra
        )

    def test_connection(self) -> Dict[str, Any]:
        """Open and immediately close a connection to verify the config."""
        try:
            self._open(connect_timeout=5).close()
            return {"status": "success", "message": "连接成功"}
        except Exception as e:
            raise Exception(f"MySQL连接失败: {str(e)}")

    def query(self, query_params: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Execute query_params['sql'] and return rows as dicts."""
        try:
            import pymysql
            sql = query_params.get('sql')
            if not sql:
                raise ValueError("缺少SQL查询语句")
            conn = self._open()
            try:
                with conn.cursor(pymysql.cursors.DictCursor) as cursor:
                    cursor.execute(sql)
                    return cursor.fetchall()
            finally:
                conn.close()
        except Exception as e:
            raise Exception(f"MySQL查询失败: {str(e)}")
class PostgreSQLConnector(DataSourceConnector):
    """Connector for PostgreSQL databases (backed by psycopg2)."""

    def _open(self, **extra):
        """Open a psycopg2 connection built from the stored config."""
        import psycopg2
        return psycopg2.connect(
            host=self.config.get('host'),
            port=self.config.get('port', 5432),
            user=self.config.get('user'),
            password=self.config.get('password'),
            database=self.config.get('database'),
            **extra
        )

    def test_connection(self) -> Dict[str, Any]:
        """Open and immediately close a connection to verify the config."""
        try:
            self._open(connect_timeout=5).close()
            return {"status": "success", "message": "连接成功"}
        except Exception as e:
            raise Exception(f"PostgreSQL连接失败: {str(e)}")

    def query(self, query_params: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Execute query_params['sql'] and return each row as a plain dict."""
        try:
            import psycopg2.extras
            sql = query_params.get('sql')
            if not sql:
                raise ValueError("缺少SQL查询语句")
            conn = self._open()
            try:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    cursor.execute(sql)
                    rows = cursor.fetchall()
                return [dict(row) for row in rows]
            finally:
                conn.close()
        except Exception as e:
            raise Exception(f"PostgreSQL查询失败: {str(e)}")
class APIConnector(DataSourceConnector):
    """Connector for HTTP APIs (backed by httpx)."""

    def test_connection(self) -> Dict[str, Any]:
        """Issue a GET against base_url to verify reachability."""
        try:
            import httpx
            base_url = self.config.get('base_url')
            if not base_url:
                raise ValueError("缺少base_url配置")
            response = httpx.get(
                base_url,
                headers=self.config.get('headers', {}),
                timeout=self.config.get('timeout', 10),
            )
            response.raise_for_status()
            return {"status": "success", "message": "连接成功", "status_code": response.status_code}
        except Exception as e:
            raise Exception(f"API连接失败: {str(e)}")

    def query(self, query_params: Dict[str, Any]) -> Any:
        """Perform the HTTP request described by query_params and return the JSON body."""
        try:
            import httpx
            method = query_params.get('method', 'GET').upper()
            endpoint = query_params.get('endpoint', '')
            params = query_params.get('params', {})
            data = query_params.get('data', {})
            headers = self.config.get('headers', {})
            timeout = self.config.get('timeout', 10)
            base = self.config.get('base_url', '').rstrip('/')
            url = f"{base}/{endpoint.lstrip('/')}"
            # Dispatch table instead of an if/elif chain; lambdas defer the call.
            senders = {
                'GET': lambda: httpx.get(url, params=params, headers=headers, timeout=timeout),
                'POST': lambda: httpx.post(url, json=data, headers=headers, timeout=timeout),
                'PUT': lambda: httpx.put(url, json=data, headers=headers, timeout=timeout),
                'DELETE': lambda: httpx.delete(url, headers=headers, timeout=timeout),
            }
            if method not in senders:
                raise ValueError(f"不支持的HTTP方法: {method}")
            response = senders[method]()
            response.raise_for_status()
            return response.json() if response.content else {}
        except Exception as e:
            raise Exception(f"API查询失败: {str(e)}")
class JSONFileConnector(DataSourceConnector):
    """Connector that reads data from a local JSON file."""

    def test_connection(self) -> Dict[str, Any]:
        """Verify that the configured file exists."""
        try:
            import os
            file_path = self.config.get('file_path')
            if not file_path:
                raise ValueError("缺少file_path配置")
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"文件不存在: {file_path}")
            return {"status": "success", "message": "文件存在"}
        except Exception as e:
            raise Exception(f"JSON文件连接失败: {str(e)}")

    def query(self, query_params: Dict[str, Any]) -> Any:
        """
        Load the JSON file and optionally drill into it.

        Args:
            query_params: May contain 'path', a dotted path into the document
                (list steps are addressed by integer index).

        Returns:
            The selected value, or None when the path does not resolve.
        """
        try:
            import json
            import os
            file_path = self.config.get('file_path')
            # FIX: a missing config previously hit os.path.exists(None) and
            # raised an opaque TypeError; fail with the same message used by
            # test_connection instead.
            if not file_path:
                raise ValueError("缺少file_path配置")
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"文件不存在: {file_path}")
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            filter_path = query_params.get('path')
            if not filter_path:
                return data
            # Dotted-path drill-down.
            result = data
            for part in filter_path.split('.'):
                if isinstance(result, dict):
                    result = result.get(part)
                elif isinstance(result, list):
                    try:
                        result = result[int(part)]
                    except (ValueError, IndexError):
                        return None
                else:
                    return None
            return result
        except Exception as e:
            raise Exception(f"JSON文件查询失败: {str(e)}")
# Registry of available connector implementations, keyed by source type.
_connector_classes = {
    'mysql': MySQLConnector,
    'postgresql': PostgreSQLConnector,
    'api': APIConnector,
    'json': JSONFileConnector,
}


def create_connector(source_type: str, config: Dict[str, Any]):
    """
    Create a data-source connector for the given type.

    Args:
        source_type: Data source type key ('mysql', 'postgresql', 'api', 'json').
        config: Connection configuration.

    Returns:
        A connector instance for the requested type.

    Raises:
        ValueError: If the type is not registered.
    """
    cls = _connector_classes.get(source_type)
    if cls is None:
        raise ValueError(f"不支持的数据源类型: {source_type}")
    return cls(source_type, config)
# API-compatibility shim: exposes one uniform class name to the API layer
# while delegating all work to the concrete connector implementation.
class DataSourceConnectorWrapper:
    """Uniform wrapper over the concrete connectors, used by the API layer."""

    def __init__(self, source_type: str, config: Dict[str, Any]):
        """Build the underlying connector and remember the request parameters."""
        self.connector = create_connector(source_type, config)
        self.source_type = source_type
        self.config = config

    def test_connection(self) -> Dict[str, Any]:
        """Delegate to the wrapped connector."""
        return self.connector.test_connection()

    def query(self, query_params: Dict[str, Any]) -> Any:
        """Delegate to the wrapped connector."""
        return self.connector.query(query_params)

View File

@@ -0,0 +1,255 @@
"""
数据转换服务
支持字段映射、数据过滤、数据转换等功能
"""
from typing import Dict, Any, List, Optional, Union
import json
import re
class DataTransformer:
    """Data transformation helpers: field mapping, filtering and computed fields."""

    # Operator name -> comparison callable used by transform_filter.
    _FILTER_OPS = {
        '==': lambda a, b: a == b,
        '!=': lambda a, b: a != b,
        '>': lambda a, b: a > b,
        '>=': lambda a, b: a >= b,
        '<': lambda a, b: a < b,
        '<=': lambda a, b: a <= b,
        'in': lambda a, b: a in b,
        'not in': lambda a, b: a not in b,
    }

    @staticmethod
    def get_nested_value(data: Dict[str, Any], path: str) -> Any:
        """
        Fetch a value from nested data.

        Supports dotted paths ('user.name'), list indices ('items[0]') and
        mixes of both anywhere in the path ('a.items[0].price'). FIX: the
        previous implementation only handled a bracket in the first path
        segment, so paths like 'a.items[0]' always returned None.

        Args:
            data: Source dict.
            path: Access path.

        Returns:
            The value, or None if any step is missing.
        """
        try:
            tokens = [t for t in re.split(r'[\[\].]', path) if t]
            value = data
            for token in tokens:
                if isinstance(value, dict):
                    value = value.get(token)
                elif isinstance(value, list) and token.isdigit():
                    index = int(token)
                    value = value[index] if index < len(value) else None
                else:
                    return None
                if value is None:
                    return None
            return value
        except (KeyError, IndexError, TypeError, AttributeError):
            return None

    @staticmethod
    def set_nested_value(data: Dict[str, Any], path: str, value: Any) -> None:
        """
        Set a value in nested data, creating intermediate dicts as needed.

        Args:
            data: Target dict (mutated in place).
            path: Dotted path such as 'user.name'.
            value: Value to store.
        """
        keys = path.split('.')
        current = data
        for key in keys[:-1]:
            if key not in current:
                current[key] = {}
            current = current[key]
        current[keys[-1]] = value

    @staticmethod
    def transform_mapping(input_data: Dict[str, Any], mapping: Dict[str, str]) -> Dict[str, Any]:
        """
        Field-mapping transformation.

        Args:
            input_data: Input data.
            mapping: Mapping rules, format: {"target_key": "source_key"}.

        Returns:
            The mapped data (missing sources are dropped).
        """
        result = {}
        for target_key, source_key in mapping.items():
            value = DataTransformer.get_nested_value(input_data, source_key)
            if value is not None:
                # Nested target keys are rebuilt as nested structures.
                if '.' in target_key or '[' in target_key:
                    DataTransformer.set_nested_value(result, target_key, value)
                else:
                    result[target_key] = value
        return result

    @staticmethod
    def transform_filter(input_data: Dict[str, Any], filter_rules: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Data filtering: keep only fields whose rule matches.

        FIX: a rule that cannot be evaluated (e.g. comparing None with '>',
        which previously raised an uncaught TypeError) now simply does not
        match instead of crashing the transform.

        Args:
            input_data: Input data.
            filter_rules: Rule list, format: [{"field": "key", "operator": ">", "value": 10}].

        Returns:
            The filtered data; if nothing matched, the original input.
        """
        result = {}
        for rule in filter_rules:
            field = rule.get('field')
            operator = rule.get('operator', '==')
            value = rule.get('value')
            if not field:
                continue
            field_value = DataTransformer.get_nested_value(input_data, field)
            op = DataTransformer._FILTER_OPS.get(operator)
            if op is None:
                continue
            try:
                should_include = bool(op(field_value, value))
            except TypeError:
                # Incomparable values (e.g. None > 10) simply don't match.
                should_include = False
            if should_include:
                if field in input_data:
                    result[field] = input_data[field]
                else:
                    # Nested field: rebuild its structure in the result.
                    DataTransformer.set_nested_value(result, field, field_value)
        return result if result else input_data

    @staticmethod
    def transform_compute(input_data: Dict[str, Any], compute_rules: Dict[str, str]) -> Dict[str, Any]:
        """
        Computed-field transformation.

        Args:
            input_data: Input data.
            compute_rules: Rules, format: {"result": "{a} + {b}"}.

        Returns:
            A copy of the input with computed fields added (None on failure).
        """
        result = input_data.copy()
        for target_key, expression in compute_rules.items():
            try:
                # Substitute {key} placeholders with their values.
                computed_expression = expression
                for key, value in input_data.items():
                    placeholder = f'{{{key}}}'
                    if placeholder in computed_expression:
                        replacement = (
                            json.dumps(value, ensure_ascii=False)
                            if isinstance(value, (dict, list))
                            else str(value)
                        )
                        computed_expression = computed_expression.replace(placeholder, replacement)
                # SECURITY NOTE: eval of a user-supplied expression. Builtins
                # are stripped and only scalar inputs are exposed, but
                # expressions from untrusted sources should not reach here.
                safe_dict = {
                    '__builtins__': {},
                    'abs': abs,
                    'min': min,
                    'max': max,
                    'sum': sum,
                    'len': len,
                }
                for key, value in input_data.items():
                    if isinstance(value, (str, int, float, bool, type(None))):
                        safe_dict[key] = value
                result[target_key] = eval(computed_expression, safe_dict)
            except Exception:
                # Computation failed — record None for this field.
                result[target_key] = None
        return result

    @staticmethod
    def transform_data(
        input_data: Dict[str, Any],
        mapping: Optional[Dict[str, str]] = None,
        filter_rules: Optional[List[Dict[str, Any]]] = None,
        compute_rules: Optional[Dict[str, str]] = None,
        mode: str = 'mapping'
    ) -> Dict[str, Any]:
        """
        Combined data transformation.

        Args:
            input_data: Input data.
            mapping: Field-mapping rules.
            filter_rules: Filter rules.
            compute_rules: Computed-field rules.
            mode: Transform mode ('mapping', 'filter', 'compute', 'all').

        Returns:
            The transformed data.
        """
        result = input_data.copy()
        if mode in ('mapping', 'all') and mapping:
            result = DataTransformer.transform_mapping(result, mapping)
        if mode in ('filter', 'all') and filter_rules:
            result = DataTransformer.transform_filter(result, filter_rules)
        if mode in ('compute', 'all') and compute_rules:
            result = DataTransformer.transform_compute(result, compute_rules)
        return result


# Module-level singleton.
data_transformer = DataTransformer()

View File

@@ -0,0 +1,132 @@
"""
加密服务
提供敏感数据的加密和解密功能
"""
from cryptography.fernet import Fernet
from app.core.config import settings
import base64
import hashlib
import logging
logger = logging.getLogger(__name__)
class EncryptionService:
    """Symmetric encryption helpers for sensitive data."""

    # Shared Fernet instance, created lazily on first use.
    _fernet: Fernet = None

    @classmethod
    def _get_fernet(cls) -> Fernet:
        """Return the shared Fernet instance (lazy singleton)."""
        if cls._fernet is None:
            # Fernet requires a urlsafe-base64-encoded 32-byte key;
            # derive one deterministically from SECRET_KEY via SHA-256.
            digest = hashlib.sha256(settings.SECRET_KEY.encode()).digest()
            cls._fernet = Fernet(base64.urlsafe_b64encode(digest))
        return cls._fernet

    @classmethod
    def encrypt(cls, plaintext: str) -> str:
        """
        Encrypt a plaintext string.

        Args:
            plaintext: Text to encrypt.

        Returns:
            Base64-encoded ciphertext; empty input yields "".

        Raises:
            ValueError: If encryption fails.
        """
        if not plaintext:
            return ""
        try:
            token = cls._get_fernet().encrypt(plaintext.encode('utf-8'))
            return token.decode('utf-8')
        except Exception as e:
            logger.error(f"加密失败: {e}")
            raise ValueError(f"加密失败: {str(e)}")

    @classmethod
    def decrypt(cls, ciphertext: str) -> str:
        """
        Decrypt a ciphertext string.

        Args:
            ciphertext: Base64-encoded ciphertext.

        Returns:
            The decrypted plaintext; empty input yields "".

        Raises:
            ValueError: If decryption fails — e.g. legacy data that was stored
                unencrypted; callers decide whether to keep the original value.
        """
        if not ciphertext:
            return ""
        try:
            plain = cls._get_fernet().decrypt(ciphertext.encode('utf-8'))
            return plain.decode('utf-8')
        except Exception as e:
            logger.error(f"解密失败: {e}")
            raise ValueError(f"解密失败: {str(e)}")

    @classmethod
    def encrypt_dict_value(cls, data: dict, key: str) -> dict:
        """Encrypt data[key] in place when present and truthy; returns the dict."""
        if data.get(key):
            data[key] = cls.encrypt(str(data[key]))
        return data

    @classmethod
    def decrypt_dict_value(cls, data: dict, key: str) -> dict:
        """Decrypt data[key] in place when present; legacy plaintext is left untouched."""
        if data.get(key):
            try:
                data[key] = cls.decrypt(str(data[key]))
            except ValueError:
                # Probably unencrypted legacy data — keep the original value.
                pass
        return data

    @classmethod
    def is_encrypted(cls, text: str) -> bool:
        """Return True when *text* decrypts successfully (i.e. was produced by encrypt)."""
        if not text:
            return False
        try:
            cls.decrypt(text)
        except Exception:
            return False
        return True

View File

@@ -0,0 +1,117 @@
"""
执行日志服务
"""
from typing import Dict, Any, Optional
from datetime import datetime
from sqlalchemy.orm import Session
from app.models.execution_log import ExecutionLog
import logging
logger = logging.getLogger(__name__)
class ExecutionLogger:
    """Persists execution logs to the database and mirrors them to the process logger."""

    def __init__(self, execution_id: str, db: Session):
        """
        Args:
            execution_id: ID of the execution being logged.
            db: Database session used to persist log rows.
        """
        self.execution_id = execution_id
        self.db = db

    def log(
        self,
        level: str,
        message: str,
        node_id: Optional[str] = None,
        node_type: Optional[str] = None,
        data: Optional[Dict[str, Any]] = None,
        duration: Optional[int] = None
    ):
        """
        Write one log entry.

        Args:
            level: Log level (INFO/WARN/ERROR/DEBUG).
            message: Log message.
            node_id: Optional node ID.
            node_type: Optional node type.
            data: Optional extra payload.
            duration: Optional elapsed time in milliseconds.
        """
        try:
            self.db.add(ExecutionLog(
                execution_id=self.execution_id,
                node_id=node_id,
                node_type=node_type,
                level=level.upper(),
                message=message,
                data=data,
                duration=duration,
                timestamp=datetime.utcnow()
            ))
            self.db.commit()
            # Mirror the entry to the standard logging module.
            pieces = [f"[执行 {self.execution_id}]"]
            if node_id:
                pieces.append(f"[节点 {node_id}]")
            pieces.append(message)
            emit = getattr(logger, level.lower(), logger.info)
            emit(" ".join(pieces))
        except Exception as e:
            # If the database write fails, at least reach the standard logger.
            logger.error(f"记录执行日志失败: {str(e)}")
            logger.error(f"[执行 {self.execution_id}] {message}")

    def info(self, message: str, **kwargs):
        """Log at INFO level."""
        self.log("INFO", message, **kwargs)

    def warn(self, message: str, **kwargs):
        """Log at WARN level."""
        self.log("WARN", message, **kwargs)

    def error(self, message: str, **kwargs):
        """Log at ERROR level."""
        self.log("ERROR", message, **kwargs)

    def debug(self, message: str, **kwargs):
        """Log at DEBUG level."""
        self.log("DEBUG", message, **kwargs)

    def log_node_start(self, node_id: str, node_type: str, input_data: Optional[Dict[str, Any]] = None):
        """Record that a node has started executing."""
        payload = {"input": input_data} if input_data else None
        self.info(
            f"节点 {node_id} ({node_type}) 开始执行",
            node_id=node_id,
            node_type=node_type,
            data=payload
        )

    def log_node_complete(self, node_id: str, node_type: str, output_data: Optional[Dict[str, Any]] = None, duration: Optional[int] = None):
        """Record that a node finished executing."""
        payload = {"output": output_data} if output_data else None
        self.info(
            f"节点 {node_id} ({node_type}) 执行完成",
            node_id=node_id,
            node_type=node_type,
            data=payload,
            duration=duration
        )

    def log_node_error(self, node_id: str, node_type: str, error: Exception, duration: Optional[int] = None):
        """Record that a node failed, including the error type and message."""
        self.error(
            f"节点 {node_id} ({node_type}) 执行失败: {str(error)}",
            node_id=node_id,
            node_type=node_type,
            data={"error": str(error), "error_type": type(error).__name__},
            duration=duration
        )

View File

@@ -0,0 +1,220 @@
"""
LLM服务 - 处理各种LLM提供商的调用
"""
from typing import Dict, Any, Optional
import json
from openai import AsyncOpenAI
from app.core.config import settings
class LLMService:
    """Async LLM gateway for OpenAI and DeepSeek (OpenAI-compatible) APIs.

    Clients configured via global settings are created eagerly in ``__init__``
    and lazily (re-)created on first use if keys become available later;
    per-call credentials build a temporary client instead of touching the
    cached ones.
    """

    def __init__(self):
        """Create provider clients from global settings when keys are present."""
        self.openai_client = None
        self.deepseek_client = None
        # OpenAI client (only when an API key is configured)
        if settings.OPENAI_API_KEY:
            self.openai_client = AsyncOpenAI(
                api_key=settings.OPENAI_API_KEY,
                base_url=settings.OPENAI_BASE_URL
            )
        # DeepSeek client (DeepSeek exposes an OpenAI-compatible API)
        if settings.DEEPSEEK_API_KEY:
            self.deepseek_client = AsyncOpenAI(
                api_key=settings.DEEPSEEK_API_KEY,
                base_url=settings.DEEPSEEK_BASE_URL
            )

    def _resolve_client(
        self,
        api_key: Optional[str],
        base_url: Optional[str],
        default_api_key: Optional[str],
        default_base_url: Optional[str],
        cache_attr: str,
        missing_key_message: str,
    ):
        """Return the client to use for a single call.

        Per-call overrides (``api_key``/``base_url``) take precedence and
        yield a temporary client; otherwise the cached client is used, lazily
        (re-)created from the current settings so keys configured after
        startup are still picked up.

        Raises:
            ValueError: when no API key is available from any source.
        """
        # NOTE: api_key may legitimately be an empty string, hence `is not None`.
        if api_key is not None or base_url is not None:
            final_api_key = api_key if api_key else default_api_key
            final_base_url = base_url if base_url else default_base_url
            if not final_api_key:
                raise ValueError(missing_key_message)
            return AsyncOpenAI(
                api_key=final_api_key,
                base_url=final_base_url
            )
        client = getattr(self, cache_attr)
        if not client:
            if not default_api_key:
                raise ValueError(missing_key_message)
            client = AsyncOpenAI(
                api_key=default_api_key,
                base_url=default_base_url
            )
            # Cache the lazily created client for subsequent calls.
            setattr(self, cache_attr, client)
        return client

    async def _complete(
        self,
        client,
        provider_label: str,
        prompt: str,
        model: str,
        temperature: float,
        max_tokens: Optional[int],
        **kwargs
    ) -> str:
        """Send a single-user-message chat completion and return its text.

        Raises:
            Exception: wrapping any SDK error or an empty response; callers
                rely on the "<provider> API调用失败: ..." message format.
        """
        try:
            response = await client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "user", "content": prompt}
                ],
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )
            content = response.choices[0].message.content
            if content is None:
                # Raised inside the try so it is re-wrapped by the handler
                # below, matching the historical wrapped-message format.
                raise Exception(f"{provider_label} API返回的内容为空请检查API配置和模型名称")
            return content
        except Exception as e:
            raise Exception(f"{provider_label} API调用失败: {str(e)}")

    async def call_openai(
        self,
        prompt: str,
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        Call the OpenAI API.

        Args:
            prompt: user prompt text
            model: model name (default gpt-3.5-turbo)
            temperature: sampling temperature (default 0.7)
            max_tokens: optional completion-token cap
            api_key: optional per-call API key (falls back to settings)
            base_url: optional per-call API base URL (falls back to settings)
            **kwargs: extra arguments forwarded to the SDK

        Returns:
            The completion text returned by the model.
        """
        client = self._resolve_client(
            api_key,
            base_url,
            settings.OPENAI_API_KEY,
            settings.OPENAI_BASE_URL,
            "openai_client",
            "OpenAI API Key未配置请在节点配置中设置API Key或在环境变量中设置OPENAI_API_KEY",
        )
        return await self._complete(
            client, "OpenAI", prompt, model, temperature, max_tokens, **kwargs
        )

    async def call_deepseek(
        self,
        prompt: str,
        model: str = "deepseek-chat",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        Call the DeepSeek API (OpenAI-compatible).

        Args:
            prompt: user prompt text
            model: model name (default deepseek-chat)
            temperature: sampling temperature (default 0.7)
            max_tokens: optional completion-token cap
            api_key: optional per-call API key (falls back to settings)
            base_url: optional per-call API base URL (falls back to settings)
            **kwargs: extra arguments forwarded to the SDK

        Returns:
            The completion text returned by the model.
        """
        client = self._resolve_client(
            api_key,
            base_url,
            settings.DEEPSEEK_API_KEY,
            settings.DEEPSEEK_BASE_URL,
            "deepseek_client",
            "DeepSeek API Key未配置请在节点配置中设置API Key或在环境变量中设置DEEPSEEK_API_KEY",
        )
        return await self._complete(
            client, "DeepSeek", prompt, model, temperature, max_tokens, **kwargs
        )

    async def call_llm(
        self,
        prompt: str,
        provider: str = "openai",
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> str:
        """
        Provider-agnostic LLM entry point.

        Args:
            prompt: user prompt text
            provider: "openai" or "deepseek"
            model: model name (provider default used when omitted)
            temperature: sampling temperature
            max_tokens: optional completion-token cap
            **kwargs: extra arguments forwarded to the SDK

        Returns:
            The completion text returned by the model.

        Raises:
            ValueError: for an unsupported provider.
        """
        if provider == "openai":
            return await self.call_openai(
                prompt=prompt,
                model=model or "gpt-3.5-turbo",
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )
        elif provider == "deepseek":
            return await self.call_deepseek(
                prompt=prompt,
                model=model or "deepseek-chat",
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )
        else:
            raise ValueError(f"不支持的LLM提供商: {provider},目前支持: openai, deepseek")
# Global, shared LLM service instance (module-level singleton).
llm_service = LLMService()

View File

@@ -0,0 +1,276 @@
"""
系统监控服务
提供系统状态、执行统计、性能指标等监控数据
"""
from sqlalchemy.orm import Session
from sqlalchemy import func, and_, case
from datetime import datetime, timedelta
from typing import Dict, Any, List
from app.models.user import User
from app.models.workflow import Workflow
from app.models.agent import Agent
from app.models.execution import Execution
from app.models.execution_log import ExecutionLog
from app.models.data_source import DataSource
from app.models.model_config import ModelConfig
import logging
logger = logging.getLogger(__name__)
class MonitoringService:
    """System monitoring service.

    Read-only aggregation helpers over executions, workflows, agents and
    related tables. All methods are static and operate on an externally
    managed SQLAlchemy session.
    """
    @staticmethod
    def get_system_overview(db: Session, user_id: str = None) -> Dict[str, Any]:
        """
        Get overview counts for the main entity tables.

        Args:
            db: database session
            user_id: if given, only count data owned by this user

        Returns:
            Dict with counts of workflows, agents, executions, data sources,
            model configs, and users (users is None unless called without a
            user_id, i.e. in the admin view).
        """
        # A bare `True` passed to .filter() leaves the query unfiltered.
        user_filter = Workflow.user_id == user_id if user_id else True
        # Workflow count
        workflow_count = db.query(func.count(Workflow.id)).filter(user_filter).scalar() or 0
        # Agent count
        agent_filter = Agent.user_id == user_id if user_id else True
        agent_count = db.query(func.count(Agent.id)).filter(agent_filter).scalar() or 0
        # Execution count; executions are scoped to a user via their workflows.
        execution_filter = None
        if user_id:
            execution_filter = Execution.workflow_id.in_(
                db.query(Workflow.id).filter(Workflow.user_id == user_id)
            )
        execution_count = db.query(func.count(Execution.id)).filter(
            execution_filter if execution_filter else True
        ).scalar() or 0
        # Data source count
        data_source_filter = DataSource.user_id == user_id if user_id else True
        data_source_count = db.query(func.count(DataSource.id)).filter(
            data_source_filter
        ).scalar() or 0
        # Model config count
        model_config_filter = ModelConfig.user_id == user_id if user_id else True
        model_config_count = db.query(func.count(ModelConfig.id)).filter(
            model_config_filter
        ).scalar() or 0
        # User count (only computed for the unscoped/admin view)
        user_count = None
        if not user_id:
            user_count = db.query(func.count(User.id)).scalar() or 0
        return {
            "workflows": workflow_count,
            "agents": agent_count,
            "executions": execution_count,
            "data_sources": data_source_count,
            "model_configs": model_config_count,
            "users": user_count
        }
    @staticmethod
    def get_execution_statistics(
        db: Session,
        user_id: str = None,
        days: int = 7
    ) -> Dict[str, Any]:
        """
        Get execution statistics for a recent time window.

        Args:
            db: database session
            user_id: if given, only include this user's executions
            days: window size in days (default 7)

        Returns:
            Dict with total count, per-status counts, success rate (percent),
            average execution time, and a 24-bucket hourly trend (the trend
            always covers the last 24 hours regardless of `days`).
        """
        # Time window
        end_time = datetime.utcnow()
        start_time = end_time - timedelta(days=days)
        # Base filter, optionally scoped to the user's workflows
        execution_filter = Execution.created_at >= start_time
        if user_id:
            execution_filter = and_(
                execution_filter,
                Execution.workflow_id.in_(
                    db.query(Workflow.id).filter(Workflow.user_id == user_id)
                )
            )
        # Total executions in the window
        total_executions = db.query(func.count(Execution.id)).filter(
            execution_filter
        ).scalar() or 0
        # Per-status counts
        status_stats = db.query(
            Execution.status,
            func.count(Execution.id).label('count')
        ).filter(execution_filter).group_by(Execution.status).all()
        status_counts = {status: count for status, count in status_stats}
        # Success rate = completed / total (failed is currently informational only)
        completed = status_counts.get('completed', 0)
        failed = status_counts.get('failed', 0)
        success_rate = (completed / total_executions * 100) if total_executions > 0 else 0
        # Average execution time over rows that recorded one
        avg_execution_time = db.query(
            func.avg(Execution.execution_time)
        ).filter(
            and_(execution_filter, Execution.execution_time.isnot(None))
        ).scalar() or 0
        # Hourly trend for the last 24 hours.
        # NOTE(review): issues one COUNT query per hour (24 round-trips);
        # consider a single grouped query if this endpoint gets hot.
        hourly_trends = []
        for i in range(24):
            hour_start = end_time - timedelta(hours=24-i)
            hour_end = hour_start + timedelta(hours=1)
            hour_filter = and_(
                execution_filter,
                Execution.created_at >= hour_start,
                Execution.created_at < hour_end
            )
            hour_count = db.query(func.count(Execution.id)).filter(
                hour_filter
            ).scalar() or 0
            hourly_trends.append({
                "hour": hour_start.strftime("%H:00"),
                "count": hour_count
            })
        return {
            "total": total_executions,
            "status_counts": status_counts,
            "success_rate": round(success_rate, 2),
            "avg_execution_time": round(avg_execution_time, 2) if avg_execution_time else 0,
            "hourly_trends": hourly_trends
        }
    @staticmethod
    def get_node_type_statistics(
        db: Session,
        user_id: str = None,
        days: int = 7
    ) -> List[Dict[str, Any]]:
        """
        Get per-node-type execution statistics over a recent window.

        Args:
            db: database session
            user_id: if given, only include this user's executions
            days: window size in days (default 7)

        Returns:
            One dict per node type with execution count, total/average
            duration, error count and success rate (percent). Log rows
            without a node_type or duration are excluded.
        """
        # Time window
        end_time = datetime.utcnow()
        start_time = end_time - timedelta(days=days)
        # Base filter, optionally scoped to the user's workflows
        execution_filter = Execution.created_at >= start_time
        if user_id:
            execution_filter = and_(
                execution_filter,
                Execution.workflow_id.in_(
                    db.query(Workflow.id).filter(Workflow.user_id == user_id)
                )
            )
        # Materialize matching execution ids for the log-table lookup
        execution_ids_query = db.query(Execution.id).filter(execution_filter)
        execution_ids = [row[0] for row in execution_ids_query.all()]
        if not execution_ids:
            return []
        # Aggregate log rows by node type; error_count counts ERROR-level rows.
        node_stats = db.query(
            ExecutionLog.node_type,
            func.count(ExecutionLog.id).label('execution_count'),
            func.sum(ExecutionLog.duration).label('total_duration'),
            func.avg(ExecutionLog.duration).label('avg_duration'),
            func.count(
                case((ExecutionLog.level == 'ERROR', 1))
            ).label('error_count')
        ).filter(
            and_(
                ExecutionLog.execution_id.in_(execution_ids),
                ExecutionLog.node_type.isnot(None),
                ExecutionLog.duration.isnot(None)
            )
        ).group_by(ExecutionLog.node_type).all()
        result = []
        for node_type, exec_count, total_dur, avg_dur, error_count in node_stats:
            result.append({
                "node_type": node_type,
                "execution_count": exec_count,
                "total_duration": round(total_dur or 0, 2),
                "avg_duration": round(avg_dur or 0, 2),
                "error_count": error_count,
                "success_rate": round((exec_count - error_count) / exec_count * 100, 2) if exec_count > 0 else 0
            })
        return result
    @staticmethod
    def get_recent_activities(
        db: Session,
        user_id: str = None,
        limit: int = 10
    ) -> List[Dict[str, Any]]:
        """
        Get the most recent execution activity entries.

        Args:
            db: database session
            user_id: if given, only include this user's executions
            limit: maximum number of entries to return

        Returns:
            List of activity dicts, newest first.
        """
        # Optionally scope executions to the user's workflows
        execution_filter = True
        if user_id:
            execution_filter = Execution.workflow_id.in_(
                db.query(Workflow.id).filter(Workflow.user_id == user_id)
            )
        # Newest executions first
        recent_executions = db.query(Execution).filter(
            execution_filter
        ).order_by(Execution.created_at.desc()).limit(limit).all()
        result = []
        for execution in recent_executions:
            # NOTE(review): one extra workflow lookup per execution (N+1);
            # acceptable for small `limit` values.
            workflow = db.query(Workflow).filter(
                Workflow.id == execution.workflow_id
            ).first() if execution.workflow_id else None
            result.append({
                "id": execution.id,
                "type": "execution",
                "workflow_name": workflow.name if workflow else "未知工作流",
                "status": execution.status,
                "created_at": execution.created_at.isoformat() if execution.created_at else None,
                "execution_time": execution.execution_time
            })
        return result

View File

@@ -0,0 +1,110 @@
"""
权限服务
提供权限检查的辅助函数
"""
from sqlalchemy.orm import Session
from app.models.permission import WorkflowPermission, AgentPermission
from app.models.user import User
from app.models.workflow import Workflow
from app.models.agent import Agent
from typing import Optional
def check_workflow_permission(
    db: Session,
    user: User,
    workflow: Workflow,
    permission_type: str
) -> bool:
    """
    Check whether a user holds a permission on a workflow.

    Grant order: admins and the workflow owner always pass; otherwise a
    direct user grant or a grant attached to any of the user's roles counts.

    Args:
        db: database session
        user: requesting user
        workflow: target workflow
        permission_type: permission type (read/write/execute/share)

    Returns:
        bool: True if the permission is granted.
    """
    # Admins hold every permission.
    if user.role == "admin":
        return True
    # The workflow owner holds every permission.
    if workflow.user_id == user.id:
        return True
    # Direct user grant.
    user_permission = db.query(WorkflowPermission).filter(
        WorkflowPermission.workflow_id == workflow.id,
        WorkflowPermission.user_id == user.id,
        WorkflowPermission.permission_type == permission_type
    ).first()
    if user_permission:
        return True
    # Role grants: one IN query over all of the user's roles instead of
    # one query per role (fixes the N+1 pattern of the original loop).
    role_ids = [role.id for role in user.roles]
    if role_ids:
        role_permission = db.query(WorkflowPermission).filter(
            WorkflowPermission.workflow_id == workflow.id,
            WorkflowPermission.role_id.in_(role_ids),
            WorkflowPermission.permission_type == permission_type
        ).first()
        if role_permission:
            return True
    return False
def check_agent_permission(
    db: Session,
    user: User,
    agent: Agent,
    permission_type: str
) -> bool:
    """
    Check whether a user holds a permission on an Agent.

    Grant order: admins and the Agent owner always pass; otherwise a direct
    user grant or a grant attached to any of the user's roles counts.

    Args:
        db: database session
        user: requesting user
        agent: target Agent
        permission_type: permission type (read/write/execute/deploy)

    Returns:
        bool: True if the permission is granted.
    """
    # Admins hold every permission.
    if user.role == "admin":
        return True
    # The Agent owner holds every permission.
    if agent.user_id == user.id:
        return True
    # Direct user grant.
    user_permission = db.query(AgentPermission).filter(
        AgentPermission.agent_id == agent.id,
        AgentPermission.user_id == user.id,
        AgentPermission.permission_type == permission_type
    ).first()
    if user_permission:
        return True
    # Role grants: one IN query over all of the user's roles instead of
    # one query per role (fixes the N+1 pattern of the original loop).
    role_ids = [role.id for role in user.roles]
    if role_ids:
        role_permission = db.query(AgentPermission).filter(
            AgentPermission.agent_id == agent.id,
            AgentPermission.role_id.in_(role_ids),
            AgentPermission.permission_type == permission_type
        ).first()
        if role_permission:
            return True
    return False

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,323 @@
"""
工作流模板服务
提供预设的工作流模板,支持快速创建
"""
import copy
import logging
from typing import Dict, Any, List
logger = logging.getLogger(__name__)
# Built-in workflow templates, keyed by template id. Each template carries a
# display name/description plus ready-to-use React-Flow-style node and edge
# lists. NOTE: callers must not mutate these nested lists in place — they are
# shared module-level state.
WORKFLOW_TEMPLATES = {
    # Linear: start -> single LLM call -> end
    "simple_llm": {
        "name": "简单LLM工作流",
        "description": "一个简单的LLM调用工作流包含开始、LLM和结束节点",
        "nodes": [
            {
                "id": "start-1",
                "type": "start",
                "position": {"x": 100, "y": 100},
                "data": {"label": "开始"}
            },
            {
                "id": "llm-1",
                "type": "llm",
                "position": {"x": 100, "y": 250},
                "data": {
                    "label": "LLM处理",
                    "prompt": "请处理以下输入:\n{input}",
                    "provider": "openai",
                    "model": "gpt-3.5-turbo",
                    "temperature": 0.7
                }
            },
            {
                "id": "end-1",
                "type": "end",
                "position": {"x": 100, "y": 400},
                "data": {"label": "结束"}
            }
        ],
        "edges": [
            {
                "id": "e1",
                "source": "start-1",
                "target": "llm-1",
                "sourceHandle": "bottom",
                "targetHandle": "top"
            },
            {
                "id": "e2",
                "source": "llm-1",
                "target": "end-1",
                "sourceHandle": "bottom",
                "targetHandle": "top"
            }
        ]
    },
    # Branching: a condition node routes to one of two LLM nodes
    "conditional_llm": {
        "name": "条件判断LLM工作流",
        "description": "根据条件判断调用不同的LLM处理",
        "nodes": [
            {
                "id": "start-1",
                "type": "start",
                "position": {"x": 100, "y": 100},
                "data": {"label": "开始"}
            },
            {
                "id": "condition-1",
                "type": "condition",
                "position": {"x": 100, "y": 200},
                "data": {
                    "label": "条件判断",
                    "condition": "{value} > 10"
                }
            },
            {
                "id": "llm-true",
                "type": "llm",
                "position": {"x": -100, "y": 350},
                "data": {
                    "label": "True分支LLM",
                    "prompt": "值大于10请分析{input}",
                    "provider": "openai",
                    "model": "gpt-3.5-turbo"
                }
            },
            {
                "id": "llm-false",
                "type": "llm",
                "position": {"x": 300, "y": 350},
                "data": {
                    "label": "False分支LLM",
                    "prompt": "值小于等于10请分析{input}",
                    "provider": "openai",
                    "model": "gpt-3.5-turbo"
                }
            },
            {
                "id": "end-1",
                "type": "end",
                "position": {"x": 100, "y": 500},
                "data": {"label": "结束"}
            }
        ],
        "edges": [
            {
                "id": "e1",
                "source": "start-1",
                "target": "condition-1"
            },
            {
                "id": "e2",
                "source": "condition-1",
                "target": "llm-true",
                "sourceHandle": "true"
            },
            {
                "id": "e3",
                "source": "condition-1",
                "target": "llm-false",
                "sourceHandle": "false"
            },
            {
                "id": "e4",
                "source": "llm-true",
                "target": "end-1"
            },
            {
                "id": "e5",
                "source": "llm-false",
                "target": "end-1"
            }
        ]
    },
    # Pipeline: field-mapping transform feeding an LLM node
    "data_transform_llm": {
        "name": "数据转换+LLM工作流",
        "description": "先进行数据转换再调用LLM处理",
        "nodes": [
            {
                "id": "start-1",
                "type": "start",
                "position": {"x": 100, "y": 100},
                "data": {"label": "开始"}
            },
            {
                "id": "transform-1",
                "type": "transform",
                "position": {"x": 100, "y": 200},
                "data": {
                    "label": "数据转换",
                    "mode": "mapping",
                    "mapping": {
                        "input_text": "raw_input",
                        "user_id": "id"
                    }
                }
            },
            {
                "id": "llm-1",
                "type": "llm",
                "position": {"x": 100, "y": 300},
                "data": {
                    "label": "LLM处理",
                    "prompt": "处理转换后的数据:{input_text}",
                    "provider": "openai",
                    "model": "gpt-3.5-turbo"
                }
            },
            {
                "id": "end-1",
                "type": "end",
                "position": {"x": 100, "y": 400},
                "data": {"label": "结束"}
            }
        ],
        "edges": [
            {
                "id": "e1",
                "source": "start-1",
                "target": "transform-1"
            },
            {
                "id": "e2",
                "source": "transform-1",
                "target": "llm-1"
            },
            {
                "id": "e3",
                "source": "llm-1",
                "target": "end-1"
            }
        ]
    },
    # Chain: three LLM nodes executed sequentially
    "multi_llm_chain": {
        "name": "多LLM链式工作流",
        "description": "多个LLM节点链式调用实现复杂处理流程",
        "nodes": [
            {
                "id": "start-1",
                "type": "start",
                "position": {"x": 100, "y": 100},
                "data": {"label": "开始"}
            },
            {
                "id": "llm-1",
                "type": "llm",
                "position": {"x": 100, "y": 200},
                "data": {
                    "label": "第一步分析",
                    "prompt": "第一步:分析输入数据:{input}",
                    "provider": "openai",
                    "model": "gpt-3.5-turbo"
                }
            },
            {
                "id": "llm-2",
                "type": "llm",
                "position": {"x": 100, "y": 300},
                "data": {
                    "label": "第二步处理",
                    "prompt": "第二步:基于第一步的结果进行处理:{input}",
                    "provider": "openai",
                    "model": "gpt-3.5-turbo"
                }
            },
            {
                "id": "llm-3",
                "type": "llm",
                "position": {"x": 100, "y": 400},
                "data": {
                    "label": "第三步总结",
                    "prompt": "第三步:总结最终结果:{input}",
                    "provider": "openai",
                    "model": "gpt-3.5-turbo"
                }
            },
            {
                "id": "end-1",
                "type": "end",
                "position": {"x": 100, "y": 500},
                "data": {"label": "结束"}
            }
        ],
        "edges": [
            {
                "id": "e1",
                "source": "start-1",
                "target": "llm-1"
            },
            {
                "id": "e2",
                "source": "llm-1",
                "target": "llm-2"
            },
            {
                "id": "e3",
                "source": "llm-2",
                "target": "llm-3"
            },
            {
                "id": "e4",
                "source": "llm-3",
                "target": "end-1"
            }
        ]
    }
}
def get_template(template_id: str) -> Dict[str, Any]:
    """
    Look up a built-in workflow template by id.

    Args:
        template_id: template identifier

    Returns:
        The template dict, or None when the id is unknown.
    """
    try:
        return WORKFLOW_TEMPLATES[template_id]
    except KeyError:
        return None
def list_templates() -> List[Dict[str, Any]]:
    """
    List all built-in templates.

    Returns:
        One summary dict per template, carrying its id, name and description.
    """
    summaries = []
    for tpl_id, tpl in WORKFLOW_TEMPLATES.items():
        summaries.append({
            "id": tpl_id,
            "name": tpl["name"],
            "description": tpl["description"]
        })
    return summaries
def create_from_template(template_id: str, name: str = None, description: str = None) -> Dict[str, Any]:
    """
    Build workflow data from a built-in template.

    Args:
        template_id: template identifier
        name: workflow name (defaults to the template's name)
        description: workflow description (defaults to the template's description)

    Returns:
        Workflow data dict with name, description, nodes and edges.

    Raises:
        ValueError: when the template id is unknown.
    """
    template = get_template(template_id)
    if not template:
        raise ValueError(f"模板不存在: {template_id}")
    # Deep-copy nodes/edges so that editing the created workflow cannot
    # mutate the shared module-level template definition.
    return {
        "name": name or template["name"],
        "description": description or template["description"],
        "nodes": copy.deepcopy(template["nodes"]),
        "edges": copy.deepcopy(template["edges"])
    }

View File

@@ -0,0 +1,268 @@
"""
工作流验证服务
验证工作流的节点连接、数据流、循环检测等
"""
from typing import Dict, Any, List, Optional, Tuple
from collections import defaultdict, deque
import logging
logger = logging.getLogger(__name__)
class WorkflowValidator:
    """Structural validator for workflow graphs.

    Checks node/edge integrity, start/end presence, cycles, reachability,
    connection rules, condition branches and per-node configuration.
    Problems are collected as errors (blocking) and warnings (advisory)
    rather than raised.
    """
    def __init__(self, nodes: List[Dict[str, Any]], edges: List[Dict[str, Any]]):
        """
        Initialize the validator.

        Args:
            nodes: list of node dicts (each must carry an 'id')
            edges: list of edge dicts ('source'/'target' are node ids)
        """
        # Record duplicate ids up front: the id->node dict below silently
        # collapses them, so they cannot be detected afterwards. (Fixes the
        # original dead duplicate check that iterated the deduplicated dict.)
        self._duplicate_node_ids: List[str] = []
        seen_ids = set()
        for node in nodes:
            node_id = node['id']
            if node_id in seen_ids:
                self._duplicate_node_ids.append(node_id)
            seen_ids.add(node_id)
        self.nodes = {node['id']: node for node in nodes}
        self.edges = edges
        self.errors = []
        self.warnings = []
    def validate(self) -> Tuple[bool, List[str], List[str]]:
        """
        Run the full validation suite.

        Returns:
            (is_valid, errors, warnings); is_valid is True iff no errors.
        """
        self.errors = []
        self.warnings = []
        # Basic integrity
        self._validate_nodes()
        self._validate_edges()
        # Structure
        self._validate_has_start_node()
        self._validate_has_end_node()
        self._validate_no_cycles()
        self._validate_all_nodes_reachable()
        # Connections
        self._validate_node_connections()
        self._validate_condition_branches()
        # Node configuration
        self._validate_node_configs()
        return len(self.errors) == 0, self.errors, self.warnings
    def _validate_nodes(self):
        """Validate node basics: non-empty graph, unique ids, known types."""
        if not self.nodes:
            self.errors.append("工作流必须包含至少一个节点")
            return
        # Duplicate ids were collected in __init__ from the raw node list.
        for node_id in self._duplicate_node_ids:
            self.errors.append(f"节点ID重复: {node_id}")
        for node_id, node in self.nodes.items():
            # Unknown types are a warning only, so new node types degrade gracefully.
            node_type = node.get('type')
            if not node_type:
                self.errors.append(f"节点 {node_id} 缺少类型")
            elif node_type not in ['start', 'input', 'llm', 'condition', 'transform', 'output', 'end', 'default', 'loop', 'foreach', 'loop_end', 'agent', 'http', 'request', 'database', 'db', 'file', 'file_operation', 'schedule', 'delay', 'timer', 'webhook', 'email', 'mail', 'message_queue', 'mq', 'rabbitmq', 'kafka']:
                self.warnings.append(f"节点 {node_id} 使用了未知类型: {node_type}")
    def _validate_edges(self):
        """Validate edges: endpoints present, referenced nodes exist, no self-loops."""
        for edge in self.edges:
            source = edge.get('source')
            target = edge.get('target')
            if not source or not target:
                self.errors.append(f"边缺少源节点或目标节点: {edge.get('id', 'unknown')}")
                continue
            if source not in self.nodes:
                self.errors.append(f"边的源节点不存在: {source}")
            if target not in self.nodes:
                self.errors.append(f"边的目标节点不存在: {target}")
            if source == target:
                self.errors.append(f"节点 {source} 不能连接到自身")
    def _validate_has_start_node(self):
        """Require exactly one start node (more than one is a warning)."""
        start_nodes = [node for node in self.nodes.values() if node.get('type') == 'start']
        if not start_nodes:
            self.errors.append("工作流必须包含至少一个开始节点")
        elif len(start_nodes) > 1:
            self.warnings.append(f"工作流包含多个开始节点: {len(start_nodes)}")
    def _validate_has_end_node(self):
        """Recommend (warning only) at least one end node."""
        end_nodes = [node for node in self.nodes.values() if node.get('type') == 'end']
        if not end_nodes:
            self.warnings.append("工作流建议包含至少一个结束节点")
    def _validate_no_cycles(self):
        """Detect cycles with a DFS over the edge adjacency list."""
        graph = defaultdict(list)
        for edge in self.edges:
            source = edge.get('source')
            target = edge.get('target')
            if source and target:
                graph[source].append(target)
        visited = set()
        rec_stack = set()
        def has_cycle(node_id: str) -> bool:
            visited.add(node_id)
            rec_stack.add(node_id)
            for neighbor in graph.get(node_id, []):
                if neighbor not in visited:
                    if has_cycle(neighbor):
                        return True
                elif neighbor in rec_stack:
                    # Back edge within the current DFS stack => cycle.
                    self.errors.append(f"检测到循环: {node_id} -> {neighbor}")
                    return True
            rec_stack.remove(node_id)
            return False
        for node_id in self.nodes.keys():
            if node_id not in visited:
                has_cycle(node_id)
    def _validate_all_nodes_reachable(self):
        """Warn about nodes not reachable from any start node (BFS)."""
        start_nodes = [node_id for node_id, node in self.nodes.items() if node.get('type') == 'start']
        if not start_nodes:
            return  # No start node: reachability is meaningless, skip.
        reachable = set()
        queue = deque(start_nodes)
        while queue:
            node_id = queue.popleft()
            if node_id in reachable:
                continue
            reachable.add(node_id)
            for edge in self.edges:
                if edge.get('source') == node_id:
                    target = edge.get('target')
                    if target and target not in reachable:
                        queue.append(target)
        unreachable = set(self.nodes.keys()) - reachable
        if unreachable:
            self.warnings.append(f"以下节点不可达: {', '.join(unreachable)}")
    def _validate_node_connections(self):
        """Warn when start nodes have incoming or end nodes have outgoing edges."""
        for node_id, node in self.nodes.items():
            if node.get('type') == 'start':
                has_incoming = any(edge.get('target') == node_id for edge in self.edges)
                if has_incoming:
                    self.warnings.append(f"开始节点 {node_id} 不应该有入边")
        for node_id, node in self.nodes.items():
            if node.get('type') == 'end':
                has_outgoing = any(edge.get('source') == node_id for edge in self.edges)
                if has_outgoing:
                    self.warnings.append(f"结束节点 {node_id} 不应该有出边")
    def _validate_condition_branches(self):
        """Warn about condition nodes missing an expression or a branch edge."""
        for node_id, node in self.nodes.items():
            if node.get('type') == 'condition':
                condition = node.get('data', {}).get('condition', '')
                if not condition:
                    self.warnings.append(f"条件节点 {node_id} 没有配置条件表达式")
                # Branches are identified by the edge's sourceHandle.
                true_edges = [e for e in self.edges if e.get('source') == node_id and e.get('sourceHandle') == 'true']
                false_edges = [e for e in self.edges if e.get('source') == node_id and e.get('sourceHandle') == 'false']
                if not true_edges and not false_edges:
                    self.warnings.append(f"条件节点 {node_id} 没有配置分支连接")
                elif not true_edges:
                    self.warnings.append(f"条件节点 {node_id} 缺少True分支")
                elif not false_edges:
                    self.warnings.append(f"条件节点 {node_id} 缺少False分支")
    def _validate_node_configs(self):
        """Warn about incomplete per-node configuration (LLM, transform)."""
        for node_id, node in self.nodes.items():
            node_type = node.get('type')
            node_data = node.get('data', {})
            if node_type == 'llm':
                prompt = node_data.get('prompt', '')
                if not prompt:
                    self.warnings.append(f"LLM节点 {node_id} 没有配置提示词")
                provider = node_data.get('provider', 'openai')
                model = node_data.get('model')
                if not model:
                    self.warnings.append(f"LLM节点 {node_id} 没有配置模型")
            elif node_type == 'transform' or node_type == 'data':
                # Each transform mode requires its matching rule set.
                mode = node_data.get('mode', 'mapping')
                mapping = node_data.get('mapping', {})
                filter_rules = node_data.get('filter_rules', [])
                compute_rules = node_data.get('compute_rules', {})
                if mode == 'mapping' and not mapping:
                    self.warnings.append(f"转换节点 {node_id} 选择了映射模式但没有配置映射规则")
                elif mode == 'filter' and not filter_rules:
                    self.warnings.append(f"转换节点 {node_id} 选择了过滤模式但没有配置过滤规则")
                elif mode == 'compute' and not compute_rules:
                    self.warnings.append(f"转换节点 {node_id} 选择了计算模式但没有配置计算规则")
def validate_workflow(nodes: List[Dict[str, Any]], edges: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Validate a workflow graph.

    Args:
        nodes: node list
        edges: edge list

    Returns:
        Result dict of the shape
        {"valid": bool, "errors": List[str], "warnings": List[str]}.
    """
    is_valid, error_list, warning_list = WorkflowValidator(nodes, edges).validate()
    return {
        "valid": is_valid,
        "errors": error_list,
        "warnings": warning_list
    }