Files
aiagent/test_database_query_tool.py
2026-03-06 22:31:41 +08:00

136 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
测试数据库查询工具
"""
import sys
import os
import asyncio
import json
# 添加项目路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'backend'))
from app.services.builtin_tools import database_query_tool, _validate_sql_query
def test_sql_validation():
"""测试SQL验证功能"""
print("=" * 60)
print("测试SQL验证功能")
print("=" * 60)
test_cases = [
("SELECT * FROM users", True, "正常SELECT查询"),
("select * from users", True, "小写SELECT查询"),
("INSERT INTO users VALUES (1, 'test')", False, "INSERT查询应拒绝"),
("UPDATE users SET name='test'", False, "UPDATE查询应拒绝"),
("DELETE FROM users", False, "DELETE查询应拒绝"),
("DROP TABLE users", False, "DROP查询应拒绝"),
("SELECT * FROM users; DROP TABLE users", False, "多语句查询(应拒绝)"),
("SELECT * FROM users WHERE id = 1", True, "带WHERE的SELECT查询"),
("SELECT u.id, u.name FROM users u", True, "带别名的SELECT查询"),
]
for sql, expected, description in test_cases:
is_safe, error_msg = _validate_sql_query(sql)
status = "" if is_safe == expected else ""
print(f"{status} {description}")
print(f" SQL: {sql[:50]}...")
if not is_safe:
print(f" 错误: {error_msg}")
print()
async def test_database_query():
"""测试数据库查询功能"""
print("=" * 60)
print("测试数据库查询功能")
print("=" * 60)
# 测试1: 查询系统表(如果存在)
print("\n1. 测试查询系统表users表")
try:
result = await database_query_tool(
query="SELECT COUNT(*) as user_count FROM users LIMIT 1",
timeout=10
)
data = json.loads(result)
if data.get("success"):
print(f" ✅ 查询成功")
print(f" 结果: {json.dumps(data, ensure_ascii=False, indent=2)}")
else:
print(f" ❌ 查询失败: {data.get('error')}")
except Exception as e:
print(f" ⚠️ 查询异常: {str(e)}")
# 测试2: 测试SQL注入防护
print("\n2. 测试SQL注入防护")
try:
result = await database_query_tool(
query="INSERT INTO users (username) VALUES ('hacker')",
timeout=10
)
data = json.loads(result)
if not data.get("success") and "不允许" in data.get("error", ""):
print(f" ✅ SQL注入防护生效")
print(f" 错误信息: {data.get('error')}")
else:
print(f" ❌ SQL注入防护失效")
except Exception as e:
print(f" ⚠️ 异常: {str(e)}")
# 测试3: 测试复杂查询
print("\n3. 测试复杂SELECT查询")
try:
result = await database_query_tool(
query="SELECT id, username, email FROM users LIMIT 5",
timeout=10
)
data = json.loads(result)
if data.get("success"):
print(f" ✅ 查询成功")
print(f" 返回行数: {data.get('row_count', 0)}")
if data.get('data'):
print(f" 示例数据: {json.dumps(data['data'][0] if data['data'] else {}, ensure_ascii=False, indent=2)}")
else:
print(f" ❌ 查询失败: {data.get('error')}")
except Exception as e:
print(f" ⚠️ 查询异常: {str(e)}")
# 测试4: 测试超时控制
print("\n4. 测试超时控制(使用长时间查询)")
try:
result = await database_query_tool(
query="SELECT SLEEP(5) as test",
timeout=2
)
data = json.loads(result)
if "超时" in data.get("error", ""):
print(f" ✅ 超时控制生效")
else:
print(f" ⚠️ 超时控制未生效可能数据库不支持SLEEP函数")
except Exception as e:
print(f" ⚠️ 异常: {str(e)}")
def main():
"""主函数"""
print("\n" + "=" * 60)
print("数据库查询工具测试")
print("=" * 60 + "\n")
# 测试SQL验证
test_sql_validation()
# 测试数据库查询
print("\n")
asyncio.run(test_database_query())
print("\n" + "=" * 60)
print("测试完成")
print("=" * 60)
if __name__ == "__main__":
main()