- 新增 embedding_service(语义检索)、knowledge_service(RAG)、text_chunker、document_parser - 新增 tool_registry(自定义工具注册表)并完善工具市场 API(CRUD + code/http 执行) - 新增 agent_vector_memory / knowledge_base 模型及对应数据库表 - 实现 SSE 流式响应与 Agent 预算控制 - AgentChat.vue 集成 MainLayout 导航布局 - 完善测试体系:7 个新测试文件共 110 个测试覆盖 - 修复 conftest.py SQLite 内存数据库连接隔离问题 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
230 lines
7.5 KiB
Python
230 lines
7.5 KiB
Python
"""
Embedding generation and semantic retrieval service.

Text vectors come from an OpenAI-compatible embeddings endpoint, chosen by
configured API key in the order SiliconFlow > OpenAI > DeepSeek. Semantic
search is performed by computing cosine similarity in memory.

If no API key is configured, the embedding methods silently degrade and
return empty (None) results instead of calling any API.
"""

from __future__ import annotations

import asyncio
import json
import logging
import time
from typing import Any, Dict, List, Optional, TypedDict

from app.core.config import settings

logger = logging.getLogger(__name__)

# Default OpenAI embedding model; used whenever the SiliconFlow backend is not selected.
EMBEDDING_MODEL = "text-embedding-3-small"
# Vector size requested via the ``dimensions`` parameter when EMBEDDING_MODEL is used.
EMBEDDING_DIMENSIONS = 1536
class VectorEntry(TypedDict, total=False):
    """A single vector entry; every key is optional (``total=False``).

    ``score`` (cosine similarity) is only present on retrieval results.
    """

    id: str
    scope_kind: str
    scope_id: str
    content_text: str
    embedding: List[float]
    metadata: Dict[str, Any]
    score: float
