trinity.trainer.verl.monkey_patch 源代码

from types import MethodType
from typing import Optional

import torch
from verl.workers.engine.fsdp.transformer_impl import (
    FSDPEngine,
    load_fsdp_model_to_gpu,
    offload_fsdp_model_to_cpu,
)


# from https://github.com/verl-project/verl/pull/6604
# Remove this patch once the fix is released in veRL
[文档] def save_checkpoint( self, local_path: str, hdfs_path: Optional[str] = None, global_step: int = 0, max_ckpt_to_keep: Optional[int] = None, **kwargs, ) -> None: """ Save FSDP checkpoint, handling parameter offload as needed. """ origin_module_device = next(self.module.parameters()).device.type if (self._is_offload_param or origin_module_device == "cpu") and not getattr( self, "_uses_fsdp2_cpu_offload_policy", False ): load_fsdp_model_to_gpu(self.module) self.checkpoint_manager.save_checkpoint( local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep, ) torch.distributed.barrier() if self._is_offload_param: offload_fsdp_model_to_cpu(self.module)
[文档] def patch_verl_engine(engine): if engine is None: return if isinstance(engine, FSDPEngine) and not getattr(engine, "_patched", False): engine.save_checkpoint = MethodType(save_checkpoint, engine) setattr(engine, "_patched", True)