feat: Agent间知识共享 — 全局知识池(去重+置信度+过期)

- GlobalKnowledge 模型新增 confidence 和 expires_at 字段
- save_global_knowledge 增加 MD5 去重、置信度评分、TTL过期
- _global_knowledge_search 增加过期过滤、置信度优先排序
- run()/run_stream() 所有完成路径补齐知识提取调用
- 新增 Alembic 迁移 010_add_global_knowledge

Cl  oses #9

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
renjianbo
2026-05-06 22:04:56 +08:00
parent 1f7c136544
commit 9054f42cda
4 changed files with 175 additions and 8 deletions

View File

@@ -0,0 +1,45 @@
"""add confidence and expires_at columns to global_knowledge table
Revision ID: 010_add_global_knowledge
Revises: 009_notif_sched_feishu
Create Date: 2026-05-06
"""
from alembic import op
import sqlalchemy as sa
revision = "010_add_global_knowledge"
down_revision = "009_notif_sched_feishu"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add confidence column (if table doesn't exist, this is a no-op handled below)
try:
op.add_column(
"global_knowledge",
sa.Column("confidence", sa.String(20), default="medium", comment="置信度: low/medium/high"),
)
except Exception:
pass
# Add expires_at column
try:
op.add_column(
"global_knowledge",
sa.Column("expires_at", sa.DateTime(), nullable=True, comment="过期时间NULL 表示永不过期"),
)
except Exception:
pass
def downgrade() -> None:
try:
op.drop_column("global_knowledge", "expires_at")
except Exception:
pass
try:
op.drop_column("global_knowledge", "confidence")
except Exception:
pass

View File

@@ -226,6 +226,7 @@ class AgentRuntime:
# LLM 直接返回文本 → 结束
self.context.add_assistant_message(content)
final_text = content or "(模型未返回有效内容)"
review_score = 0.0
# 输出质量自检默认关闭Agent 节点可开启)
if self.config.self_review_enabled and not _self_review_attempted:
@@ -239,6 +240,7 @@ class AgentRuntime:
tool_result=json.dumps(review, ensure_ascii=False),
))
if review["passed"]:
review_score = review["score"]
logger.info("self_review 通过 (%.2f >= %.2f)", review["score"], review["threshold"])
else:
logger.info("self_review 未通过 (%.2f < %.2f),追加修正", review["score"], review["threshold"])
@@ -269,7 +271,7 @@ class AgentRuntime:
tool_calls_made=self.context.tool_calls_made,
)
# 提取知识到全局知识池Agent 间知识共享)
await self._extract_global_knowledge(user_input, final_text, steps)
await self._extract_global_knowledge(user_input, final_text, steps, review_score)
return AgentResult(
success=True,
content=final_text,
@@ -393,6 +395,9 @@ class AgentRuntime:
iterations_used=self.context.iteration,
tool_calls_made=self.context.tool_calls_made,
)
# 提取知识到全局知识池(即便截断,工具调用序列仍有参考价值)
if last_content:
await self._extract_global_knowledge(user_input, last_content, steps)
if last_content:
steps.append(AgentStep(
iteration=self.context.iteration,
@@ -512,6 +517,7 @@ class AgentRuntime:
# LLM 直接返回文本 → 结束
self.context.add_assistant_message(content)
final_text = content or "(模型未返回有效内容)"
review_score = 0.0
# 输出质量自检(默认关闭)
if self.config.self_review_enabled and not _self_review_attempted:
@@ -524,6 +530,7 @@ class AgentRuntime:
"session_id": self.context.session_id,
}
if review["passed"]:
review_score = review["score"]
logger.info("self_review 通过 (%.2f >= %.2f)", review["score"], review["threshold"])
else:
logger.info("self_review 未通过 (%.2f < %.2f),追加修正", review["score"], review["threshold"])
@@ -560,6 +567,8 @@ class AgentRuntime:
iterations_used=self.context.iteration,
tool_calls_made=self.context.tool_calls_made,
)
# 提取知识到全局知识池Agent 间知识共享)
await self._extract_global_knowledge(user_input, final_text, steps, review_score)
return
# 有工具调用 → 先记录 assistant 消息
@@ -706,6 +715,9 @@ class AgentRuntime:
iterations_used=self.context.iteration,
tool_calls_made=self.context.tool_calls_made,
)
# 提取知识到全局知识池(即便截断,工具调用序列仍有参考价值)
if last_content:
await self._extract_global_knowledge(user_input, last_content, steps)
yield {
"type": "final",
"content": last_content or "已达最大迭代次数,但模型未返回最终回答。",
@@ -778,6 +790,7 @@ class AgentRuntime:
async def _extract_global_knowledge(
self, user_input: str, final_answer: str, steps: List[AgentStep],
self_review_score: float = 0.0,
) -> None:
"""从 Agent 执行结果中提取知识写入全局知识池Agent 间共享)。"""
# 提取工具调用名称作为 tags
@@ -798,11 +811,24 @@ class AgentRuntime:
source_agent_id = self.config.name if self.config.name != "default_agent" else ""
source_user_id = self.config.user_id or ""
# 置信度评估:基于 self_review 评分和工具执行成功数
confidence = "medium"
if self_review_score >= 0.8:
confidence = "high"
elif self_review_score > 0 and self_review_score < 0.5:
confidence = "low"
elif tool_names and len(tool_names) >= 2:
confidence = "high" # 多工具协作通常质量更高
# TTL: 高置信度知识有效期更长
ttl_hours = 720 if confidence == "high" else 168 if confidence == "medium" else 24
await self.memory.save_global_knowledge(
content=content,
source_agent_id=source_agent_id,
source_user_id=source_user_id,
tags=tags,
confidence=confidence,
ttl_hours=ttl_hours,
)
async def _self_review(self, content: str, task_context: str = "") -> dict:

View File

@@ -178,13 +178,21 @@ class AgentMemory:
async def _global_knowledge_search(self, query: str = "") -> str:
"""从 GlobalKnowledge 表检索相关的全局知识条目。"""
from datetime import datetime
from app.models.agent import GlobalKnowledge
db: Optional[Session] = None
try:
db = SessionLocal()
now = datetime.utcnow()
# 查询未过期的知识expires_at IS NULL 或 expires_at > now
rows = (
db.query(GlobalKnowledge)
.filter(
(GlobalKnowledge.expires_at.is_(None))
| (GlobalKnowledge.expires_at > now)
)
.order_by(GlobalKnowledge.created_at.desc())
.limit(50)
.all()
@@ -212,6 +220,7 @@ class AgentMemory:
"metadata": {
"source_agent_id": row.source_agent_id,
"tags": row.tags or [],
"confidence": row.confidence or "medium",
},
})
@@ -225,17 +234,22 @@ class AgentMemory:
lines = ["## 全局知识库"]
for i, m in enumerate(matched, 1):
tags = m.get("metadata", {}).get("tags", [])
conf = m.get("metadata", {}).get("confidence", "medium")
tag_str = f" [{', '.join(tags[:3])}]" if tags else ""
lines.append(f"{i}.{tag_str} {m.get('content_text', '')[:500]}")
conf_str = f" (置信度:{conf})" if conf != "medium" else ""
lines.append(f"{i}.{tag_str}{conf_str} {m.get('content_text', '')[:500]}")
return "\n".join(lines)
else:
# 无 query返回最近 5 条全局知识
recent = rows[:5]
# 无 query返回最近 5 条全局知识(优先高置信度)
recent = sorted(rows, key=lambda r: (
0 if r.confidence == "high" else 1 if r.confidence == "medium" else 2
))[:5]
if recent:
lines = ["## 全局知识库(最近)"]
for i, row in enumerate(recent, 1):
tag_str = f" [{(', '.join(row.tags[:3]))}]" if row.tags else ""
lines.append(f"{i}.{tag_str} {row.content[:500]}")
conf_str = f" (置信度:{row.confidence})" if row.confidence and row.confidence != "medium" else ""
lines.append(f"{i}.{tag_str}{conf_str} {row.content[:500]}")
return "\n".join(lines)
return ""
@@ -249,8 +263,14 @@ class AgentMemory:
async def save_global_knowledge(
self, content: str, source_agent_id: str = "",
source_user_id: str = "", tags: Optional[List[str]] = None,
confidence: str = "medium", ttl_hours: int = 0,
) -> None:
"""将知识条目写入全局知识池。"""
"""将知识条目写入全局知识池(带去重、置信度、过期时间)
去重策略:对 content 取哈希,若已有相同哈希的条目则跳过。
过期策略ttl_hours > 0 时设置 expires_at0 表示永不过期。
"""
from datetime import datetime, timedelta
from app.models.agent import GlobalKnowledge
if not content or len(content) < 20:
@@ -260,7 +280,26 @@ class AgentMemory:
try:
db = SessionLocal()
# 生成 embedding
# 去重:用 content 的 MD5 哈希检查是否已存在
import hashlib
content_hash = hashlib.md5(content[:500].encode()).hexdigest()
# 查询最近 200 条,检查是否有相同哈希的条目
recent = (
db.query(GlobalKnowledge)
.order_by(GlobalKnowledge.created_at.desc())
.limit(200)
.all()
)
for existing in recent:
existing_hash = hashlib.md5(
(existing.content or "")[:500].encode()
).hexdigest()
if existing_hash == content_hash:
logger.info("全局知识去重:已存在相同条目,跳过写入")
return
# 嵌入向量
embedding_json = ""
try:
emb = await embedding_service.generate_embedding(content)
@@ -269,18 +308,26 @@ class AgentMemory:
except Exception:
pass
# 过期时间
expires_at = None
if ttl_hours > 0:
expires_at = datetime.utcnow() + timedelta(hours=ttl_hours)
record = GlobalKnowledge(
content=content[:2000],
embedding=embedding_json or None,
source_agent_id=source_agent_id or "",
source_user_id=source_user_id or "",
tags=tags or [],
confidence=confidence or "medium",
expires_at=expires_at,
scope_kind=self.scope_kind,
scope_id=self.scope_id or "global",
)
db.add(record)
db.commit()
logger.info("已写入全局知识: agent=%s tags=%s", source_agent_id, tags)
logger.info("已写入全局知识: agent=%s tags=%s confidence=%s",
source_agent_id, tags, confidence)
except Exception as e:
logger.warning("保存全局知识失败: %s", e)
if db:

