""" Pytest配置和共享fixtures """ import pytest from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from fastapi.testclient import TestClient from app.core.database import Base, get_db, SessionLocal from app.main import app as fastapi_app from app.core.config import settings # 导入所有模型,确保 Base.metadata 包含所有表 from app.core.database import Base as _Base import app.models.user # noqa: F401 import app.models.workflow # noqa: F401 import app.models.agent # noqa: F401 import app.models.execution # noqa: F401 import app.models.model_config # noqa: F401 import app.models.workflow_template # noqa: F401 import app.models.permission # noqa: F401 import app.models.alert_rule # noqa: F401 import app.models.agent_llm_log # noqa: F401 import app.models.agent_vector_memory # noqa: F401 import app.models.knowledge_base # noqa: F401 import app.models.tool # noqa: F401 assert len(_Base.metadata.tables) > 0, "没有模型表被注册" # 测试数据库URL # 使用临时文件而非 :memory: 避免 FastAPI 异步/多线程请求中的 # SQLite 内存数据库连接隔离问题(每个连接看到不同的数据库) import tempfile as _tempfile import os as _os import atexit as _atexit _test_db_fd, _test_db_path = _tempfile.mkstemp(suffix=".db") _os.close(_test_db_fd) TEST_DATABASE_URL = f"sqlite:///{_test_db_path}" # 创建测试数据库引擎 test_engine = create_engine( TEST_DATABASE_URL, connect_args={"check_same_thread": False} ) # 创建测试会话工厂 TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=test_engine) # 在进程退出时清理临时数据库文件 def _cleanup_test_db(): test_engine.dispose() if _os.path.exists(_test_db_path): _os.unlink(_test_db_path) _atexit.register(_cleanup_test_db) @pytest.fixture(scope="function") def db_session(): """创建测试数据库会话""" _Base.metadata.create_all(bind=test_engine) session = TestingSessionLocal() try: yield session finally: session.close() _Base.metadata.drop_all(bind=test_engine) @pytest.fixture(scope="function") def client(db_session): """创建测试客户端""" def override_get_db(): try: yield db_session finally: pass fastapi_app.dependency_overrides[get_db] = override_get_db with TestClient(fastapi_app) as test_client: yield test_client fastapi_app.dependency_overrides.clear() @pytest.fixture def test_user_data(): """测试用户数据""" return { "username": "testuser", "email": "test@example.com", "password": "testpassword123" } @pytest.fixture def authenticated_client(client, test_user_data): """创建已认证的测试客户端""" # 注册用户 response = client.post("/api/v1/auth/register", json=test_user_data) assert response.status_code == 201 # 登录获取token login_response = client.post( "/api/v1/auth/login", data={ "username": test_user_data["username"], "password": test_user_data["password"] } ) assert login_response.status_code == 200 token = login_response.json()["access_token"] # 设置认证头 client.headers.update({"Authorization": f"Bearer {token}"}) return client @pytest.fixture def sample_workflow_data(): """示例工作流数据""" return { "name": "测试工作流", "description": "这是一个测试工作流", "nodes": [ { "id": "start-1", "type": "start", "position": {"x": 0, "y": 0}, "data": {"label": "开始"} }, { "id": "llm-1", "type": "llm", "position": {"x": 200, "y": 0}, "data": { "label": "LLM节点", "provider": "deepseek", "prompt": "请回答:{input}", "model": "deepseek-chat" } }, { "id": "end-1", "type": "end", "position": {"x": 400, "y": 0}, "data": {"label": "结束"} } ], "edges": [ { "id": "e1", "source": "start-1", "target": "llm-1" }, { "id": "e2", "source": "llm-1", "target": "end-1" } ] }