Source code for trinity.trainer.verl.checkpoint

# -*- coding: utf-8 -*-
"""Checkpoint coordination for verl trainer.

Provides CheckpointCoordinator — the single entry point for all checkpoint
operations. It wraps the CheckpointMonitor Ray actor and background-thread
management so that callers (VERLTrainer, workers, engine helpers) never need
to interact with CheckpointMonitor directly.

Flow:
    1. Main thread collects state dicts (requires FSDP/Megatron context)
    2. CheckpointCoordinator.save_async() offloads torch.save to a background thread
    3. The thread calls notify_started → save → notify_finished on CheckpointMonitor
    4. CheckpointMonitor only updates latest_state_dict_iteration.txt after ALL
       registered saves for a given step complete
"""

import threading
from typing import Callable

import ray
from ray.actor import ActorHandle

from trinity.utils.log import get_logger


class CheckpointCoordinator:
    """Manages checkpoint saves and coordinates with CheckpointMonitor.

    This is the **only** interface callers should use for checkpoint operations.
    CheckpointMonitor (a Ray actor) is an internal implementation detail — all
    Monitor RPCs are routed through this class.

    Each save operation runs in a named background thread. Threads with the
    same name are serialized (the previous one is joined before a new one
    starts). The CheckpointMonitor is notified before and after each save,
    which gates the iteration file update.

    Usage::

        coordinator = CheckpointCoordinator(checkpoint_monitor)
        coordinator.save_async("model", lambda: torch.save(sd, path), step, is_state_dict=True)
        # ... training continues while save runs in background ...
        coordinator.wait_all()  # block until all saves complete
    """

    def __init__(self, checkpoint_monitor: ActorHandle):
        """Store the monitor handle and set up per-name thread tracking.

        Args:
            checkpoint_monitor: Handle of the CheckpointMonitor Ray actor that
                receives all notifications issued by this coordinator.
        """
        self._monitor = checkpoint_monitor
        # Most recent background thread per logical save name.
        self._threads: dict[str, threading.Thread] = {}
        self.logger = get_logger(__name__)

    # ------------------------------------------------------------------
    # Worker-level: background save with Monitor notifications
    # ------------------------------------------------------------------

    def _notify_and_save(
        self,
        save_fn: Callable[[], None],
        global_step: int,
        is_state_dict: bool,
    ) -> None:
        """Run notify_started, then save_fn, then notify_finished, blocking on each RPC.

        If ``notify_started`` or ``save_fn`` raises, ``notify_finished`` is
        NOT sent and the exception propagates to the caller.
        """
        ctx = ray.get_runtime_context()
        ray.get(
            self._monitor.notify_started.remote(
                node_id=ctx.get_node_id(), job_id=ctx.get_job_id()
            )
        )
        save_fn()
        ray.get(self._monitor.notify_finished.remote(global_step, is_state_dict))

    def save_async(
        self,
        name: str,
        save_fn: Callable[[], None],
        global_step: int,
        is_state_dict: bool = False,
    ) -> None:
        """Run save_fn in a background thread with CheckpointMonitor coordination.

        Joins any previous thread with the same ``name`` before starting, so
        concurrent writes to the same destination are impossible.

        A failure inside the background thread is logged and re-raised in
        that thread; it is not re-raised by ``Thread.join`` and therefore is
        not surfaced by :meth:`wait_all`.

        Args:
            name: Logical name for this save slot (e.g. "model_state_dict").
            save_fn: Zero-arg callable that does the actual I/O.
            global_step: Training step, passed to CheckpointMonitor.
            is_state_dict: True for state-dict-only saves (weight sync),
                False for full checkpoint components.
        """
        # Serialize saves that share a slot: wait for the previous one first.
        self._join(name)

        def _run():
            try:
                self._notify_and_save(save_fn, global_step, is_state_dict)
            except Exception:
                self.logger.error(
                    f"Background save '{name}' failed at step {global_step}", exc_info=True
                )
                raise

        t = threading.Thread(target=_run, name=f"ckpt-{name}-step{global_step}")
        t.start()
        self._threads[name] = t

    def save_sync(
        self,
        save_fn: Callable[[], None],
        global_step: int,
        is_state_dict: bool = False,
    ) -> None:
        """Run save_fn synchronously with CheckpointMonitor notifications.

        Use this for saves that must run on the main thread (e.g. because they
        involve distributed barriers).

        Args:
            save_fn: Zero-arg callable that does the actual I/O.
            global_step: Training step, passed to CheckpointMonitor.
            is_state_dict: True for state-dict-only saves, False for full
                checkpoint components.
        """
        self._notify_and_save(save_fn, global_step, is_state_dict)

    def _join(self, name: str) -> None:
        """Block until the tracked thread for ``name`` (if any) has finished."""
        t = self._threads.get(name)
        if t is not None:
            t.join()

    def wait_all(self) -> None:
        """Block until all background save threads complete."""
        for t in self._threads.values():
            t.join()
        self._threads.clear()

    # ------------------------------------------------------------------
    # Trainer-level: register expected save counts and commit step
    # ------------------------------------------------------------------

    async def register_and_monitor(
        self,
        global_step: int,
        is_state_dict: bool = False,
        state_dict_thread_count: int = 0,
        checkpoint_thread_count: int = 0,
    ) -> None:
        """Register expected save thread counts and commit the step.

        Called by VERLTrainer after dispatching save operations to workers.
        This wraps the two Monitor RPCs (register_thread_count + monitor_step)
        into a single call so the trainer never talks to Monitor directly.

        Args:
            global_step: Training step.
            is_state_dict: Whether this is a state-dict-only save.
            state_dict_thread_count: Number of state dict save threads to expect.
            checkpoint_thread_count: Number of full checkpoint save threads to expect.
        """
        if state_dict_thread_count or checkpoint_thread_count:
            # Await the ObjectRef instead of calling ray.get(): a blocking
            # ray.get() inside a coroutine stalls the event loop. Registration
            # still completes before monitor_step is submitted.
            await self._monitor.register_thread_count.remote(
                global_step,
                state_dict_thread_count=state_dict_thread_count,
                checkpoint_thread_count=checkpoint_thread_count,
            )
        await self._monitor.monitor_step.remote(global_step, is_state_dict=is_state_dict)