Files
aiagent/backend/app/services/embedding_service.py
renjianbo 7b9e0826de feat: 向量记忆 RAG、工具市场、SSE 流式响应、前端集成与测试覆盖
- 新增 embedding_service(语义检索)、knowledge_service(RAG)、text_chunker、document_parser
- 新增 tool_registry(自定义工具注册表)并完善工具市场 API(CRUD + code/http 执行)
- 新增 agent_vector_memory / knowledge_base 模型及对应数据库表
- 实现 SSE 流式响应与 Agent 预算控制
- AgentChat.vue 集成 MainLayout 导航布局
- 完善测试体系:7 个新测试文件共 110 个测试覆盖
- 修复 conftest.py SQLite 内存数据库连接隔离问题

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-01 22:30:46 +08:00

230 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Embedding 生成与语义检索服务。
使用 OpenAI text-embedding-3-small 生成文本向量,
在内存中计算余弦相似度实现语义搜索。
如未配置 OpenAI API Key所有方法静默降级返回空结果。
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from typing import Any, Dict, List, Optional, TypedDict
from app.core.config import settings
logger = logging.getLogger(__name__)
# Default embedding model; used for every backend except SiliconFlow, which
# has its own configured model.
EMBEDDING_MODEL = "text-embedding-3-small"
# Vector size requested via the API's ``dimensions`` parameter, sent only
# when EMBEDDING_MODEL is the active model.
EMBEDDING_DIMENSIONS = 1536
VectorEntry = TypedDict(
    "VectorEntry",
    {
        "id": str,
        "scope_kind": str,
        "scope_id": str,
        "content_text": str,
        "embedding": List[float],
        "metadata": Dict[str, Any],
        # Cosine similarity; only present on entries returned by a similarity search.
        "score": float,
    },
    total=False,
)
VectorEntry.__doc__ = "A stored vector entry; every key is optional (total=False)."
class EmbeddingService:
    """
    Embedding generation and in-memory semantic search.

    The OpenAI-compatible client is created lazily on first use. When no
    backend API key is configured, generation degrades silently (``None`` per
    text) instead of raising.

    Usage:
        svc = EmbeddingService()
        emb = await svc.generate_embedding("你好")
        results = await svc.similarity_search(query_emb, entries, top_k=5)
    """

    # Texts are stripped and truncated to this many characters before being sent.
    _MAX_INPUT_CHARS = 8000

    def __init__(self):
        self._client: Optional[Any] = None
        # Serializes lazy client creation across concurrent first callers.
        self._client_lock = asyncio.Lock()

    async def _get_client(self):
        """Lazily create the OpenAI-compatible async client on first use.

        Backend priority: SiliconFlow > OpenAI > DeepSeek. SiliconFlow's
        netease-youdao/bce-embedding-base_v1 is reachable directly from
        mainland China and free to use.

        Returns:
            The ``AsyncOpenAI`` client, or ``None`` when no API key is
            configured. ``None`` is not cached, so configuration is re-checked
            on the next call.
        """
        if self._client is not None:
            return self._client
        async with self._client_lock:
            # Re-check: another coroutine may have built the client while we waited.
            if self._client is not None:
                return self._client
            from openai import AsyncOpenAI

            api_key: Optional[str] = None
            base_url: Optional[str] = None
            backend_label = "none"

            # 1) SiliconFlow (direct access from mainland China; recommended)
            if settings.SILICONFLOW_API_KEY:
                api_key = settings.SILICONFLOW_API_KEY
                base_url = settings.SILICONFLOW_BASE_URL or "https://api.siliconflow.cn/v1"
                backend_label = "siliconflow"
                logger.info("Embedding 后端: SiliconFlow (model=%s)", settings.SILICONFLOW_EMBEDDING_MODEL)

            # 2) OpenAI (the literal placeholder "your-openai-api-key" is ignored)
            if not api_key:
                if settings.OPENAI_API_KEY and settings.OPENAI_API_KEY != "your-openai-api-key":
                    api_key = settings.OPENAI_API_KEY
                    base_url = settings.OPENAI_BASE_URL or "https://api.openai.com/v1"
                    backend_label = "openai"

            # 3) DeepSeek (some proxies may support embeddings)
            if not api_key:
                if settings.DEEPSEEK_API_KEY:
                    api_key = settings.DEEPSEEK_API_KEY
                    base_url = settings.DEEPSEEK_BASE_URL or "https://api.deepseek.com"
                    backend_label = "deepseek"

            if not api_key:
                logger.info("未配置任何 API Key向量记忆功能已禁用请配置 SILICONFLOW_API_KEY 或 OPENAI_API_KEY")
                self._client = None
                return None

            self._client = AsyncOpenAI(
                api_key=api_key,
                base_url=base_url,
            )
            logger.info("EmbeddingService 已初始化 (backend=%s)", backend_label)
            return self._client

    def _get_model(self) -> str:
        """Return the embedding model name for the configured backend.

        SiliconFlow uses ``settings.SILICONFLOW_EMBEDDING_MODEL``; every other
        backend uses ``EMBEDDING_MODEL``.
        """
        if settings.SILICONFLOW_API_KEY:
            return settings.SILICONFLOW_EMBEDDING_MODEL
        return EMBEDDING_MODEL

    @staticmethod
    def _request_kwargs(model: str, payload: Any) -> Dict[str, Any]:
        """Build keyword arguments for ``client.embeddings.create``."""
        kwargs: Dict[str, Any] = {"model": model, "input": payload}
        # text-embedding-3-small accepts an explicit ``dimensions``; other models may not.
        if model == EMBEDDING_MODEL:
            kwargs["dimensions"] = EMBEDDING_DIMENSIONS
        return kwargs

    async def generate_embedding(self, text: str) -> Optional[List[float]]:
        """
        Generate an embedding vector for a single text.

        Args:
            text: Input text; it is stripped and truncated to ``_MAX_INPUT_CHARS``.

        Returns:
            The embedding as a list of floats, or ``None`` for blank input, when
            no API key is configured, or when the API call fails (logged).
        """
        if not text or not text.strip():
            return None
        client = await self._get_client()
        if not client:
            return None
        model = self._get_model()
        try:
            resp = await client.embeddings.create(
                **self._request_kwargs(model, text.strip()[: self._MAX_INPUT_CHARS])
            )
            return resp.data[0].embedding
        except Exception as e:
            # Best-effort: embedding failures are logged, never raised to the caller.
            logger.warning("生成 embedding 失败: %s", e)
            return None

    async def generate_embeddings(self, texts: List[str]) -> List[Optional[List[float]]]:
        """
        Generate embeddings for several texts in one API request.

        Args:
            texts: Input texts; blank entries are not sent and map to ``None``.

        Returns:
            A list aligned with ``texts``. Each slot holds the embedding or
            ``None``. A slot is ``None`` for blank text, and every slot is
            ``None`` when no API key is configured, the API call fails, or the
            response's item count does not match the request.
        """
        if not texts:
            return []
        client = await self._get_client()
        if not client:
            return [None] * len(texts)
        # Positions of non-blank texts; only these are sent to the API.
        valid_indices = [i for i, t in enumerate(texts) if t and t.strip()]
        if not valid_indices:
            return [None] * len(texts)
        model = self._get_model()
        try:
            inputs = [texts[i].strip()[: self._MAX_INPUT_CHARS] for i in valid_indices]
            resp = await client.embeddings.create(**self._request_kwargs(model, inputs))
            # The OpenAI embeddings API tags each item with the ``index`` of its
            # input; order by it rather than trusting list order (stable sort keeps
            # list order if ``index`` is absent).
            items = sorted(resp.data, key=lambda item: getattr(item, "index", 0))
            embeddings = [item.embedding for item in items]
        except Exception as e:
            logger.warning("批量生成 embedding 失败: %s", e)
            return [None] * len(texts)

        if len(embeddings) != len(valid_indices):
            # zip() would silently truncate and could pair vectors with the wrong texts.
            logger.warning(
                "批量生成 embedding 数量不匹配: expected=%d, got=%d",
                len(valid_indices),
                len(embeddings),
            )
            return [None] * len(texts)

        # Scatter results back to their original positions.
        result: List[Optional[List[float]]] = [None] * len(texts)
        for idx, emb in zip(valid_indices, embeddings):
            result[idx] = emb
        return result

    @staticmethod
    def cosine_similarity(a: List[float], b: List[float]) -> float:
        """Return the cosine similarity of ``a`` and ``b``.

        Returns 0.0 when either vector is empty, the lengths differ, or either
        vector has zero norm.
        """
        if not a or not b or len(a) != len(b):
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(y * y for y in b) ** 0.5
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return dot / (norm_a * norm_b)

    async def similarity_search(
        self,
        query_embedding: List[float],
        entries: List[VectorEntry],
        top_k: int = 5,
        min_score: float = 0.3,
    ) -> List[VectorEntry]:
        """
        Rank ``entries`` by cosine similarity to ``query_embedding``, in memory.

        Args:
            query_embedding: The query vector.
            entries: Candidate entries; those without an ``embedding`` are skipped.
                The entries themselves are not modified.
            top_k: Maximum number of results; ``<= 0`` yields an empty list.
            min_score: Minimum similarity; lower-scoring entries are dropped.

        Returns:
            Up to ``top_k`` shallow copies of matching entries, each with a
            ``score`` key, sorted by score descending (ties keep input order).
        """
        if top_k <= 0:
            return []
        scored: List[VectorEntry] = []
        for entry in entries:
            emb = entry.get("embedding")
            if not emb:
                continue
            score = self.cosine_similarity(query_embedding, emb)
            if score >= min_score:
                # Copy so the caller's entries are not stamped with a per-query score.
                scored.append(dict(entry, score=score))
        scored.sort(key=lambda x: x["score"], reverse=True)
        return scored[:top_k]

    @staticmethod
    def serialize_embedding(embedding: List[float]) -> str:
        """Serialize an embedding to a JSON string."""
        return json.dumps(embedding, ensure_ascii=False)

    @staticmethod
    def deserialize_embedding(data: Any) -> List[float]:
        """Deserialize an embedding from JSON.

        An already-deserialized list is returned as-is. Invalid input, and
        valid JSON that is not an array (e.g. ``"null"``, ``"{}"``), yield ``[]``.
        """
        if isinstance(data, list):
            return data
        try:
            parsed = json.loads(data)
        except (ValueError, TypeError):
            # ValueError covers JSONDecodeError and undecodable bytes input.
            return []
        return parsed if isinstance(parsed, list) else []
# Module-level shared instance (global singleton) of the embedding service.
embedding_service: EmbeddingService = EmbeddingService()