200 lines
7.6 KiB
Python
200 lines
7.6 KiB
Python
"""
|
||
工作流协作管理器
|
||
管理多人协作编辑工作流的WebSocket连接和消息同步
|
||
"""
|
||
import asyncio
import json
import logging
import zlib
from datetime import datetime
from typing import Dict, Set, List, Optional

from fastapi import WebSocket
|
||
|
||
# Module-level logger, named after this module so log records can be filtered
# per module by the application's logging configuration.
logger: logging.Logger = logging.getLogger(__name__)
|
||
|
||
|
||
class CollaborationManager:
    """Workflow collaboration manager.

    Keeps track of the WebSocket connections of the users who are editing a
    workflow together and relays edit operations and presence events
    (user joined / user left) between them.

    All state lives in plain in-memory dictionaries, so an instance only knows
    about connections that were registered through it.
    """

    # Palette from which every collaborator is assigned a display colour.
    _USER_COLORS = (
        "#FF6B6B", "#4ECDC4", "#45B7D1", "#FFA07A", "#98D8C8",
        "#F7DC6F", "#BB8FCE", "#85C1E2", "#F8B739", "#52BE80",
    )

    def __init__(self):
        """Create a manager with no connections."""
        # workflow_id -> set of WebSocket connections editing that workflow
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        # WebSocket -> user_info dict of the user behind that connection
        self.connection_users: Dict[WebSocket, Dict] = {}
        # workflow_id -> {user_id: user_info} of the users currently online
        self.workflow_users: Dict[str, Dict[str, Dict]] = {}
        # workflow_id -> lock used to serialize conflicting operations
        self.operation_locks: Dict[str, asyncio.Lock] = {}
        # Strong references to fire-and-forget notification tasks; the event
        # loop itself only keeps weak references to tasks.
        self._background_tasks: Set[asyncio.Task] = set()

    async def connect(self, websocket: WebSocket, workflow_id: str, user_id: str, username: str):
        """Accept ``websocket`` and register it as a collaborator on a workflow.

        The other connections of the workflow receive a ``user_joined``
        message; the new connection receives a ``collaboration_init`` message
        containing its own user info and the list of online users.

        Args:
            websocket: The connection to accept and register.
            workflow_id: Identifier of the workflow being edited.
            user_id: Identifier of the connecting user.
            username: Display name of the connecting user.
        """
        await websocket.accept()

        self.active_connections.setdefault(workflow_id, set()).add(websocket)
        online_users = self.workflow_users.setdefault(workflow_id, {})
        # Never replace an existing lock: it may have been created (and be
        # held) through acquire_lock() before the first connection arrived.
        if workflow_id not in self.operation_locks:
            self.operation_locks[workflow_id] = asyncio.Lock()

        user_info = {
            "user_id": user_id,
            "username": username,
            "joined_at": datetime.now().isoformat(),
            "color": self._get_user_color(user_id),  # colour shown for this user
        }
        self.connection_users[websocket] = user_info
        online_users[user_id] = user_info

        # Tell everybody else that a new user joined.
        await self.broadcast_user_joined(workflow_id, user_info, exclude_websocket=websocket)

        # Send the current list of online users to the new connection.
        await self.send_personal_message({
            "type": "collaboration_init",
            "workflow_id": workflow_id,
            "current_user": user_info,
            "online_users": self.get_online_users(workflow_id),
        }, websocket)

        logger.info(f"用户 {username} ({user_id}) 加入工作流 {workflow_id} 的协作编辑")

    def disconnect(self, websocket: WebSocket, workflow_id: str):
        """Unregister ``websocket`` from ``workflow_id`` and release idle state.

        Safe to call for unknown workflows or connections and safe to call
        more than once for the same connection. This method stays synchronous,
        so the ``user_left`` notification is scheduled on the running event
        loop rather than awaited.

        Args:
            websocket: The connection that went away.
            workflow_id: Identifier of the workflow the connection belonged to.
        """
        connections = self.active_connections.get(workflow_id)
        if connections is not None:
            connections.discard(websocket)

        user_info = self.connection_users.pop(websocket, None)
        if user_info is not None:
            user_id = user_info["user_id"]

            # The same user may hold several connections (second tab, or a
            # reconnect racing with the old socket closing); they only leave
            # once their last connection on this workflow is gone.
            still_connected = connections is not None and any(
                self.connection_users.get(ws, {}).get("user_id") == user_id
                for ws in connections
            )
            online_users = self.workflow_users.get(workflow_id, {})
            if not still_connected and user_id in online_users:
                del online_users[user_id]
                # Tell the remaining users that this user left.
                self._schedule_user_left(workflow_id, user_id, websocket)

            logger.info(f"用户 {user_info.get('username')} ({user_id}) 离开工作流 {workflow_id} 的协作编辑")

        # Last connection gone: release the per-workflow resources.
        if connections is not None and not connections:
            self.active_connections.pop(workflow_id, None)
            self.workflow_users.pop(workflow_id, None)
            lock = self.operation_locks.get(workflow_id)
            # Keep a lock that is currently held; dropping it would make the
            # holder's release_lock() a no-op and strand any waiters.
            if lock is not None and not lock.locked():
                del self.operation_locks[workflow_id]

    def _schedule_user_left(self, workflow_id: str, user_id: str,
                            exclude_websocket: Optional[WebSocket]) -> None:
        """Run ``broadcast_user_left`` as a background task if a loop is running."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Called outside of an event loop: nothing can be sent from here.
            return
        task = loop.create_task(
            self.broadcast_user_left(workflow_id, user_id, exclude_websocket=exclude_websocket)
        )
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _broadcast(self, workflow_id: str, message: Dict,
                         exclude_websocket: Optional[WebSocket], failure_label: str) -> None:
        """Send ``message`` to every connection of a workflow except one.

        Connections whose send raises are treated as dead and disconnected.
        ``failure_label`` is the prefix of the warning logged for them.
        """
        connections = self.active_connections.get(workflow_id)
        if not connections:
            return

        failed = []
        # Iterate over a snapshot: other coroutines may connect or disconnect
        # while this one is suspended inside send_json().
        for ws in list(connections):
            if ws == exclude_websocket or ws not in connections:
                continue
            try:
                await ws.send_json(message)
            except Exception as e:  # any send failure means the peer is gone
                logger.warning(f"{failure_label}: {e}")
                failed.append(ws)

        # Clean up the connections that turned out to be dead.
        for ws in failed:
            self.disconnect(ws, workflow_id)

    async def broadcast_operation(self, workflow_id: str, operation: Dict, exclude_websocket: Optional[WebSocket] = None):
        """Broadcast an edit operation to the connected clients of a workflow.

        Args:
            workflow_id: Identifier of the workflow the operation applies to.
            operation: The operation payload, forwarded unchanged.
            exclude_websocket: Connection that must not receive the message
                (typically the sender).
        """
        if workflow_id not in self.active_connections:
            return

        message = {
            "type": "operation",
            "workflow_id": workflow_id,
            "operation": operation,
            "timestamp": datetime.now().isoformat(),
        }
        await self._broadcast(workflow_id, message, exclude_websocket, "发送协作消息失败")

    async def broadcast_user_joined(self, workflow_id: str, user_info: Dict, exclude_websocket: Optional[WebSocket] = None):
        """Broadcast a ``user_joined`` message carrying ``user_info``.

        Args:
            workflow_id: Identifier of the workflow the user joined.
            user_info: Info dict of the user who joined.
            exclude_websocket: Connection that must not receive the message.
        """
        message = {
            "type": "user_joined",
            "workflow_id": workflow_id,
            "user": user_info,
        }
        await self._broadcast(workflow_id, message, exclude_websocket, "发送用户加入消息失败")

    async def broadcast_user_left(self, workflow_id: str, user_id: str, exclude_websocket: Optional[WebSocket] = None):
        """Broadcast a ``user_left`` message for ``user_id``.

        Args:
            workflow_id: Identifier of the workflow the user left.
            user_id: Identifier of the user who left.
            exclude_websocket: Connection that must not receive the message.
        """
        message = {
            "type": "user_left",
            "workflow_id": workflow_id,
            "user_id": user_id,
        }
        await self._broadcast(workflow_id, message, exclude_websocket, "发送用户离开消息失败")

    async def send_personal_message(self, message: Dict, websocket: WebSocket):
        """Send ``message`` to a single connection; failures are only logged."""
        try:
            await websocket.send_json(message)
        except Exception as e:  # best effort: a dead peer must not break the caller
            logger.warning(f"发送个人消息失败: {e}")

    def get_online_users(self, workflow_id: str) -> List[Dict]:
        """Return the info dicts of the users online on ``workflow_id``.

        Returns an empty list for a workflow without connections.
        """
        return list(self.workflow_users.get(workflow_id, {}).values())

    def _get_user_color(self, user_id: str) -> str:
        """Return the palette colour assigned to ``user_id``.

        A CRC32 checksum is used instead of the builtin ``hash()`` because
        string hashes are salted per process, which would give the same user
        a different colour in every worker and after every restart.
        """
        palette = CollaborationManager._USER_COLORS
        digest = zlib.crc32(str(user_id).encode("utf-8"))
        return palette[digest % len(palette)]

    async def acquire_lock(self, workflow_id: str):
        """Acquire the operation lock of a workflow (used for conflict resolution).

        The lock is created on first use. Returns the result of
        ``asyncio.Lock.acquire()``.
        """
        lock = self.operation_locks.get(workflow_id)
        if lock is None:
            lock = self.operation_locks[workflow_id] = asyncio.Lock()
        return await lock.acquire()

    def release_lock(self, workflow_id: str):
        """Release the operation lock of a workflow, if that workflow has one.

        Raises:
            RuntimeError: If the lock exists but is not currently held.
        """
        lock = self.operation_locks.get(workflow_id)
        if lock is not None:
            lock.release()
|
||
|
||
|
||
# Global collaboration manager instance, created once when this module is
# imported so that all importers share the same connection state.
collaboration_manager: CollaborationManager = CollaborationManager()
|