trinity.trainer package
Subpackages
- trinity.trainer.tinker package
- trinity.trainer.verl package
- Submodules
- trinity.trainer.verl.checkpoint module
- trinity.trainer.verl.config module
- trinity.trainer.verl.fsdp_engine module
- trinity.trainer.verl.losses module
- trinity.trainer.verl.megatron_engine module
- trinity.trainer.verl.monkey_patch module
- trinity.trainer.verl.trainer module
- trinity.trainer.verl.utils module
- trinity.trainer.verl.workers module
- Module contents
- Submodules
- trinity.trainer.verl_legacy package
- Submodules
- trinity.trainer.verl_legacy.dp_actor module
- trinity.trainer.verl_legacy.fsdp_checkpoint_manager module
- trinity.trainer.verl_legacy.fsdp_workers module
- trinity.trainer.verl_legacy.megatron_actor module
- trinity.trainer.verl_legacy.megatron_checkpoint_manager module
- trinity.trainer.verl_legacy.megatron_workers module
- trinity.trainer.verl_legacy.monkey_patch module
- trinity.trainer.verl_legacy.utils module
- trinity.trainer.verl_legacy.verl_config module
- trinity.trainer.verl_legacy.verl_trainer module
- Module contents
- Submodules
Submodules
- trinity.trainer.trainer module
  - Trainer
    - Trainer.__init__()
    - Trainer.prepare()
    - Trainer.get_weight_sync_info()
    - Trainer.setup_weight_sync_group()
    - Trainer.teardown_weight_sync_group()
    - Trainer.train()
    - Trainer.train_step()
    - Trainer.need_sync()
    - Trainer.need_save()
    - Trainer.sync_weight()
    - Trainer.save_checkpoint()
    - Trainer.shutdown()
    - Trainer.train_step_num
    - Trainer.is_alive()
    - Trainer.get_actor()
  - TrainEngineWrapper
    - TrainEngineWrapper.prepare()
    - TrainEngineWrapper.train_step_num
    - TrainEngineWrapper.train_step()
    - TrainEngineWrapper.save_checkpoint()
    - TrainEngineWrapper.wait_for_save()
    - TrainEngineWrapper.sync_weight_nccl()
    - TrainEngineWrapper.upload_state_dict()
    - TrainEngineWrapper.save_state_dict()
    - TrainEngineWrapper.get_weight_sync_info()
    - TrainEngineWrapper.setup_weight_sync_group()
    - TrainEngineWrapper.teardown_weight_sync_group()
  - is_verl_legacy()
  - get_latest_hf_checkpoint_path()
  - get_trainer_wrapper()
Module contents
- class trinity.trainer.Trainer(config: Config) [source]
Bases: object
Consume experiences and train the model.
- async get_weight_sync_info() -> Tuple[str, int, List] | None [source]
Get rendezvous info for NCCL weight sync group setup.
Returns (master_address, master_port, state_dict_meta) from the trainer's GPU worker rank 0. Called by Synchronizer before coordinating NCCL group creation.
- async setup_weight_sync_group(master_address: str, master_port: int, world_size: int, group_name: str, timeout: int) -> None [source]
Join the NCCL weight sync group. Called by Synchronizer.
- async teardown_weight_sync_group() -> None [source]
Destroy the NCCL weight sync group. Called by Synchronizer.
- async train_step(exps: List[Experience]) -> Dict [source]
Train one step.
- Returns:
Metrics of the training step.
- Return type:
Dict
- property train_step_num: int
Get the current training step number.
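The three weight-sync coroutines documented above imply a call order on the coordinating side: fetch the rendezvous info, join the group, and destroy it afterwards. The following is a minimal sketch of that order only, not Trinity's actual Synchronizer. `StubTrainer`, `coordinate`, the group name `"weight_sync"`, the 60-second timeout, the placeholder rendezvous values, and the reading of a `None` result as "nothing to set up" are all assumptions made for illustration.

```python
import asyncio
from typing import List, Optional, Tuple


class StubTrainer:
    """Hypothetical stand-in that only mimics the documented coroutine names."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    async def get_weight_sync_info(self) -> Optional[Tuple[str, int, List]]:
        self.calls.append("get_weight_sync_info")
        # Placeholder values; the real state_dict_meta layout is not shown here.
        return ("127.0.0.1", 29500, [])

    async def setup_weight_sync_group(
        self,
        master_address: str,
        master_port: int,
        world_size: int,
        group_name: str,
        timeout: int,
    ) -> None:
        self.calls.append(f"setup:{group_name}:{world_size}")

    async def teardown_weight_sync_group(self) -> None:
        self.calls.append("teardown")


async def coordinate(trainer: StubTrainer, world_size: int) -> bool:
    """Run the assumed order: get info, join the group, then tear it down."""
    info = await trainer.get_weight_sync_info()
    if info is None:
        # Assumption: None means there is no group to join.
        return False
    master_address, master_port, _state_dict_meta = info
    await trainer.setup_weight_sync_group(
        master_address,
        master_port,
        world_size,
        group_name="weight_sync",  # hypothetical group name
        timeout=60,  # hypothetical timeout value
    )
    try:
        pass  # weight transfers would happen while the group is alive
    finally:
        # Always release the group, even if a transfer fails.
        await trainer.teardown_weight_sync_group()
    return True


trainer = StubTrainer()
joined = asyncio.run(coordinate(trainer, world_size=2))
print(joined, trainer.calls)
```

Putting the teardown in a `finally` block is a design choice of this sketch: it keeps a failed transfer from leaving a half-open group behind.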
- class trinity.trainer.TrainEngineWrapper [source]
Bases: ABC
A wrapper class for the supported training engines.
- abstractmethod async get_weight_sync_info() -> Tuple[str, int, List] | None [source]
Get (master_address, master_port, state_dict_meta) for NCCL group setup.
- abstractmethod async save_checkpoint(block_until_saved: bool = False, save_as_hf: bool = False) -> None [source]
Save the whole checkpoint, including the model, optimizer, and other states.
- abstractmethod async save_state_dict() -> None [source]
Save only the model state dict for the Synchronizer (used by the CHECKPOINT sync method).
- abstractmethod async setup_weight_sync_group(master_address: str, master_port: int, world_size: int, group_name: str, timeout: int) -> None [source]
Join the NCCL weight sync group.
- abstractmethod async train_step(batch_exps: List[Experience]) -> Dict [source]
Train one step.
- Parameters:
batch_exps (List[Experience]) -- A batch of experiences to train on.
- Returns:
Metrics of the training step.
- Return type:
Dict
- abstract property train_step_num: int
Get the current training step number.
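Because `TrainEngineWrapper` derives from `ABC`, a concrete engine has to override every abstract member before it can be instantiated. The sketch below illustrates that contract with a local stand-in rather than the real class: `EngineWrapperSketch` mirrors only `train_step` and `train_step_num`, plain dicts with a `reward` key stand in for `Experience` objects, and the metric names are hypothetical.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Dict, List


class EngineWrapperSketch(ABC):
    """Local stand-in mirroring two of the documented abstract members."""

    @property
    @abstractmethod
    def train_step_num(self) -> int:
        """Current training step number."""

    @abstractmethod
    async def train_step(self, batch_exps: List[dict]) -> Dict:
        """Train one step and return its metrics."""


class CountingEngine(EngineWrapperSketch):
    """Toy engine: counts steps and reports the batch's mean reward."""

    def __init__(self) -> None:
        self._step = 0

    @property
    def train_step_num(self) -> int:
        return self._step

    async def train_step(self, batch_exps: List[dict]) -> Dict:
        self._step += 1
        mean_reward = sum(e["reward"] for e in batch_exps) / len(batch_exps)
        # Metric names are invented for this example.
        return {"step": self._step, "mean_reward": mean_reward}


engine = CountingEngine()
metrics = asyncio.run(engine.train_step([{"reward": 1.0}, {"reward": 3.0}]))
print(metrics)  # {'step': 1, 'mean_reward': 2.0}

try:
    EngineWrapperSketch()  # abstract members are not implemented
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)
```

A subclass of the real wrapper would also need to implement the other abstract methods listed above, such as `save_checkpoint()` and `save_state_dict()`.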
- trinity.trainer.get_latest_hf_checkpoint_path(config: Config) -> str | None [source]
Return the latest HF checkpoint path for a verl trainer config.
- trinity.trainer.get_trainer_wrapper(config: Config) -> TrainEngineWrapper [source]
Get the trainer wrapper for the given config.