#!/usr/bin/env python3 """ 测试数据库查询工具 """ import sys import os import asyncio import json # 添加项目路径 sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'backend')) from app.services.builtin_tools import database_query_tool, _validate_sql_query def test_sql_validation(): """测试SQL验证功能""" print("=" * 60) print("测试SQL验证功能") print("=" * 60) test_cases = [ ("SELECT * FROM users", True, "正常SELECT查询"), ("select * from users", True, "小写SELECT查询"), ("INSERT INTO users VALUES (1, 'test')", False, "INSERT查询(应拒绝)"), ("UPDATE users SET name='test'", False, "UPDATE查询(应拒绝)"), ("DELETE FROM users", False, "DELETE查询(应拒绝)"), ("DROP TABLE users", False, "DROP查询(应拒绝)"), ("SELECT * FROM users; DROP TABLE users", False, "多语句查询(应拒绝)"), ("SELECT * FROM users WHERE id = 1", True, "带WHERE的SELECT查询"), ("SELECT u.id, u.name FROM users u", True, "带别名的SELECT查询"), ] for sql, expected, description in test_cases: is_safe, error_msg = _validate_sql_query(sql) status = "✅" if is_safe == expected else "❌" print(f"{status} {description}") print(f" SQL: {sql[:50]}...") if not is_safe: print(f" 错误: {error_msg}") print() async def test_database_query(): """测试数据库查询功能""" print("=" * 60) print("测试数据库查询功能") print("=" * 60) # 测试1: 查询系统表(如果存在) print("\n1. 测试查询系统表(users表)") try: result = await database_query_tool( query="SELECT COUNT(*) as user_count FROM users LIMIT 1", timeout=10 ) data = json.loads(result) if data.get("success"): print(f" ✅ 查询成功") print(f" 结果: {json.dumps(data, ensure_ascii=False, indent=2)}") else: print(f" ❌ 查询失败: {data.get('error')}") except Exception as e: print(f" ⚠️ 查询异常: {str(e)}") # 测试2: 测试SQL注入防护 print("\n2. 测试SQL注入防护") try: result = await database_query_tool( query="INSERT INTO users (username) VALUES ('hacker')", timeout=10 ) data = json.loads(result) if not data.get("success") and "不允许" in data.get("error", ""): print(f" ✅ SQL注入防护生效") print(f" 错误信息: {data.get('error')}") else: print(f" ❌ SQL注入防护失效!") except Exception as e: print(f" ⚠️ 异常: {str(e)}") # 测试3: 测试复杂查询 print("\n3. 测试复杂SELECT查询") try: result = await database_query_tool( query="SELECT id, username, email FROM users LIMIT 5", timeout=10 ) data = json.loads(result) if data.get("success"): print(f" ✅ 查询成功") print(f" 返回行数: {data.get('row_count', 0)}") if data.get('data'): print(f" 示例数据: {json.dumps(data['data'][0] if data['data'] else {}, ensure_ascii=False, indent=2)}") else: print(f" ❌ 查询失败: {data.get('error')}") except Exception as e: print(f" ⚠️ 查询异常: {str(e)}") # 测试4: 测试超时控制 print("\n4. 测试超时控制(使用长时间查询)") try: result = await database_query_tool( query="SELECT SLEEP(5) as test", timeout=2 ) data = json.loads(result) if "超时" in data.get("error", ""): print(f" ✅ 超时控制生效") else: print(f" ⚠️ 超时控制未生效(可能数据库不支持SLEEP函数)") except Exception as e: print(f" ⚠️ 异常: {str(e)}") def main(): """主函数""" print("\n" + "=" * 60) print("数据库查询工具测试") print("=" * 60 + "\n") # 测试SQL验证 test_sql_validation() # 测试数据库查询 print("\n") asyncio.run(test_database_query()) print("\n" + "=" * 60) print("测试完成") print("=" * 60) if __name__ == "__main__": main()