# Source code for trinity.trainer.verl.dp_actor

# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Single Process Actor.
Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/workers/actor/dp_actor.py
"""

import logging
import os

import torch
import verl.utils.torch_functional as verl_F
from torch import nn
from verl import DataProto
from verl.utils.attention_utils import (
    index_first_axis,
    pad_input,
    rearrange,
    unpad_input,
)
from verl.utils.debug import GPUMemoryLogger
from verl.utils.device import get_device_id
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch
from verl.utils.torch_functional import logprobs_from_logits
from verl.utils.ulysses import (
    gather_outputs_and_unpad,
    ulysses_pad,
    ulysses_pad_and_slice_inputs,
)
from verl.workers.actor.dp_actor import DataParallelPPOActor as DPActor

from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN
from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import DummyEntropyLossFn
from trinity.algorithm.kl_fn.kl_fn import DummyKLFn
from trinity.algorithm.utils import prefix_metrics
from trinity.common.config import AlgorithmConfig

# Public API of this module: only the actor class is re-exported.
__all__ = ["DataParallelPPOActor"]

# NOTE(review): uses __file__ (the file path) rather than the conventional
# __name__ as the logger name — kept to match verl's own logging setup.
logger = logging.getLogger(__file__)
# Log level is controlled by the VERL_LOGGING_LEVEL env var (default: WARN).
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class DataParallelPPOActor(DPActor):
    """Data-parallel PPO actor with Trinity's pluggable loss functions.

    Extends verl's ``DataParallelPPOActor`` so that the policy loss, KL loss,
    and entropy loss are looked up from Trinity's registries
    (``POLICY_LOSS_FN`` / ``KL_FN`` / ``ENTROPY_LOSS_FN``) via
    :meth:`set_algorithm` instead of being hard-coded.
    """

    def __init__(
        self, config, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None
    ):
        """When optimizer is None, it is Reference Policy.

        Args:
            config: actor config (same object verl's ``DPActor`` expects).
            actor_module: the policy network to train or evaluate.
            actor_optimizer: optimizer for ``actor_module``; ``None`` for a
                frozen reference policy that only computes log-probs.
        """
        super().__init__(config, actor_module, actor_optimizer)
        # Loss functions are installed later via set_algorithm(); until then
        # update_policy() must not be called.
        self.policy_loss_fn = None
        self.kl_loss_fn = None
        self.entropy_loss_fn = None

    def set_algorithm(self, algorithm_config: AlgorithmConfig):
        """Instantiate the policy/KL/entropy loss functions from the config.

        Args:
            algorithm_config: Trinity algorithm config naming each loss
                function and carrying its keyword arguments.
        """
        self.loss_agg_mode = algorithm_config.loss_agg_mode
        self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)(
            backend="verl", **algorithm_config.policy_loss_fn_args
        )
        self.kl_loss_fn = KL_FN.get(algorithm_config.kl_loss_fn)(**algorithm_config.kl_loss_fn_args)
        self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)(
            **algorithm_config.entropy_loss_fn_args
        )

    @GPUMemoryLogger(role="dp actor", logger=logger)
    def update_policy(self, data: DataProto):  # noqa: C901
        """Run PPO updates over mini-batches and return training metrics.

        Args:
            data: batch with ``input_ids``/``attention_mask``/``position_ids``/
                ``responses``/``response_mask`` plus whatever the configured
                policy loss declares in ``select_keys``. ``temperature`` must
                be present in ``data.meta_info``.

        Returns:
            dict of accumulated per-microbatch metrics (lists per key).
        """
        # make sure we are in training mode
        self.actor_module.train()

        # temperature must be in the data.meta_info to avoid silent error
        temperature = data.meta_info["temperature"]

        select_keys = [
            "input_ids",
            "position_ids",
            "attention_mask",
            "responses",
            "response_mask",
        ]
        select_keys.extend(self.policy_loss_fn.select_keys)
        if not isinstance(self.kl_loss_fn, DummyKLFn):
            select_keys.append("ref_log_prob")
        # rollout_is_weights will be used in policy loss
        # rollout_log_probs is equal to old_log_prob in Trinity
        select_keys = list(set(select_keys))
        has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
        non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
        data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)

        # Split to make minibatch iterator for updating the actor
        # See PPO paper for details. https://arxiv.org/abs/1707.06347
        mini_batches = data.split(self.config.ppo_mini_batch_size)

        # EXPERIMENTAL: apply loss scale fix (only meaningful for token-mean
        # aggregation, where per-sequence scaling biases the gradient).
        do_fix_actor_microbatch_loss_scale = self.config.fix_actor_microbatch_loss_scale and (
            self.loss_agg_mode == "token-mean"
        )

        metrics = {}
        for _ in range(self.config.ppo_epochs):
            for batch_idx, mini_batch in enumerate(mini_batches):
                if self.config.use_dynamic_bsz:
                    max_token_len = (
                        self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                    )
                    micro_batches, _ = prepare_dynamic_batch(
                        mini_batch, max_token_len=max_token_len
                    )
                else:
                    self.gradient_accumulation = (
                        self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
                    )
                    micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)

                if do_fix_actor_microbatch_loss_scale:
                    # calculate the total number of response tokens in the minibatch
                    mini_batch_token_num = torch.sum(
                        mini_batch.batch["response_mask"].to(get_device_id())
                    ).item()

                self.actor_optimizer.zero_grad()

                for micro_batch in micro_batches:
                    micro_batch = micro_batch.to(get_device_id())
                    micro_batch_metrics = {}
                    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
                    response_mask = model_inputs["response_mask"]
                    loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")

                    # all return: (bsz, response_length)
                    # FIX: the original compared the loss-fn *instance* to the
                    # class with `!=`, which is always True and forced an
                    # entropy pass even for the dummy entropy loss. Use
                    # isinstance, mirroring the DummyKLFn check above.
                    calculate_entropy = not isinstance(self.entropy_loss_fn, DummyEntropyLossFn)
                    outputs = self._forward_micro_batch(
                        micro_batch=model_inputs,
                        temperature=temperature,
                        calculate_entropy=calculate_entropy,
                    )
                    log_prob = outputs["log_probs"]
                    entropy = outputs["entropys"] if calculate_entropy else None

                    pg_loss, pg_loss_metrics = self.policy_loss_fn(  # type: ignore
                        logprob=log_prob, **model_inputs
                    )
                    prefix_metrics(
                        src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics
                    )

                    # TODO: to be check
                    # Skip if using bypass_mode loss (metrics already computed in pg_metrics)
                    rollout_log_prob = model_inputs.get("rollout_log_probs", None)
                    if loss_mode != "bypass_mode" and rollout_log_prob is not None:
                        # Compute metrics using CURRENT policy π_θ vs π_rollout
                        # Tracks evolving off-policy gap as π_θ updates during mini-batch training
                        from verl.trainer.ppo.rollout_corr_helper import (
                            compute_rollout_corr_metrics_from_logprobs,
                        )

                        rollout_corr_metrics = compute_rollout_corr_metrics_from_logprobs(
                            log_prob=log_prob,
                            rollout_log_prob=rollout_log_prob,
                            response_mask=response_mask,
                        )
                        micro_batch_metrics.update(rollout_corr_metrics)

                    # compute entropy loss from entropy
                    entropy_loss, entropy_loss_metrics = self.entropy_loss_fn(  # type: ignore
                        entropy=entropy,
                        action_mask=response_mask,
                        loss_agg_mode=self.loss_agg_mode,
                        **model_inputs,
                    )
                    prefix_metrics(
                        src_metrics=entropy_loss_metrics,
                        prefix="actor",
                        dst_metrics=micro_batch_metrics,
                    )

                    # compute policy loss (entropy bonus is subtracted)
                    policy_loss = pg_loss - entropy_loss

                    kl_loss, kl_loss_metrics = self.kl_loss_fn.calculate_kl_loss(
                        logprob=log_prob,
                        ref_logprob=model_inputs.get("ref_log_prob", None),
                        response_mask=response_mask,
                        loss_agg_mode=self.loss_agg_mode,
                        old_logprob=model_inputs.get("old_log_probs", None),
                    )
                    prefix_metrics(
                        src_metrics=kl_loss_metrics,
                        prefix="actor",
                        dst_metrics=micro_batch_metrics,
                    )
                    policy_loss = policy_loss + kl_loss

                    # set loss scale for the microbatch
                    if not do_fix_actor_microbatch_loss_scale:
                        # original implementation of microbatch loss scale
                        if self.config.use_dynamic_bsz:
                            loss_scale = response_mask.shape[0] / self.config.ppo_mini_batch_size
                        else:
                            loss_scale = 1.0 / self.gradient_accumulation
                    else:
                        # EXPERIMENTAL: fix for token-mean loss aggregation
                        # scale microbatch loss according to the number of tokens (rather than sequences)
                        loss_scale = torch.sum(response_mask).item() / (mini_batch_token_num + 1e-6)
                    loss = policy_loss * loss_scale
                    micro_batch_metrics["actor/final_loss"] = loss.detach().item()
                    # Rescale the reported component losses so they match the
                    # contribution actually back-propagated.
                    if "actor/kl_loss" in micro_batch_metrics:
                        micro_batch_metrics["actor/kl_loss"] *= loss_scale
                    if "actor/pg_loss" in micro_batch_metrics:
                        micro_batch_metrics["actor/pg_loss"] *= loss_scale

                    if self.scaler is not None:
                        # AMP path: scale the loss before backward
                        self.scaler.scale(loss).backward()
                    else:
                        loss.backward()

                    append_to_dict(metrics, micro_batch_metrics)

                grad_norm = self._optimizer_step()
                mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()}
                append_to_dict(metrics, mini_batch_metrics)
        self.actor_optimizer.zero_grad()
        return metrics

    # TODO: remove this method after upgrading verl
    def _forward_micro_batch(  # type: ignore # noqa: C901
        self, micro_batch, temperature, calculate_entropy=False
    ) -> dict[str, torch.Tensor]:
        """Forward one micro-batch and compute response log-probs (and entropy).

        Returns:
            dict[str, torch.Tensor]: a dict containing keys

                - ``log_probs``: tensor of shape [batch_size, response_length]. torch.float32.
                - ``entropys``: tensor of shape [batch_size, response_length]. torch.float32
                  (only present when ``calculate_entropy`` is True).
        """
        response_length = micro_batch["responses"].size(-1)
        multi_modal_inputs = {}
        if "multi_modal_inputs" in micro_batch.keys():
            from verl.utils.model import extract_multi_modal_inputs

            multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"])

        with torch.autocast(device_type=self.device_name, dtype=self.param_dtype):
            input_ids = micro_batch["input_ids"]
            batch_size, seqlen = input_ids.shape
            attention_mask = micro_batch["attention_mask"]
            position_ids = micro_batch["position_ids"]
            entropy = None
            if position_ids.dim() == 3:  # qwen2vl mrope
                position_ids = position_ids.transpose(0, 1)  # (bsz, 4, seqlen) -> (4, bsz, seqlen)

            if self.use_remove_padding:
                input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(
                    input_ids.unsqueeze(-1), attention_mask
                )  # input_ids_rmpad (total_nnz, ...)
                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

                # unpad the position_ids to align the rotary
                if position_ids.dim() == 3:
                    position_ids_rmpad = (
                        index_first_axis(
                            rearrange(position_ids, "c b s ... -> (b s) c ..."), indices
                        )
                        .transpose(0, 1)
                        .unsqueeze(1)
                    )  # (4, bsz, seqlen) -> (4, 1, bsz * seqlen)
                else:
                    position_ids_rmpad = index_first_axis(
                        rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
                    ).transpose(0, 1)

                # Degenerate micro-batch (all tokens masked out): substitute a
                # minimal dummy input so the forward pass — and any collective
                # ops inside it — still runs; outputs are discarded below.
                is_mask_all_zero = attention_mask.sum() == 0
                if is_mask_all_zero:
                    input_ids_rmpad = torch.zeros(
                        (1, self.ulysses_sequence_parallel_size),
                        device=input_ids.device,
                        dtype=input_ids.dtype,
                    )
                    if position_ids.dim() == 3:
                        position_ids_rmpad = torch.zeros(
                            (position_ids.shape[0], 1, self.ulysses_sequence_parallel_size),
                            device=position_ids.device,
                            dtype=position_ids.dtype,
                        )
                    else:
                        position_ids_rmpad = torch.zeros(
                            (1, self.ulysses_sequence_parallel_size),
                            device=position_ids.device,
                            dtype=position_ids.dtype,
                        )

                if "image_bound" in multi_modal_inputs:
                    from verl.utils.dataset.vision_utils import (
                        process_multi_modal_inputs_for_minicpmo,
                    )

                    multi_modal_inputs = process_multi_modal_inputs_for_minicpmo(
                        input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs
                    )

                # for compute the log_prob: labels are the inputs shifted left by one
                input_ids_rmpad_rolled = torch.roll(
                    input_ids_rmpad, shifts=-1, dims=1
                )  # (1, total_nnz)

                # pad and slice the inputs if sp > 1
                if self.use_ulysses_sp:
                    is_vlm_model = hasattr(
                        getattr(self.actor_module, "module", self.actor_module).config,
                        "vision_config",
                    )
                    if is_vlm_model:
                        # vlm model's inputs will be sliced after embedding
                        input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(
                            input_ids_rmpad,
                            position_ids_rmpad=position_ids_rmpad,
                            sp_size=self.ulysses_sequence_parallel_size,
                        )
                    else:
                        (
                            input_ids_rmpad,
                            position_ids_rmpad,
                            pad_size,
                        ) = ulysses_pad_and_slice_inputs(
                            input_ids_rmpad,
                            position_ids_rmpad=position_ids_rmpad,
                            sp_size=self.ulysses_sequence_parallel_size,
                        )
                    input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
                        input_ids_rmpad_rolled,
                        position_ids_rmpad=None,
                        sp_size=self.ulysses_sequence_parallel_size,
                    )

                input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(
                    0
                )  # ((total_nnz / sp) + pad)

                # only pass input_ids and position_ids to enable flash_attn_varlen
                extra_args = {}
                if self.use_fused_kernels:
                    extra_args["temperature"] = temperature
                    extra_args["return_dict"] = True

                output = self.actor_module(
                    input_ids=input_ids_rmpad,
                    attention_mask=None,
                    position_ids=position_ids_rmpad,
                    **multi_modal_inputs,
                    use_cache=False,
                    **extra_args,
                )  # prevent model thinks we are generating

                if self.use_fused_kernels:
                    log_probs = output.log_probs.squeeze(0)  # (total_nnz,)
                    entropy_rmpad = output.entropy.squeeze(0)  # (total_nnz,)
                else:
                    logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)
                    logits_rmpad.div_(temperature)

                    # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
                    # in-place backward on logits is only safe when the logits
                    # are not needed again (i.e. no entropy computation)
                    inplace_backward = True
                    if calculate_entropy:
                        inplace_backward = False
                    log_probs = logprobs_from_logits(
                        logits=logits_rmpad,
                        labels=input_ids_rmpad_rolled,
                        inplace_backward=inplace_backward,
                    )

                    # compute entropy
                    if calculate_entropy:
                        if not self.config.entropy_checkpointing:
                            entropy_rmpad = self.compute_entropy_from_logits(
                                logits_rmpad
                            )  # ((total_nnz / sp) + pad)
                        else:
                            # trade compute for memory via activation checkpointing
                            entropy_rmpad = torch.utils.checkpoint.checkpoint(
                                self.compute_entropy_from_logits, logits_rmpad
                            )

                # gather log_prob if sp > 1
                if self.use_ulysses_sp:
                    # gather and unpad for the ulysses sp
                    log_probs = gather_outputs_and_unpad(
                        log_probs,
                        gather_dim=0,
                        unpad_dim=0,
                        padding_size=pad_size,
                    )
                    if calculate_entropy:
                        entropy_rmpad = gather_outputs_and_unpad(
                            entropy_rmpad,
                            gather_dim=0,
                            unpad_dim=0,
                            padding_size=pad_size,
                        )
                if is_mask_all_zero:
                    # drop the dummy outputs produced for the all-masked batch
                    log_probs = log_probs[:0]
                    if calculate_entropy:
                        entropy_rmpad = entropy_rmpad[:0]
                # pad back to (bsz, seqlen)
                if calculate_entropy:
                    full_entropy = pad_input(
                        hidden_states=entropy_rmpad.unsqueeze(-1),
                        indices=indices,
                        batch=batch_size,
                        seqlen=seqlen,
                    )
                full_log_probs = pad_input(
                    hidden_states=log_probs.unsqueeze(-1),
                    indices=indices,
                    batch=batch_size,
                    seqlen=seqlen,
                )

                # only return response part:
                if calculate_entropy:
                    entropy = full_entropy.squeeze(-1)[
                        :, -response_length - 1 : -1
                    ]  # (bsz, response_length)
                log_probs = full_log_probs.squeeze(-1)[
                    :, -response_length - 1 : -1
                ]  # (bsz, response_length)
            else:  # not using rmpad and no ulysses sp
                extra_args = {}
                if self.use_fused_kernels:
                    extra_args["temperature"] = temperature
                    extra_args["return_dict"] = True
                output = self.actor_module(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    **multi_modal_inputs,
                    use_cache=False,
                    **extra_args,
                )  # prevent model thinks we are generating
                if self.use_fused_kernels:
                    log_probs = output.log_probs[:, -response_length - 1 : -1]
                    entropy = output.entropy[:, -response_length - 1 : -1]  # (bsz, response_length)
                else:
                    logits = output.logits
                    logits.div_(temperature)
                    logits = logits[
                        :, -response_length - 1 : -1, :
                    ]  # (bsz, response_length, vocab_size)
                    log_probs = logprobs_from_logits(logits, micro_batch["responses"])
                    if calculate_entropy:
                        if not self.config.entropy_checkpointing:
                            entropy = verl_F.entropy_from_logits(logits)  # (bsz, response_length)
                        else:
                            entropy = torch.utils.checkpoint.checkpoint(
                                verl_F.entropy_from_logits, logits
                            )

        outputs = {"log_probs": log_probs}
        if calculate_entropy:
            outputs["entropys"] = entropy
        return outputs

    # TODO: remove this method after upgrading verl
    @GPUMemoryLogger(role="dp actor", logger=logger)
    def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> dict[str, torch.Tensor]:
        """Compute the log probability of the responses given input_ids, attention_mask and position_ids.

        Args:
            data (DataProto): a DataProto containing keys

                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.
                    Note that input_ids is the concatenation of prompt and response.
                    Note that ``sequence_length = prompt_length + response_length``.

                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``responses``: tensor of shape [batch_size, response_length]. torch.int64.

        Returns:
            dict[str, torch.Tensor]: a dict containing keys

                - ``log_probs``: tensor of shape [batch_size, response_length]. torch.float32.
                - ``entropys``: tensor of shape [batch_size, response_length]. torch.float32.
        """
        # set to eval
        self.actor_module.eval()

        micro_batch_size = data.meta_info["micro_batch_size"]
        temperature = data.meta_info[
            "temperature"
        ]  # temperature must be in the data.meta_info to avoid silent error
        use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]
        has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()

        select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
        non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
        data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)

        if use_dynamic_bsz:
            max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
            micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len)
        else:
            micro_batches = data.split(micro_batch_size)

        log_probs_lst = []
        entropy_lst = []
        for micro_batch in micro_batches:
            micro_batch = micro_batch.to(get_device_id())
            model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
            with torch.no_grad():
                outputs = self._forward_micro_batch(
                    model_inputs, temperature=temperature, calculate_entropy=calculate_entropy
                )
            log_probs_lst.append(outputs["log_probs"])
            if calculate_entropy:
                entropy_lst.append(outputs["entropys"])
        log_probs = torch.concat(log_probs_lst, dim=0)
        if calculate_entropy:
            entropys = torch.concat(entropy_lst, dim=0)

        if use_dynamic_bsz:
            # dynamic batching reorders sequences; restore the original order
            log_probs = restore_dynamic_batch(log_probs, batch_idx_list)
            if calculate_entropy:
                entropys = restore_dynamic_batch(entropys, batch_idx_list)

        outputs = {"log_probs": log_probs}
        if calculate_entropy:
            outputs["entropys"] = entropys
        return outputs