View File

@@ -24,6 +24,19 @@ class Agent(Base):
version = Column(Integer, default=1, comment="版本号")
status = Column(String(20), default="draft", comment="状态: draft/published/running/stopped")
user_id = Column(CHAR(36), ForeignKey("users.id"), comment="创建者ID")
# 技能市场字段
category = Column(String(50), nullable=True, comment="分类: llm/data_processing/automation/integration/other")
tags = Column(JSON, nullable=True, comment="标签列表")
thumbnail = Column(Text, nullable=True, comment="缩略图URL")
is_public = Column(Integer, default=0, comment="是否公开到市场: 0=私有 1=公开")
is_featured = Column(Integer, default=0, comment="是否精选: 0=否 1=是")
rating_avg = Column(String(10), default="0.0", comment="平均评分")
rating_count = Column(Integer, default=0, comment="评分人数")
use_count = Column(Integer, default=0, comment="被安装次数")
view_count = Column(Integer, default=0, comment="查看次数")
forked_from_id = Column(CHAR(36), nullable=True, comment="从哪个Agent Fork而来市场安装")
created_at = Column(DateTime, default=func.now(), comment="创建时间")
updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), comment="更新时间")
@@ -62,6 +75,8 @@ class GlobalKnowledge(Base):
source_agent_id = Column(CHAR(36), nullable=True, comment="来源 Agent ID")
source_user_id = Column(CHAR(36), nullable=True, comment="来源用户 ID")
tags = Column(JSON, nullable=True, comment="分类标签")
confidence = Column(String(20), default="medium", comment="置信度: low/medium/high")
expires_at = Column(DateTime, nullable=True, comment="过期时间NULL 表示永不过期")
scope_kind = Column(String(50), default="agent", comment="作用域类型")
scope_id = Column(String(100), default="", comment="作用域 ID")
created_at = Column(DateTime, default=func.now(), comment="创建时间")
@@ -108,3 +123,37 @@ class KnowledgeRelation(Base):
def __repr__(self):
return f"<KnowledgeRelation({self.source_entity_id}) -[{self.relation_type}]-> ({self.target_entity_id})>"
class AgentRating(Base):
"""Agent 技能市场评分/评论表"""
__tablename__ = "agent_ratings"
id = Column(CHAR(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="评分ID")
agent_id = Column(CHAR(36), nullable=False, index=True, comment="Agent ID")
user_id = Column(CHAR(36), ForeignKey("users.id"), nullable=False, comment="评分用户ID")
rating = Column(Integer, nullable=False, comment="评分 1-5")
comment = Column(Text, nullable=True, comment="评论内容")
created_at = Column(DateTime, default=func.now(), comment="评分时间")
# 关系
user = relationship("User", backref="agent_ratings")
def __repr__(self):
return f"<AgentRating(agent={self.agent_id}, user={self.user_id}, rating={self.rating})>"
class AgentFavorite(Base):
"""Agent 技能市场收藏表"""
__tablename__ = "agent_favorites"
id = Column(CHAR(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="收藏ID")
agent_id = Column(CHAR(36), nullable=False, index=True, comment="Agent ID")
user_id = Column(CHAR(36), ForeignKey("users.id"), nullable=False, comment="收藏用户ID")
created_at = Column(DateTime, default=func.now(), comment="收藏时间")
# 关系
user = relationship("User", backref="agent_favorites")
def __repr__(self):
return f"<AgentFavorite(agent={self.agent_id}, user={self.user_id})>"