trinity.trainer.trainer module
Trainer Class
- class trinity.trainer.trainer.Trainer(config: Config)
Bases: object
Consume the experience and train the model.
- async get_weight_sync_info() → Tuple[str, int, List] | None
Get rendezvous info for NCCL weight sync group setup.
Returns (master_address, master_port, state_dict_meta) from the trainer’s GPU worker rank 0. Called by Synchronizer before coordinating NCCL group creation.
- async setup_weight_sync_group(master_address: str, master_port: int, world_size: int, group_name: str, timeout: int) → None
Join the NCCL weight sync group. Called by Synchronizer.
- async teardown_weight_sync_group() → None
Destroy the NCCL weight sync group. Called by Synchronizer.
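The three methods above suggest a lifecycle of fetching rendezvous info, joining the group, and tearing it down. The sketch below illustrates that order only. `FakeTrainer`, `run_sync_lifecycle`, the group name, the timeout value, and the placeholder address and port are all hypothetical; the real Synchronizer may sequence these calls differently, and this reference does not state when `get_weight_sync_info` returns `None` or what unit `timeout` uses.

```python
import asyncio
from typing import List, Optional, Tuple


class FakeTrainer:
    """Hypothetical stand-in exposing the three documented coroutines."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    async def get_weight_sync_info(self) -> Optional[Tuple[str, int, List]]:
        self.calls.append("get_weight_sync_info")
        # Placeholder rendezvous values; a real trainer reports its own.
        return "127.0.0.1", 29500, []

    async def setup_weight_sync_group(
        self,
        master_address: str,
        master_port: int,
        world_size: int,
        group_name: str,
        timeout: int,
    ) -> None:
        self.calls.append(f"setup:{master_address}:{master_port}:{world_size}")

    async def teardown_weight_sync_group(self) -> None:
        self.calls.append("teardown")


async def run_sync_lifecycle(trainer: FakeTrainer, world_size: int) -> List[str]:
    """Call the documented methods in the order their docstrings imply."""
    info = await trainer.get_weight_sync_info()
    if info is None:
        # Assumption: no rendezvous info means there is no group to join.
        return trainer.calls
    master_address, master_port, _state_dict_meta = info
    await trainer.setup_weight_sync_group(
        master_address,
        master_port,
        world_size,
        group_name="weight_sync",  # hypothetical name
        timeout=60,  # hypothetical value; unit not stated in this reference
    )
    try:
        pass  # weight transfer would happen here
    finally:
        # Always release the group, even if the transfer fails.
        await trainer.teardown_weight_sync_group()
    return trainer.calls


calls = asyncio.run(run_sync_lifecycle(FakeTrainer(), world_size=2))
```

Wrapping the transfer in `try`/`finally` is a design choice in this sketch so that the group is destroyed even when the transfer raises.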
- async train_step(exps: List[Experience]) → Dict
Train one step.
- Returns:
Metrics of the training step.
- Return type:
Dict
- property train_step_num: int
Get the current training step number.
- class trinity.trainer.trainer.TrainEngineWrapper
Bases: ABC
A wrapper class to wrap various training engines.
- abstract property train_step_num: int
Get the current training step number.
- abstractmethod async train_step(batch_exps: List[Experience]) → Dict
Train one step.
- Parameters:
batch_exps (List[Experience]) – A batch of experiences to train.
- Returns:
Metrics of the training step.
- Return type:
Dict
- abstractmethod async save_checkpoint(block_until_saved: bool = False, save_as_hf: bool = False) → None
Save the whole checkpoint (including model, optimizer, and other states).
- async wait_for_save() → None
Wait for any pending background save operations to complete.
Default implementation is a no-op. Override in subclasses that use background save threads to ensure the checkpoint iteration file is written before the trainer exits.
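As an illustration of the override described above, the following sketch shows one way a subclass that saves in a background thread might implement `wait_for_save`. `BackgroundSaver`, its `_write` helper, and the `"step_1"` label are hypothetical and are not part of trinity; the real engines may track pending saves differently.

```python
import asyncio
import threading
import time
from typing import List, Optional


class BackgroundSaver:
    """Hypothetical helper that writes checkpoints on a background thread."""

    def __init__(self) -> None:
        self.written: List[str] = []
        self._thread: Optional[threading.Thread] = None

    def _write(self, name: str) -> None:
        time.sleep(0.05)  # simulate slow checkpoint I/O
        self.written.append(name)

    async def save_checkpoint(self, block_until_saved: bool = False) -> None:
        # Start the write in the background and return immediately.
        self._thread = threading.Thread(target=self._write, args=("step_1",))
        self._thread.start()
        if block_until_saved:
            await self.wait_for_save()

    async def wait_for_save(self) -> None:
        thread = self._thread
        if thread is not None:
            # Join in a worker thread so the event loop is not blocked.
            await asyncio.to_thread(thread.join)
            self._thread = None


async def demo() -> List[str]:
    saver = BackgroundSaver()
    await saver.save_checkpoint()  # returns before the write finishes
    await saver.wait_for_save()  # ensures the write completed before exit
    return saver.written


written = asyncio.run(demo())
```

Calling `wait_for_save` when nothing is pending does nothing, which matches the no-op default described above.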
- abstractmethod sync_weight_nccl() → None
Sync the model weights via NCCL. (For NCCL sync method)
- abstractmethod async upload_state_dict() → None
Upload the state dict to Synchronizer. (For MEMORY sync method)
- abstractmethod async save_state_dict() → None
Save only the model state dict for Synchronizer. (For CHECKPOINT sync method)
- abstractmethod async get_weight_sync_info() → Tuple[str, int, List] | None
Get (master_address, master_port, state_dict_meta) for NCCL group setup.
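To show what a concrete engine has to provide, here is a minimal sketch of a subclass. So that it runs without trinity installed, it defines `EngineInterface`, a local mirror of the abstract members listed above, rather than importing the real `TrainEngineWrapper`. `CountingEngine`, the `Experience` alias, and the `batch_size` metric key are hypothetical.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

Experience = Any  # stand-in for trinity's Experience type


class EngineInterface(ABC):
    """Local mirror of the documented members (not the real base class)."""

    @property
    @abstractmethod
    def train_step_num(self) -> int: ...

    @abstractmethod
    async def train_step(self, batch_exps: List[Experience]) -> Dict: ...

    @abstractmethod
    async def save_checkpoint(
        self, block_until_saved: bool = False, save_as_hf: bool = False
    ) -> None: ...

    async def wait_for_save(self) -> None:
        return None  # documented default: no-op

    @abstractmethod
    def sync_weight_nccl(self) -> None: ...

    @abstractmethod
    async def upload_state_dict(self) -> None: ...

    @abstractmethod
    async def save_state_dict(self) -> None: ...

    @abstractmethod
    async def get_weight_sync_info(self) -> Optional[Tuple[str, int, List]]: ...


class CountingEngine(EngineInterface):
    """Toy engine: counts steps and reports the batch size as a metric."""

    def __init__(self) -> None:
        self._step = 0

    @property
    def train_step_num(self) -> int:
        return self._step

    async def train_step(self, batch_exps: List[Experience]) -> Dict:
        self._step += 1
        return {"batch_size": len(batch_exps)}  # hypothetical metric key

    async def save_checkpoint(
        self, block_until_saved: bool = False, save_as_hf: bool = False
    ) -> None:
        return None  # nothing to persist in this toy engine

    def sync_weight_nccl(self) -> None:
        return None

    async def upload_state_dict(self) -> None:
        return None

    async def save_state_dict(self) -> None:
        return None

    async def get_weight_sync_info(self) -> Optional[Tuple[str, int, List]]:
        return None  # this toy engine has no NCCL group


engine = CountingEngine()
metrics = asyncio.run(engine.train_step(["exp_a", "exp_b", "exp_c"]))
```

Because every member except `wait_for_save` is abstract, a subclass that omits any of them cannot be instantiated; Python raises `TypeError` at construction time.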
- trinity.trainer.trainer.is_verl_legacy() → bool
Return True when the installed verl package is < 0.8 (legacy backend).
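The "< 0.8" check has to compare version components numerically, because a plain string comparison would rank "0.10" below "0.8". The sketch below shows one way to do that with the standard library only. `parse_major_minor` and `is_legacy` are hypothetical helpers, not the function's actual implementation, which may rely on a version-parsing library instead. The installed version string could be obtained with `importlib.metadata.version("verl")`; it is passed in explicitly here so the example runs without verl installed.

```python
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) as integers, ignoring suffixes such as 'rc1'."""
    parts = version.split(".")
    major = int(parts[0])
    minor_text = parts[1] if len(parts) > 1 else "0"
    digits = ""
    for ch in minor_text:
        if not ch.isdigit():
            break  # stop at the first non-digit, e.g. the 'r' in '7rc1'
        digits += ch
    return major, int(digits or 0)


def is_legacy(version: str) -> bool:
    """Return True when the given version is below 0.8."""
    return parse_major_minor(version) < (0, 8)


results = {v: is_legacy(v) for v in ["0.7.1", "0.8.0", "0.10.0", "1.0"]}
```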
- trinity.trainer.trainer.get_latest_hf_checkpoint_path(config: Config) → str | None
Return the latest HF checkpoint path for a verl trainer config.
- trinity.trainer.trainer.get_trainer_wrapper(config: Config) → TrainEngineWrapper
Get a trainer wrapper.