""" 任务系统核心 — 原子认领 + 依赖图 + Agent 状态管理 参考 Claude Code src/utils/tasks.ts 的设计模式: - claimTask 使用 DB 行锁(SELECT FOR UPDATE)实现原子认领 - block/blockedBy 双向依赖管理 - Agent busy/idle 状态跟踪 """ from sqlalchemy.orm import Session from sqlalchemy import and_ from typing import List, Optional, Dict, Any from datetime import datetime from dataclasses import dataclass, field from enum import Enum import logging from app.models.task import Task from app.core.exceptions import NotFoundError, ValidationError logger = logging.getLogger(__name__) # ──────────────────────────── 状态定义 ──────────────────────────── class TaskStatus(str, Enum): PENDING = "pending" IN_PROGRESS = "in_progress" AWAITING_APPROVAL = "awaiting_approval" COMPLETED = "completed" FAILED = "failed" CANCELLED = "cancelled" class AgentStatus(str, Enum): IDLE = "idle" BUSY = "busy" @dataclass class ClaimResult: success: bool reason: Optional[str] = None # task_not_found / already_claimed / already_resolved / blocked / agent_busy task: Optional[Task] = None busy_with_tasks: List[str] = field(default_factory=list) # agent_busy 时 blocked_by_tasks: List[str] = field(default_factory=list) # blocked 时 @dataclass class AgentState: agent_id: str name: str = "" agent_type: str = "" status: AgentStatus = AgentStatus.IDLE current_tasks: List[str] = field(default_factory=list) # ──────────────────────────── 任务系统 ──────────────────────────── class TaskSystem: """任务认领与依赖管理系统""" def __init__(self, db: Session): self.db = db # ── 原子认领 ── def claim_task( self, task_id: str, agent_id: str, check_busy: bool = True, ) -> ClaimResult: """ 使用 SELECT FOR UPDATE 原子认领任务。 检查顺序: 1. 任务存在性 2. 未被其他 Agent 认领 3. 任务未完成 4. 所有 blockedBy 依赖已满足 5. (可选) Agent 不忙碌 """ # SELECT FOR UPDATE — 锁定行直到事务提交 task = ( self.db.query(Task) .filter(Task.id == task_id) .with_for_update() .first() ) if not task: return ClaimResult(success=False, reason="task_not_found") # 检查是否已被其他 Agent 认领 if task.owner and task.owner != agent_id: return ClaimResult(success=False, reason="already_claimed", task=task) # 检查是否已完成 if task.status == TaskStatus.COMPLETED.value: return ClaimResult(success=False, reason="already_resolved", task=task) # 检查依赖: blockedBy 中的任务必须全部完成 blocked_by = task.depends_on or [] if blocked_by: unresolved = ( self.db.query(Task) .filter( and_( Task.id.in_(blocked_by), Task.status != TaskStatus.COMPLETED.value, ) ) .all() ) if unresolved: return ClaimResult( success=False, reason="blocked", task=task, blocked_by_tasks=[t.id for t in unresolved], ) # 检查 Agent 是否忙碌 if check_busy: agent_open_tasks = ( self.db.query(Task) .filter( and_( Task.owner == agent_id, Task.status.in_([ TaskStatus.PENDING.value, TaskStatus.IN_PROGRESS.value, TaskStatus.AWAITING_APPROVAL.value, ]), Task.id != task_id, ) ) .all() ) if agent_open_tasks: return ClaimResult( success=False, reason="agent_busy", task=task, busy_with_tasks=[t.id for t in agent_open_tasks], ) # 认领 task.owner = agent_id task.status = TaskStatus.IN_PROGRESS.value if task.started_at is None: task.started_at = datetime.now() self.db.commit() self.db.refresh(task) logger.info(f"Task {task_id} claimed by agent {agent_id}") return ClaimResult(success=True, task=task) # ── 依赖管理 ── def block_task(self, from_task_id: str, to_task_id: str) -> bool: """ 设置任务依赖: from_task 阻塞 to_task。 即: to_task 依赖 from_task 完成后才能执行。 等价于: - from_task.blocks += to_task_id - to_task.depends_on += from_task_id """ from_task = self.db.query(Task).filter(Task.id == from_task_id).first() to_task = self.db.query(Task).filter(Task.id == to_task_id).first() if not from_task or not to_task: return False # 检测循环依赖: 如果 to_task 已经(直接或间接)被 from_task 依赖,则形成环 if self._would_create_cycle(from_task_id, to_task_id): raise ValidationError( f"无法设置依赖 {from_task_id} → {to_task_id}: 会产生循环依赖" ) # from_task 阻塞 to_task blocks = list(from_task.blocks or []) if to_task_id not in blocks: blocks.append(to_task_id) from_task.blocks = blocks # to_task 被 from_task 阻塞 depends = list(to_task.depends_on or []) if from_task_id not in depends: depends.append(from_task_id) to_task.depends_on = depends self.db.commit() logger.info(f"Task dependency set: {from_task_id} blocks {to_task_id}") return True def unblock_task(self, from_task_id: str, to_task_id: str) -> bool: """移除任务依赖""" from_task = self.db.query(Task).filter(Task.id == from_task_id).first() to_task = self.db.query(Task).filter(Task.id == to_task_id).first() if not from_task or not to_task: return False blocks = list(from_task.blocks or []) if to_task_id in blocks: blocks.remove(to_task_id) from_task.blocks = blocks depends = list(to_task.depends_on or []) if from_task_id in depends: depends.remove(from_task_id) to_task.depends_on = depends self.db.commit() return True def _would_create_cycle(self, from_id: str, to_id: str) -> bool: """检查 from → to 是否会产生循环依赖""" # 收集 to_task 直接和间接阻塞的所有任务 visited = set() stack = [to_id] while stack: current = stack.pop() if current == from_id: return True if current in visited: continue visited.add(current) task = self.db.query(Task).filter(Task.id == current).first() if task and task.blocks: for blocked_id in task.blocks: if blocked_id not in visited: stack.append(blocked_id) return False # ── Agent 状态 ── def get_agent_status(self, agent_id: str) -> AgentState: """获取 Agent 忙闲状态及当前持有的任务""" open_tasks = ( self.db.query(Task) .filter( and_( Task.owner == agent_id, Task.status.in_([ TaskStatus.PENDING.value, TaskStatus.IN_PROGRESS.value, TaskStatus.AWAITING_APPROVAL.value, ]), ) ) .all() ) task_ids = [t.id for t in open_tasks] status = AgentStatus.BUSY if task_ids else AgentStatus.IDLE return AgentState( agent_id=agent_id, status=status, current_tasks=task_ids, ) def get_all_agent_statuses(self, task_list_owner_ids: List[str]) -> List[AgentState]: """批量获取多个 Agent 的状态""" result = [] for agent_id in task_list_owner_ids: result.append(self.get_agent_status(agent_id)) return result # ── 任务释放 ── def unassign_agent_tasks( self, agent_id: str, ) -> List[Task]: """释放 Agent 持有的所有未完成任务(Agent 下线/终止时调用)""" open_tasks = ( self.db.query(Task) .filter( and_( Task.owner == agent_id, Task.status.in_([ TaskStatus.PENDING.value, TaskStatus.IN_PROGRESS.value, TaskStatus.AWAITING_APPROVAL.value, ]), ) ) .with_for_update() .all() ) unassigned = [] for task in open_tasks: task.owner = None task.status = TaskStatus.PENDING.value unassigned.append(task) logger.info(f"Task {task.id} unassigned from agent {agent_id}") if unassigned: self.db.commit() return unassigned def release_task(self, task_id: str, agent_id: str) -> bool: """释放单个任务(Agent 主动放弃)""" task = ( self.db.query(Task) .filter(Task.id == task_id) .with_for_update() .first() ) if not task or task.owner != agent_id: return False task.owner = None task.status = TaskStatus.PENDING.value self.db.commit() logger.info(f"Task {task_id} released by agent {agent_id}") return True # ── 任务完成/失败 ── def complete_task(self, task_id: str, result: Optional[Dict[str, Any]] = None) -> Optional[Task]: """标记任务完成""" task = self.db.query(Task).filter(Task.id == task_id).first() if not task: return None task.status = TaskStatus.COMPLETED.value task.result = result or {} task.completed_at = datetime.now() self.db.commit() self.db.refresh(task) # 检查被此任务阻塞的任务是否现在可以执行 self._check_unblocked_tasks(task) return task def fail_task(self, task_id: str, error_message: str = "") -> Optional[Task]: """标记任务失败""" task = self.db.query(Task).filter(Task.id == task_id).first() if not task: return None task.status = TaskStatus.FAILED.value task.error_message = error_message task.completed_at = datetime.now() self.db.commit() self.db.refresh(task) return task def _check_unblocked_tasks(self, completed_task: Task) -> None: """检查被已完成任务阻塞的任务是否已解除阻塞""" blocks = completed_task.blocks or [] for blocked_id in blocks: blocked_task = self.db.query(Task).filter(Task.id == blocked_id).first() if not blocked_task: continue # 检查 blocked_task 的所有依赖是否都已满足 deps = blocked_task.depends_on or [] all_deps_met = True for dep_id in deps: dep = self.db.query(Task).filter(Task.id == dep_id).first() if dep and dep.status != TaskStatus.COMPLETED.value: all_deps_met = False break if all_deps_met and blocked_task.status == TaskStatus.PENDING.value: logger.info( f"Task {blocked_id} is now unblocked (all dependencies met)" ) # ── 查询辅助 ── def get_unresolved_blockers(self, task_id: str) -> List[Task]: """获取某个任务尚未完成的阻塞任务""" task = self.db.query(Task).filter(Task.id == task_id).first() if not task or not task.depends_on: return [] return ( self.db.query(Task) .filter( and_( Task.id.in_(task.depends_on), Task.status != TaskStatus.COMPLETED.value, ) ) .all() ) def get_next_available_tasks(self, goal_id: str, limit: int = 10) -> List[Task]: """获取下一个可执行的任务(依赖已满足、未被认领)""" # 获取 goal 下所有任务 all_tasks = ( self.db.query(Task) .filter(Task.goal_id == goal_id) .all() ) available = [] for task in all_tasks: if task.status != TaskStatus.PENDING.value: continue if task.owner is not None: continue # 检查依赖 deps = task.depends_on or [] blocked = False for dep_id in deps: dep = next((t for t in all_tasks if t.id == dep_id), None) if dep and dep.status != TaskStatus.COMPLETED.value: blocked = True break if not blocked: available.append(task) if len(available) >= limit: break return available