287 lines
8.6 KiB
Python
287 lines
8.6 KiB
Python
"""
|
||
模型配置管理API
|
||
"""
|
||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||
from sqlalchemy.orm import Session
|
||
from pydantic import BaseModel
|
||
from typing import List, Optional
|
||
from datetime import datetime
|
||
import logging
|
||
from app.core.database import get_db
|
||
from app.models.model_config import ModelConfig
|
||
from app.api.auth import get_current_user
|
||
from app.models.user import User
|
||
from app.core.exceptions import NotFoundError, ValidationError, ConflictError
|
||
from app.services.encryption_service import EncryptionService
|
||
|
||
logger = logging.getLogger(__name__)

# Error responses declared once for every route mounted on this router; FastAPI
# records them in the generated OpenAPI schema for each endpoint.
_COMMON_ERROR_RESPONSES = {
    401: {"description": "未授权"},
    404: {"description": "资源不存在"},
    400: {"description": "请求参数错误"},
    500: {"description": "服务器内部错误"},
}

router = APIRouter(
    prefix="/api/v1/model-configs",
    tags=["model-configs"],
    responses=_COMMON_ERROR_RESPONSES,
)
|
||
|
||
|
||
class ModelConfigCreate(BaseModel):
    """Request body for creating a model configuration.

    The plaintext ``api_key`` accepted here is encrypted by the create endpoint
    before it is stored, and it is never part of the response model.
    """

    # Display name; the create endpoint rejects a name the user already uses.
    name: str
    # Provider id; the endpoints accept openai / deepseek / anthropic / local.
    provider: str
    # Model identifier forwarded to the provider call.
    model_name: str
    # Plaintext API key (write-only from the client's point of view).
    api_key: str
    # Optional base URL forwarded to the provider client; None if not given.
    base_url: Optional[str] = None
|
||
|
||
|
||
class ModelConfigUpdate(BaseModel):
    """Request body for a partial update of a model configuration.

    Every field is optional. The update endpoint only touches fields whose
    value is not None, so omitting a field (or sending null) leaves it
    unchanged; as a consequence ``base_url`` cannot be reset to None here.
    """

    # New display name (checked for duplicates among the user's other configs).
    name: Optional[str] = None
    # New provider id (validated against the supported provider list).
    provider: Optional[str] = None
    # New model identifier.
    model_name: Optional[str] = None
    # New plaintext API key; re-encrypted before being stored.
    api_key: Optional[str] = None
    # New base URL.
    base_url: Optional[str] = None
|
||
|
||
|
||
class ModelConfigResponse(BaseModel):
    """Outgoing representation of a stored model configuration.

    ``api_key`` is intentionally not declared, so the stored key is never
    serialized into any response that uses this model.
    """

    # Primary key of the configuration row.
    id: str
    name: str
    provider: str
    model_name: str
    # May be None.
    base_url: Optional[str]
    # Owner of the configuration.
    user_id: str
    created_at: datetime
    updated_at: datetime

    class Config:
        # The endpoints return ModelConfig ORM rows; this lets the response
        # model read their attributes directly.
        from_attributes = True
