""" Embedding 生成与语义检索服务。 使用 OpenAI text-embedding-3-small 生成文本向量, 在内存中计算余弦相似度实现语义搜索。 如未配置 OpenAI API Key,所有方法静默降级返回空结果。 """ from __future__ import annotations import asyncio import json import logging import time from typing import Any, Dict, List, Optional, TypedDict from app.core.config import settings logger = logging.getLogger(__name__) # 默认 embedding 模型 EMBEDDING_MODEL = "text-embedding-3-small" EMBEDDING_DIMENSIONS = 1536 class VectorEntry(TypedDict, total=False): """向量条目""" id: str scope_kind: str scope_id: str content_text: str embedding: List[float] metadata: Dict[str, Any] score: float # 余弦相似度,仅检索结果包含 class EmbeddingService: """ Embedding 服务。 用法: svc = EmbeddingService() emb = await svc.generate_embedding("你好") results = await svc.similarity_search(query_emb, entries, top_k=5) """ def __init__(self): self._client: Optional[Any] = None self._client_lock = asyncio.Lock() async def _get_client(self): """延迟初始化 OpenAI 客户端(仅在首次调用时创建)。 优先级:SiliconFlow > OpenAI > DeepSeek。 SiliconFlow 的 netease-youdao/bce-embedding-base_v1 在国内可直连且免费。 """ if self._client is not None: return self._client async with self._client_lock: if self._client is not None: return self._client from openai import AsyncOpenAI api_key: Optional[str] = None base_url: Optional[str] = None backend_label = "none" # 1) SiliconFlow(国内直连,推荐) if settings.SILICONFLOW_API_KEY: api_key = settings.SILICONFLOW_API_KEY base_url = settings.SILICONFLOW_BASE_URL or "https://api.siliconflow.cn/v1" backend_label = "siliconflow" logger.info("Embedding 后端: SiliconFlow (model=%s)", settings.SILICONFLOW_EMBEDDING_MODEL) # 2) OpenAI if not api_key: if settings.OPENAI_API_KEY and settings.OPENAI_API_KEY != "your-openai-api-key": api_key = settings.OPENAI_API_KEY base_url = settings.OPENAI_BASE_URL or "https://api.openai.com/v1" backend_label = "openai" # 3) DeepSeek(部分代理可能支持 embedding) if not api_key: if settings.DEEPSEEK_API_KEY: api_key = settings.DEEPSEEK_API_KEY base_url = settings.DEEPSEEK_BASE_URL or "https://api.deepseek.com" backend_label = "deepseek" if not api_key: logger.info("未配置任何 API Key,向量记忆功能已禁用(请配置 SILICONFLOW_API_KEY 或 OPENAI_API_KEY)") self._client = None return None self._client = AsyncOpenAI( api_key=api_key, base_url=base_url, ) logger.info("EmbeddingService 已初始化 (backend=%s)", backend_label) return self._client def _get_model(self) -> str: """根据后端选择对应的 embedding 模型。""" if settings.SILICONFLOW_API_KEY: return settings.SILICONFLOW_EMBEDDING_MODEL return EMBEDDING_MODEL async def generate_embedding(self, text: str) -> Optional[List[float]]: """ 为单段文本生成 embedding 向量。 返回 float 列表,无 API key 或出错时返回 None。 """ if not text or not text.strip(): return None client = await self._get_client() if not client: return None model = self._get_model() try: kwargs: Dict[str, Any] = { "model": model, "input": text.strip()[:8000], } # OpenAI text-embedding-3-small 支持指定 dimensions;其它模型可能不支持 if model == EMBEDDING_MODEL: kwargs["dimensions"] = EMBEDDING_DIMENSIONS resp = await client.embeddings.create(**kwargs) return resp.data[0].embedding except Exception as e: logger.warning("生成 embedding 失败: %s", e) return None async def generate_embeddings(self, texts: List[str]) -> List[Optional[List[float]]]: """ 批量生成 embedding 向量。 """ if not texts: return [] client = await self._get_client() if not client: return [None] * len(texts) # 清理空文本 valid_indices = [i for i, t in enumerate(texts) if t and t.strip()] if not valid_indices: return [None] * len(texts) model = self._get_model() try: inputs = [texts[i].strip()[:8000] for i in valid_indices] kwargs: Dict[str, Any] = { "model": model, "input": inputs, } if model == EMBEDDING_MODEL: kwargs["dimensions"] = EMBEDDING_DIMENSIONS resp = await client.embeddings.create(**kwargs) embeddings = [r.embedding for r in resp.data] except Exception as e: logger.warning("批量生成 embedding 失败: %s", e) return [None] * len(texts) # 按原始顺序排列结果 result: List[Optional[List[float]]] = [None] * len(texts) for idx, emb in zip(valid_indices, embeddings): result[idx] = emb return result @staticmethod def cosine_similarity(a: List[float], b: List[float]) -> float: """计算两个向量的余弦相似度。""" if not a or not b or len(a) != len(b): return 0.0 dot = sum(x * y for x, y in zip(a, b)) norm_a = sum(x * x for x in a) ** 0.5 norm_b = sum(y * y for y in b) ** 0.5 if norm_a == 0 or norm_b == 0: return 0.0 return dot / (norm_a * norm_b) async def similarity_search( self, query_embedding: List[float], entries: List[VectorEntry], top_k: int = 5, min_score: float = 0.3, ) -> List[VectorEntry]: """ 在内存中对 entries 做余弦相似度搜索,返回 Top-K 结果(按得分降序)。 min_score: 最低相似度阈值,低于该值的结果被过滤。 """ scored: List[VectorEntry] = [] for entry in entries: emb = entry.get("embedding") if not emb: continue score = self.cosine_similarity(query_embedding, emb) if score >= min_score: entry["score"] = score scored.append(entry) scored.sort(key=lambda x: x["score"], reverse=True) return scored[:top_k] @staticmethod def serialize_embedding(embedding: List[float]) -> str: """将 embedding 序列化为 JSON 字符串。""" return json.dumps(embedding, ensure_ascii=False) @staticmethod def deserialize_embedding(data: str) -> List[float]: """从 JSON 字符串反序列化 embedding。""" if isinstance(data, list): return data try: return json.loads(data) except (json.JSONDecodeError, TypeError): return [] # ─── 离线兜底:关键词匹配(无需任何外部 API) ─── @staticmethod def _tokenize(text: str) -> set: """ 轻量分词:中文用字符二元组,英文/数字用空格分词。 混合文本同时提取中英文 token。零外部依赖,完全离线可用。 """ tokens: set = set() text_lower = text.lower() # 分离 CJK 字符和 ASCII/数字 import re # 提取所有英文/数字词(>=2字符) alpha_words = re.findall(r'[a-z0-9]{2,}', text_lower) for w in alpha_words: tokens.add(w) # 提取所有连续 CJK 字符段,生成二元组 + 单字 cjk_segments = re.findall(r'[\u4e00-\u9fff]+', text_lower) for seg in cjk_segments: for i in range(len(seg) - 1): tokens.add(seg[i:i+2]) for c in seg: tokens.add(c) # 提取数字 numbers = re.findall(r'\d+', text_lower) for n in numbers: tokens.add(n) return tokens def keyword_search( self, query: str, entries: List[VectorEntry], top_k: int = 5, min_score: float = 0.1, ) -> List[VectorEntry]: """ 离线关键词匹配(Embedding API 不可用时的兜底方案)。 对每个 entry 计算与 query 的 Jaccard 关键词重叠分数, 返回 top-K 结果。 完全不依赖外部 API,零网络请求。 """ q_tokens = self._tokenize(query) if not q_tokens: return entries[:top_k] scored: List[VectorEntry] = [] for entry in entries: text = entry.get("content_text", "") t_tokens = self._tokenize(text) if not t_tokens: continue intersection = q_tokens & t_tokens union = q_tokens | t_tokens score = len(intersection) / len(union) if union else 0.0 if score >= min_score: entry["score"] = score scored.append(entry) scored.sort(key=lambda x: x["score"], reverse=True) return scored[:top_k] @property def offline_available(self) -> bool: """离线兜底始终可用(关键词匹配无需外部依赖)。""" return True # 全局单例 embedding_service = EmbeddingService()