trinity.trainer.verl.fsdp_engine module


FSDP-specific checkpoint and weight sync helpers for Trinity.

These helper functions are called by TrinityActorRolloutRefWorker to perform FSDP/FSDP2-specific operations that veRL 0.8’s engine does not provide natively:

  • fsdp_save_state_dict: Save sharded state dict for checkpoint sync

  • fsdp_upload_state_dict: Upload full state dict to Synchronizer (memory sync)

  • fsdp_sync_weight_nccl: Broadcast full params via NCCL

All functions receive the engine object (an FSDPEngine instance from verl.workers.engine.fsdp) which exposes:

  • engine.module: the FSDP-wrapped model

  • engine.get_per_tensor_param(): generator of (name, tensor) pairs

  • engine.save_checkpoint() / engine.load_checkpoint()
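The members listed above can be written down as a structural type. This is a minimal sketch, not the actual FSDPEngine class: `EngineLike` and `FakeEngine` are hypothetical names, and the method signatures are assumptions, since only the member names are documented here.

```python
from typing import Any, Iterator, Protocol, Tuple, runtime_checkable


@runtime_checkable
class EngineLike(Protocol):
    """Hypothetical structural type for the engine members listed above."""

    module: Any  # the FSDP-wrapped model

    def get_per_tensor_param(self) -> Iterator[Tuple[str, Any]]: ...

    def save_checkpoint(self, *args: Any, **kwargs: Any) -> None: ...

    def load_checkpoint(self, *args: Any, **kwargs: Any) -> None: ...


class FakeEngine:
    """Stand-in used only to exercise the protocol; it holds plain Python values."""

    def __init__(self, params: dict):
        self.module = object()
        self._params = params

    def get_per_tensor_param(self):
        # Yield (name, tensor) pairs one at a time instead of building a full dict.
        yield from self._params.items()

    def save_checkpoint(self, *args, **kwargs):
        pass

    def load_checkpoint(self, *args, **kwargs):
        pass


engine = FakeEngine({"layer.weight": [1.0, 2.0], "layer.bias": [0.5]})
names = [name for name, _ in engine.get_per_tensor_param()]
```

Typing the helpers against such a protocol would let them be unit-tested with a fake engine, without starting a distributed job.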

trinity.trainer.verl.fsdp_engine.fsdp_save_state_dict(engine, local_path: str, global_step: int, coordinator: CheckpointCoordinator, logger)

Save FSDP model state dict (sharded) for checkpoint-based weight sync.

Collects the sharded state dict on the main thread (this step requires the FSDP context), then hands torch.save off to a background thread through the coordinator so that the training loop does not block on I/O.

On rank 0, also saves HF model config and tokenizer to {local_path}/huggingface/ so that FSDPModelMerger (used by Synchronizer) can merge the shards back into a full state dict.

The coordinator notifies CheckpointMonitor before and after saving, which gates the iteration-file update and prevents the Synchronizer from reading an incomplete checkpoint.

Parameters:
  • engine – The FSDPEngine instance (engine.actor.engine).

  • local_path – Local directory path to save the state dict.

  • global_step – Current training step.

  • coordinator – CheckpointCoordinator for background save + Monitor integration.

  • logger – Logger instance from the calling worker.
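The save flow described above (collect on the main thread, write in the background, notify before and after) can be sketched with the standard library alone. Everything in this sketch is an assumption made for illustration: `ToyCoordinator`, `submit_save`, `save_sharded` and the shard file name are hypothetical, and `pickle.dump` stands in for torch.save. It is not the real CheckpointCoordinator API.

```python
import os
import pickle
import tempfile
from concurrent.futures import Future, ThreadPoolExecutor


class ToyCoordinator:
    """Hypothetical stand-in for CheckpointCoordinator; records monitor events in a list."""

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=1)
        self.events = []

    def submit_save(self, state: dict, path: str, step: int) -> Future:
        # Announce the save before any bytes are written.
        self.events.append(("save_started", step))

        def _write():
            tmp = path + ".tmp"
            with open(tmp, "wb") as f:
                pickle.dump(state, f)  # stands in for torch.save
            os.replace(tmp, path)  # publish the finished file in one step
            # Only after this event would the iteration file be allowed to advance.
            self.events.append(("save_finished", step))
            return path

        return self._pool.submit(_write)


def save_sharded(state: dict, local_path: str, global_step: int, rank: int,
                 coordinator: ToyCoordinator) -> Future:
    # The caller collects `state` on the main thread; only the I/O is deferred.
    os.makedirs(local_path, exist_ok=True)
    shard_file = os.path.join(local_path, f"model_rank_{rank}.pt")  # hypothetical name
    return coordinator.submit_save(state, shard_file, global_step)


coordinator = ToyCoordinator()
workdir = tempfile.mkdtemp()
future = save_sharded({"w": [1.0, 2.0]}, workdir, global_step=3, rank=0,
                      coordinator=coordinator)
saved_path = future.result()  # a training loop would normally not block here
```

Writing to a temporary file and renaming it is one way to keep a reader from seeing a partial shard; the page describes the notifications to CheckpointMonitor as the mechanism that serves this purpose in Trinity.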

trinity.trainer.verl.fsdp_engine.fsdp_upload_state_dict(engine, synchronizer, global_step: int, logger)

Upload full FSDP model state dict to Synchronizer for memory-based weight sync.

Gathers the full state dict on rank 0 and sends it to the Synchronizer actor, which makes it available for the Explorer to load.

Parameters:
  • engine – The FSDPEngine instance (engine.actor.engine).

  • synchronizer – The Synchronizer Ray actor handle.

  • global_step – Current training step (used as version key).

  • logger – Logger instance from the calling worker.
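The rank-0 upload pattern can be illustrated without Ray or a distributed job. This is a minimal sketch under stated assumptions: `ToySynchronizer`, its `set_model_state_dict` method and `upload_full_state` are hypothetical names, and plain lists stand in for tensors. The real Synchronizer actor interface may differ.

```python
from typing import Dict, Iterable, Optional, Tuple


class ToySynchronizer:
    """Hypothetical in-process stand-in for the Synchronizer Ray actor."""

    def __init__(self):
        self.version: Optional[int] = None
        self.state_dict: Dict[str, object] = {}

    def set_model_state_dict(self, state_dict: Dict[str, object], version: int) -> None:
        # Hypothetical method name; keeps the newest weights under their step number.
        self.state_dict = state_dict
        self.version = version


def upload_full_state(params: Iterable[Tuple[str, object]], rank: int,
                      synchronizer: ToySynchronizer, global_step: int) -> bool:
    # In a real FSDP run every rank iterates, because gathering shards is collective.
    full = {name: tensor for name, tensor in params}
    if rank != 0:
        return False  # other ranks take part in the gather but do not upload
    synchronizer.set_model_state_dict(full, version=global_step)
    return True


sync = ToySynchronizer()
params = [("layer.weight", [1.0, 2.0]), ("layer.bias", [0.5])]
uploaded_by_rank1 = upload_full_state(params, rank=1, synchronizer=sync, global_step=7)
uploaded_by_rank0 = upload_full_state(params, rank=0, synchronizer=sync, global_step=7)
```

Using global_step as the version key lets a consumer tell whether the weights it holds are older than the ones currently stored.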

trinity.trainer.verl.fsdp_engine.fsdp_sync_weight_nccl(engine, model_update_group)

Broadcast full model parameters via NCCL for FSDP/FSDP2.

With FSDP1, the function uses FSDP.summon_full_params to gather the full parameters before broadcasting. With FSDP2, it calls param.full_tensor() on each parameter to obtain the full tensor.

Parameters:
  • engine – The FSDPEngine instance (engine.actor.engine).

  • model_update_group – The NCCL process group for weight broadcast.
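The per-parameter flow (materialize the full tensor, then broadcast it) can be sketched with duck typing so that it runs without a process group. This is an illustrative sketch, not the function's actual body: `materialize_full`, `broadcast_params` and `FakeShardedParam` are hypothetical, and the `broadcast` callable is injected where a real run would call torch.distributed.broadcast with the model update group.

```python
from typing import Any, Callable, Iterable, List, Tuple


def materialize_full(param: Any) -> Any:
    # FSDP2 parameters are DTensors and expose full_tensor(); a plain tensor
    # (for example inside an FSDP1 summon_full_params context) is already full.
    if hasattr(param, "full_tensor"):
        return param.full_tensor()
    return param


def broadcast_params(named_params: Iterable[Tuple[str, Any]],
                     broadcast: Callable[[Any], None]) -> List[str]:
    sent = []
    for name, param in named_params:
        full = materialize_full(param)
        # In practice: torch.distributed.broadcast(full, src, group=model_update_group)
        broadcast(full)
        sent.append(name)
        del full  # drop the gathered copy before moving to the next parameter
    return sent


class FakeShardedParam:
    """Stand-in for a sharded parameter; concatenates its shards on demand."""

    def __init__(self, shards):
        self._shards = shards

    def full_tensor(self):
        return [x for shard in self._shards for x in shard]


received = []
sent_names = broadcast_params(
    [("layer.weight", FakeShardedParam([[1, 2], [3, 4]])), ("layer.bias", [9])],
    broadcast=received.append,
)
```

Handling one parameter at a time keeps peak memory close to the size of the largest single parameter instead of the whole model.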