|
||
|
||
|
||
@router.get("", response_model=List[ModelConfigResponse])
async def get_model_configs(
    skip: int = Query(0, ge=0, description="跳过记录数"),
    limit: int = Query(100, ge=1, le=100, description="每页记录数"),
    provider: Optional[str] = Query(None, description="提供商筛选"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List the current user's model configurations, newest first.

    Supports offset pagination and an optional provider filter. Only rows
    owned by ``current_user`` are returned.

    Args:
        skip: Number of rows to skip (>= 0).
        limit: Maximum number of rows to return (1..100).
        provider: Exact provider id to filter on; empty or absent means no filter.
        db: Database session injected by ``get_db``.
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        A list of ModelConfig rows, serialized through ModelConfigResponse.
    """
    # NOTE(review): `db` is a synchronous SQLAlchemy Session used inside an
    # `async def`, so these queries run on the event-loop thread.
    query = db.query(ModelConfig).filter(ModelConfig.user_id == current_user.id)

    if provider:
        query = query.filter(ModelConfig.provider == provider)

    # Break created_at ties with the primary key. With created_at alone, rows
    # that share a timestamp have no defined order, so consecutive
    # OFFSET/LIMIT pages could repeat or skip them.
    return (
        query.order_by(ModelConfig.created_at.desc(), ModelConfig.id.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )
|
||
|
||
|
||
@router.post("", response_model=ModelConfigResponse, status_code=status.HTTP_201_CREATED)
async def create_model_config(
    config_data: ModelConfigCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create a model configuration owned by the current user.

    The API key is encrypted with ``EncryptionService.encrypt`` before it is
    stored; the response model does not expose it.

    Raises:
        ValidationError: ``provider`` is not a supported provider id.
        ConflictError: the user already has a configuration with this name.
    """
    supported_providers = ['openai', 'deepseek', 'anthropic', 'local']
    if config_data.provider not in supported_providers:
        raise ValidationError(f"不支持的提供商: {config_data.provider}")

    # Duplicate names are rejected only within the current user's configs.
    name_taken = db.query(ModelConfig).filter(
        ModelConfig.name == config_data.name,
        ModelConfig.user_id == current_user.id
    ).first()
    if name_taken:
        raise ConflictError(f"模型配置名称 '{config_data.name}' 已存在")

    new_config = ModelConfig(
        name=config_data.name,
        provider=config_data.provider,
        model_name=config_data.model_name,
        # Only the encrypted form of the key is ever persisted.
        api_key=EncryptionService.encrypt(config_data.api_key),
        base_url=config_data.base_url,
        user_id=current_user.id
    )
    db.add(new_config)
    db.commit()
    # Reload so that database-populated columns are present on the instance.
    db.refresh(new_config)

    logger.info(f"用户 {current_user.username} 创建了模型配置: {new_config.name} ({new_config.id})")
    return new_config
|
||
|
||
|
||
@router.get("/{config_id}", response_model=ModelConfigResponse)
async def get_model_config(
    config_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return a single model configuration belonging to the current user.

    The lookup is scoped to ``current_user``, so another user's configuration
    is reported as not found. The stored API key is not returned, because
    ModelConfigResponse has no ``api_key`` field.

    Raises:
        NotFoundError: no configuration with ``config_id`` is owned by the user.
    """
    owned_config = (
        db.query(ModelConfig)
        .filter(ModelConfig.id == config_id, ModelConfig.user_id == current_user.id)
        .first()
    )
    if not owned_config:
        raise NotFoundError(f"模型配置不存在: {config_id}")
    return owned_config
|
||
|
||
|
||
@router.put("/{config_id}", response_model=ModelConfigResponse)
async def update_model_config(
    config_id: str,
    config_data: ModelConfigUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Partially update one of the current user's model configurations.

    Only fields whose value in ``config_data`` is not None are changed (so a
    field cannot be cleared to None through this endpoint). A new API key is
    encrypted before being stored.

    Raises:
        NotFoundError: no configuration with ``config_id`` is owned by the user.
        ConflictError: the new name is already used by another of the user's configs.
        ValidationError: the new provider is not a supported provider id.
    """
    config = db.query(ModelConfig).filter(
        ModelConfig.id == config_id,
        ModelConfig.user_id == current_user.id
    ).first()

    if not config:
        raise NotFoundError(f"模型配置不存在: {config_id}")

    # Run every check, and the encryption, before touching `config`. Previously
    # config.name was assigned before the provider was validated, so a rejected
    # request could leave a half-modified ORM object in the session. The checks
    # keep their original order: name conflict first, then provider.
    if config_data.name is not None:
        # Exclude the config being updated so keeping the same name is allowed.
        name_taken = db.query(ModelConfig).filter(
            ModelConfig.name == config_data.name,
            ModelConfig.user_id == current_user.id,
            ModelConfig.id != config_id
        ).first()
        if name_taken:
            raise ConflictError(f"模型配置名称 '{config_data.name}' 已存在")

    if config_data.provider is not None:
        valid_providers = ['openai', 'deepseek', 'anthropic', 'local']
        if config_data.provider not in valid_providers:
            raise ValidationError(f"不支持的提供商: {config_data.provider}")

    encrypted_api_key = None
    if config_data.api_key is not None:
        encrypted_api_key = EncryptionService.encrypt(config_data.api_key)

    # All checks passed: apply the provided fields.
    if config_data.name is not None:
        config.name = config_data.name
    if config_data.provider is not None:
        config.provider = config_data.provider
    if config_data.model_name is not None:
        config.model_name = config_data.model_name
    if encrypted_api_key is not None:
        config.api_key = encrypted_api_key
    if config_data.base_url is not None:
        config.base_url = config_data.base_url

    db.commit()
    db.refresh(config)

    logger.info(f"用户 {current_user.username} 更新了模型配置: {config.name} ({config.id})")
    return config
|
||
|
||
|
||
@router.delete("/{config_id}", status_code=status.HTTP_200_OK)
async def delete_model_config(
    config_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Delete one of the current user's model configurations.

    Returns:
        A confirmation payload of the form ``{"message": ...}``.

    Raises:
        NotFoundError: no configuration with ``config_id`` is owned by the user.
    """
    doomed = (
        db.query(ModelConfig)
        .filter(ModelConfig.id == config_id, ModelConfig.user_id == current_user.id)
        .first()
    )
    if not doomed:
        raise NotFoundError(f"模型配置不存在: {config_id}")

    # Read the name while the row still exists; the log line below runs after
    # the delete has been committed.
    deleted_name = doomed.name
    db.delete(doomed)
    db.commit()

    logger.info(f"用户 {current_user.username} 删除了模型配置: {deleted_name} ({config_id})")
    return {"message": "模型配置已删除"}
|
||
|
||
|
||
@router.post("/{config_id}/test", status_code=status.HTTP_200_OK)
async def test_model_config(
    config_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Check that a stored model configuration can reach its provider.

    Decrypts the stored API key and sends the prompt "test" through
    ``llm_service``. Only the ``openai`` and ``deepseek`` providers are
    exercised. Other providers get a ``"warning"`` status, and any failure is
    reported in the body with status ``"error"`` rather than as an HTTP error.

    Returns:
        ``{"status": "success" | "warning" | "error", "message": str}``

    Raises:
        NotFoundError: no configuration with ``config_id`` is owned by the user.
    """
    config = db.query(ModelConfig).filter(
        ModelConfig.id == config_id,
        ModelConfig.user_id == current_user.id
    ).first()

    if not config:
        raise NotFoundError(f"模型配置不存在: {config_id}")

    try:
        # Imported inside the try, so an import failure is reported as a
        # failed test rather than a server error.
        from app.services.llm_service import llm_service

        # Decrypted before the provider check (as before), so an undecryptable
        # key is reported as an error for every provider.
        decrypted_api_key = EncryptionService.decrypt(config.api_key)

        if config.provider == 'openai':
            call_provider = llm_service.call_openai
        elif config.provider == 'deepseek':
            call_provider = llm_service.call_deepseek
        else:
            return {
                "status": "warning",
                "message": f"提供商 {config.provider} 的测试功能暂未实现"
            }

        # The reply content is not inspected; completing without an exception
        # counts as success.
        await call_provider(
            prompt="test",
            model=config.model_name,
            api_key=decrypted_api_key,
            base_url=config.base_url
        )
    except Exception as e:
        # logger.exception logs at ERROR like before but also keeps the
        # traceback, which the previous logger.error(f"...") call dropped.
        logger.exception("模型配置测试失败: %s", e)
        return {
            "status": "error",
            "message": f"模型配置测试失败: {str(e)}"
        }
    else:
        return {
            "status": "success",
            "message": "模型配置测试成功"
        }
|