trinity.trainer.verl.checkpoint module

Checkpoint coordination for the verl trainer.

Provides CheckpointCoordinator, the single entry point for all checkpoint operations. It wraps the CheckpointMonitor Ray actor and background-thread management so that callers (VERLTrainer, workers, engine helpers) never need to interact with CheckpointMonitor directly.

Flow:
  1. Main thread collects state dicts (requires FSDP/Megatron context)

  2. CheckpointCoordinator.save_async() offloads torch.save to a background thread

  3. The thread calls notify_started → save → notify_finished on CheckpointMonitor

  4. CheckpointMonitor only updates latest_state_dict_iteration.txt after ALL registered saves for a given step complete
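
The gating in step 4 can be illustrated with a minimal in-process sketch. ToyMonitor and run_save below are hypothetical stand-ins written for this example; the real CheckpointMonitor is a Ray actor whose method signatures and bookkeeping are not shown on this page, so the arguments used here are assumptions.

```python
import threading


class ToyMonitor:
    """Hypothetical in-process stand-in for CheckpointMonitor (not the Ray actor)."""

    def __init__(self, expected_saves: int) -> None:
        self._lock = threading.Lock()
        self._expected = expected_saves  # saves registered for this step
        self._finished = 0
        # Stands in for the contents of latest_state_dict_iteration.txt.
        self.latest_iteration = None

    def notify_started(self, step: int) -> None:
        # A real monitor would presumably record the in-flight save here.
        pass

    def notify_finished(self, step: int) -> None:
        with self._lock:
            self._finished += 1
            # Publish the step only once every registered save has completed.
            if self._finished == self._expected:
                self.latest_iteration = step


def run_save(monitor: ToyMonitor, save_fn, step: int) -> None:
    """Body of one background save thread: notify, save, notify."""
    monitor.notify_started(step)
    save_fn()
    monitor.notify_finished(step)


monitor = ToyMonitor(expected_saves=2)
written = []

first = threading.Thread(
    target=run_save, args=(monitor, lambda: written.append("model"), 7)
)
second = threading.Thread(
    target=run_save, args=(monitor, lambda: written.append("optimizer"), 7)
)

first.start()
first.join()
after_first = monitor.latest_iteration  # one save is still outstanding

second.start()
second.join()
after_second = monitor.latest_iteration  # both saves are done
```

In this sketch the iteration marker stays unset after the first save and is set to the step only after the second, which mirrors the "ALL registered saves" condition described above.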

class trinity.trainer.verl.checkpoint.CheckpointCoordinator(checkpoint_monitor: ActorHandle)

Bases: object

Manages checkpoint saves and coordinates with CheckpointMonitor.

This is the only interface callers should use for checkpoint operations. CheckpointMonitor (a Ray actor) is an internal implementation detail: all Monitor RPCs are routed through this class.

Each save operation runs in a named background thread. Threads with the same name are serialized (the previous one is joined before a new one starts). The CheckpointMonitor is notified before and after each save, which gates the iteration file update.

Usage:

import torch

coordinator = CheckpointCoordinator(checkpoint_monitor)
# sd, path, and step are supplied by the caller's training loop
coordinator.save_async("model", lambda: torch.save(sd, path), step, is_state_dict=True)
# ... training continues while save runs in background ...
coordinator.wait_all()  # block until all saves complete

__init__(checkpoint_monitor: ActorHandle)

save_async(name: str, save_fn: Callable[[], None], global_step: int, is_state_dict: bool = False) → None

Run save_fn in a background thread with CheckpointMonitor coordination.

Joins any previous thread with the same name before starting, so two saves to the same named slot never run concurrently.

Parameters:
  • name – Logical name for this save slot (e.g. "model_state_dict").

  • save_fn – Zero-arg callable that does the actual I/O.

  • global_step – Training step, passed to CheckpointMonitor.

  • is_state_dict – True for state-dict-only saves (weight sync), False for full checkpoint components.
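
The per-name serialization described for save_async can be sketched as follows. SlotRunner and make_save are hypothetical names invented for this example, and the CheckpointMonitor notifications are left out; the sketch only assumes what the text states, namely that a previous thread with the same name is joined before a new one starts.

```python
import threading
import time
from typing import Callable, Dict, List


class SlotRunner:
    """Hypothetical helper: at most one live thread per slot name."""

    def __init__(self) -> None:
        self._threads: Dict[str, threading.Thread] = {}

    def submit(self, name: str, fn: Callable[[], None]) -> None:
        previous = self._threads.get(name)
        if previous is not None:
            # Wait for the earlier save in the same slot before starting.
            previous.join()
        thread = threading.Thread(target=fn, name=name)
        self._threads[name] = thread
        thread.start()

    def wait_all(self) -> None:
        # Joining an already finished thread returns immediately.
        for thread in self._threads.values():
            thread.join()


events: List[str] = []


def make_save(tag: str, delay: float) -> Callable[[], None]:
    """Build a fake save function that records when it starts and ends."""

    def save() -> None:
        events.append(f"{tag}:start")
        time.sleep(delay)  # simulate slow I/O
        events.append(f"{tag}:end")

    return save


runner = SlotRunner()
runner.submit("model", make_save("step1", 0.05))
runner.submit("model", make_save("step2", 0.0))  # blocks until step1 ends
runner.wait_all()
```

Because both submissions use the slot name "model", the second save cannot begin until the first has finished, so the recorded events never interleave. Saves submitted under different names would run in parallel.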

save_sync(save_fn: Callable[[], None], global_step: int, is_state_dict: bool = False) → None

Run save_fn synchronously with CheckpointMonitor notifications.

Use this for saves that must run on the main thread (e.g. because they involve distributed barriers).
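
As a rough illustration of the synchronous path, the sketch below runs the save on the calling thread and brackets it with notifications. RecordingMonitor and save_sync_sketch are hypothetical; the argument lists of notify_started and notify_finished are assumptions, and error handling is deliberately omitted because this page does not describe it.

```python
import threading
from typing import Callable, List, Tuple


class RecordingMonitor:
    """Hypothetical monitor stub that records the notifications it receives."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, int, bool]] = []

    def notify_started(self, global_step: int, is_state_dict: bool) -> None:
        self.calls.append(("started", global_step, is_state_dict))

    def notify_finished(self, global_step: int, is_state_dict: bool) -> None:
        self.calls.append(("finished", global_step, is_state_dict))


def save_sync_sketch(
    monitor: RecordingMonitor,
    save_fn: Callable[[], None],
    global_step: int,
    is_state_dict: bool = False,
) -> None:
    """Notify, run save_fn on the calling thread, then notify again."""
    monitor.notify_started(global_step, is_state_dict)
    save_fn()  # no background thread: collectives or barriers stay on this thread
    monitor.notify_finished(global_step, is_state_dict)


recording_monitor = RecordingMonitor()
caller = threading.current_thread().name
ran_on: List[str] = []

save_sync_sketch(
    recording_monitor,
    lambda: ran_on.append(threading.current_thread().name),
    global_step=3,
)
```

The save function observes the same thread name as the caller, which is the property that matters when the save takes part in a distributed barrier.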

wait_all() → None

Block until all background save threads complete.

async register_and_monitor(global_step: int, is_state_dict: bool = False, state_dict_thread_count: int = 0, checkpoint_thread_count: int = 0) → None

Register expected save thread counts and commit the step.

Called by VERLTrainer after dispatching save operations to workers. This wraps the two Monitor RPCs (register_thread_count + monitor_step) into a single call so the trainer never talks to Monitor directly.

Parameters:
  • global_step – Training step.

  • is_state_dict – Whether this is a state-dict-only save.

  • state_dict_thread_count – Number of state dict save threads to expect.

  • checkpoint_thread_count – Number of full checkpoint save threads to expect.
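
The "two RPCs in one call" shape can be sketched with plain asyncio. StubMonitor and register_and_monitor_sketch are hypothetical; the RPC names register_thread_count and monitor_step come from the description above, but their argument lists are assumptions, and plain coroutines stand in for the remote calls a real Ray actor would need.

```python
import asyncio
from typing import List, Tuple


class StubMonitor:
    """Hypothetical async stub; the real CheckpointMonitor is a Ray actor."""

    def __init__(self) -> None:
        self.log: List[Tuple] = []

    async def register_thread_count(
        self,
        global_step: int,
        state_dict_thread_count: int,
        checkpoint_thread_count: int,
    ) -> None:
        self.log.append(
            ("register", global_step, state_dict_thread_count, checkpoint_thread_count)
        )

    async def monitor_step(self, global_step: int, is_state_dict: bool) -> None:
        self.log.append(("monitor", global_step, is_state_dict))


async def register_and_monitor_sketch(
    monitor: StubMonitor,
    global_step: int,
    is_state_dict: bool = False,
    state_dict_thread_count: int = 0,
    checkpoint_thread_count: int = 0,
) -> None:
    # First tell the monitor how many saves to expect for this step ...
    await monitor.register_thread_count(
        global_step, state_dict_thread_count, checkpoint_thread_count
    )
    # ... then ask it to commit the step once those saves have reported in.
    await monitor.monitor_step(global_step, is_state_dict)


stub = StubMonitor()
asyncio.run(
    register_and_monitor_sketch(
        stub, global_step=10, is_state_dict=True, state_dict_thread_count=4
    )
)
```

Wrapping both calls in one coroutine keeps their order fixed (register first, then monitor), so a caller cannot commit a step before its expected thread counts are known.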