Files
aiagent/backend/tests/conftest.py
2026-01-19 00:09:36 +08:00

137 lines
3.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Pytest配置和共享fixtures
"""
import os

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from app.core.database import Base, get_db, SessionLocal
from app.main import app
from app.core.config import settings
# Tests run against an in-memory SQLite database instead of the configured one.
TEST_DATABASE_URL = "sqlite:///:memory:"

# Engine shared by every test.
# - check_same_thread=False: the sqlite3 driver otherwise refuses to use a
#   connection from a thread other than the one that created it, and FastAPI
#   may execute dependencies/handlers in worker threads.
# - poolclass=StaticPool: every new connection to ":memory:" opens a brand-new,
#   empty database. Reusing a single connection keeps the tables created by the
#   db_session fixture visible to requests served through TestClient, whichever
#   thread serves them.
test_engine = create_engine(
    TEST_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)

# Session factory bound to the test engine (used by the db_session fixture).
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=test_engine)
@pytest.fixture(scope="function")
def db_session():
    """Yield a SQLAlchemy session backed by freshly created test tables.

    Every table registered on ``Base.metadata`` is created on ``test_engine``
    before the test runs. After the test the session is closed and all those
    tables are dropped again.
    """
    metadata = Base.metadata
    metadata.create_all(bind=test_engine)

    db = TestingSessionLocal()
    try:
        yield db
    finally:
        # Release the session's connection before tearing down the schema.
        db.close()

    metadata.drop_all(bind=test_engine)
@pytest.fixture(scope="function")
def client(db_session):
    """Yield a ``TestClient`` whose ``get_db`` dependency returns ``db_session``.

    Requests made through the client and the test itself share the same
    per-test session. That session's lifecycle (closing, schema teardown) is
    owned by the ``db_session`` fixture, so the override does not close it.
    """

    def override_get_db():
        # Hand out the shared test session; cleanup is left to db_session.
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        # Using TestClient as a context manager runs the app's startup and
        # shutdown handlers around the test.
        with TestClient(app) as test_client:
            yield test_client
    finally:
        # Always drop the override, even if app startup/shutdown raised, so
        # later tests never inherit a reference to a closed session.
        app.dependency_overrides.clear()
@pytest.fixture
def test_user_data():
    """Return the username/email/password payload for the test user."""
    return dict(
        username="testuser",
        email="test@example.com",
        password="testpassword123",
    )
@pytest.fixture
def authenticated_client(client, test_user_data):
    """Return ``client`` authenticated as the ``test_user_data`` user.

    Registers the user, logs in via the form-encoded login endpoint, and adds
    the returned ``access_token`` as a Bearer ``Authorization`` header. The
    same ``TestClient`` object provided by the ``client`` fixture is returned;
    its headers are mutated in place.
    """
    # Register the test user; the endpoint is expected to answer 201.
    register_response = client.post("/api/v1/auth/register", json=test_user_data)
    assert register_response.status_code == 201, (
        f"registration failed: {register_response.status_code} {register_response.text}"
    )

    # Log in with form fields (not JSON) to obtain an access token.
    login_response = client.post(
        "/api/v1/auth/login",
        data={
            "username": test_user_data["username"],
            "password": test_user_data["password"],
        },
    )
    assert login_response.status_code == 200, (
        f"login failed: {login_response.status_code} {login_response.text}"
    )
    token = login_response.json()["access_token"]

    # Every subsequent request through this client carries the token.
    client.headers.update({"Authorization": f"Bearer {token}"})
    return client
@pytest.fixture
def sample_workflow_data():
    """Return a minimal linear workflow payload: start -> llm -> end.

    The three nodes lie on the y=0 row at x = 0, 200 and 400, and two edges
    chain them in order.
    """

    def make_node(node_id, node_type, x, data):
        # All nodes share y=0; only the horizontal offset differs.
        return {
            "id": node_id,
            "type": node_type,
            "position": {"x": x, "y": 0},
            "data": data,
        }

    llm_config = {
        "label": "LLM节点",
        "provider": "deepseek",
        "prompt": "请回答:{input}",
        "model": "deepseek-chat",
    }
    nodes = [
        make_node("start-1", "start", 0, {"label": "开始"}),
        make_node("llm-1", "llm", 200, llm_config),
        make_node("end-1", "end", 400, {"label": "结束"}),
    ]

    links = (("e1", "start-1", "llm-1"), ("e2", "llm-1", "end-1"))
    edges = [
        {"id": edge_id, "source": source, "target": target}
        for edge_id, source, target in links
    ]

    return {
        "name": "测试工作流",
        "description": "这是一个测试工作流",
        "nodes": nodes,
        "edges": edges,
    }