class EmbeddingService:
    """
    Embedding generation and in-memory semantic search.

    Usage:
        svc = EmbeddingService()
        emb = await svc.generate_embedding("你好")
        results = await svc.similarity_search(query_emb, entries, top_k=5)

    The OpenAI-compatible client is created lazily on first use. Backend
    priority is SiliconFlow > OpenAI > DeepSeek. When no API key is
    configured, or the ``openai`` package cannot be imported, the embedding
    methods return None results instead of raising.
    """

    # Inputs are truncated to this many characters before being sent.
    _MAX_INPUT_CHARS = 8000

    def __init__(self):
        self._client: Optional[Any] = None
        # Backend label and model recorded when the client is created, so the
        # requested model always matches the endpoint the client points at.
        self._backend: str = "none"
        self._model: Optional[str] = None
        # Set once the "embedding disabled" notice has been logged, so it is
        # not repeated on every call while no backend is available.
        self._disabled_logged = False
        self._client_lock = asyncio.Lock()

    @staticmethod
    def _resolve_backend() -> tuple:
        """Choose ``(backend_label, api_key, base_url, model)`` from settings.

        Priority: SiliconFlow > OpenAI > DeepSeek. SiliconFlow (for example
        netease-youdao/bce-embedding-base_v1) is reachable directly from
        mainland China. ``api_key`` is None when nothing usable is configured.
        """
        if settings.SILICONFLOW_API_KEY:
            return (
                "siliconflow",
                settings.SILICONFLOW_API_KEY,
                settings.SILICONFLOW_BASE_URL or "https://api.siliconflow.cn/v1",
                settings.SILICONFLOW_EMBEDDING_MODEL,
            )
        # The literal placeholder from the example config is treated as "unset".
        if settings.OPENAI_API_KEY and settings.OPENAI_API_KEY != "your-openai-api-key":
            return (
                "openai",
                settings.OPENAI_API_KEY,
                settings.OPENAI_BASE_URL or "https://api.openai.com/v1",
                EMBEDDING_MODEL,
            )
        # Some DeepSeek-compatible proxies may also serve embeddings.
        if settings.DEEPSEEK_API_KEY:
            return (
                "deepseek",
                settings.DEEPSEEK_API_KEY,
                settings.DEEPSEEK_BASE_URL or "https://api.deepseek.com",
                EMBEDDING_MODEL,
            )
        return ("none", None, None, EMBEDDING_MODEL)

    async def _get_client(self):
        """Return the shared AsyncOpenAI client, creating it on first use.

        Returns None when no API key is configured or the ``openai`` package
        cannot be imported. Availability is re-checked on every call until a
        client has been created.
        """
        if self._client is not None:
            return self._client

        async with self._client_lock:
            if self._client is not None:
                return self._client

            backend, api_key, base_url, model = self._resolve_backend()
            if not api_key:
                if not self._disabled_logged:
                    logger.info("未配置任何 API Key,向量记忆功能已禁用(请配置 SILICONFLOW_API_KEY 或 OPENAI_API_KEY)")
                    self._disabled_logged = True
                return None

            # Imported lazily, and only once a key exists, so that a missing
            # openai package degrades to "disabled" instead of raising.
            try:
                from openai import AsyncOpenAI
            except ImportError:
                if not self._disabled_logged:
                    logger.warning("未安装 openai 包,向量记忆功能已禁用")
                    self._disabled_logged = True
                return None

            if backend == "siliconflow":
                logger.info("Embedding 后端: SiliconFlow (model=%s)", model)

            self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
            self._backend = backend
            self._model = model
            logger.info("EmbeddingService 已初始化 (backend=%s)", backend)
            return self._client

    def _get_model(self) -> str:
        """Return the embedding model for the active backend.

        Prefers the model recorded when the client was created. Before that,
        it falls back to the SiliconFlow model if a SiliconFlow key is set,
        otherwise EMBEDDING_MODEL.
        """
        if self._model:
            return self._model
        if settings.SILICONFLOW_API_KEY:
            return settings.SILICONFLOW_EMBEDDING_MODEL
        return EMBEDDING_MODEL

    @staticmethod
    def _request_kwargs(model: str, payload: Any) -> Dict[str, Any]:
        """Build keyword arguments for ``client.embeddings.create``."""
        kwargs: Dict[str, Any] = {"model": model, "input": payload}
        # Only OpenAI text-embedding-3-small is known to accept ``dimensions``;
        # other models may reject the parameter.
        if model == EMBEDDING_MODEL:
            kwargs["dimensions"] = EMBEDDING_DIMENSIONS
        return kwargs

    async def generate_embedding(self, text: str) -> Optional[List[float]]:
        """
        Generate an embedding vector for a single text.

        Args:
            text: Input text. It is stripped of surrounding whitespace and
                truncated to ``_MAX_INPUT_CHARS`` characters.

        Returns:
            The embedding as a list of floats. Returns None for blank input,
            when no backend is available, or when the API call fails (the
            failure is logged).
        """
        if not text or not text.strip():
            return None

        client = await self._get_client()
        if not client:
            return None

        kwargs = self._request_kwargs(
            self._get_model(), text.strip()[: self._MAX_INPUT_CHARS]
        )
        try:
            resp = await client.embeddings.create(**kwargs)
            return resp.data[0].embedding
        except Exception as e:  # best-effort: degrade to None rather than raise
            logger.warning("生成 embedding 失败: %s", e)
            return None

    async def generate_embeddings(self, texts: List[str]) -> List[Optional[List[float]]]:
        """
        Generate embeddings for several texts in a single request.

        Returns:
            A list aligned with ``texts``. Blank texts map to None. Every slot
            is None when no backend is available, when the request fails, or
            when the response does not contain exactly one vector per
            submitted text.
        """
        if not texts:
            return []
        result: List[Optional[List[float]]] = [None] * len(texts)

        client = await self._get_client()
        if not client:
            return result

        # Blank texts are not sent; remember where the real ones came from.
        valid_indices = [i for i, t in enumerate(texts) if t and t.strip()]
        if not valid_indices:
            return result

        inputs = [texts[i].strip()[: self._MAX_INPUT_CHARS] for i in valid_indices]
        kwargs = self._request_kwargs(self._get_model(), inputs)
        try:
            resp = await client.embeddings.create(**kwargs)
            items = list(resp.data)
            # Align vectors by each item's ``index`` rather than trusting the
            # response order, when every item carries an integer index.
            if all(isinstance(getattr(d, "index", None), int) for d in items):
                items.sort(key=lambda d: d.index)
            embeddings = [d.embedding for d in items]
        except Exception as e:  # best-effort: degrade to None rather than raise
            logger.warning("批量生成 embedding 失败: %s", e)
            return result

        if len(embeddings) != len(valid_indices):
            # zip() would silently truncate and pair vectors with the wrong texts.
            logger.warning(
                "批量生成 embedding 返回数量不一致: 期望 %d, 实际 %d",
                len(valid_indices),
                len(embeddings),
            )
            return result

        for idx, emb in zip(valid_indices, embeddings):
            result[idx] = emb
        return result

    @staticmethod
    def cosine_similarity(a: List[float], b: List[float]) -> float:
        """Return the cosine similarity of two vectors.

        Returns 0.0 when either vector is None or empty, when their lengths
        differ, or when either has zero norm. Any sized numeric sequence is
        accepted; no truthiness test is applied to the vectors themselves.
        """
        if a is None or b is None or len(a) == 0 or len(a) != len(b):
            return 0.0

        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(y * y for y in b) ** 0.5

        if norm_a == 0 or norm_b == 0:
            return 0.0
        return dot / (norm_a * norm_b)

    async def similarity_search(
        self,
        query_embedding: List[float],
        entries: List[VectorEntry],
        top_k: int = 5,
        min_score: float = 0.3,
    ) -> List[VectorEntry]:
        """
        Rank ``entries`` by cosine similarity to ``query_embedding``.

        Args:
            query_embedding: Query vector.
            entries: Candidates. Entries without an ``embedding`` are skipped.
            top_k: Maximum number of hits to return. Values <= 0 return [].
            min_score: Hits scoring below this threshold are dropped.

        Returns:
            Up to ``top_k`` hits, sorted by descending score. Each hit is a
            shallow copy of its entry with ``score`` set; the input entries
            are left unmodified.

        Note: the scoring loop is pure Python and contains no ``await``, so it
        runs to completion on the event loop.
        """
        if top_k <= 0:
            return []

        scored: List[VectorEntry] = []
        for entry in entries:
            emb = entry.get("embedding")
            if emb is None or len(emb) == 0:
                continue
            score = self.cosine_similarity(query_embedding, emb)
            if score >= min_score:
                # Copy instead of mutating: callers may cache and reuse entry
                # lists across searches, and an in-place score would be clobbered.
                scored.append({**entry, "score": score})

        scored.sort(key=lambda x: x["score"], reverse=True)
        return scored[:top_k]

    @staticmethod
    def serialize_embedding(embedding: List[float]) -> str:
        """Serialize an embedding to a JSON string."""
        return json.dumps(embedding, ensure_ascii=False)

    @staticmethod
    def deserialize_embedding(data: str) -> List[float]:
        """Deserialize an embedding stored as JSON.

        A list is returned unchanged. Returns [] for None, invalid JSON,
        undecodable bytes, or JSON that does not decode to a list.
        """
        if isinstance(data, list):
            return data
        try:
            parsed = json.loads(data)
        except (ValueError, TypeError):  # JSONDecodeError/UnicodeDecodeError are ValueErrors
            return []
        return parsed if isinstance(parsed, list) else []
# Module-level shared instance (global singleton).
embedding_service = EmbeddingService()