"""Rollout coordinator for async batch submission and finalize."""
import asyncio
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
import ray
from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline
from trinity.common.config import Config
from trinity.common.models import InferenceModel
from trinity.common.workflows import Task
from trinity.explorer.scheduler import Scheduler
from trinity.utils.log import get_logger
from trinity.utils.monitor import gather_eval_metrics, gather_metrics
# Identifier of one submitted batch: train batches use an int (see
# finalize_train_batch), eval batches a str (see finalize_eval_batch).
BatchId = Union[int, str]
# Discriminator for the two batch kinds tracked by the coordinator.
BatchType = Literal["train", "eval"]
class BatchLifecycleState(str, Enum):
    """Lifecycle states that one submitted batch moves through.

    Inherits from ``str`` so states compare equal to their plain string
    values and serialize naturally.
    """

    PENDING = "pending"  # registered, no tasks scheduled yet
    RUNNING = "running"  # tasks handed to the scheduler
    FINALIZING = "finalizing"  # finalize in progress
    FINALIZED = "finalized"  # terminal: result published
    ABORTED = "aborted"  # terminal: batch was aborted
@dataclass
class BatchState:
    """In-memory state tracked for one train or eval batch."""

    batch_id: BatchId
    batch_type: BatchType
    # Number of tasks originally scheduled for this batch.
    expected_task_count: int
    # Completed task statuses keyed by task id; one entry per finished task.
    statuses: Dict[Union[int, str], Any] = field(default_factory=dict)
    # Minimum completed-task count that allows a partial finalize; None means
    # the full batch is required.
    min_wait_num: Optional[int] = None
    state: BatchLifecycleState = BatchLifecycleState.PENDING
    # Cached terminal result so repeated finalize calls return the same dict.
    final_result: Optional[dict] = None
    # Serializes concurrent finalize attempts on the same batch.
    finalize_lock: asyncio.Lock = field(default_factory=asyncio.Lock)

    @property
    def completed_task_count(self) -> int:
        """Return the number of completed tasks tracked by status."""
        return len(self.statuses)
class RolloutCoordinator:
    """Own scheduler-side batch state and expose batch-level finalize APIs."""

    def __init__(
        self,
        config: Config,
        rollout_model: List[InferenceModel],
        auxiliary_models: Optional[List[List[InferenceModel]]] = None,
    ):
        """Create a coordinator; its scheduler and pipeline are built in prepare().

        Args:
            config: Global configuration object.
            rollout_model: Inference models used for rollouts.
            auxiliary_models: Optional groups of auxiliary inference models.
        """
        self.logger = get_logger(f"{config.explorer.name}_rollout_coordinator", in_ray_actor=True)
        self.config = config
        self.rollout_model = rollout_model
        self.auxiliary_models = auxiliary_models or []
        # Owned dependencies; populated lazily by prepare().
        self.experience_pipeline = None
        self.scheduler: Optional[Scheduler] = None
        # Active batches keyed by id; terminal batches are evicted on finish/abort.
        self.pending_batches: Dict[BatchId, BatchState] = {}
        self.running = False
        # Detailed eval statistics flag; tolerate configs without a monitor section.
        self.detailed_stats = getattr(getattr(config, "monitor", None), "detailed_stats", False)
