"""
|
||
LLM服务 - 处理各种LLM提供商的调用
|
||
"""
|
||
from typing import Dict, Any, Optional, List
|
||
import json
|
||
import asyncio
|
||
import logging
|
||
from openai import AsyncOpenAI
|
||
from app.core.config import settings
|
||
from app.services.tool_registry import tool_registry
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class LLMService:
|
||
"""LLM服务类"""
|
||
|
||
def __init__(self):
|
||
"""初始化LLM服务"""
|
||
self.openai_client = None
|
||
self.deepseek_client = None
|
||
|
||
# 初始化OpenAI客户端
|
||
if settings.OPENAI_API_KEY:
|
||
self.openai_client = AsyncOpenAI(
|
||
api_key=settings.OPENAI_API_KEY,
|
||
base_url=settings.OPENAI_BASE_URL
|
||
)
|
||
|
||
# 初始化DeepSeek客户端(兼容OpenAI API)
|
||
if settings.DEEPSEEK_API_KEY:
|
||
self.deepseek_client = AsyncOpenAI(
|
||
api_key=settings.DEEPSEEK_API_KEY,
|
||
base_url=settings.DEEPSEEK_BASE_URL
|
||
)
|
||
|
||
async def call_openai(
|
||
self,
|
||
prompt: str,
|
||
model: str = "gpt-3.5-turbo",
|
||
temperature: float = 0.7,
|
||
max_tokens: Optional[int] = None,
|
||
api_key: Optional[str] = None,
|
||
base_url: Optional[str] = None,
|
||
**kwargs
|
||
) -> str:
|
||
"""
|
||
调用OpenAI API
|
||
|
||
Args:
|
||
prompt: 提示词
|
||
model: 模型名称,默认gpt-3.5-turbo
|
||
temperature: 温度参数,默认0.7
|
||
max_tokens: 最大token数
|
||
api_key: API密钥(可选,如果不提供则使用默认配置)
|
||
base_url: API地址(可选,如果不提供则使用默认配置)
|
||
**kwargs: 其他参数
|
||
|
||
Returns:
|
||
LLM返回的文本
|
||
"""
|
||
# 如果提供了api_key或base_url,创建临时客户端
|
||
# 注意:api_key 可能是空字符串,需要检查是否为 None
|
||
if api_key is not None or base_url is not None:
|
||
# 如果提供了 api_key,使用它;否则使用系统默认配置
|
||
final_api_key = api_key if api_key else settings.OPENAI_API_KEY
|
||
final_base_url = base_url if base_url else settings.OPENAI_BASE_URL
|
||
|
||
if not final_api_key:
|
||
raise ValueError("OpenAI API Key未配置,请在节点配置中设置API Key或在环境变量中设置OPENAI_API_KEY")
|
||
|
||
client = AsyncOpenAI(
|
||
api_key=final_api_key,
|
||
base_url=final_base_url
|
||
)
|
||
else:
|
||
# 如果 openai_client 未初始化,尝试从 settings 重新读取并初始化
|
||
if not self.openai_client:
|
||
if settings.OPENAI_API_KEY:
|
||
self.openai_client = AsyncOpenAI(
|
||
api_key=settings.OPENAI_API_KEY,
|
||
base_url=settings.OPENAI_BASE_URL
|
||
)
|
||
else:
|
||
raise ValueError("OpenAI API Key未配置,请在节点配置中设置API Key或在环境变量中设置OPENAI_API_KEY")
|
||
client = self.openai_client
|
||
|
||
try:
|
||
response = await client.chat.completions.create(
|
||
model=model,
|
||
messages=[
|
||
{"role": "user", "content": prompt}
|
||
],
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
**kwargs
|
||
)
|
||
|
||
content = response.choices[0].message.content
|
||
if content is None:
|
||
raise Exception("OpenAI API返回的内容为空,请检查API配置和模型名称")
|
||
return content
|
||
except Exception as e:
|
||
raise Exception(f"OpenAI API调用失败: {str(e)}")
|
||
|
||
async def call_deepseek(
|
||
self,
|
||
prompt: str,
|
||
model: str = "deepseek-chat",
|
||
temperature: float = 0.7,
|
||
max_tokens: Optional[int] = None,
|
||
api_key: Optional[str] = None,
|
||
base_url: Optional[str] = None,
|
||
**kwargs
|
||
) -> str:
|
||
"""
|
||
调用DeepSeek API
|
||
|
||
Args:
|
||
prompt: 提示词
|
||
model: 模型名称,默认deepseek-chat
|
||
temperature: 温度参数,默认0.7
|
||
max_tokens: 最大token数
|
||
api_key: API密钥(可选,如果不提供则使用默认配置)
|
||
base_url: API地址(可选,如果不提供则使用默认配置)
|
||
**kwargs: 其他参数
|
||
|
||
Returns:
|
||
LLM返回的文本
|
||
"""
|
||
# 如果提供了api_key或base_url,创建临时客户端
|
||
# 注意:api_key 可能是空字符串,需要检查是否为 None
|
||
if api_key is not None or base_url is not None:
|
||
# 如果提供了 api_key,使用它;否则使用系统默认配置
|
||
final_api_key = api_key if api_key else settings.DEEPSEEK_API_KEY
|
||
final_base_url = base_url if base_url else settings.DEEPSEEK_BASE_URL
|
||
|
||
if not final_api_key:
|
||
raise ValueError("DeepSeek API Key未配置,请在节点配置中设置API Key或在环境变量中设置DEEPSEEK_API_KEY")
|
||
|
||
client = AsyncOpenAI(
|
||
api_key=final_api_key,
|
||
base_url=final_base_url
|
||
)
|
||
else:
|
||
# 如果 deepseek_client 未初始化,尝试从 settings 重新读取并初始化
|
||
if not self.deepseek_client:
|
||
if settings.DEEPSEEK_API_KEY:
|
||
self.deepseek_client = AsyncOpenAI(
|
||
api_key=settings.DEEPSEEK_API_KEY,
|
||
base_url=settings.DEEPSEEK_BASE_URL
|
||
)
|
||
else:
|
||
raise ValueError("DeepSeek API Key未配置,请在节点配置中设置API Key或在环境变量中设置DEEPSEEK_API_KEY")
|
||
client = self.deepseek_client
|
||
|
||
try:
|
||
# 记录实际发送给LLM的prompt
|
||
import logging
|
||
logger = logging.getLogger(__name__)
|
||
logger.info(f"[rjb] DeepSeek实际发送的prompt前200字符: {prompt[:200] if len(prompt) > 200 else prompt}")
|
||
|
||
response = await client.chat.completions.create(
|
||
model=model,
|
||
messages=[
|
||
{"role": "user", "content": prompt}
|
||
],
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
**kwargs
|
||
)
|
||
|
||
content = response.choices[0].message.content
|
||
if content is None:
|
||
raise Exception("DeepSeek API返回的内容为空,请检查API配置和模型名称")
|
||
return content
|
||
except Exception as e:
|
||
raise Exception(f"DeepSeek API调用失败: {str(e)}")
|
||
|
||
async def call_llm(
|
||
self,
|
||
prompt: str,
|
||
provider: str = "openai",
|
||
model: Optional[str] = None,
|
||
temperature: float = 0.7,
|
||
max_tokens: Optional[int] = None,
|
||
**kwargs
|
||
) -> str:
|
||
"""
|
||
通用LLM调用接口
|
||
|
||
Args:
|
||
prompt: 提示词
|
||
provider: 提供商,支持openai、deepseek
|
||
model: 模型名称
|
||
temperature: 温度参数
|
||
max_tokens: 最大token数
|
||
**kwargs: 其他参数
|
||
|
||
Returns:
|
||
LLM返回的文本
|
||
"""
|
||
if provider == "openai":
|
||
# 默认模型
|
||
if not model:
|
||
model = "gpt-3.5-turbo"
|
||
return await self.call_openai(
|
||
prompt=prompt,
|
||
model=model,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
**kwargs
|
||
)
|
||
elif provider == "deepseek":
|
||
# 默认模型
|
||
if not model:
|
||
model = "deepseek-chat"
|
||
return await self.call_deepseek(
|
||
prompt=prompt,
|
||
model=model,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
**kwargs
|
||
)
|
||
else:
|
||
raise ValueError(f"不支持的LLM提供商: {provider},目前支持: openai, deepseek")
|
||
|
||
async def call_openai_with_tools(
|
||
self,
|
||
prompt: str,
|
||
tools: List[Dict[str, Any]],
|
||
model: str = "gpt-3.5-turbo",
|
||
temperature: float = 0.7,
|
||
max_tokens: Optional[int] = None,
|
||
api_key: Optional[str] = None,
|
||
base_url: Optional[str] = None,
|
||
max_iterations: int = 5
|
||
) -> str:
|
||
"""
|
||
调用OpenAI API,支持工具调用
|
||
|
||
Args:
|
||
prompt: 提示词
|
||
tools: 工具定义列表(OpenAI Function格式)
|
||
model: 模型名称
|
||
temperature: 温度参数
|
||
max_tokens: 最大token数
|
||
api_key: API密钥
|
||
base_url: API地址
|
||
max_iterations: 最大工具调用迭代次数
|
||
|
||
Returns:
|
||
LLM返回的最终文本
|
||
"""
|
||
# 获取客户端
|
||
if api_key is not None or base_url is not None:
|
||
final_api_key = api_key if api_key else settings.OPENAI_API_KEY
|
||
final_base_url = base_url if base_url else settings.OPENAI_BASE_URL
|
||
if not final_api_key:
|
||
raise ValueError("OpenAI API Key未配置")
|
||
client = AsyncOpenAI(api_key=final_api_key, base_url=final_base_url)
|
||
else:
|
||
if not self.openai_client:
|
||
if settings.OPENAI_API_KEY:
|
||
self.openai_client = AsyncOpenAI(
|
||
api_key=settings.OPENAI_API_KEY,
|
||
base_url=settings.OPENAI_BASE_URL
|
||
)
|
||
else:
|
||
raise ValueError("OpenAI API Key未配置")
|
||
client = self.openai_client
|
||
|
||
messages = [{"role": "user", "content": prompt}]
|
||
|
||
try:
|
||
for iteration in range(max_iterations):
|
||
# 准备工具参数(只在第一次调用时传递tools)
|
||
create_kwargs = {
|
||
"model": model,
|
||
"messages": messages,
|
||
"temperature": temperature,
|
||
"max_tokens": max_tokens
|
||
}
|
||
|
||
if iteration == 0:
|
||
# 转换工具格式为OpenAI格式
|
||
openai_tools = []
|
||
for tool in tools:
|
||
if isinstance(tool, dict):
|
||
if "type" in tool and tool["type"] == "function":
|
||
openai_tools.append(tool)
|
||
elif "function" in tool:
|
||
openai_tools.append(tool)
|
||
else:
|
||
# 假设是function格式,包装一下
|
||
openai_tools.append({
|
||
"type": "function",
|
||
"function": tool
|
||
})
|
||
create_kwargs["tools"] = openai_tools
|
||
create_kwargs["tool_choice"] = "auto"
|
||
|
||
# 调用LLM
|
||
response = await client.chat.completions.create(**create_kwargs)
|
||
|
||
message = response.choices[0].message
|
||
|
||
# 添加助手回复到消息历史
|
||
messages.append({
|
||
"role": "assistant",
|
||
"content": message.content,
|
||
"tool_calls": [
|
||
{
|
||
"id": tc.id,
|
||
"type": tc.type,
|
||
"function": {
|
||
"name": tc.function.name,
|
||
"arguments": tc.function.arguments
|
||
}
|
||
} for tc in (message.tool_calls or [])
|
||
]
|
||
})
|
||
|
||
# 检查是否有工具调用
|
||
if message.tool_calls and len(message.tool_calls) > 0:
|
||
logger.info(f"检测到 {len(message.tool_calls)} 个工具调用")
|
||
# 处理每个工具调用
|
||
for tool_call in message.tool_calls:
|
||
tool_name = tool_call.function.name
|
||
try:
|
||
tool_args = json.loads(tool_call.function.arguments)
|
||
except:
|
||
tool_args = {}
|
||
|
||
logger.info(f"执行工具: {tool_name}, 参数: {tool_args}")
|
||
|
||
# 执行工具
|
||
tool_result = await self._execute_tool(tool_name, tool_args)
|
||
|
||
# 添加工具结果到消息历史
|
||
messages.append({
|
||
"role": "tool",
|
||
"tool_call_id": tool_call.id,
|
||
"content": tool_result
|
||
})
|
||
else:
|
||
# 没有工具调用,返回最终回复
|
||
final_content = message.content or ""
|
||
if final_content:
|
||
logger.info("LLM返回最终回复,工具调用完成")
|
||
return final_content
|
||
|
||
# 达到最大迭代次数
|
||
logger.warning(f"达到最大工具调用迭代次数 ({max_iterations})")
|
||
last_message = messages[-1] if messages else {}
|
||
return last_message.get("content", "达到最大工具调用次数")
|
||
except Exception as e:
|
||
logger.error(f"工具调用过程中出错: {str(e)}")
|
||
raise Exception(f"OpenAI工具调用失败: {str(e)}")
|
||
|
||
async def call_deepseek_with_tools(
|
||
self,
|
||
prompt: str,
|
||
tools: List[Dict[str, Any]],
|
||
model: str = "deepseek-chat",
|
||
temperature: float = 0.7,
|
||
max_tokens: Optional[int] = None,
|
||
api_key: Optional[str] = None,
|
||
base_url: Optional[str] = None,
|
||
max_iterations: int = 5
|
||
) -> str:
|
||
"""
|
||
调用DeepSeek API,支持工具调用(DeepSeek兼容OpenAI API格式)
|
||
"""
|
||
# DeepSeek使用相同的实现
|
||
return await self.call_openai_with_tools(
|
||
prompt=prompt,
|
||
tools=tools,
|
||
model=model,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
api_key=api_key or settings.DEEPSEEK_API_KEY,
|
||
base_url=base_url or settings.DEEPSEEK_BASE_URL,
|
||
max_iterations=max_iterations
|
||
)
|
||
|
||
async def call_llm_with_tools(
|
||
self,
|
||
prompt: str,
|
||
tools: List[Dict[str, Any]],
|
||
provider: str = "openai",
|
||
model: Optional[str] = None,
|
||
temperature: float = 0.7,
|
||
max_tokens: Optional[int] = None,
|
||
**kwargs
|
||
) -> str:
|
||
"""
|
||
通用LLM调用接口(支持工具)
|
||
|
||
Args:
|
||
prompt: 提示词
|
||
tools: 工具定义列表
|
||
provider: 提供商,支持openai、deepseek
|
||
model: 模型名称
|
||
temperature: 温度参数
|
||
max_tokens: 最大token数
|
||
**kwargs: 其他参数
|
||
|
||
Returns:
|
||
LLM返回的最终文本
|
||
"""
|
||
if provider == "openai":
|
||
if not model:
|
||
model = "gpt-3.5-turbo"
|
||
return await self.call_openai_with_tools(
|
||
prompt=prompt,
|
||
tools=tools,
|
||
model=model,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
**kwargs
|
||
)
|
||
elif provider == "deepseek":
|
||
if not model:
|
||
model = "deepseek-chat"
|
||
return await self.call_deepseek_with_tools(
|
||
prompt=prompt,
|
||
tools=tools,
|
||
model=model,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
**kwargs
|
||
)
|
||
else:
|
||
raise ValueError(f"不支持的LLM提供商: {provider},目前支持: openai, deepseek")
|
||
|
||
async def _execute_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> str:
|
||
"""
|
||
执行工具
|
||
|
||
Args:
|
||
tool_name: 工具名称
|
||
tool_args: 工具参数
|
||
|
||
Returns:
|
||
工具执行结果(JSON字符串)
|
||
"""
|
||
# 从注册表获取工具函数
|
||
tool_func = tool_registry.get_tool_function(tool_name)
|
||
|
||
if not tool_func:
|
||
error_msg = f"工具 {tool_name} 未找到"
|
||
logger.error(error_msg)
|
||
return json.dumps({"error": error_msg}, ensure_ascii=False)
|
||
|
||
try:
|
||
logger.info(f"执行工具 {tool_name},参数: {tool_args}")
|
||
|
||
# 执行工具(支持异步函数)
|
||
if asyncio.iscoroutinefunction(tool_func):
|
||
result = await tool_func(**tool_args)
|
||
else:
|
||
# 同步函数在事件循环中执行
|
||
result = tool_func(**tool_args)
|
||
|
||
# 将结果转换为字符串
|
||
if isinstance(result, (dict, list)):
|
||
result_str = json.dumps(result, ensure_ascii=False)
|
||
else:
|
||
result_str = str(result)
|
||
|
||
logger.info(f"工具 {tool_name} 执行成功,结果长度: {len(result_str)}")
|
||
return result_str
|
||
except Exception as e:
|
||
error_msg = f"工具 {tool_name} 执行失败: {str(e)}"
|
||
logger.error(error_msg, exc_info=True)
|
||
return json.dumps({"error": error_msg}, ensure_ascii=False)
|
||
|
||
|
||
# Global LLM service instance, shared across the application.
llm_service = LLMService()
|