230 lines
7.5 KiB
Python
230 lines
7.5 KiB
Python
|
|
"""
|
|||
|
|
Embedding 生成与语义检索服务。
|
|||
|
|
|
|||
|
|
使用 OpenAI text-embedding-3-small 生成文本向量,
|
|||
|
|
在内存中计算余弦相似度实现语义搜索。
|
|||
|
|
|
|||
|
|
如未配置 OpenAI API Key,所有方法静默降级返回空结果。
|
|||
|
|
"""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import asyncio
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
import time
|
|||
|
|
from typing import Any, Dict, List, Optional, TypedDict
|
|||
|
|
|
|||
|
|
from app.core.config import settings
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
# Default embedding model (OpenAI); SiliconFlow overrides this via settings.
EMBEDDING_MODEL = "text-embedding-3-small"
# Output dimensionality requested when using text-embedding-3-small.
EMBEDDING_DIMENSIONS = 1536
|
|||
|
|
|
|||
|
|
|
|||
|
|
class VectorEntry(TypedDict, total=False):
    """A stored vector entry. All keys are optional (total=False)."""
    id: str
    scope_kind: str
    scope_id: str
    content_text: str
    embedding: List[float]
    metadata: Dict[str, Any]
    score: float  # cosine similarity; present only on search results