[文档]
async def prepare(self) -> None:
"""Initialize the owned pipeline and scheduler."""
if self.running:
return
if self.experience_pipeline is None:
await self._init_experience_pipeline()
if self.scheduler is None:
await self._init_scheduler()
self.running = True
[文档]
async def shutdown(self) -> None:
"""Stop background work and close owned dependencies."""
self.running = False
if self.scheduler is not None:
await self.scheduler.stop()
self.scheduler = None
if self.experience_pipeline is not None:
await self.experience_pipeline.close()
self.experience_pipeline = None
async def _init_experience_pipeline(self):
"""Create the experience pipeline owned by this coordinator actor."""
if self.config.mode == "bench":
return None
self.experience_pipeline = ExperiencePipeline(self.config)
await self.experience_pipeline.prepare()
async def _init_scheduler(self):
"""Create the scheduler owned by this coordinator."""
if self.config.mode == "serve":
return
self.scheduler = Scheduler(
self.config,
self.rollout_model,
self.auxiliary_models,
)
await self.scheduler.start()
def _require_scheduler(self) -> Scheduler:
"""Return the initialized scheduler."""
assert self.scheduler is not None, "RolloutCoordinator.prepare() must be called first."
return self.scheduler
[文档]
async def submit_batch(
self,
*,
batch_id: BatchId,
tasks: list[Task],
batch_type: BatchType,
min_wait_num: Optional[int] = None,
) -> None:
"""Register a new batch and schedule its tasks."""
existing_state = self.pending_batches.get(batch_id)
if existing_state is not None and existing_state.state not in {
BatchLifecycleState.FINALIZED,
BatchLifecycleState.ABORTED,
}:
raise ValueError(f"Batch {batch_id} is already active.")
batch_state = BatchState(
batch_id=batch_id,
batch_type=batch_type,
expected_task_count=len(tasks),
min_wait_num=min_wait_num,
)
self.pending_batches[batch_id] = batch_state
if tasks:
self._require_scheduler().schedule(tasks, batch_id=batch_id)
batch_state.state = BatchLifecycleState.RUNNING
[文档]
async def finalize_train_batch(
self,
batch_id: int,
*,
timeout: Optional[float] = None,
) -> dict:
"""Finalize one train batch and return aggregated metrics."""
batch_state = self._get_batch_state(batch_id, expected_type="train")
return await self._finalize_train_batch(batch_state, timeout=timeout)
[文档]
async def finalize_eval_batch(
self,
batch_id: str,
*,
timeout: Optional[float] = None,
) -> dict:
"""Finalize one eval batch and return aggregated eval metrics."""
batch_state = self._get_batch_state(batch_id, expected_type="eval")
return await self._finalize_eval_batch(batch_state, timeout=timeout)
async def _finalize_eval_batch(
self, batch_state: BatchState, *, timeout: Optional[float]
) -> dict:
"""Finalize one eval batch."""
scheduler = self._require_scheduler()
async with batch_state.finalize_lock:
existing_result = self._get_existing_final_result(batch_state)
if existing_result is not None:
return existing_result
statuses = await scheduler.get_statuses(
batch_id=batch_state.batch_id,
timeout=timeout,
return_partial_tasks=self.config.explorer.over_rollout.return_partial_tasks,
)
for task_id, status in enumerate(statuses):
if task_id in batch_state.statuses:
continue
batch_state.statuses[task_id] = status
return self._finish_batch(batch_state, pipeline_metrics={})
[文档]
async def abort_batch(
self,
batch_id: BatchId,
*,
reason: str,
keep_partial_results: bool = False,
) -> None:
"""Abort one batch and cleanup its running and staged state."""
scheduler = self._require_scheduler()
batch_state = self.pending_batches.get(batch_id)
if batch_state is None:
return
if batch_state.state in {BatchLifecycleState.FINALIZED, BatchLifecycleState.ABORTED}:
return
self.logger.warning("Abort batch %s: %s", batch_id, reason)
await scheduler.abort_batch(
batch_id,
return_partial_tasks=keep_partial_results,
restart_runners=True,
)
scheduler.discard_completed_results(batch_id)
batch_state.state = BatchLifecycleState.ABORTED
batch_state.final_result = self._build_batch_result(batch_state, pipeline_metrics={})
self.pending_batches.pop(batch_id, None)
[文档]
async def process_experiences(self, payloads: list[bytes]) -> dict:
"""Process one batch of experience payloads through the pipeline."""
if self.experience_pipeline is None:
raise RuntimeError("Experience pipeline is not initialized.")
if not payloads:
return {}
return await self.experience_pipeline.process_serialized_chunks(payloads)
[文档]
@classmethod
def get_actor(
cls, config: Config, models: List, auxiliary_models: List
) -> ray.actor.ActorHandle:
"""Init rollout coordinator for the task-event-completion path."""
return (
ray.remote(RolloutCoordinator)
.options(namespace=config.ray_namespace)
.remote(
config,
models,
auxiliary_models,
)
)
def _get_batch_state(self, batch_id: BatchId, *, expected_type: BatchType) -> BatchState:
"""Return one registered batch and validate its type."""
batch_state = self.pending_batches.get(batch_id)
if batch_state is None:
raise KeyError(f"Batch {batch_id} is not registered.")
if batch_state.batch_type != expected_type:
raise ValueError(
f"Batch {batch_id} is {batch_state.batch_type}, expected {expected_type}."
)
return batch_state
def _get_existing_final_result(self, batch_state: BatchState) -> Optional[dict]:
"""Reuse an in-flight final result or synthesize an abort result."""
if batch_state.final_result is not None:
return dict(batch_state.final_result)
if batch_state.state != BatchLifecycleState.ABORTED:
return None
batch_state.final_result = self._build_batch_result(batch_state, pipeline_metrics={})
return dict(batch_state.final_result)
    async def _finalize_train_batch(
        self, batch_state: BatchState, *, timeout: Optional[float]
    ) -> dict:
        """Finalize one train batch.

        Waits for scheduler results, pushes the collected payloads through the
        experience pipeline, and returns the aggregated batch result.

        Raises:
            TimeoutError: If no task finished, or if a full batch was required
                (``min_wait_num`` unset) but not every task completed.
        """
        # The per-batch lock serializes concurrent finalize attempts.
        async with batch_state.finalize_lock:
            # Reuse a result produced by an earlier finalize or abort.
            existing_result = self._get_existing_final_result(batch_state)
            if existing_result is not None:
                return existing_result
            scheduler = self._require_scheduler()
            scheduled_num = batch_state.expected_task_count
            statuses, payload_chunks = await scheduler.get_payload_results(
                batch_id=batch_state.batch_id,
                min_num=batch_state.min_wait_num,
                timeout=timeout,
                clear_timeout_tasks=False,
                return_partial_tasks=self.config.explorer.over_rollout.return_partial_tasks,
            )
            completed_count = len(statuses)
            if scheduled_num == 0:
                # An empty batch is trivially complete.
                is_complete = True
            else:
                if completed_count == 0:
                    raise TimeoutError(f"Timeout waiting for batch {batch_state.batch_id}.")
                if batch_state.min_wait_num is None and completed_count < scheduled_num:
                    # No partial-wait threshold: every scheduled task must finish.
                    raise TimeoutError(f"Timeout waiting for batch {batch_state.batch_id}.")
                batch_state.statuses = {task_id: status for task_id, status in enumerate(statuses)}
                is_complete = completed_count >= scheduled_num
            batch_state.state = BatchLifecycleState.FINALIZING
            try:
                pipeline_metrics = await self.process_experiences(payload_chunks)
                if not is_complete:
                    # Partial finalize: drop the still-unfinished remainder.
                    await self._cleanup_train_batch_runtime(batch_state)
            except Exception:
                # Restore an active state so a later finalize attempt can retry.
                batch_state.state = self._get_active_batch_state(batch_state)
                raise
            return self._finish_batch(batch_state, pipeline_metrics=pipeline_metrics)
