trinity.trainer.verl.dp_actor module

Single-process actor. Modified from volcengine/verl.

class trinity.trainer.verl.dp_actor.DataParallelPPOActor(config, actor_module: Module, actor_optimizer: Optimizer = None)

Bases: DataParallelPPOActor

__init__(config, actor_module: Module, actor_optimizer: Optimizer = None)

When `actor_optimizer` is None, the actor serves as the reference policy.

set_algorithm(algorithm_config: AlgorithmConfig)
update_policy(**kwargs)

Update the policy with an iterator of DataProto

Parameters:

data (DataProto) -- an iterator over the DataProto, as returned by `make_minibatch_iterator`

Returns:

a dictionary that may contain arbitrary entries. Typically it holds statistics collected while updating the model, such as `loss` and `grad_norm`.

Return type:

Dict
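
To illustrate the kind of dictionary described above, the following sketch averages per-mini-batch statistics into one flat dict. The helper name `aggregate_metrics` and the sample values are hypothetical and are not part of trinity or verl; the keys actually returned by `update_policy` depend on the configured algorithm.

```python
from collections import defaultdict
from statistics import fmean
from typing import Dict, Iterable, Mapping


def aggregate_metrics(per_batch: Iterable[Mapping[str, float]]) -> Dict[str, float]:
    """Average each statistic over all mini-batches (hypothetical helper)."""
    collected = defaultdict(list)
    for stats in per_batch:
        for name, value in stats.items():
            collected[name].append(float(value))
    # One mean per key; a key missing from some batches is averaged
    # only over the batches that reported it.
    return {name: fmean(values) for name, values in collected.items()}


# Made-up statistics from two mini-batch updates.
metrics = aggregate_metrics([
    {"loss": 0.5, "grad_norm": 1.0},
    {"loss": 0.25, "grad_norm": 3.0},
])
```

Averaging is only one possible reduction; a caller might equally keep the last value or the maximum for a given key.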

compute_log_prob(**kwargs)

Compute log probabilities for a batch of data.

Parameters:

data (DataProto) -- a batch of data represented by DataProto. It must contain the keys `input_ids`, `attention_mask`, and `position_ids`.

Returns:

a DataProto containing the key `log_probs`

Return type:

DataProto
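
To make the meaning of `log_probs` concrete, here is a minimal pure-Python sketch of per-token log probabilities: a log-softmax over each position's logits, followed by a lookup at the token that was actually chosen. The function `token_log_probs` and the toy inputs are assumptions for illustration only; the documented method takes a DataProto with `input_ids`, `attention_mask`, and `position_ids` and is not reproduced here.

```python
import math
from typing import List, Sequence


def token_log_probs(logits: Sequence[Sequence[float]], token_ids: Sequence[int]) -> List[float]:
    """Return log p(token) at each position from raw logits (illustrative only)."""
    result = []
    for row, token in zip(logits, token_ids):
        # Subtract the row maximum before exponentiating for numerical stability.
        peak = max(row)
        log_norm = peak + math.log(sum(math.exp(x - peak) for x in row))
        result.append(row[token] - log_norm)
    return result


# Two positions over a toy vocabulary of three tokens.
log_probs = token_log_probs([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]], [1, 0])
```

With uniform logits at the first position, the selected token gets log(1/3); at the second position the favored token gets a value closer to zero.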