|
|||
|
|
|
|||
|
|
|
|||
|
|
class EmbeddingService:
    """
    Embedding service.

    Usage:
        svc = EmbeddingService()
        emb = await svc.generate_embedding("你好")
        results = await svc.similarity_search(query_emb, entries, top_k=5)
    """

    def __init__(self):
        # Lazily-created OpenAI-compatible async client; stays None when no
        # API key is configured.
        self._client: Optional[Any] = None
        # Serializes first-time client construction across coroutines.
        self._client_lock = asyncio.Lock()

    async def _get_client(self):
        """Lazily initialize the OpenAI client (created on first call).

        Backend priority: SiliconFlow > OpenAI > DeepSeek.
        SiliconFlow's netease-youdao/bce-embedding-base_v1 is directly
        reachable (and free) from mainland China.

        Returns:
            The shared AsyncOpenAI client, or None when no API key is set.
        """
        if self._client is not None:
            return self._client

        async with self._client_lock:
            # Double-checked locking: another coroutine may have finished
            # initialization while we waited for the lock.
            if self._client is not None:
                return self._client

            # Imported lazily so the module loads even without the openai package.
            from openai import AsyncOpenAI

            api_key: Optional[str] = None
            base_url: Optional[str] = None
            backend_label = "none"

            # 1) SiliconFlow (directly reachable in mainland China; recommended)
            if settings.SILICONFLOW_API_KEY:
                api_key = settings.SILICONFLOW_API_KEY
                base_url = settings.SILICONFLOW_BASE_URL or "https://api.siliconflow.cn/v1"
                backend_label = "siliconflow"
                logger.info("Embedding 后端: SiliconFlow (model=%s)", settings.SILICONFLOW_EMBEDDING_MODEL)

            # 2) OpenAI (skip the placeholder key shipped in default config)
            if not api_key and settings.OPENAI_API_KEY and settings.OPENAI_API_KEY != "your-openai-api-key":
                api_key = settings.OPENAI_API_KEY
                base_url = settings.OPENAI_BASE_URL or "https://api.openai.com/v1"
                backend_label = "openai"

            # 3) DeepSeek (some proxies may support embeddings)
            if not api_key and settings.DEEPSEEK_API_KEY:
                api_key = settings.DEEPSEEK_API_KEY
                base_url = settings.DEEPSEEK_BASE_URL or "https://api.deepseek.com"
                backend_label = "deepseek"

            if not api_key:
                # NOTE: _client stays None, so the next call retries the whole
                # lookup — deliberate, in case settings are populated later.
                logger.info("未配置任何 API Key,向量记忆功能已禁用(请配置 SILICONFLOW_API_KEY 或 OPENAI_API_KEY)")
                self._client = None
                return None

            self._client = AsyncOpenAI(
                api_key=api_key,
                base_url=base_url,
            )
            logger.info("EmbeddingService 已初始化 (backend=%s)", backend_label)
            return self._client

    def _get_model(self) -> str:
        """Pick the embedding model matching the active backend."""
        if settings.SILICONFLOW_API_KEY:
            return settings.SILICONFLOW_EMBEDDING_MODEL
        return EMBEDDING_MODEL

    async def generate_embedding(self, text: str) -> Optional[List[float]]:
        """
        Generate an embedding vector for a single piece of text.

        Returns a list of floats, or None when text is blank, no API key is
        configured, or the request fails.
        """
        if not text or not text.strip():
            return None

        client = await self._get_client()
        if not client:
            return None

        model = self._get_model()
        try:
            kwargs: Dict[str, Any] = {
                "model": model,
                # Truncate to keep well under the model's token limit.
                "input": text.strip()[:8000],
            }
            # OpenAI text-embedding-3-small supports an explicit dimensions
            # parameter; other models may not.
            if model == EMBEDDING_MODEL:
                kwargs["dimensions"] = EMBEDDING_DIMENSIONS
            resp = await client.embeddings.create(**kwargs)
            return resp.data[0].embedding
        except Exception as e:
            # Best-effort by design: log and degrade to None.
            logger.warning("生成 embedding 失败: %s", e)
            return None

    async def generate_embeddings(self, texts: List[str]) -> List[Optional[List[float]]]:
        """
        Generate embedding vectors in batch.

        Returns one entry per input text, in order; blank texts (and every
        text on failure) map to None.
        """
        if not texts:
            return []

        client = await self._get_client()
        if not client:
            return [None] * len(texts)

        # Drop blank texts but remember their positions.
        valid_indices = [i for i, t in enumerate(texts) if t and t.strip()]
        if not valid_indices:
            return [None] * len(texts)

        model = self._get_model()
        try:
            inputs = [texts[i].strip()[:8000] for i in valid_indices]
            kwargs: Dict[str, Any] = {
                "model": model,
                "input": inputs,
            }
            if model == EMBEDDING_MODEL:
                kwargs["dimensions"] = EMBEDDING_DIMENSIONS
            resp = await client.embeddings.create(**kwargs)
            embeddings = [r.embedding for r in resp.data]
        except Exception as e:
            logger.warning("批量生成 embedding 失败: %s", e)
            return [None] * len(texts)

        # Re-align results with the original input order.
        result: List[Optional[List[float]]] = [None] * len(texts)
        for idx, emb in zip(valid_indices, embeddings):
            result[idx] = emb
        return result

    @staticmethod
    def cosine_similarity(a: List[float], b: List[float]) -> float:
        """Return the cosine similarity of two vectors (0.0 on mismatch/zero norm)."""
        if not a or not b or len(a) != len(b):
            return 0.0

        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(y * y for y in b) ** 0.5

        if norm_a == 0 or norm_b == 0:
            return 0.0
        return dot / (norm_a * norm_b)

    async def similarity_search(
        self,
        query_embedding: List[float],
        entries: List[VectorEntry],
        top_k: int = 5,
        min_score: float = 0.3,
    ) -> List[VectorEntry]:
        """
        Cosine-similarity search over entries in memory; returns the Top-K
        results sorted by score (descending).

        min_score: minimum similarity threshold; lower-scoring entries are
        filtered out. Entries without an embedding are skipped.

        Note: results are shallow copies with "score" added — the caller's
        entries are never mutated (previous behavior wrote score in place).
        """
        scored: List[VectorEntry] = []
        for entry in entries:
            emb = entry.get("embedding")
            if not emb:
                continue
            score = self.cosine_similarity(query_embedding, emb)
            if score >= min_score:
                # Copy so the caller's dict is left untouched.
                hit: VectorEntry = {**entry, "score": score}
                scored.append(hit)

        scored.sort(key=lambda x: x["score"], reverse=True)
        return scored[:top_k]

    @staticmethod
    def serialize_embedding(embedding: List[float]) -> str:
        """Serialize an embedding to a JSON string."""
        return json.dumps(embedding, ensure_ascii=False)

    @staticmethod
    def deserialize_embedding(data: str | List[float]) -> List[float]:
        """Deserialize an embedding from a JSON string.

        Accepts an already-deserialized list and returns it unchanged;
        returns [] on invalid input.
        """
        if isinstance(data, list):
            return data
        try:
            return json.loads(data)
        except (json.JSONDecodeError, TypeError):
            return []
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Module-level singleton shared across the application.
embedding_service = EmbeddingService()
|