trinity.trainer.verl.dp_actor module
Single Process Actor. Modified from volcengine/verl
- class trinity.trainer.verl.dp_actor.DataParallelPPOActor(config, actor_module: Module, actor_optimizer: Optimizer = None)
Bases: DataParallelPPOActor
- __init__(config, actor_module: Module, actor_optimizer: Optimizer = None)
When `actor_optimizer` is None, the actor acts as the reference policy.
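The optional-optimizer convention can be illustrated with a small stand-in class. This is a sketch only: `TinyActor`, `is_reference`, and the error raised in its `update_policy` are hypothetical and are not taken from `DataParallelPPOActor`. The sketch just shows how a `None` optimizer can mark an object as a frozen reference policy.

```python
from typing import Any, Optional


class TinyActor:
    """Hypothetical stand-in, not the real DataParallelPPOActor."""

    def __init__(self, module: Any, optimizer: Optional[Any] = None):
        self.module = module
        self.optimizer = optimizer

    @property
    def is_reference(self) -> bool:
        # Assumption: no optimizer means the weights are never updated.
        return self.optimizer is None

    def update_policy(self) -> None:
        if self.is_reference:
            raise RuntimeError("reference policy has no optimizer to step")
        # A trainable actor would run its optimization step here.


reference = TinyActor(module=object())
trainable = TinyActor(module=object(), optimizer=object())
```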
- set_algorithm(algorithm_config: AlgorithmConfig)
- update_policy(**kwargs)
Update the policy with an iterator of DataProto
- Parameters:
data (DataProto) – an iterator over the DataProto mini-batches returned by `make_minibatch_iterator`
- Returns:
a dictionary with arbitrary contents; typically it holds statistics collected while updating the model, such as `loss` and `grad_norm`.
- Return type:
Dict
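The shape of such a statistics dictionary can be sketched in plain Python. The helper `run_minibatches`, the `step_fn` callback, and the choice to average each statistic over mini-batches are assumptions made for illustration; the actual method may aggregate or name its statistics differently.

```python
from collections import defaultdict
from typing import Callable, Dict, Iterable, List


def run_minibatches(
    minibatches: Iterable[List[float]],
    step_fn: Callable[[List[float]], Dict[str, float]],
) -> Dict[str, float]:
    """Call step_fn on every mini-batch and average the statistics it reports."""
    collected: Dict[str, List[float]] = defaultdict(list)
    for batch in minibatches:
        for name, value in step_fn(batch).items():
            collected[name].append(value)
    return {name: sum(values) / len(values) for name, values in collected.items()}


def toy_step(batch: List[float]) -> Dict[str, float]:
    # Hypothetical update step: report the batch mean as "loss" and a constant "grad_norm".
    return {"loss": sum(batch) / len(batch), "grad_norm": 1.0}


metrics = run_minibatches([[1.0, 3.0], [2.0, 6.0]], toy_step)
```

Here `metrics` has one averaged entry per statistic name reported by the step function.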
- compute_log_prob(**kwargs)
Compute log probabilities for a batch of data.
- Parameters:
data (DataProto) – a batch of data represented by DataProto. It must contain the keys `input_ids`, `attention_mask` and `position_ids`.
- Returns:
a DataProto containing the key `log_probs`
- Return type:
DataProto
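For orientation, the per-token log probability that a `log_probs` entry usually refers to can be computed from logits as below. This is a pure-Python sketch, assuming `log_probs` means the log-softmax value of each observed token. `token_log_probs` is a hypothetical helper, not part of this module, and the real method operates on batched tensors.

```python
import math
from typing import List


def token_log_probs(logits: List[List[float]], token_ids: List[int]) -> List[float]:
    """Return log p(token) at each position, using a numerically stable log-softmax."""
    result = []
    for row, token in zip(logits, token_ids):
        peak = max(row)  # subtract the max before exponentiating to avoid overflow
        log_norm = peak + math.log(sum(math.exp(x - peak) for x in row))
        result.append(row[token] - log_norm)
    return result


# Two positions over a vocabulary of three tokens.
example = token_log_probs([[0.0, 0.0, 0.0], [2.0, 1.0, 0.0]], [1, 0])
```

With uniform logits at the first position, the first value equals `log(1/3)`; every returned value is at most zero because it is the logarithm of a probability.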