287 lines
8.6 KiB
Python
287 lines
8.6 KiB
Python
|
|
"""
|
|||
|
|
模型配置管理API
|
|||
|
|
"""
|
|||
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
from typing import List, Optional
|
|||
|
|
from datetime import datetime
|
|||
|
|
import logging
|
|||
|
|
from app.core.database import get_db
|
|||
|
|
from app.models.model_config import ModelConfig
|
|||
|
|
from app.api.auth import get_current_user
|
|||
|
|
from app.models.user import User
|
|||
|
|
from app.core.exceptions import NotFoundError, ValidationError, ConflictError
|
|||
|
|
from app.services.encryption_service import EncryptionService
|
|||
|
|
|
|||
|
|
# Logger scoped to this module's import path.
logger = logging.getLogger(__name__)

# Error responses documented in the OpenAPI schema for every route on this router.
_COMMON_ERROR_RESPONSES = {
    401: {"description": "未授权"},
    404: {"description": "资源不存在"},
    400: {"description": "请求参数错误"},
    500: {"description": "服务器内部错误"},
}

# All model-configuration endpoints are mounted under /api/v1/model-configs.
router = APIRouter(
    prefix="/api/v1/model-configs",
    tags=["model-configs"],
    responses=_COMMON_ERROR_RESPONSES,
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ModelConfigCreate(BaseModel):
    """Request body for creating a model configuration."""

    # Pydantic v2 reserves the "model_" prefix for its own API; older v2 releases
    # warn about the `model_name` field unless that protection is disabled.
    model_config = {"protected_namespaces": ()}

    name: str
    # Expected values: openai / deepseek / anthropic / local (validated by the endpoints).
    provider: str
    model_name: str
    # Plain-text key from the client; the create endpoint encrypts it before storing.
    api_key: str
    base_url: Optional[str] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ModelConfigUpdate(BaseModel):
    """Request body for a partial update; fields left as None are not changed."""

    # Pydantic v2 reserves the "model_" prefix for its own API; older v2 releases
    # warn about the `model_name` field unless that protection is disabled.
    model_config = {"protected_namespaces": ()}

    name: Optional[str] = None
    # Expected values: openai / deepseek / anthropic / local (validated by the endpoint).
    provider: Optional[str] = None
    model_name: Optional[str] = None
    # Plain-text key; the update endpoint encrypts it before storing.
    api_key: Optional[str] = None
    base_url: Optional[str] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ModelConfigResponse(BaseModel):
    """
    Response schema for a model configuration.

    There is intentionally no api_key field, so the stored key is never returned.
    """

    # from_attributes: allow building the schema from ORM instances returned by the endpoints.
    # protected_namespaces=(): avoid Pydantic v2's "model_" namespace warning for `model_name`.
    # A single model_config replaces the deprecated class-based `Config`.
    model_config = {"from_attributes": True, "protected_namespaces": ()}

    id: str
    name: str
    provider: str
    model_name: str
    base_url: Optional[str]
    user_id: str
    created_at: datetime
    updated_at: datetime
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("", response_model=List[ModelConfigResponse])
async def get_model_configs(
    skip: int = Query(0, ge=0, description="跳过记录数"),
    limit: int = Query(100, ge=1, le=100, description="每页记录数"),
    provider: Optional[str] = Query(None, description="提供商筛选"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    List the current user's model configurations, newest first.

    Supports offset/limit pagination and an optional exact-match provider filter.
    """
    conditions = [ModelConfig.user_id == current_user.id]
    # A falsy provider (None or "") means "do not filter by provider".
    if provider:
        conditions.append(ModelConfig.provider == provider)

    return (
        db.query(ModelConfig)
        .filter(*conditions)
        .order_by(ModelConfig.created_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("", response_model=ModelConfigResponse, status_code=status.HTTP_201_CREATED)
async def create_model_config(
    config_data: ModelConfigCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Create a model configuration owned by the current user.

    The API key is encrypted with EncryptionService before it is persisted.

    Raises:
        ValidationError: the provider is not openai/deepseek/anthropic/local.
        ConflictError: the user already has a configuration with this name.
    """
    requested_provider = config_data.provider
    if requested_provider not in ('openai', 'deepseek', 'anthropic', 'local'):
        raise ValidationError(f"不支持的提供商: {config_data.provider}")

    # Configuration names are unique per user.
    name_taken = db.query(ModelConfig).filter_by(
        name=config_data.name,
        user_id=current_user.id,
    ).first()
    if name_taken:
        raise ConflictError(f"模型配置名称 '{config_data.name}' 已存在")

    new_config = ModelConfig(
        name=config_data.name,
        provider=requested_provider,
        model_name=config_data.model_name,
        # Only the encrypted form of the key is stored.
        api_key=EncryptionService.encrypt(config_data.api_key),
        base_url=config_data.base_url,
        user_id=current_user.id,
    )
    db.add(new_config)
    db.commit()
    db.refresh(new_config)

    logger.info(f"用户 {current_user.username} 创建了模型配置: {new_config.name} ({new_config.id})")
    return new_config
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/{config_id}", response_model=ModelConfigResponse)
async def get_model_config(
    config_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Return one of the current user's model configurations.

    The API key is not part of the response schema and is never returned.

    Raises:
        NotFoundError: no configuration with this id belongs to the current user.
    """
    found = (
        db.query(ModelConfig)
        .filter_by(id=config_id, user_id=current_user.id)
        .first()
    )
    if found:
        return found
    raise NotFoundError(f"模型配置不存在: {config_id}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.put("/{config_id}", response_model=ModelConfigResponse)
async def update_model_config(
    config_id: str,
    config_data: ModelConfigUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Partially update one of the current user's model configurations.

    Only fields that are not None in the request body are applied, so a field
    cannot be reset to null through this endpoint. Every check runs before the
    ORM instance is modified, so a rejected request leaves it untouched.

    Raises:
        NotFoundError: no configuration with this id belongs to the current user.
        ConflictError: another configuration of this user already has the new name.
        ValidationError: the requested provider is not supported.
    """
    config = db.query(ModelConfig).filter(
        ModelConfig.id == config_id,
        ModelConfig.user_id == current_user.id
    ).first()

    if not config:
        raise NotFoundError(f"模型配置不存在: {config_id}")

    # ---- Validation phase: do not touch `config` until every check has passed.
    # Name is checked before provider so the error precedence matches earlier behavior.
    if config_data.name is not None:
        # Names are unique per user; exclude the configuration being updated.
        existing_config = db.query(ModelConfig).filter(
            ModelConfig.name == config_data.name,
            ModelConfig.user_id == current_user.id,
            ModelConfig.id != config_id
        ).first()
        if existing_config:
            raise ConflictError(f"模型配置名称 '{config_data.name}' 已存在")

    valid_providers = ('openai', 'deepseek', 'anthropic', 'local')
    if config_data.provider is not None and config_data.provider not in valid_providers:
        raise ValidationError(f"不支持的提供商: {config_data.provider}")

    # ---- Apply phase.
    if config_data.name is not None:
        config.name = config_data.name
    if config_data.provider is not None:
        config.provider = config_data.provider
    if config_data.model_name is not None:
        config.model_name = config_data.model_name
    if config_data.api_key is not None:
        # Only the encrypted form of the key is stored.
        config.api_key = EncryptionService.encrypt(config_data.api_key)
    if config_data.base_url is not None:
        config.base_url = config_data.base_url

    db.commit()
    db.refresh(config)

    logger.info("用户 %s 更新了模型配置: %s (%s)", current_user.username, config.name, config.id)
    return config
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.delete("/{config_id}", status_code=status.HTTP_200_OK)
async def delete_model_config(
    config_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Delete one of the current user's model configurations.

    Raises:
        NotFoundError: no configuration with this id belongs to the current user.
    """
    target = (
        db.query(ModelConfig)
        .filter_by(id=config_id, user_id=current_user.id)
        .first()
    )
    if not target:
        raise NotFoundError(f"模型配置不存在: {config_id}")

    # Read the name before deleting so the log line does not touch the deleted instance.
    deleted_name = target.name
    db.delete(target)
    db.commit()

    logger.info(f"用户 {current_user.username} 删除了模型配置: {deleted_name} ({config_id})")
    return {"message": "模型配置已删除"}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/{config_id}/test", status_code=status.HTTP_200_OK)
async def test_model_config(
    config_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Probe a model configuration by sending the prompt "test" to its provider.

    Always answers 200 with a {"status", "message"} payload:
    - "success": the provider call returned without raising;
    - "warning": no probe is implemented for this provider;
    - "error": importing the LLM service, decrypting the key, or the call raised.

    Raises:
        NotFoundError: no configuration with this id belongs to the current user.
    """
    config = db.query(ModelConfig).filter(
        ModelConfig.id == config_id,
        ModelConfig.user_id == current_user.id
    ).first()

    if not config:
        raise NotFoundError(f"模型配置不存在: {config_id}")

    # Only these providers have a probe; others are reported without decrypting
    # the key or loading the LLM service.
    if config.provider not in ('openai', 'deepseek'):
        return {
            "status": "warning",
            "message": f"提供商 {config.provider} 的测试功能暂未实现"
        }

    try:
        # Imported lazily, as in the original endpoint.
        from app.services.llm_service import llm_service

        # The key is stored encrypted; decrypt it only for this call.
        decrypted_api_key = EncryptionService.decrypt(config.api_key)

        if config.provider == 'openai':
            call = llm_service.call_openai
        else:
            call = llm_service.call_deepseek

        # The response content is not inspected; reaching this point counts as success.
        await call(
            prompt="test",
            model=config.model_name,
            api_key=decrypted_api_key,
            base_url=config.base_url
        )
    except Exception as e:
        # Best-effort probe: report the failure to the client instead of a 500,
        # but keep the full traceback in the server log.
        # NOTE(review): str(e) is echoed to the client; confirm provider errors
        # never contain sensitive details.
        logger.exception("模型配置测试失败: %s", e)
        return {
            "status": "error",
            "message": f"模型配置测试失败: {str(e)}"
        }

    return {
        "status": "success",
        "message": "模型配置测试成功"
    }
|