trinity.trainer.verl.checkpoint module
Checkpoint coordination for verl trainer.
Provides CheckpointCoordinator, the single entry point for all checkpoint operations. It wraps the CheckpointMonitor Ray actor and background-thread management so that callers (VERLTrainer, workers, engine helpers) never need to interact with CheckpointMonitor directly.
- Flow:
  1. The main thread collects state dicts (requires FSDP/Megatron context).
  2. CheckpointCoordinator.save_async() offloads torch.save to a background thread.
  3. The thread calls notify_started → save → notify_finished on CheckpointMonitor.
  4. CheckpointMonitor only updates latest_state_dict_iteration.txt after ALL registered saves for a given step complete.
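The gating in the last step can be illustrated with a small in-process model. This is only a sketch under stated assumptions: `ToyMonitor`, `register`, `run_save`, and the `latest_iteration` attribute are hypothetical stand-ins, the real CheckpointMonitor is a Ray actor, and its internals may differ.

```python
import threading


class ToyMonitor:
    """In-process stand-in for CheckpointMonitor (hypothetical, not the Ray actor)."""

    def __init__(self):
        self._lock = threading.Lock()
        self._expected = {}  # step -> number of registered saves
        self._finished = {}  # step -> number of completed saves
        self.latest_iteration = None  # stands in for latest_state_dict_iteration.txt

    def register(self, step, count):
        with self._lock:
            self._expected[step] = count
            self._commit_if_done(step)

    def notify_started(self, step):
        with self._lock:
            self._finished.setdefault(step, 0)

    def notify_finished(self, step):
        with self._lock:
            self._finished[step] = self._finished.get(step, 0) + 1
            self._commit_if_done(step)

    def _commit_if_done(self, step):
        # Caller holds the lock. Commit only once every registered save is done.
        expected = self._expected.get(step)
        if expected is not None and self._finished.get(step, 0) >= expected:
            self.latest_iteration = step


def run_save(monitor, step, save_fn):
    # Mirrors the documented order: notify_started -> save -> notify_finished.
    monitor.notify_started(step)
    save_fn()
    monitor.notify_finished(step)


monitor = ToyMonitor()
monitor.register(step=10, count=2)
saved = []

first = threading.Thread(target=run_save, args=(monitor, 10, lambda: saved.append("model")))
first.start()
first.join()
after_first = monitor.latest_iteration  # still None: one of two saves is done

second = threading.Thread(target=run_save, args=(monitor, 10, lambda: saved.append("optimizer")))
second.start()
second.join()
after_second = monitor.latest_iteration  # 10: both registered saves are done
```

The point of the gate is that a reader of the iteration marker never sees a step whose files are only partially written.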
- class trinity.trainer.verl.checkpoint.CheckpointCoordinator(checkpoint_monitor: ActorHandle)
Bases: object
Manages checkpoint saves and coordinates with CheckpointMonitor.
This is the only interface callers should use for checkpoint operations. CheckpointMonitor (a Ray actor) is an internal implementation detail; all Monitor RPCs are routed through this class.
Each save operation runs in a named background thread. Threads with the same name are serialized (the previous one is joined before a new one starts). The CheckpointMonitor is notified before and after each save, which gates the iteration file update.
Usage:
    coordinator = CheckpointCoordinator(checkpoint_monitor)
    coordinator.save_async("model", lambda: torch.save(sd, path), step, is_state_dict=True)
    # ... training continues while save runs in background ...
    coordinator.wait_all()  # block until all saves complete
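The "same name is serialized" rule can be sketched with plain threads. This is a minimal illustration, assuming a dictionary of threads keyed by name; `ToyCoordinator` is a hypothetical class that omits all CheckpointMonitor calls and is not the actual implementation.

```python
import threading
import time


class ToyCoordinator:
    """Hypothetical sketch of per-name thread serialization (no monitor calls)."""

    def __init__(self):
        self._threads = {}  # name -> most recent thread for that slot

    def save_async(self, name, save_fn):
        previous = self._threads.get(name)
        if previous is not None:
            previous.join()  # wait for the earlier save in this slot
        thread = threading.Thread(target=save_fn, name=name)
        self._threads[name] = thread
        thread.start()

    def wait_all(self):
        for thread in self._threads.values():
            thread.join()


events = []


def slow_save():
    time.sleep(0.05)  # simulate slow I/O
    events.append("first")


def fast_save():
    events.append("second")


coordinator = ToyCoordinator()
coordinator.save_async("model", slow_save)
coordinator.save_async("model", fast_save)  # blocks until slow_save has finished
coordinator.wait_all()
```

Because the second call joins the first thread before starting, `events` ends up in submission order even though the first save is slower.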
- save_async(name: str, save_fn: Callable[[], None], global_step: int, is_state_dict: bool = False) -> None
Run save_fn in a background thread with CheckpointMonitor coordination.
Joins any previous thread with the same name before starting, so concurrent writes to the same destination are impossible.
- Parameters:
name: Logical name for this save slot (e.g. "model_state_dict").
save_fn: Zero-arg callable that does the actual I/O.
global_step: Training step, passed to CheckpointMonitor.
is_state_dict: True for state-dict-only saves (weight sync), False for full checkpoint components.
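Since save_fn takes no arguments, the data and destination have to be bound into it before it is handed off. One way to do that, sketched here with the standard library only, is `functools.partial`; `write_blob` and the shard file names are hypothetical stand-ins for the real I/O such as torch.save.

```python
import functools
import tempfile
import threading
from pathlib import Path


def write_blob(path: Path, payload: bytes) -> None:
    """Hypothetical stand-in for the real I/O (e.g. torch.save)."""
    path.write_bytes(payload)


out_dir = Path(tempfile.mkdtemp())
save_fns = []
for rank in range(3):
    target = out_dir / f"shard_{rank}.bin"
    # partial binds target and payload at creation time; a bare lambda that
    # referenced `target` would see only the value from the last iteration.
    save_fns.append(functools.partial(write_blob, target, f"rank-{rank}".encode()))

# Run each zero-arg callable in its own thread, as a background save would.
threads = [threading.Thread(target=fn) for fn in save_fns]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

written = sorted(p.name for p in out_dir.iterdir())
```

Binding early matters most when the callable is created in a loop or when the variables it refers to are reassigned before the background thread runs.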
- save_sync(save_fn: Callable[[], None], global_step: int, is_state_dict: bool = False) -> None
Run save_fn synchronously with CheckpointMonitor notifications.
Use this for saves that must run on the main thread (e.g. because they involve distributed barriers).
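A synchronous save can be pictured as the same notify/save/notify sequence run inline on the calling thread. The sketch below assumes that ordering; `RecordingMonitor`, `save_sync_sketch`, and the argument lists of the notification methods are hypothetical, and the source does not specify how failures inside save_fn are reported.

```python
class RecordingMonitor:
    """Hypothetical monitor stand-in that records the order of calls."""

    def __init__(self):
        self.calls = []

    def notify_started(self, global_step, is_state_dict):
        self.calls.append(("started", global_step, is_state_dict))

    def notify_finished(self, global_step, is_state_dict):
        self.calls.append(("finished", global_step, is_state_dict))


def save_sync_sketch(monitor, save_fn, global_step, is_state_dict=False):
    # Everything runs on the caller's thread, so save_fn may safely use
    # operations that must execute there (e.g. distributed barriers).
    monitor.notify_started(global_step, is_state_dict)
    save_fn()
    monitor.notify_finished(global_step, is_state_dict)


monitor = RecordingMonitor()
save_sync_sketch(monitor, lambda: monitor.calls.append(("save",)), global_step=7)
```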
- async register_and_monitor(global_step: int, is_state_dict: bool = False, state_dict_thread_count: int = 0, checkpoint_thread_count: int = 0) -> None
Register expected save thread counts and commit the step.
Called by VERLTrainer after dispatching save operations to workers. This wraps the two Monitor RPCs (register_thread_count + monitor_step) into a single call so the trainer never talks to Monitor directly.
- Parameters:
global_step: Training step.
is_state_dict: Whether this is a state-dict-only save.
state_dict_thread_count: Number of state dict save threads to expect.
checkpoint_thread_count: Number of full checkpoint save threads to expect.
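The wrapping of two Monitor RPCs into one awaitable call might look roughly like the following. This is a sketch only: `FakeMonitor` is a local asyncio stand-in rather than a Ray actor, and the argument lists given to `register_thread_count` and `monitor_step` are assumptions, since the source names the RPCs but not their signatures.

```python
import asyncio


class FakeMonitor:
    """Local stand-in for the monitor actor; RPC signatures are assumed."""

    def __init__(self):
        self.calls = []

    async def register_thread_count(self, global_step, state_dict_thread_count,
                                    checkpoint_thread_count):
        self.calls.append(("register_thread_count", global_step,
                           state_dict_thread_count, checkpoint_thread_count))

    async def monitor_step(self, global_step, is_state_dict):
        self.calls.append(("monitor_step", global_step, is_state_dict))


async def register_and_monitor_sketch(monitor, global_step, is_state_dict=False,
                                      state_dict_thread_count=0,
                                      checkpoint_thread_count=0):
    # First tell the monitor how many saves to expect, then commit the step.
    await monitor.register_thread_count(
        global_step, state_dict_thread_count, checkpoint_thread_count
    )
    await monitor.monitor_step(global_step, is_state_dict)


monitor = FakeMonitor()
asyncio.run(
    register_and_monitor_sketch(
        monitor, 100, is_state_dict=True, state_dict_thread_count=2
    )
)
```

Keeping both calls behind one method means the trainer only needs to know the expected thread counts, not the monitor's protocol.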