136 lines
4.4 KiB
Python
136 lines
4.4 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
测试数据库查询工具
|
||
"""
|
||
import sys
|
||
import os
|
||
import asyncio
|
||
import json
|
||
|
||
# 添加项目路径
|
||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'backend'))
|
||
|
||
from app.services.builtin_tools import database_query_tool, _validate_sql_query
|
||
|
||
|
||
def test_sql_validation():
|
||
"""测试SQL验证功能"""
|
||
print("=" * 60)
|
||
print("测试SQL验证功能")
|
||
print("=" * 60)
|
||
|
||
test_cases = [
|
||
("SELECT * FROM users", True, "正常SELECT查询"),
|
||
("select * from users", True, "小写SELECT查询"),
|
||
("INSERT INTO users VALUES (1, 'test')", False, "INSERT查询(应拒绝)"),
|
||
("UPDATE users SET name='test'", False, "UPDATE查询(应拒绝)"),
|
||
("DELETE FROM users", False, "DELETE查询(应拒绝)"),
|
||
("DROP TABLE users", False, "DROP查询(应拒绝)"),
|
||
("SELECT * FROM users; DROP TABLE users", False, "多语句查询(应拒绝)"),
|
||
("SELECT * FROM users WHERE id = 1", True, "带WHERE的SELECT查询"),
|
||
("SELECT u.id, u.name FROM users u", True, "带别名的SELECT查询"),
|
||
]
|
||
|
||
for sql, expected, description in test_cases:
|
||
is_safe, error_msg = _validate_sql_query(sql)
|
||
status = "✅" if is_safe == expected else "❌"
|
||
print(f"{status} {description}")
|
||
print(f" SQL: {sql[:50]}...")
|
||
if not is_safe:
|
||
print(f" 错误: {error_msg}")
|
||
print()
|
||
|
||
|
||
async def test_database_query():
|
||
"""测试数据库查询功能"""
|
||
print("=" * 60)
|
||
print("测试数据库查询功能")
|
||
print("=" * 60)
|
||
|
||
# 测试1: 查询系统表(如果存在)
|
||
print("\n1. 测试查询系统表(users表)")
|
||
try:
|
||
result = await database_query_tool(
|
||
query="SELECT COUNT(*) as user_count FROM users LIMIT 1",
|
||
timeout=10
|
||
)
|
||
data = json.loads(result)
|
||
if data.get("success"):
|
||
print(f" ✅ 查询成功")
|
||
print(f" 结果: {json.dumps(data, ensure_ascii=False, indent=2)}")
|
||
else:
|
||
print(f" ❌ 查询失败: {data.get('error')}")
|
||
except Exception as e:
|
||
print(f" ⚠️ 查询异常: {str(e)}")
|
||
|
||
# 测试2: 测试SQL注入防护
|
||
print("\n2. 测试SQL注入防护")
|
||
try:
|
||
result = await database_query_tool(
|
||
query="INSERT INTO users (username) VALUES ('hacker')",
|
||
timeout=10
|
||
)
|
||
data = json.loads(result)
|
||
if not data.get("success") and "不允许" in data.get("error", ""):
|
||
print(f" ✅ SQL注入防护生效")
|
||
print(f" 错误信息: {data.get('error')}")
|
||
else:
|
||
print(f" ❌ SQL注入防护失效!")
|
||
except Exception as e:
|
||
print(f" ⚠️ 异常: {str(e)}")
|
||
|
||
# 测试3: 测试复杂查询
|
||
print("\n3. 测试复杂SELECT查询")
|
||
try:
|
||
result = await database_query_tool(
|
||
query="SELECT id, username, email FROM users LIMIT 5",
|
||
timeout=10
|
||
)
|
||
data = json.loads(result)
|
||
if data.get("success"):
|
||
print(f" ✅ 查询成功")
|
||
print(f" 返回行数: {data.get('row_count', 0)}")
|
||
if data.get('data'):
|
||
print(f" 示例数据: {json.dumps(data['data'][0] if data['data'] else {}, ensure_ascii=False, indent=2)}")
|
||
else:
|
||
print(f" ❌ 查询失败: {data.get('error')}")
|
||
except Exception as e:
|
||
print(f" ⚠️ 查询异常: {str(e)}")
|
||
|
||
# 测试4: 测试超时控制
|
||
print("\n4. 测试超时控制(使用长时间查询)")
|
||
try:
|
||
result = await database_query_tool(
|
||
query="SELECT SLEEP(5) as test",
|
||
timeout=2
|
||
)
|
||
data = json.loads(result)
|
||
if "超时" in data.get("error", ""):
|
||
print(f" ✅ 超时控制生效")
|
||
else:
|
||
print(f" ⚠️ 超时控制未生效(可能数据库不支持SLEEP函数)")
|
||
except Exception as e:
|
||
print(f" ⚠️ 异常: {str(e)}")
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
print("\n" + "=" * 60)
|
||
print("数据库查询工具测试")
|
||
print("=" * 60 + "\n")
|
||
|
||
# 测试SQL验证
|
||
test_sql_validation()
|
||
|
||
# 测试数据库查询
|
||
print("\n")
|
||
asyncio.run(test_database_query())
|
||
|
||
print("\n" + "=" * 60)
|
||
print("测试完成")
|
||
print("=" * 60)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|