Files
aiagent/test_database_query_tool.py

136 lines
4.4 KiB
Python
Raw Permalink Normal View History

2026-03-06 22:31:41 +08:00
#!/usr/bin/env python3
"""
测试数据库查询工具
"""
import sys
import os
import asyncio
import json
# 添加项目路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'backend'))
from app.services.builtin_tools import database_query_tool, _validate_sql_query
def test_sql_validation():
"""测试SQL验证功能"""
print("=" * 60)
print("测试SQL验证功能")
print("=" * 60)
test_cases = [
("SELECT * FROM users", True, "正常SELECT查询"),
("select * from users", True, "小写SELECT查询"),
("INSERT INTO users VALUES (1, 'test')", False, "INSERT查询应拒绝"),
("UPDATE users SET name='test'", False, "UPDATE查询应拒绝"),
("DELETE FROM users", False, "DELETE查询应拒绝"),
("DROP TABLE users", False, "DROP查询应拒绝"),
("SELECT * FROM users; DROP TABLE users", False, "多语句查询(应拒绝)"),
("SELECT * FROM users WHERE id = 1", True, "带WHERE的SELECT查询"),
("SELECT u.id, u.name FROM users u", True, "带别名的SELECT查询"),
]
for sql, expected, description in test_cases:
is_safe, error_msg = _validate_sql_query(sql)
status = "" if is_safe == expected else ""
print(f"{status} {description}")
print(f" SQL: {sql[:50]}...")
if not is_safe:
print(f" 错误: {error_msg}")
print()
async def test_database_query():
"""测试数据库查询功能"""
print("=" * 60)
print("测试数据库查询功能")
print("=" * 60)
# 测试1: 查询系统表(如果存在)
print("\n1. 测试查询系统表users表")
try:
result = await database_query_tool(
query="SELECT COUNT(*) as user_count FROM users LIMIT 1",
timeout=10
)
data = json.loads(result)
if data.get("success"):
print(f" ✅ 查询成功")
print(f" 结果: {json.dumps(data, ensure_ascii=False, indent=2)}")
else:
print(f" ❌ 查询失败: {data.get('error')}")
except Exception as e:
print(f" ⚠️ 查询异常: {str(e)}")
# 测试2: 测试SQL注入防护
print("\n2. 测试SQL注入防护")
try:
result = await database_query_tool(
query="INSERT INTO users (username) VALUES ('hacker')",
timeout=10
)
data = json.loads(result)
if not data.get("success") and "不允许" in data.get("error", ""):
print(f" ✅ SQL注入防护生效")
print(f" 错误信息: {data.get('error')}")
else:
print(f" ❌ SQL注入防护失效")
except Exception as e:
print(f" ⚠️ 异常: {str(e)}")
# 测试3: 测试复杂查询
print("\n3. 测试复杂SELECT查询")
try:
result = await database_query_tool(
query="SELECT id, username, email FROM users LIMIT 5",
timeout=10
)
data = json.loads(result)
if data.get("success"):
print(f" ✅ 查询成功")
print(f" 返回行数: {data.get('row_count', 0)}")
if data.get('data'):
print(f" 示例数据: {json.dumps(data['data'][0] if data['data'] else {}, ensure_ascii=False, indent=2)}")
else:
print(f" ❌ 查询失败: {data.get('error')}")
except Exception as e:
print(f" ⚠️ 查询异常: {str(e)}")
# 测试4: 测试超时控制
print("\n4. 测试超时控制(使用长时间查询)")
try:
result = await database_query_tool(
query="SELECT SLEEP(5) as test",
timeout=2
)
data = json.loads(result)
if "超时" in data.get("error", ""):
print(f" ✅ 超时控制生效")
else:
print(f" ⚠️ 超时控制未生效可能数据库不支持SLEEP函数")
except Exception as e:
print(f" ⚠️ 异常: {str(e)}")
def main():
"""主函数"""
print("\n" + "=" * 60)
print("数据库查询工具测试")
print("=" * 60 + "\n")
# 测试SQL验证
test_sql_validation()
# 测试数据库查询
print("\n")
asyncio.run(test_database_query())
print("\n" + "=" * 60)
print("测试完成")
print("=" * 60)
if __name__ == "__main__":
main()