Files
aiagent/backend/app/services/data_source_connector.py
2026-01-19 00:09:36 +08:00

286 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
数据源连接器服务
"""
from typing import Dict, Any, List, Optional
import logging
import json
# Module-level logger for this service.
# NOTE(review): not referenced by any code in the section shown here.
logger = logging.getLogger(__name__)
class DataSourceConnector:
    """Base class for every data source connector.

    Concrete connectors must override :meth:`test_connection` and
    :meth:`query`; the base implementations only raise.
    """

    def __init__(self, source_type: str, config: Dict[str, Any]):
        """Remember the connector's type tag and its connection settings.

        Args:
            source_type: Identifier of the data source kind (e.g. ``'mysql'``).
            config: Connection settings consumed by the concrete connector.
        """
        self.source_type, self.config = source_type, config

    def test_connection(self) -> Dict[str, Any]:
        """Check that the data source can be reached.

        Returns:
            A dict describing the outcome of the connection test.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        message = "子类必须实现test_connection方法"
        raise NotImplementedError(message)

    def query(self, query_params: Dict[str, Any]) -> Any:
        """Fetch data from the source.

        Args:
            query_params: Connector-specific query parameters.

        Returns:
            The query result, in a connector-specific shape.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        message = "子类必须实现query方法"
        raise NotImplementedError(message)
class MySQLConnector(DataSourceConnector):
    """MySQL connector backed by ``pymysql``.

    Connection settings are read from ``self.config``: ``host``, ``port``
    (default 3306), ``user``, ``password`` and ``database``.
    """

    def _connect(self, **extra: Any):
        """Open a new pymysql connection from ``self.config``.

        Args:
            **extra: Extra keyword arguments forwarded to ``pymysql.connect``
                (e.g. ``connect_timeout``).

        Returns:
            An open pymysql connection; the caller is responsible for closing it.
        """
        import pymysql
        return pymysql.connect(
            host=self.config.get('host'),
            port=self.config.get('port', 3306),
            user=self.config.get('user'),
            password=self.config.get('password'),
            database=self.config.get('database'),
            **extra,
        )

    def test_connection(self) -> Dict[str, Any]:
        """Open and immediately close a connection (5 s connect timeout).

        Returns:
            ``{"status": "success", "message": ...}`` on success.

        Raises:
            Exception: Wrapping any failure (including a missing ``pymysql``);
                the original error is chained as ``__cause__``.
        """
        try:
            connection = self._connect(connect_timeout=5)
            connection.close()
            return {"status": "success", "message": "连接成功"}
        except Exception as e:
            # Keep the generic Exception type callers rely on, but chain the
            # root cause so it is not lost from the traceback.
            raise Exception(f"MySQL连接失败: {str(e)}") from e

    def query(self, query_params: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run a SQL statement and return all rows as dicts.

        Args:
            query_params: ``sql`` (required) is the statement to run.
                ``params`` (optional) is a sequence or mapping of bound values
                passed to ``cursor.execute`` so values are escaped by the driver
                instead of being formatted into the SQL string. When ``params``
                is absent the SQL is executed as-is.

        Returns:
            A list of row dicts (empty when the statement yields no rows).

        Raises:
            Exception: Wrapping any failure, including a missing ``sql``; the
                original error is chained as ``__cause__``.
        """
        try:
            import pymysql
            sql = query_params.get('sql')
            if not sql:
                raise ValueError("缺少SQL查询语句")
            bound_params = query_params.get('params')
            connection = self._connect()
            try:
                with connection.cursor(pymysql.cursors.DictCursor) as cursor:
                    cursor.execute(sql, bound_params)
                    # fetchall() may return a tuple (e.g. no result set);
                    # normalize to the annotated list type.
                    return list(cursor.fetchall())
            finally:
                connection.close()
        except Exception as e:
            raise Exception(f"MySQL查询失败: {str(e)}") from e
class PostgreSQLConnector(DataSourceConnector):
    """PostgreSQL connector backed by ``psycopg2``.

    Connection settings are read from ``self.config``: ``host``, ``port``
    (default 5432), ``user``, ``password`` and ``database``.
    """

    def _connect(self, **extra: Any):
        """Open a new psycopg2 connection from ``self.config``.

        Args:
            **extra: Extra keyword arguments forwarded to ``psycopg2.connect``
                (e.g. ``connect_timeout``).

        Returns:
            An open psycopg2 connection; the caller is responsible for closing it.
        """
        import psycopg2
        return psycopg2.connect(
            host=self.config.get('host'),
            port=self.config.get('port', 5432),
            user=self.config.get('user'),
            password=self.config.get('password'),
            database=self.config.get('database'),
            **extra,
        )

    def test_connection(self) -> Dict[str, Any]:
        """Open and immediately close a connection (5 s connect timeout).

        Returns:
            ``{"status": "success", "message": ...}`` on success.

        Raises:
            Exception: Wrapping any failure (including a missing ``psycopg2``);
                the original error is chained as ``__cause__``.
        """
        try:
            connection = self._connect(connect_timeout=5)
            connection.close()
            return {"status": "success", "message": "连接成功"}
        except Exception as e:
            # Keep the generic Exception type callers rely on, but chain the
            # root cause so it is not lost from the traceback.
            raise Exception(f"PostgreSQL连接失败: {str(e)}") from e

    def query(self, query_params: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run a SQL statement and return all rows as plain dicts.

        Args:
            query_params: ``sql`` (required) is the statement to run.
                ``params`` (optional) is a sequence or mapping of bound values
                passed to ``cursor.execute`` so values are escaped by the driver
                instead of being formatted into the SQL string. When ``params``
                is absent the SQL is executed as-is.

        Returns:
            A list of row dicts.

        Raises:
            Exception: Wrapping any failure, including a missing ``sql``; the
                original error is chained as ``__cause__``.
        """
        try:
            import psycopg2
            import psycopg2.extras
            sql = query_params.get('sql')
            if not sql:
                raise ValueError("缺少SQL查询语句")
            bound_params = query_params.get('params')
            connection = self._connect()
            try:
                with connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    cursor.execute(sql, bound_params)
                    # RealDictRow -> plain dict so callers get builtin types.
                    return [dict(row) for row in cursor.fetchall()]
            finally:
                connection.close()
        except Exception as e:
            raise Exception(f"PostgreSQL查询失败: {str(e)}") from e
class APIConnector(DataSourceConnector):
    """HTTP API connector backed by ``httpx``.

    Config keys: ``base_url`` (required), ``headers`` (default ``{}``) and
    ``timeout`` (default 10, passed to httpx as the request timeout).
    """

    # HTTP methods accepted by query(), and the subset that carries a JSON body.
    _SUPPORTED_METHODS = ('GET', 'POST', 'PUT', 'DELETE')
    _BODY_METHODS = ('POST', 'PUT')

    def _require_base_url(self) -> str:
        """Return ``config['base_url']``.

        Raises:
            ValueError: If ``base_url`` is missing or empty.
        """
        url = self.config.get('base_url')
        if not url:
            raise ValueError("缺少base_url配置")
        return url

    def test_connection(self) -> Dict[str, Any]:
        """Issue a GET to ``base_url`` and require a non-error status.

        Returns:
            ``{"status": "success", "message": ..., "status_code": <int>}``.

        Raises:
            Exception: Wrapping any failure (missing config, network error,
                4xx/5xx status); the original error is chained as ``__cause__``.
        """
        try:
            import httpx
            url = self._require_base_url()
            response = httpx.get(
                url,
                headers=self.config.get('headers', {}),
                timeout=self.config.get('timeout', 10),
            )
            response.raise_for_status()
            return {"status": "success", "message": "连接成功", "status_code": response.status_code}
        except Exception as e:
            raise Exception(f"API连接失败: {str(e)}") from e

    def query(self, query_params: Dict[str, Any]) -> Any:
        """Send one HTTP request to ``base_url`` + ``endpoint``.

        Args:
            query_params: ``method`` (default ``'GET'``, case-insensitive; one
                of GET/POST/PUT/DELETE), ``endpoint`` (appended to
                ``base_url``), ``params`` (URL query parameters, sent for every
                method) and ``data`` (JSON body, sent only for POST/PUT).

        Returns:
            The decoded JSON response body, or ``{}`` when the body is empty.

        Raises:
            Exception: Wrapping any failure (unsupported method, missing
                ``base_url``, network error, 4xx/5xx status, non-JSON body);
                the original error is chained as ``__cause__``.
        """
        try:
            import httpx
            method = query_params.get('method', 'GET').upper()
            if method not in self._SUPPORTED_METHODS:
                raise ValueError(f"不支持的HTTP方法: {method}")
            endpoint = query_params.get('endpoint', '')
            base_url = self._require_base_url().rstrip('/')
            url = f"{base_url}/{endpoint.lstrip('/')}"
            request_kwargs: Dict[str, Any] = {
                'params': query_params.get('params', {}),
                'headers': self.config.get('headers', {}),
                'timeout': self.config.get('timeout', 10),
            }
            # Only POST/PUT send a body; GET/DELETE must not get a "{}" payload.
            if method in self._BODY_METHODS:
                request_kwargs['json'] = query_params.get('data', {})
            response = httpx.request(method, url, **request_kwargs)
            response.raise_for_status()
            return response.json() if response.content else {}
        except Exception as e:
            raise Exception(f"API查询失败: {str(e)}") from e
class JSONFileConnector(DataSourceConnector):
    """Connector that reads a local JSON file.

    Config keys:
        file_path: Path to the JSON file (required).
    """

    def _require_file_path(self) -> str:
        """Return ``config['file_path']`` after checking it names a regular file.

        Raises:
            ValueError: If ``file_path`` is missing or empty.
            FileNotFoundError: If the path does not refer to an existing file.
        """
        import os
        file_path = self.config.get('file_path')
        if not file_path:
            raise ValueError("缺少file_path配置")
        # isfile rather than exists: a directory would pass exists() and then
        # fail later inside open().
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"文件不存在: {file_path}")
        return file_path

    @staticmethod
    def _resolve_path(data: Any, path: str) -> Any:
        """Walk *data* along a dot-separated *path* such as ``'items.0.name'``.

        Dict levels are indexed by key and list levels by integer index
        (negative indices count from the end, as in Python).

        Returns:
            The value at *path*, or ``None`` if any segment cannot be resolved
            (missing key, non-integer or out-of-range list index, or a scalar
            reached before the path ends).
        """
        node = data
        for part in path.split('.'):
            if isinstance(node, dict):
                node = node.get(part)
            elif isinstance(node, list):
                try:
                    node = node[int(part)]
                except (ValueError, IndexError):
                    return None
            else:
                return None
        return node

    def test_connection(self) -> Dict[str, Any]:
        """Check that ``file_path`` is configured and points to a file.

        Returns:
            ``{"status": "success", "message": ...}`` on success.

        Raises:
            Exception: Wrapping the validation error; the original error is
                chained as ``__cause__``.
        """
        try:
            self._require_file_path()
            return {"status": "success", "message": "文件存在"}
        except Exception as e:
            raise Exception(f"JSON文件连接失败: {str(e)}") from e

    def query(self, query_params: Dict[str, Any]) -> Any:
        """Load the JSON file and optionally extract a nested value.

        Args:
            query_params: ``path`` (optional) is a dot-separated path into the
                document; when absent or empty the whole document is returned.

        Returns:
            The whole parsed document, or the value at ``path`` (``None`` when
            the path cannot be resolved).

        Raises:
            Exception: Wrapping any failure (missing config, missing file,
                invalid JSON); the original error is chained as ``__cause__``.
        """
        try:
            file_path = self._require_file_path()
            # 'utf-8-sig' decodes plain UTF-8 identically and also strips a
            # leading BOM, which json.load would otherwise reject.
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                data = json.load(f)
            filter_path = query_params.get('path')
            if filter_path:
                return self._resolve_path(data, filter_path)
            return data
        except Exception as e:
            raise Exception(f"JSON文件查询失败: {str(e)}") from e
# Connector factory registry: maps a source_type key to its connector class.
_connector_classes = dict(
    mysql=MySQLConnector,
    postgresql=PostgreSQLConnector,
    api=APIConnector,
    json=JSONFileConnector,
)
def create_connector(source_type: str, config: Dict[str, Any]) -> 'DataSourceConnector':
    """Instantiate the connector registered for *source_type*.

    Lookup ignores surrounding whitespace and letter case, so ``'MySQL'`` and
    ``' mysql '`` resolve the same as ``'mysql'``. The connector receives the
    normalized key as its ``source_type``.

    Args:
        source_type: Data source kind; a key of ``_connector_classes``.
        config: Connection settings handed to the connector.

    Returns:
        A new connector instance.

    Raises:
        ValueError: If no connector is registered for *source_type*.
    """
    # Only strings can be normalized; anything else is looked up as-is.
    key = source_type.strip().lower() if isinstance(source_type, str) else source_type
    connector_class = _connector_classes.get(key)
    if not connector_class:
        raise ValueError(f"不支持的数据源类型: {source_type}")
    return connector_class(key, config)
# Unified wrapper class around the concrete connectors, kept for API compatibility.
class DataSourceConnectorWrapper:
    """Facade that builds a concrete connector and forwards calls to it.

    Attributes:
        connector: The concrete connector returned by ``create_connector``.
        source_type: The source type string passed to the constructor.
        config: The connection settings passed to the constructor.
    """

    def __init__(self, source_type: str, config: Dict[str, Any]):
        delegate = create_connector(source_type, config)
        self.connector = delegate
        self.source_type = source_type
        self.config = config

    def test_connection(self) -> Dict[str, Any]:
        """Forward to the wrapped connector's ``test_connection``."""
        delegate = self.connector
        return delegate.test_connection()

    def query(self, query_params: Dict[str, Any]) -> Any:
        """Forward *query_params* to the wrapped connector's ``query``."""
        delegate = self.connector
        return delegate.query(query_params)
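# Export note: the API layer is meant to use the wrapper class uniformly as its
# "DataSourceConnector", while the object doing the work is the concrete
# connector implementation returned by create_connector().
# NOTE(review): no alias binding DataSourceConnector to the wrapper appears in
# the code shown here — confirm how the API layer imports it.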