Phase 3 能力: - DAG 并行执行 (workflow_engine): asyncio.gather 并行执行就绪节点 - Debate 并行 (orchestrator): for 循环改为 asyncio.gather - 粒度进度上报 (workflow_engine + tasks + websocket): Redis 推送 + DB 降级 - 工具结果缓存 (tool_manager): 确定性工具默认开启缓存 - LLM 响应缓存 (core): messages[-4:] + model 哈希,5min TTL AgentChat bug 修复 (Gitea #1-#5): - #1 SSE 降级重复空消息: fallback POST 前移除占位消息 - #2 streamTimeout 泄漏: while 正常退出后 clearTimeout - #3 loading 闪烁: final/error 事件中提前设 loading=false - #4 SSE 事件类型对齐: 确认匹配,未知类型加 console.warn - #5 retryMessage 流式残留: 重试时清理占位消息 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
153 lines
5.4 KiB
Python
153 lines
5.4 KiB
Python
"""
|
||
WebSocket API
|
||
"""
import asyncio
import json
import logging
from typing import Optional

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from app.websocket.manager import websocket_manager
from app.core.database import SessionLocal
from app.models.execution import Execution


# APIRouter collecting this module's WebSocket endpoint(s).
router: APIRouter = APIRouter()


def _get_progress_from_redis(execution_id: str) -> Optional[dict]:
    """Read the cached progress payload for an execution from Redis.

    This is a best-effort lookup. It returns ``None`` in each of these
    cases, so callers can fall back to polling the database:

    * the Redis client module cannot be imported;
    * ``get_redis_client()`` returns a falsy value;
    * the key is missing or empty;
    * any exception is raised while reading or decoding;
    * the stored JSON value is not an object.

    Args:
        execution_id: ID of the execution; the key read is
            ``workflow:progress:{execution_id}``.

    Returns:
        The decoded progress dict, or ``None`` when no usable payload exists.
    """
    try:
        from app.core.redis_client import get_redis_client

        redis_client = get_redis_client()
        if not redis_client:
            return None
        raw = redis_client.get(f"workflow:progress:{execution_id}")
        if not raw:
            return None
        progress = json.loads(raw)
    except Exception:
        # Redis is optional here, so degrade to the DB fallback, but leave a
        # trace instead of swallowing the failure silently.
        logging.getLogger(__name__).debug(
            "Failed to read progress for execution %s from Redis",
            execution_id,
            exc_info=True,
        )
        return None

    # Callers call .get() on this value and splat it with **, so any JSON
    # value other than an object (e.g. a bare number or list) counts as
    # "no progress".
    if not isinstance(progress, dict):
        return None
    return progress


@router.websocket("/api/v1/ws/executions/{execution_id}")
async def websocket_execution_status(
    websocket: WebSocket,
    execution_id: str,
    token: Optional[str] = None
):
    """Stream the status and progress of one execution over a WebSocket.

    Flow:
      1. Register the connection with ``websocket_manager`` and send an
         initial snapshot. This is the Redis progress payload if one exists,
         otherwise the execution's DB status. If no execution row matches
         ``execution_id``, send an error message and close.
      2. Loop roughly once per second:
         * wait up to 1 s for a client message, answering a JSON
           ``{"type": "ping"}`` with ``{"type": "pong"}``;
         * push progress only when its percentage changed. The value comes
           from Redis when available, otherwise it is derived from the DB
           status (``running`` -> 50, any other non-terminal status -> 0).
      3. When the DB status is ``completed`` or ``failed``, send the final
         status with ``result`` / ``error`` / ``execution_time``, then close.

    Args:
        websocket: The WebSocket connection.
        execution_id: ID of the execution record to watch.
        token: Optional JWT passed as a query parameter.
            NOTE(review): this value is never checked in this function, so
            any client that knows an execution ID receives its status and
            output. Confirm that auth is enforced elsewhere.
    """
    # NOTE(review): the Redis read and the SQLAlchemy query/refresh below are
    # synchronous calls made on the event loop. With many open connections,
    # consider offloading them to a thread.
    terminal_statuses = ("completed", "failed")

    await websocket_manager.connect(websocket, execution_id)

    db = None
    try:
        # Opened inside ``try`` so ``finally`` always unregisters the
        # connection, even if creating the session fails.
        db = SessionLocal()

        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if execution is None:
            await websocket_manager.send_personal_message({
                "type": "error",
                "message": f"执行记录 {execution_id} 不存在"
            }, websocket)
            await websocket.close()
            return

        # Percentage of the last progress message pushed, used to skip
        # duplicates; -1 means nothing comparable has been sent yet.
        last_progress = -1

        # Initial snapshot: prefer the Redis progress payload.
        redis_progress = _get_progress_from_redis(execution_id)
        if redis_progress:
            # Record what was sent so the first loop iteration does not push
            # the identical Redis payload a second time.
            last_progress = redis_progress.get("progress", -1)
            await websocket_manager.send_personal_message({
                "type": "progress",
                "execution_id": execution_id,
                **redis_progress,
            }, websocket)
        else:
            await websocket_manager.send_personal_message({
                "type": "status",
                "execution_id": execution_id,
                "status": execution.status,
                "progress": 0,
                "message": "连接已建立"
            }, websocket)

        while True:
            # Wait for a client message (heartbeat etc.). The 1 s timeout
            # doubles as the progress polling interval.
            try:
                data = await asyncio.wait_for(websocket.receive_text(), timeout=1.0)
            except asyncio.TimeoutError:
                data = None
            except WebSocketDisconnect:
                break

            if data is not None:
                try:
                    message = json.loads(data)
                except ValueError:
                    message = None  # non-JSON text is ignored
                if isinstance(message, dict) and message.get("type") == "ping":
                    await websocket_manager.send_personal_message({
                        "type": "pong"
                    }, websocket)

            # Prefer the progress payload stored in Redis.
            redis_progress = _get_progress_from_redis(execution_id)
            # One refresh per iteration, shared by the DB fallback and the
            # completion check below.
            db.refresh(execution)
            finished = execution.status in terminal_statuses

            if redis_progress:
                pct = redis_progress.get("progress", -1)
                if pct != last_progress:
                    last_progress = pct
                    await websocket_manager.send_personal_message({
                        "type": "progress",
                        "execution_id": execution_id,
                        **redis_progress,
                    }, websocket)
            elif not finished:
                # No Redis payload: fall back to a coarse value derived from
                # the DB status. Terminal states are left to the final
                # message below, so a finished run never gets "等待执行"
                # ("waiting").
                running = execution.status == "running"
                pct = 50 if running else 0
                if pct != last_progress:
                    last_progress = pct
                    await websocket_manager.send_personal_message({
                        "type": "status",
                        "execution_id": execution_id,
                        "status": execution.status,
                        "progress": pct,
                        "message": "执行中..." if running else "等待执行"
                    }, websocket)

            if finished:
                await websocket_manager.send_personal_message({
                    "type": "status",
                    "execution_id": execution_id,
                    "status": execution.status,
                    "progress": 100,
                    "result": execution.output_data if execution.status == "completed" else None,
                    "error": execution.error_message if execution.status == "failed" else None,
                    "execution_time": execution.execution_time
                }, websocket)
                # Give the client a moment to receive the final message, then
                # close explicitly, as the not-found path does.
                await asyncio.sleep(1)
                try:
                    await websocket.close()
                except Exception:
                    # Best effort: the client may already have disconnected.
                    logging.getLogger(__name__).debug(
                        "Closing WebSocket for execution %s failed",
                        execution_id,
                        exc_info=True,
                    )
                break

    except WebSocketDisconnect:
        pass
    except Exception as e:
        logging.getLogger(__name__).exception("WebSocket错误: %s", e)
        try:
            await websocket_manager.send_personal_message({
                "type": "error",
                "message": f"发生错误: {str(e)}"
            }, websocket)
        except Exception:
            # The connection may already be broken; nothing more to report.
            pass
    finally:
        websocket_manager.disconnect(websocket, execution_id)
        if db is not None:
            db.close()