Files
aiagent/backend/app/services/llm_service.py
2026-01-23 09:49:45 +08:00

482 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
LLM服务 - 处理各种LLM提供商的调用
"""
from typing import Dict, Any, Optional, List
import json
import asyncio
import logging
from openai import AsyncOpenAI
from app.core.config import settings
from app.services.tool_registry import tool_registry
logger = logging.getLogger(__name__)
class LLMService:
    """Service layer for calling OpenAI-compatible chat-completion providers.

    Currently supports OpenAI and DeepSeek (DeepSeek exposes an
    OpenAI-compatible HTTP API, so both share the ``AsyncOpenAI`` client).
    Credentials default to ``settings`` but may be overridden per call, and an
    iterative tool-calling loop is provided for function calling.
    """

    # User-facing error message raised when no API key can be resolved,
    # keyed by provider.  (Kept verbatim from the original Chinese messages.)
    _MISSING_KEY_MSG = {
        "openai": "OpenAI API Key未配置请在节点配置中设置API Key或在环境变量中设置OPENAI_API_KEY",
        "deepseek": "DeepSeek API Key未配置请在节点配置中设置API Key或在环境变量中设置DEEPSEEK_API_KEY",
    }

    def __init__(self):
        """Eagerly build a client for each provider that has a key configured.

        Providers without a configured key get ``None`` here and are created
        lazily on first use (see ``_resolve_client``), in case settings are
        populated after import time.
        """
        self.openai_client: Optional[AsyncOpenAI] = None
        self.deepseek_client: Optional[AsyncOpenAI] = None
        if settings.OPENAI_API_KEY:
            self.openai_client = AsyncOpenAI(
                api_key=settings.OPENAI_API_KEY,
                base_url=settings.OPENAI_BASE_URL,
            )
        # DeepSeek is OpenAI-API-compatible, so the same client class is used.
        if settings.DEEPSEEK_API_KEY:
            self.deepseek_client = AsyncOpenAI(
                api_key=settings.DEEPSEEK_API_KEY,
                base_url=settings.DEEPSEEK_BASE_URL,
            )

    def _resolve_client(
        self,
        provider: str,
        api_key: Optional[str],
        base_url: Optional[str],
    ) -> AsyncOpenAI:
        """Return an ``AsyncOpenAI`` client for *provider* ("openai"/"deepseek").

        If *api_key* or *base_url* is not ``None``, a one-off client is built
        from the overrides; an empty-string override still falls back to the
        corresponding settings value (the original code deliberately
        distinguished ``None`` from ``""``).  Otherwise the cached
        per-provider client is returned, lazily created from settings on
        first use.

        Raises:
            ValueError: if no API key can be resolved from either source.
        """
        if provider == "openai":
            default_key = settings.OPENAI_API_KEY
            default_base = settings.OPENAI_BASE_URL
            cache_attr = "openai_client"
        else:
            default_key = settings.DEEPSEEK_API_KEY
            default_base = settings.DEEPSEEK_BASE_URL
            cache_attr = "deepseek_client"

        if api_key is not None or base_url is not None:
            # Per-call override: empty strings fall through to settings.
            final_key = api_key or default_key
            if not final_key:
                raise ValueError(self._MISSING_KEY_MSG[provider])
            return AsyncOpenAI(
                api_key=final_key,
                base_url=base_url or default_base,
            )

        client = getattr(self, cache_attr)
        if client is None:
            # Settings may have been populated after __init__; retry lazily.
            if not default_key:
                raise ValueError(self._MISSING_KEY_MSG[provider])
            client = AsyncOpenAI(api_key=default_key, base_url=default_base)
            setattr(self, cache_attr, client)
        return client

    async def _chat(
        self,
        provider: str,
        prompt: str,
        model: str,
        temperature: float,
        max_tokens: Optional[int],
        api_key: Optional[str],
        base_url: Optional[str],
        **kwargs,
    ) -> str:
        """Run one single-turn chat completion and return the text content.

        Shared implementation behind ``call_openai`` / ``call_deepseek``.

        Raises:
            ValueError: if no API key is configured.
            Exception: wrapping any API/transport failure (cause chained),
                or when the API returns a ``None`` content.
        """
        client = self._resolve_client(provider, api_key, base_url)
        label = "OpenAI" if provider == "openai" else "DeepSeek"
        try:
            response = await client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs,
            )
            content = response.choices[0].message.content
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise Exception(f"{label} API调用失败: {str(e)}") from e
        # Checked outside the try block so this error is not double-wrapped
        # by the generic "API调用失败" message above.
        if content is None:
            raise Exception(f"{label} API返回的内容为空请检查API配置和模型名称")
        return content

    async def call_openai(
        self,
        prompt: str,
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Call the OpenAI chat-completions API with a single user prompt.

        Args:
            prompt: The user prompt text.
            model: Model name (default ``gpt-3.5-turbo``).
            temperature: Sampling temperature (default 0.7).
            max_tokens: Maximum completion tokens, or ``None`` for API default.
            api_key: Optional per-call API key; falls back to settings.
            base_url: Optional per-call API base URL; falls back to settings.
            **kwargs: Extra parameters forwarded to the API.

        Returns:
            The completion text.

        Raises:
            ValueError: if no API key is configured.
            Exception: on API failure or empty content.
        """
        return await self._chat(
            "openai", prompt, model, temperature, max_tokens,
            api_key, base_url, **kwargs,
        )

    async def call_deepseek(
        self,
        prompt: str,
        model: str = "deepseek-chat",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Call the DeepSeek API (OpenAI-compatible) with a single user prompt.

        Args:
            prompt: The user prompt text.
            model: Model name (default ``deepseek-chat``).
            temperature: Sampling temperature (default 0.7).
            max_tokens: Maximum completion tokens, or ``None`` for API default.
            api_key: Optional per-call API key; falls back to settings.
            base_url: Optional per-call API base URL; falls back to settings.
            **kwargs: Extra parameters forwarded to the API.

        Returns:
            The completion text.

        Raises:
            ValueError: if no API key is configured.
            Exception: on API failure or empty content.
        """
        # Debug trace of the prompt actually sent (first 200 chars).  Uses the
        # module-level logger instead of re-importing logging per call.
        logger.info("[rjb] DeepSeek实际发送的prompt前200字符: %s", prompt[:200])
        return await self._chat(
            "deepseek", prompt, model, temperature, max_tokens,
            api_key, base_url, **kwargs,
        )

    async def call_llm(
        self,
        prompt: str,
        provider: str = "openai",
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
        """Provider-agnostic single-turn LLM call.

        Args:
            prompt: The user prompt text.
            provider: ``"openai"`` or ``"deepseek"``.
            model: Model name; a provider-specific default is used when falsy.
            temperature: Sampling temperature.
            max_tokens: Maximum completion tokens.
            **kwargs: Extra parameters forwarded to the provider call.

        Returns:
            The completion text.

        Raises:
            ValueError: for an unsupported provider.
        """
        if provider == "openai":
            return await self.call_openai(
                prompt=prompt,
                model=model or "gpt-3.5-turbo",
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs,
            )
        elif provider == "deepseek":
            return await self.call_deepseek(
                prompt=prompt,
                model=model or "deepseek-chat",
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs,
            )
        else:
            raise ValueError(f"不支持的LLM提供商: {provider},目前支持: openai, deepseek")

    @staticmethod
    def _normalize_tools(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Coerce tool definitions into OpenAI's ``{"type": "function", ...}`` shape.

        Dicts already carrying ``type == "function"`` or a ``function`` key are
        passed through unchanged; any other dict is assumed to be a bare
        function definition and gets wrapped.  Non-dict entries are dropped,
        matching the original behavior.
        """
        normalized: List[Dict[str, Any]] = []
        for tool in tools:
            if isinstance(tool, dict):
                if tool.get("type") == "function" or "function" in tool:
                    normalized.append(tool)
                else:
                    normalized.append({"type": "function", "function": tool})
        return normalized

    async def call_openai_with_tools(
        self,
        prompt: str,
        tools: List[Dict[str, Any]],
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        max_iterations: int = 5
    ) -> str:
        """Call the OpenAI API with tool (function) calling support.

        Runs an agent loop: each requested tool call is executed via the tool
        registry and its result fed back, until the model answers without
        tool calls or *max_iterations* is reached.

        Args:
            prompt: The user prompt text.
            tools: Tool definitions (OpenAI function format, or bare
                function dicts which will be wrapped).
            model: Model name.
            temperature: Sampling temperature.
            max_tokens: Maximum completion tokens.
            api_key: Optional per-call API key.
            base_url: Optional per-call API base URL.
            max_iterations: Upper bound on LLM round-trips.

        Returns:
            The model's final text reply, or a fallback message when the
            iteration limit is hit.

        Raises:
            ValueError: if no API key is configured.
            Exception: wrapping any failure in the loop (cause chained).
        """
        client = self._resolve_client("openai", api_key, base_url)

        messages: List[Dict[str, Any]] = [{"role": "user", "content": prompt}]
        try:
            for iteration in range(max_iterations):
                create_kwargs: Dict[str, Any] = {
                    "model": model,
                    "messages": messages,
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                }
                # NOTE(review): tools are only sent on the FIRST request
                # (preserved from the original design), so the model cannot
                # request *new* tool calls in later rounds — it can only
                # consume the tool results already in the history.
                if iteration == 0:
                    create_kwargs["tools"] = self._normalize_tools(tools)
                    create_kwargs["tool_choice"] = "auto"

                response = await client.chat.completions.create(**create_kwargs)
                message = response.choices[0].message

                # Mirror the assistant turn into our message history.  The
                # "tool_calls" key is only added when non-empty: the API
                # rejects an assistant message carrying an empty tool_calls
                # array.
                assistant_msg: Dict[str, Any] = {
                    "role": "assistant",
                    "content": message.content,
                }
                if message.tool_calls:
                    assistant_msg["tool_calls"] = [
                        {
                            "id": tc.id,
                            "type": tc.type,
                            "function": {
                                "name": tc.function.name,
                                "arguments": tc.function.arguments,
                            },
                        }
                        for tc in message.tool_calls
                    ]
                messages.append(assistant_msg)

                if message.tool_calls:
                    logger.info(f"检测到 {len(message.tool_calls)} 个工具调用")
                    for tool_call in message.tool_calls:
                        tool_name = tool_call.function.name
                        try:
                            tool_args = json.loads(tool_call.function.arguments)
                        except (json.JSONDecodeError, TypeError):
                            # Model emitted malformed arguments; run the tool
                            # with no arguments rather than aborting the loop.
                            tool_args = {}
                        logger.info(f"执行工具: {tool_name}, 参数: {tool_args}")
                        tool_result = await self._execute_tool(tool_name, tool_args)
                        messages.append({
                            "role": "tool",
                            "tool_call_id": tool_call.id,
                            "content": tool_result,
                        })
                else:
                    # No tool calls: this is the final answer.
                    final_content = message.content or ""
                    if final_content:
                        logger.info("LLM返回最终回复工具调用完成")
                    return final_content

            # Iteration budget exhausted: fall back to the last message's
            # content (usually a tool result) or a fixed notice.  `or` also
            # covers a present-but-None content, so a str is always returned.
            logger.warning(f"达到最大工具调用迭代次数 ({max_iterations})")
            last_message = messages[-1] if messages else {}
            return last_message.get("content") or "达到最大工具调用次数"
        except Exception as e:
            logger.error(f"工具调用过程中出错: {str(e)}")
            raise Exception(f"OpenAI工具调用失败: {str(e)}") from e

    async def call_deepseek_with_tools(
        self,
        prompt: str,
        tools: List[Dict[str, Any]],
        model: str = "deepseek-chat",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        max_iterations: int = 5
    ) -> str:
        """Call the DeepSeek API with tool support.

        DeepSeek is OpenAI-API-compatible, so this delegates to
        ``call_openai_with_tools`` with DeepSeek credentials filled in.
        NOTE(review): if neither the arguments nor the DeepSeek settings
        provide a key/URL, the delegate would fall back to OPENAI settings —
        confirm that DEEPSEEK settings are always configured for this path.
        """
        return await self.call_openai_with_tools(
            prompt=prompt,
            tools=tools,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            api_key=api_key or settings.DEEPSEEK_API_KEY,
            base_url=base_url or settings.DEEPSEEK_BASE_URL,
            max_iterations=max_iterations
        )

    async def call_llm_with_tools(
        self,
        prompt: str,
        tools: List[Dict[str, Any]],
        provider: str = "openai",
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> str:
        """Provider-agnostic LLM call with tool support.

        Args:
            prompt: The user prompt text.
            tools: Tool definitions.
            provider: ``"openai"`` or ``"deepseek"``.
            model: Model name; a provider-specific default is used when falsy.
            temperature: Sampling temperature.
            max_tokens: Maximum completion tokens.
            **kwargs: Extra parameters forwarded to the provider call.

        Returns:
            The model's final text reply.

        Raises:
            ValueError: for an unsupported provider.
        """
        if provider == "openai":
            return await self.call_openai_with_tools(
                prompt=prompt,
                tools=tools,
                model=model or "gpt-3.5-turbo",
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )
        elif provider == "deepseek":
            return await self.call_deepseek_with_tools(
                prompt=prompt,
                tools=tools,
                model=model or "deepseek-chat",
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )
        else:
            raise ValueError(f"不支持的LLM提供商: {provider},目前支持: openai, deepseek")

    async def _execute_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> str:
        """Execute a registered tool and return its result as a string.

        Never raises: lookup failures and execution errors are returned as a
        JSON ``{"error": ...}`` string so the LLM loop can keep going.

        Args:
            tool_name: Name registered in ``tool_registry``.
            tool_args: Keyword arguments for the tool function.

        Returns:
            The tool's result — JSON-encoded for dict/list results,
            ``str(...)`` otherwise — or a JSON error payload.
        """
        tool_func = tool_registry.get_tool_function(tool_name)
        if not tool_func:
            error_msg = f"工具 {tool_name} 未找到"
            logger.error(error_msg)
            return json.dumps({"error": error_msg}, ensure_ascii=False)
        try:
            logger.info(f"执行工具 {tool_name},参数: {tool_args}")
            if asyncio.iscoroutinefunction(tool_func):
                result = await tool_func(**tool_args)
            else:
                # NOTE(review): a synchronous tool runs directly on the event
                # loop and will block it for the duration of the call.
                result = tool_func(**tool_args)
            if isinstance(result, (dict, list)):
                result_str = json.dumps(result, ensure_ascii=False)
            else:
                result_str = str(result)
            logger.info(f"工具 {tool_name} 执行成功,结果长度: {len(result_str)}")
            return result_str
        except Exception as e:
            error_msg = f"工具 {tool_name} 执行失败: {str(e)}"
            logger.error(error_msg, exc_info=True)
            return json.dumps({"error": error_msg}, ensure_ascii=False)
# Global LLM service singleton shared by all importers of this module.
llm_service = LLMService()