trinity.trainer.trainer module

Trainer Class

class trinity.trainer.trainer.Trainer(config: Config)

Bases: object

Consume the experience and train the model.

__init__(config: Config) → None
async prepare() → None

Prepare the trainer.

async get_weight_sync_info() → Tuple[str, int, List] | None

Get rendezvous info for NCCL weight sync group setup.

Returns (master_address, master_port, state_dict_meta) from rank 0 of the trainer’s GPU workers. Called by the Synchronizer before it coordinates NCCL group creation.

async setup_weight_sync_group(master_address: str, master_port: int, world_size: int, group_name: str, timeout: int) → None

Join the NCCL weight sync group. Called by the Synchronizer.

async teardown_weight_sync_group() → None

Destroy the NCCL weight sync group. Called by the Synchronizer.

async train() → str

Train the model.

async train_step(exps: List[Experience]) → Dict

Train one step.

Returns:

Metrics of the training step.

Return type:

Dict

async need_sync() → bool

Return whether the model weights should be synced.

need_save() → bool

Return whether a checkpoint should be saved.

async sync_weight() → Dict

Sync the model weights.

async save_checkpoint(block_until_saved: bool = False, save_as_hf: bool = False) → Dict
async shutdown() → None
property train_step_num: int

Get the current training step number.

async is_alive() → bool

Check if the trainer is alive.

classmethod get_actor(config: Config)

Get a Ray actor for the trainer.

class trinity.trainer.trainer.TrainEngineWrapper

Bases: ABC

A wrapper class to wrap various training engines.

abstractmethod async prepare() → None

Perform any preparation required before training starts.

abstract property train_step_num: int

Get the current training step number.

abstractmethod async train_step(batch_exps: List[Experience]) → Dict

Train one step.

Parameters:

batch_exps (List[Experience]) – A batch of experiences to train.

Returns:

Metrics of the training step.

Return type:

Dict

abstractmethod async save_checkpoint(block_until_saved: bool = False, save_as_hf: bool = False) → None

Save the whole checkpoint (including the model, optimizer, and other states).

async wait_for_save() → None

Wait for any pending background save operations to complete.

Default implementation is a no-op. Override in subclasses that use background save threads to ensure the checkpoint iteration file is written before the trainer exits.

abstractmethod sync_weight_nccl() → None

Sync the model weights via NCCL (used by the NCCL sync method).

abstractmethod async upload_state_dict() → None

Upload the state dict to the Synchronizer (used by the MEMORY sync method).

abstractmethod async save_state_dict() → None

Save only the model state dict for the Synchronizer (used by the CHECKPOINT sync method).

abstractmethod async get_weight_sync_info() → Tuple[str, int, List] | None

Get (master_address, master_port, state_dict_meta) for NCCL group setup.

abstractmethod async setup_weight_sync_group(master_address: str, master_port: int, world_size: int, group_name: str, timeout: int) → None

Join the NCCL weight sync group.

abstractmethod async teardown_weight_sync_group() → None

Tear down the NCCL weight sync group.

trinity.trainer.trainer.is_verl_legacy() → bool

Return True when the installed verl package is < 0.8 (legacy backend).

trinity.trainer.trainer.get_latest_hf_checkpoint_path(config: Config) → str | None

Return the latest HF checkpoint path for a verl trainer config.

trinity.trainer.trainer.get_trainer_wrapper(config: Config) → TrainEngineWrapper

Get a trainer wrapper.