Files
aiagent/backend/app/services/embedding_service.py

230 lines
7.5 KiB
Python
Raw Normal View History

"""
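Embedding generation and semantic retrieval service.

Text vectors are generated through an OpenAI-compatible embeddings API. The
backend is chosen from the configured API keys in priority order
SiliconFlow > OpenAI > DeepSeek; OpenAI's text-embedding-3-small is the
default model whenever SiliconFlow is not configured.
Semantic search is implemented by computing cosine similarity in memory.
If no API key is configured, embedding generation degrades silently and
returns None / empty results instead of raising.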
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from typing import Any, Dict, List, Optional, TypedDict
from app.core.config import settings
logger = logging.getLogger(__name__)
# Default embedding model; used whenever SiliconFlow is not configured.
EMBEDDING_MODEL = "text-embedding-3-small"
# Vector size sent as the `dimensions` argument when EMBEDDING_MODEL is used.
EMBEDDING_DIMENSIONS = 1536
class VectorEntry(TypedDict, total=False):
    """A stored text snippet together with its embedding vector.

    Declared with ``total=False``, so every key is optional.
    """

    id: str
    # NOTE(review): scope_kind / scope_id presumably identify the owner of the
    # entry; nothing in this module reads them -- confirm against callers.
    scope_kind: str
    scope_id: str
    content_text: str
    # Vector compared against the query in EmbeddingService.similarity_search;
    # entries without it are skipped there.
    embedding: List[float]
    metadata: Dict[str, Any]
    score: float  # Cosine similarity; only present on search results.
class EmbeddingService:
    """Generate text embeddings and run in-memory semantic search.

    The network-facing methods are best-effort: when no API key is configured
    or the remote call fails they return ``None`` (or a list of ``None``)
    instead of raising.

    Usage::

        svc = EmbeddingService()
        emb = await svc.generate_embedding("hello")
        results = await svc.similarity_search(query_emb, entries, top_k=5)
    """

    # Inputs are truncated to this many characters before being sent.
    MAX_INPUT_CHARS = 8000

    def __init__(self):
        # The client is created on first use by ``_get_client``; the lock
        # serialises that one-time initialisation across coroutines.
        self._client: Optional[Any] = None
        self._client_lock = asyncio.Lock()

    @staticmethod
    def _resolve_backend() -> tuple:
        """Pick the embedding backend from the configured API keys.

        Priority: SiliconFlow > OpenAI > DeepSeek.

        Returns:
            ``(api_key, base_url, label)``. When nothing is configured,
            ``api_key`` and ``base_url`` are ``None`` and ``label`` is
            ``"none"``.
        """
        # 1) SiliconFlow (original author's note: reachable directly from
        #    mainland China and free, e.g. netease-youdao/bce-embedding-base_v1).
        if settings.SILICONFLOW_API_KEY:
            return (
                settings.SILICONFLOW_API_KEY,
                settings.SILICONFLOW_BASE_URL or "https://api.siliconflow.cn/v1",
                "siliconflow",
            )
        # 2) OpenAI; the template placeholder value counts as "not configured".
        if settings.OPENAI_API_KEY and settings.OPENAI_API_KEY != "your-openai-api-key":
            return (
                settings.OPENAI_API_KEY,
                settings.OPENAI_BASE_URL or "https://api.openai.com/v1",
                "openai",
            )
        # 3) DeepSeek (original author's note: some proxies may support
        #    embeddings).
        if settings.DEEPSEEK_API_KEY:
            return (
                settings.DEEPSEEK_API_KEY,
                settings.DEEPSEEK_BASE_URL or "https://api.deepseek.com",
                "deepseek",
            )
        return None, None, "none"

    async def _get_client(self):
        """Return the shared ``AsyncOpenAI`` client, creating it on first use.

        Returns:
            The cached client, or ``None`` when no API key is configured. The
            "not configured" outcome is not cached, so the settings are
            re-read on the next call.

        Raises:
            ImportError: if a key is configured but ``openai`` is not
                installed.
        """
        if self._client is not None:
            return self._client
        async with self._client_lock:
            # Re-check under the lock: another coroutine may have finished the
            # initialisation while this one was waiting.
            if self._client is not None:
                return self._client

            api_key, base_url, backend_label = self._resolve_backend()
            if backend_label == "siliconflow":
                logger.info("Embedding 后端: SiliconFlow (model=%s)", settings.SILICONFLOW_EMBEDDING_MODEL)
            if not api_key:
                logger.info("未配置任何 API Key向量记忆功能已禁用请配置 SILICONFLOW_API_KEY 或 OPENAI_API_KEY")
                return None

            # Imported only once a key is known, so the "no key" path never
            # depends on the ``openai`` package being importable.
            from openai import AsyncOpenAI

            self._client = AsyncOpenAI(
                api_key=api_key,
                base_url=base_url,
            )
            logger.info("EmbeddingService 已初始化 (backend=%s)", backend_label)
            return self._client

    def _get_model(self) -> str:
        """Return the embedding model name matching the selected backend."""
        if settings.SILICONFLOW_API_KEY:
            return settings.SILICONFLOW_EMBEDDING_MODEL
        return EMBEDDING_MODEL

    @staticmethod
    def _build_request(model: str, payload: Any) -> Dict[str, Any]:
        """Assemble the keyword arguments for ``client.embeddings.create``.

        Args:
            model: Model name to request.
            payload: A single string or a list of strings to embed.
        """
        kwargs: Dict[str, Any] = {"model": model, "input": payload}
        # Only the default model is sent an explicit ``dimensions`` value;
        # other models may not support the parameter.
        if model == EMBEDDING_MODEL:
            kwargs["dimensions"] = EMBEDDING_DIMENSIONS
        return kwargs

    async def generate_embedding(self, text: str) -> Optional[List[float]]:
        """Generate the embedding vector for a single piece of text.

        Args:
            text: Text to embed; it is stripped and truncated to
                ``MAX_INPUT_CHARS`` characters.

        Returns:
            The vector as a list of floats, or ``None`` when the text is
            blank, no API key is configured, or the request fails.
        """
        if not text or not text.strip():
            return None
        client = await self._get_client()
        if not client:
            return None
        model = self._get_model()
        try:
            request = self._build_request(model, text.strip()[: self.MAX_INPUT_CHARS])
            resp = await client.embeddings.create(**request)
            return resp.data[0].embedding
        except Exception as e:
            # Best-effort boundary: any failure degrades to "no embedding".
            logger.warning("生成 embedding 失败: %s", e)
            return None

    async def generate_embeddings(self, texts: List[str]) -> List[Optional[List[float]]]:
        """Generate embedding vectors for a batch of texts in one request.

        Args:
            texts: Texts to embed. Blank items are not sent.

        Returns:
            A list the same length as ``texts``. Position ``i`` holds the
            vector for ``texts[i]``, or ``None`` for blank items. Every
            position is ``None`` when no API key is configured, the request
            fails, or the response does not contain exactly one vector per
            submitted text.
        """
        if not texts:
            return []
        client = await self._get_client()
        if not client:
            return [None] * len(texts)
        # Only non-blank texts are submitted; remember where they came from.
        valid_indices = [i for i, t in enumerate(texts) if t and t.strip()]
        if not valid_indices:
            return [None] * len(texts)
        model = self._get_model()
        try:
            inputs = [texts[i].strip()[: self.MAX_INPUT_CHARS] for i in valid_indices]
            resp = await client.embeddings.create(**self._build_request(model, inputs))
            items = list(resp.data)
            # Each item carries the position of its input; order by it when
            # every item provides one, rather than trusting response order.
            if all(isinstance(getattr(item, "index", None), int) for item in items):
                items.sort(key=lambda item: item.index)
            embeddings = [item.embedding for item in items]
        except Exception as e:
            logger.warning("批量生成 embedding 失败: %s", e)
            return [None] * len(texts)
        # A count mismatch means vectors cannot be paired with texts reliably;
        # refuse to guess instead of silently misaligning them.
        if len(embeddings) != len(valid_indices):
            logger.warning(
                "批量生成 embedding 返回数量不匹配: 期望 %d 实际 %d",
                len(valid_indices),
                len(embeddings),
            )
            return [None] * len(texts)
        # Scatter the vectors back to their original positions.
        result: List[Optional[List[float]]] = [None] * len(texts)
        for idx, emb in zip(valid_indices, embeddings):
            result[idx] = emb
        return result

    @staticmethod
    def cosine_similarity(a: List[float], b: List[float]) -> float:
        """Return the cosine similarity of two vectors.

        Returns ``0.0`` when either vector is empty, the lengths differ, or
        either vector has zero norm.
        """
        if not a or not b or len(a) != len(b):
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(y * y for y in b) ** 0.5
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return dot / (norm_a * norm_b)

    async def similarity_search(
        self,
        query_embedding: List[float],
        entries: List[VectorEntry],
        top_k: int = 5,
        min_score: float = 0.3,
    ) -> List[VectorEntry]:
        """Rank ``entries`` against ``query_embedding`` by cosine similarity.

        Args:
            query_embedding: Vector to compare against.
            entries: Candidate entries; those without an ``embedding`` are
                skipped.
            top_k: Maximum number of results; values ``<= 0`` yield ``[]``.
            min_score: Minimum similarity; lower-scoring entries are dropped.

        Returns:
            Up to ``top_k`` results ordered by descending score. Each result
            is a shallow copy of the matching entry with ``score`` added; the
            input entries are not modified.
        """
        if top_k <= 0:
            # A negative slice bound would otherwise return "all but the last
            # k" results instead of none.
            return []
        scored: List[VectorEntry] = []
        for entry in entries or ():
            emb = entry.get("embedding")
            if not emb:
                continue
            score = self.cosine_similarity(query_embedding, emb)
            if score >= min_score:
                # Copy so the caller's (possibly cached) entry does not keep a
                # stale score from this query.
                hit = dict(entry)
                hit["score"] = score
                scored.append(hit)
        scored.sort(key=lambda hit: hit["score"], reverse=True)
        return scored[:top_k]

    @staticmethod
    def serialize_embedding(embedding: List[float]) -> str:
        """Serialize an embedding to a JSON string."""
        return json.dumps(embedding, ensure_ascii=False)

    @staticmethod
    def deserialize_embedding(data: str) -> List[float]:
        """Deserialize an embedding from a JSON string.

        A value that is already a list is returned unchanged. Anything that
        cannot be decoded, or that decodes to something other than a list,
        yields ``[]``.
        """
        if isinstance(data, list):
            return data
        try:
            decoded = json.loads(data)
        except (ValueError, TypeError):
            # ValueError covers json.JSONDecodeError and undecodable bytes;
            # TypeError covers non-string input such as None.
            return []
        return decoded if isinstance(decoded, list) else []
# Module-level shared instance, created when this module is imported.
embedding_service = EmbeddingService()