def _finish_batch(
self,
batch_state: BatchState,
pipeline_metrics: dict,
) -> dict:
"""Persist one terminal result and evict the batch from active state."""
self._require_scheduler().discard_completed_results(batch_state.batch_id)
batch_state.state = BatchLifecycleState.FINALIZED
batch_state.final_result = self._build_batch_result(batch_state, pipeline_metrics)
self.pending_batches.pop(batch_state.batch_id, None)
return dict(batch_state.final_result)
def _get_active_batch_state(self, batch_state: BatchState) -> BatchLifecycleState:
"""Return the active lifecycle state to restore after a failed finalize attempt."""
if batch_state.expected_task_count == 0:
return BatchLifecycleState.PENDING
return BatchLifecycleState.RUNNING
async def _cleanup_train_batch_runtime(self, batch_state: BatchState) -> None:
"""Drop unfinished train work after a non-complete finalize result."""
scheduler = self._require_scheduler()
await scheduler.abort_batch(
batch_state.batch_id,
return_partial_tasks=False,
restart_runners=True,
)
def _build_batch_result(
self,
batch_state: BatchState,
pipeline_metrics: dict,
) -> dict:
"""Build the public finalize result returned to Explorer."""
metrics = dict(pipeline_metrics)
status_metrics = [
status.metrics[0] for status in batch_state.statuses.values() if status.metrics
]
if batch_state.batch_type == "train":
if status_metrics:
metrics.update(gather_metrics(status_metrics, "rollout"))
metrics["rollout/finished_task_count"] = float(batch_state.completed_task_count)
else:
prefix = self._eval_metric_prefix(batch_state.batch_id)
if status_metrics:
metrics.update(
gather_eval_metrics(
status_metrics,
prefix,
detailed_stats=self.detailed_stats,
)
)
metrics[f"{prefix}/finished_task_count"] = float(batch_state.completed_task_count)
return {
"batch_id": batch_state.batch_id,
"batch_type": batch_state.batch_type,
"finished_task_count": batch_state.completed_task_count,
"metrics": metrics,
}
def _eval_metric_prefix(self, batch_id: BatchId) -> str:
"""Return the metric namespace prefix for one eval batch."""
batch_name = str(batch_id)
if "/" in batch_name:
batch_name = batch_name.split("/", 1)[1]
return f"eval/{batch_name}"