""" WebSocket API """ from fastapi import APIRouter, WebSocket, WebSocketDisconnect from app.websocket.manager import websocket_manager from app.core.database import SessionLocal from app.models.execution import Execution from typing import Optional import json import asyncio router = APIRouter() def _get_progress_from_redis(execution_id: str) -> Optional[dict]: """从 Redis 读取进度数据。""" try: from app.core.redis_client import get_redis_client redis_client = get_redis_client() if redis_client: raw = redis_client.get(f"workflow:progress:{execution_id}") if raw: return json.loads(raw) except Exception: pass return None @router.websocket("/api/v1/ws/executions/{execution_id}") async def websocket_execution_status( websocket: WebSocket, execution_id: str, token: Optional[str] = None ): """ WebSocket实时推送执行状态 Args: websocket: WebSocket连接 execution_id: 执行记录ID token: JWT Token(可选,通过query参数传递) """ await websocket_manager.connect(websocket, execution_id) db = SessionLocal() try: # 发送初始状态 execution = db.query(Execution).filter(Execution.id == execution_id).first() if execution: # 尝试从 Redis 读取进度 redis_progress = _get_progress_from_redis(execution_id) if redis_progress: await websocket_manager.send_personal_message({ "type": "progress", "execution_id": execution_id, **redis_progress, }, websocket) else: await websocket_manager.send_personal_message({ "type": "status", "execution_id": execution_id, "status": execution.status, "progress": 0, "message": "连接已建立" }, websocket) else: await websocket_manager.send_personal_message({ "type": "error", "message": f"执行记录 {execution_id} 不存在" }, websocket) await websocket.close() return last_progress = -1 # 持续监听并推送状态更新 while True: try: # 接收客户端消息(心跳等),超时 1 秒以便轮询进度 try: data = await asyncio.wait_for(websocket.receive_text(), timeout=1.0) try: message = json.loads(data) if message.get("type") == "ping": await websocket_manager.send_personal_message({ "type": "pong" }, websocket) except Exception: pass except asyncio.TimeoutError: pass # 超时后轮询进度 except WebSocketDisconnect: break # 优先从 Redis 读取进度(推送模式) redis_progress = _get_progress_from_redis(execution_id) if redis_progress: pct = redis_progress.get("progress", -1) if pct != last_progress: last_progress = pct await websocket_manager.send_personal_message({ "type": "progress", "execution_id": execution_id, **redis_progress, }, websocket) else: # Redis 不可用时回退到 DB 轮询 db.refresh(execution) pct = 100 if execution.status in ["completed", "failed"] else (50 if execution.status == "running" else 0) if pct != last_progress: last_progress = pct await websocket_manager.send_personal_message({ "type": "status", "execution_id": execution_id, "status": execution.status, "progress": pct, "message": f"执行中..." if execution.status == "running" else "等待执行" }, websocket) # 如果执行完成或失败,发送最终状态并断开 db.refresh(execution) if execution.status in ["completed", "failed"]: await websocket_manager.send_personal_message({ "type": "status", "execution_id": execution_id, "status": execution.status, "progress": 100, "result": execution.output_data if execution.status == "completed" else None, "error": execution.error_message if execution.status == "failed" else None, "execution_time": execution.execution_time }, websocket) await asyncio.sleep(1) break except WebSocketDisconnect: pass except Exception as e: print(f"WebSocket错误: {e}") try: await websocket_manager.send_personal_message({ "type": "error", "message": f"发生错误: {str(e)}" }, websocket) except Exception: pass finally: websocket_manager.disconnect(websocket, execution_id) db.close()