# Source code for trinity.trainer.verl.megatron_engine

# -*- coding: utf-8 -*-
"""Megatron-specific checkpoint and weight sync helpers for Trinity.

These helper functions are called by `TrinityActorRolloutRefWorker` to
perform Megatron-specific operations that veRL 0.8's engine does not
provide natively:

- megatron_save_state_dict:    Save state dict for checkpoint sync
- megatron_upload_state_dict:  Upload state dict to Synchronizer (memory sync)
- megatron_sync_weight_nccl:   Broadcast params via NCCL

All functions receive the `engine` object (a McoreEngine instance from
`verl.workers.engine.megatron`) which exposes:
  - engine.module: the Megatron model
  - engine.get_per_tensor_param(): generator of (name, tensor) pairs
  - engine.save_checkpoint() / engine.load_checkpoint()
"""
import ray
import torch
from verl.utils.memory_utils import aggressive_empty_cache

from trinity.trainer.verl.checkpoint import CheckpointCoordinator


def megatron_save_state_dict(
    engine,
    local_path: str,
    global_step: int,
    coordinator: CheckpointCoordinator,
    logger,
):
    """Save the Megatron actor model state dict for checkpoint-based weight sync.

    Delegates to the engine's built-in ``save_checkpoint`` for proper
    distributed checkpoint handling, but temporarily restricts the checkpoint
    manager's ``checkpoint_save_contents`` to ``["model"]`` (plus
    ``"hf_model"`` if it was already enabled). Only the actor model parameters
    are then written to disk; optimizer and extra (rng / scheduler) states are
    skipped. This path is used only for weight-sync to the rollout side, not
    for training resume. The original contents are restored afterwards, even
    if the save raises.

    If no checkpoint manager, or no ``checkpoint_save_contents`` value, can be
    found on ``engine``, the restriction cannot be applied. In that case a
    warning is logged and the engine saves whatever it is configured to save.

    On rank 0 the save is wrapped with CheckpointMonitor notifications via the
    coordinator; other ranks call ``save_checkpoint`` directly. Megatron's
    save_checkpoint involves distributed barriers internally, so it runs
    synchronously. The coordinator's ``save_sync`` is used to add Monitor
    notifications without background threading. All ranks then meet at a
    ``torch.distributed.barrier()``.

    Args:
        engine: The McoreEngine instance (engine.actor.engine).
        local_path: Local directory path to save the state dict. If ``None``,
            the function returns immediately without saving.
        global_step: Current training step.
        coordinator: CheckpointCoordinator for Monitor integration.
        logger: Logger instance from the calling worker.
    """
    if local_path is None:
        return

    # The veRL Megatron engine attribute is ``checkpoint_mananager`` (sic,
    # upstream typo). Fall back to the correctly spelled name.
    ckpt_mgr = getattr(engine, "checkpoint_mananager", None)
    if ckpt_mgr is None:
        ckpt_mgr = getattr(engine, "checkpoint_manager", None)

    # getattr on a None manager simply yields None, which is handled below.
    current_contents = getattr(ckpt_mgr, "checkpoint_save_contents", None)
    original_save_contents = None
    if current_contents is not None:
        original_save_contents = list(current_contents)
        # Preserve hf_model export if it was originally enabled, but drop
        # optimizer / extra so we don't pay their cost on every weight sync.
        new_contents = ["model"]
        if "hf_model" in original_save_contents:
            new_contents.append("hf_model")
        ckpt_mgr.checkpoint_save_contents = new_contents
    else:
        logger.warning(
            "[Megatron] could not find checkpoint_save_contents on the engine's "
            "checkpoint manager; weight-sync save will use the full configured "
            "checkpoint contents"
        )

    def _do_save():
        engine.save_checkpoint(local_path=local_path, global_step=global_step)

    try:
        if torch.distributed.get_rank() == 0:
            coordinator.save_sync(_do_save, global_step, is_state_dict=True)
        else:
            _do_save()
    finally:
        # Restore even on failure so subsequent saves use the originally
        # configured contents again.
        if original_save_contents is not None:
            ckpt_mgr.checkpoint_save_contents = original_save_contents

    torch.distributed.barrier()
    # Only claim optimizer/extra were skipped when the restriction was applied.
    contents_note = (
        "(optimizer/extra skipped)"
        if original_save_contents is not None
        else "(checkpoint contents not restricted)"
    )
    logger.info(
        f"[Megatron] actor state_dict saved: path={local_path}, step={global_step} "
        f"{contents_note}"
    )
def megatron_upload_state_dict(engine, synchronizer, global_step: int, logger):
    """Upload the Megatron model state dict to the Synchronizer (memory sync).

    Every rank iterates ``engine.get_per_tensor_param()`` and then waits at
    ``torch.distributed.barrier()``, so this must be called on all ranks. The
    iteration presumably also involves model-parallel collectives (TODO:
    confirm against veRL). Only rank 0 keeps CPU copies of the tensors. It
    sends the assembled dict to the Synchronizer via
    ``set_model_state_dict`` and blocks on the Ray call before the barrier.

    Nothing is done (no upload, no barrier) when ``global_step == 0``.

    Args:
        engine: The McoreEngine instance (engine.actor.engine).
        synchronizer: The Synchronizer Ray actor handle.
        global_step: Current training step, passed to the Synchronizer as the
            version key. Step 0 is skipped.
        logger: Logger instance from the calling worker.
    """
    if global_step == 0:
        return

    aggressive_empty_cache(force_sync=True)
    # Rank is loop-invariant; look it up once.
    is_rank_zero = torch.distributed.get_rank() == 0

    state_dict = {}
    per_tensor_param, _ = engine.get_per_tensor_param()
    for name, weight in per_tensor_param:
        if is_rank_zero:
            # Detach before copying so the CPU copy is never recorded by
            # autograd.
            state_dict[name] = weight.detach().cpu()
        # Drop this loop's reference before the next tensor is produced.
        del weight

    if is_rank_zero:
        ray.get(synchronizer.set_model_state_dict.remote(state_dict, global_step))
    torch.distributed.barrier()
    torch.cuda.empty_cache()
    logger.info(f"[Megatron] state_dict uploaded to Synchronizer: step={global_step}")
def megatron_sync_weight_nccl(engine, model_update_group):
    """Broadcast Megatron model parameters via NCCL from trainer rank 0.

    ``engine.get_per_tensor_param()`` is iterated on every rank, because it
    presumably performs model-parallel collectives to materialize each tensor
    (TODO: confirm against veRL). This function must therefore be called on
    all trainer ranks. Only global rank 0 issues the broadcast on
    ``model_update_group`` and synchronizes its CUDA stream afterwards.

    Args:
        engine: The McoreEngine instance (engine.actor.engine).
        model_update_group: The NCCL process group for weight broadcast. It is
            only used on rank 0. Other ranks may pass ``None`` or a group they
            are not part of (assumption; verify against the caller).
    """
    is_rank_zero = torch.distributed.get_rank() == 0
    per_tensor_param, _ = engine.get_per_tensor_param()
    for _, param in per_tensor_param:
        # Guard the send: on other ranks ``group=None`` would mean the default
        # (trainer) process group. That would start a collective rank 0 never
        # joins and hang.
        if is_rank_zero:
            torch.distributed.broadcast(param, src=0, group=model_update_group)
        del param
    if is_rank_zero:
        torch.cuda.synchronize()