221 lines
7.8 KiB
Python
221 lines
7.8 KiB
Python
"""
|
||
LLM服务 - 处理各种LLM提供商的调用
|
||
"""
|
||
from typing import Dict, Any, Optional
|
||
import json
|
||
from openai import AsyncOpenAI
|
||
from app.core.config import settings
|
||
|
||
|
||
class LLMService:
|
||
"""LLM服务类"""
|
||
|
||
def __init__(self):
|
||
"""初始化LLM服务"""
|
||
self.openai_client = None
|
||
self.deepseek_client = None
|
||
|
||
# 初始化OpenAI客户端
|
||
if settings.OPENAI_API_KEY:
|
||
self.openai_client = AsyncOpenAI(
|
||
api_key=settings.OPENAI_API_KEY,
|
||
base_url=settings.OPENAI_BASE_URL
|
||
)
|
||
|
||
# 初始化DeepSeek客户端(兼容OpenAI API)
|
||
if settings.DEEPSEEK_API_KEY:
|
||
self.deepseek_client = AsyncOpenAI(
|
||
api_key=settings.DEEPSEEK_API_KEY,
|
||
base_url=settings.DEEPSEEK_BASE_URL
|
||
)
|
||
|
||
async def call_openai(
|
||
self,
|
||
prompt: str,
|
||
model: str = "gpt-3.5-turbo",
|
||
temperature: float = 0.7,
|
||
max_tokens: Optional[int] = None,
|
||
api_key: Optional[str] = None,
|
||
base_url: Optional[str] = None,
|
||
**kwargs
|
||
) -> str:
|
||
"""
|
||
调用OpenAI API
|
||
|
||
Args:
|
||
prompt: 提示词
|
||
model: 模型名称,默认gpt-3.5-turbo
|
||
temperature: 温度参数,默认0.7
|
||
max_tokens: 最大token数
|
||
api_key: API密钥(可选,如果不提供则使用默认配置)
|
||
base_url: API地址(可选,如果不提供则使用默认配置)
|
||
**kwargs: 其他参数
|
||
|
||
Returns:
|
||
LLM返回的文本
|
||
"""
|
||
# 如果提供了api_key或base_url,创建临时客户端
|
||
# 注意:api_key 可能是空字符串,需要检查是否为 None
|
||
if api_key is not None or base_url is not None:
|
||
# 如果提供了 api_key,使用它;否则使用系统默认配置
|
||
final_api_key = api_key if api_key else settings.OPENAI_API_KEY
|
||
final_base_url = base_url if base_url else settings.OPENAI_BASE_URL
|
||
|
||
if not final_api_key:
|
||
raise ValueError("OpenAI API Key未配置,请在节点配置中设置API Key或在环境变量中设置OPENAI_API_KEY")
|
||
|
||
client = AsyncOpenAI(
|
||
api_key=final_api_key,
|
||
base_url=final_base_url
|
||
)
|
||
else:
|
||
# 如果 openai_client 未初始化,尝试从 settings 重新读取并初始化
|
||
if not self.openai_client:
|
||
if settings.OPENAI_API_KEY:
|
||
self.openai_client = AsyncOpenAI(
|
||
api_key=settings.OPENAI_API_KEY,
|
||
base_url=settings.OPENAI_BASE_URL
|
||
)
|
||
else:
|
||
raise ValueError("OpenAI API Key未配置,请在节点配置中设置API Key或在环境变量中设置OPENAI_API_KEY")
|
||
client = self.openai_client
|
||
|
||
try:
|
||
response = await client.chat.completions.create(
|
||
model=model,
|
||
messages=[
|
||
{"role": "user", "content": prompt}
|
||
],
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
**kwargs
|
||
)
|
||
|
||
content = response.choices[0].message.content
|
||
if content is None:
|
||
raise Exception("OpenAI API返回的内容为空,请检查API配置和模型名称")
|
||
return content
|
||
except Exception as e:
|
||
raise Exception(f"OpenAI API调用失败: {str(e)}")
|
||
|
||
async def call_deepseek(
|
||
self,
|
||
prompt: str,
|
||
model: str = "deepseek-chat",
|
||
temperature: float = 0.7,
|
||
max_tokens: Optional[int] = None,
|
||
api_key: Optional[str] = None,
|
||
base_url: Optional[str] = None,
|
||
**kwargs
|
||
) -> str:
|
||
"""
|
||
调用DeepSeek API
|
||
|
||
Args:
|
||
prompt: 提示词
|
||
model: 模型名称,默认deepseek-chat
|
||
temperature: 温度参数,默认0.7
|
||
max_tokens: 最大token数
|
||
api_key: API密钥(可选,如果不提供则使用默认配置)
|
||
base_url: API地址(可选,如果不提供则使用默认配置)
|
||
**kwargs: 其他参数
|
||
|
||
Returns:
|
||
LLM返回的文本
|
||
"""
|
||
# 如果提供了api_key或base_url,创建临时客户端
|
||
# 注意:api_key 可能是空字符串,需要检查是否为 None
|
||
if api_key is not None or base_url is not None:
|
||
# 如果提供了 api_key,使用它;否则使用系统默认配置
|
||
final_api_key = api_key if api_key else settings.DEEPSEEK_API_KEY
|
||
final_base_url = base_url if base_url else settings.DEEPSEEK_BASE_URL
|
||
|
||
if not final_api_key:
|
||
raise ValueError("DeepSeek API Key未配置,请在节点配置中设置API Key或在环境变量中设置DEEPSEEK_API_KEY")
|
||
|
||
client = AsyncOpenAI(
|
||
api_key=final_api_key,
|
||
base_url=final_base_url
|
||
)
|
||
else:
|
||
# 如果 deepseek_client 未初始化,尝试从 settings 重新读取并初始化
|
||
if not self.deepseek_client:
|
||
if settings.DEEPSEEK_API_KEY:
|
||
self.deepseek_client = AsyncOpenAI(
|
||
api_key=settings.DEEPSEEK_API_KEY,
|
||
base_url=settings.DEEPSEEK_BASE_URL
|
||
)
|
||
else:
|
||
raise ValueError("DeepSeek API Key未配置,请在节点配置中设置API Key或在环境变量中设置DEEPSEEK_API_KEY")
|
||
client = self.deepseek_client
|
||
|
||
try:
|
||
response = await client.chat.completions.create(
|
||
model=model,
|
||
messages=[
|
||
{"role": "user", "content": prompt}
|
||
],
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
**kwargs
|
||
)
|
||
|
||
content = response.choices[0].message.content
|
||
if content is None:
|
||
raise Exception("DeepSeek API返回的内容为空,请检查API配置和模型名称")
|
||
return content
|
||
except Exception as e:
|
||
raise Exception(f"DeepSeek API调用失败: {str(e)}")
|
||
|
||
async def call_llm(
|
||
self,
|
||
prompt: str,
|
||
provider: str = "openai",
|
||
model: Optional[str] = None,
|
||
temperature: float = 0.7,
|
||
max_tokens: Optional[int] = None,
|
||
**kwargs
|
||
) -> str:
|
||
"""
|
||
通用LLM调用接口
|
||
|
||
Args:
|
||
prompt: 提示词
|
||
provider: 提供商,支持openai、deepseek
|
||
model: 模型名称
|
||
temperature: 温度参数
|
||
max_tokens: 最大token数
|
||
**kwargs: 其他参数
|
||
|
||
Returns:
|
||
LLM返回的文本
|
||
"""
|
||
if provider == "openai":
|
||
# 默认模型
|
||
if not model:
|
||
model = "gpt-3.5-turbo"
|
||
return await self.call_openai(
|
||
prompt=prompt,
|
||
model=model,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
**kwargs
|
||
)
|
||
elif provider == "deepseek":
|
||
# 默认模型
|
||
if not model:
|
||
model = "deepseek-chat"
|
||
return await self.call_deepseek(
|
||
prompt=prompt,
|
||
model=model,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
**kwargs
|
||
)
|
||
else:
|
||
raise ValueError(f"不支持的LLM提供商: {provider},目前支持: openai, deepseek")
|
||
|
||
|
||
# 全局LLM服务实例
|
||
llm_service = LLMService()
|