Files
aiagent/backend/app/api/websocket.py
renjianbo 7e00b027d4 feat: Phase 3 - parallel execution, progress reporting, result caching + AgentChat bug fixes
Phase 3 能力:
- DAG 并行执行 (workflow_engine): asyncio.gather 并行执行就绪节点
- Debate 并行 (orchestrator): for 循环改为 asyncio.gather
- 粒度进度上报 (workflow_engine + tasks + websocket): Redis 推送 + DB 降级
- 工具结果缓存 (tool_manager): 确定性工具默认开启缓存
- LLM 响应缓存 (core): messages[-4:] + model 哈希,5min TTL

AgentChat bug 修复 (Gitea #1-#5):
- #1 SSE 降级重复空消息: fallback POST 前移除占位消息
- #2 streamTimeout 泄漏: while 正常退出后 clearTimeout
- #3 loading 闪烁: final/error 事件中提前设 loading=false
- #4 SSE 事件类型对齐: 确认匹配,未知类型加 console.warn
- #5 retryMessage 流式残留: 重试时清理占位消息

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-05 00:00:51 +08:00

153 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
WebSocket API
"""
import asyncio
import json
import logging
from typing import Optional

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from app.core.database import SessionLocal
from app.models.execution import Execution
from app.websocket.manager import websocket_manager
# Router that holds this module's WebSocket endpoint; presumably included into
# the FastAPI app elsewhere (not visible here).
router = APIRouter()
def _get_progress_from_redis(execution_id: str) -> Optional[dict]:
"""从 Redis 读取进度数据。"""
try:
from app.core.redis_client import get_redis_client
redis_client = get_redis_client()
if redis_client:
raw = redis_client.get(f"workflow:progress:{execution_id}")
if raw:
return json.loads(raw)
except Exception:
pass
return None
# Execution statuses after which the handler sends a final message and stops.
_TERMINAL_STATUSES = ("completed", "failed")


async def _reply_to_client_message(websocket: WebSocket, raw: str) -> None:
    """Answer a client heartbeat.

    A JSON object with ``"type": "ping"`` gets a ``{"type": "pong"}`` reply.
    Anything else (non-JSON text, non-object JSON, other message types) is
    ignored. Send errors are NOT swallowed, so a disconnect while replying
    reaches the caller.
    """
    try:
        message = json.loads(raw)
    except (ValueError, RecursionError):
        # Malformed or pathologically nested input from the client: ignore it.
        return
    if isinstance(message, dict) and message.get("type") == "ping":
        await websocket_manager.send_personal_message({
            "type": "pong"
        }, websocket)


@router.websocket("/api/v1/ws/executions/{execution_id}")
async def websocket_execution_status(
    websocket: WebSocket,
    execution_id: str,
    token: Optional[str] = None,
):
    """Push real-time status/progress of one execution over a WebSocket.

    On connect:

    - If the execution does not exist, send an ``error`` message, close the
      socket and return.
    - Otherwise send one ``progress`` message built from the Redis payload
      when available, else a ``status`` message with progress 0.

    Then loop. Each iteration waits up to 1 s for client input (replying
    ``pong`` to ``ping``), then re-reads the execution row and:

    - If Redis has a progress payload, push it as a ``progress`` message
      when its ``progress`` value differs from the last value sent.
    - Otherwise, for non-terminal statuses, push a coarse ``status`` update
      (50 while running, 0 otherwise) when that value changes.
    - Once the status is completed/failed, send a final ``status`` message
      with result/error/execution_time and return.

    A client disconnect ends the handler quietly. Any other exception is
    logged, reported to the client as an ``error`` message (best effort),
    and ends the handler. The connection is always unregistered and the DB
    session closed.

    NOTE(review): the DB session and the Redis helper are synchronous and
    are called directly from this coroutine, so each poll blocks the event
    loop briefly.

    Args:
        websocket: WebSocket connection.
        execution_id: Execution record ID.
        token: Optional JWT passed as a query parameter. NOTE(review): it is
            accepted but never validated here, so the endpoint is effectively
            unauthenticated. TODO: verify the token before streaming data.
    """
    await websocket_manager.connect(websocket, execution_id)
    db = SessionLocal()
    try:
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if not execution:
            await websocket_manager.send_personal_message({
                "type": "error",
                "message": f"执行记录 {execution_id} 不存在"
            }, websocket)
            await websocket.close()
            return

        # Initial state. Remember the progress value we sent so the first poll
        # does not push the same value again.
        redis_progress = _get_progress_from_redis(execution_id)
        if redis_progress:
            last_progress = redis_progress.get("progress")
            await websocket_manager.send_personal_message({
                "type": "progress",
                "execution_id": execution_id,
                **redis_progress,
            }, websocket)
        else:
            last_progress = 0
            await websocket_manager.send_personal_message({
                "type": "status",
                "execution_id": execution_id,
                "status": execution.status,
                "progress": 0,
                "message": "连接已建立"
            }, websocket)

        while True:
            # Wait up to 1 s for a client message, then fall through to polling.
            # WebSocketDisconnect propagates to the outer handler and ends the loop.
            try:
                raw = await asyncio.wait_for(websocket.receive_text(), timeout=1.0)
            except asyncio.TimeoutError:
                raw = None
            if raw is not None:
                await _reply_to_client_message(websocket, raw)

            # End the session's current (read-only) transaction before re-reading
            # the row. Otherwise the whole connection runs inside one transaction.
            # That keeps a pooled DB connection checked out for the socket's
            # lifetime. Under snapshot isolation (e.g. MySQL's default
            # REPEATABLE READ) it would also keep returning the status seen at
            # connect time, so completion might never be observed. Nothing is
            # written here, so rolling back is safe.
            db.rollback()
            db.refresh(execution)
            status = execution.status
            finished = status in _TERMINAL_STATUSES

            # Prefer the fine-grained progress pushed to Redis by the workers.
            redis_progress = _get_progress_from_redis(execution_id)
            if redis_progress:
                pct = redis_progress.get("progress")
                if pct != last_progress:
                    last_progress = pct
                    await websocket_manager.send_personal_message({
                        "type": "progress",
                        "execution_id": execution_id,
                        **redis_progress,
                    }, websocket)
            elif not finished:
                # No Redis data: derive coarse progress from the DB status.
                # Terminal states are reported by the final message below rather
                # than as a misleading "waiting" update.
                pct = 50 if status == "running" else 0
                if pct != last_progress:
                    last_progress = pct
                    await websocket_manager.send_personal_message({
                        "type": "status",
                        "execution_id": execution_id,
                        "status": status,
                        "progress": pct,
                        "message": "执行中..." if status == "running" else "等待执行"
                    }, websocket)

            if finished:
                completed = status == "completed"
                await websocket_manager.send_personal_message({
                    "type": "status",
                    "execution_id": execution_id,
                    "status": status,
                    "progress": 100,
                    "result": execution.output_data if completed else None,
                    "error": execution.error_message if status == "failed" else None,
                    "execution_time": execution.execution_time
                }, websocket)
                # Short pause before returning, presumably so the final message
                # reaches the client before the connection is torn down.
                await asyncio.sleep(1)
                break
    except WebSocketDisconnect:
        # Client went away; cleanup happens in `finally`.
        pass
    except Exception as e:
        logging.getLogger(__name__).exception("WebSocket错误: %s", e)
        try:
            # NOTE(review): this exposes the raw exception text to the client.
            await websocket_manager.send_personal_message({
                "type": "error",
                "message": f"发生错误: {str(e)}"
            }, websocket)
        except Exception:
            # Best effort: the socket may already be unusable.
            pass
    finally:
        websocket_manager.disconnect(websocket, execution_id)
        db.close()