Source code for trinity.common.patch.qwen3_5

from dataclasses import dataclass
from functools import wraps
from typing import Any, Optional

import torch
import torch.distributed as dist
from torch import Tensor
from transformers.models.qwen3_5.modeling_qwen3_5 import (
    BaseModelOutputWithPooling,
    Cache,
    F,
    Qwen3_5CausalLMOutputWithPast,
    Qwen3_5ForConditionalGeneration,
    Qwen3_5ModelOutputWithPast,
    TransformersKwargs,
    Unpack,
    apply_mask_to_padding_states,
    can_return_tuple,
)
from verl.utils.ulysses import all_gather_tensor


class Slice(torch.autograd.Function):
    """Autograd function that keeps only this rank's shard of a tensor.

    Forward narrows ``global_tensor`` along ``dim`` to the chunk owned by the
    calling rank of ``group``. Backward all-gathers the per-rank gradients and
    re-assembles them along ``dim`` so the full tensor receives a gradient.
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        global_tensor: Tensor,
        dim: int,
        grad_scaler: bool = True,
        async_op=False,
    ) -> Tensor:
        """Return the contiguous shard of ``global_tensor`` owned by this rank.

        Args:
            ctx: Autograd context; stores everything ``backward`` reads.
            group: Process group whose world size / rank define the sharding.
            global_tensor: Tensor to shard.
            dim: Dimension to shard along.
            grad_scaler: If true, ``backward`` divides the incoming gradient
                by the group's world size before gathering.
            async_op: Forwarded unchanged to ``all_gather_tensor`` in
                ``backward``.

        Raises:
            ValueError: If the size of ``dim`` is not divisible by the group's
                world size.
        """
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)

        # Everything backward needs is stashed on the context.
        ctx.group = group
        ctx.dim = dim
        ctx.grad_scaler = grad_scaler
        ctx.async_op = async_op
        ctx.sp_world_size = world_size
        ctx.sp_rank = rank

        dim_size = global_tensor.size(dim)
        if dim_size % world_size != 0:
            raise ValueError(
                f"Cannot evenly slice tensor of size {dim_size} along dim {dim} "
                f"across {world_size} ranks. This would truncate data. "
                "Ensure the dimension size is divisible by the SP world size."
            )

        shard_len = dim_size // world_size
        # narrow(dim, start, length) selects [rank * shard_len, (rank + 1) * shard_len)
        return global_tensor.narrow(dim, rank * shard_len, shard_len).contiguous()

    @staticmethod
    def backward(ctx, grad_outputs: Tensor) -> Any:
        """Gather the shard gradients from all ranks into a full gradient.

        Returns one entry per ``forward`` argument (plus one extra trailing
        ``None``); only ``global_tensor`` receives a gradient.
        """
        local_grad = grad_outputs / ctx.sp_world_size if ctx.grad_scaler else grad_outputs

        # presumably all_gather_tensor stacks the per-rank tensors along dim 0
        # (the split below relies on that) - verify against verl.utils.ulysses
        gathered = all_gather_tensor(local_grad, ctx.group, ctx.async_op)
        per_rank_grads = gathered.split(local_grad.size(0), dim=0)
        full_grad = torch.cat(per_rank_grads, dim=ctx.dim).contiguous()

        return None, full_grad, None, None, None, None
# Module-level switch for the Ulysses patches below: the wrapped GateDeltaNet
# ``forward`` sets it to True for the duration of the call, and the patched
# conv1d ``__getattr__`` and ``torch.split`` wrappers read it to decide
# whether to apply their sequence-parallel handling.
_in_gate_delta_net_with_sp = False
def ulysses_gate_delta_net_decorator(net, ulysses_sp_size):
    """Patch a Qwen3.5 GateDeltaNet module in-place for Ulysses sequence parallelism.

    The following hooks are installed when ``ulysses_sp_size != 1``:

    - ``net.forward`` raises the module-level ``_in_gate_delta_net_with_sp``
      flag for the duration of the call (and always lowers it again, even if
      the wrapped forward raises).
    - ``net.in_proj_qkv.forward`` passes its output through verl's
      ``gather_seq_scatter_heads`` (``seq_dim=1``, ``head_dim=2``).
    - ``net.conv1d`` gets a dynamically created subclass whose ``__getattr__``
      returns the rank-local ``Slice`` of ``weight`` (along dim 0) while the
      flag is raised.
    - ``torch.split`` is wrapped once, process-wide: while the flag is raised,
      a call with ``dim == -1`` and a 3-element list/tuple of sections first
      passes the tensor through ``gather_heads_scatter_seq``.
    - ``net.chunk_gated_delta_rule`` applies ``gather_seq_scatter_heads`` to
      query/key/value/g/beta and ``gather_heads_scatter_seq`` to its output.

    Args:
        net: The GateDeltaNet module to patch.
        ulysses_sp_size: Sequence-parallel world size. If 1, nothing is patched.

    Returns:
        None. ``net`` is modified in-place.

    Note:
        ``net._is_patched`` is set on the first call regardless of
        ``ulysses_sp_size``, so any later call on the same module is a no-op.
    """
    if getattr(net, "_is_patched", False):
        return
    net._is_patched = True

    if ulysses_sp_size == 1:
        # no need to patch (and no need to import the SP helpers)
        return

    # ulysses sequence parallel setup
    from verl.utils.ulysses import (
        gather_heads_scatter_seq,
        gather_seq_scatter_heads,
        get_ulysses_sequence_parallel_group,
    )

    # Patch net.forward
    original_net_forward = net.forward

    @wraps(original_net_forward)
    def new_net_forward(*args, **kwargs):
        global _in_gate_delta_net_with_sp
        _in_gate_delta_net_with_sp = True
        try:
            return original_net_forward(*args, **kwargs)
        finally:
            # Always lower the flag: leaving it raised after an exception
            # would keep the global torch.split hook active for unrelated code.
            _in_gate_delta_net_with_sp = False

    net.forward = new_net_forward

    # Patch in_proj_qkv
    original_in_proj_qkv_forward = net.in_proj_qkv.forward

    @wraps(original_in_proj_qkv_forward)
    def new_in_proj_qkv_forward(input):
        output = original_in_proj_qkv_forward(input)
        group = get_ulysses_sequence_parallel_group()
        output = gather_seq_scatter_heads(output, seq_dim=1, head_dim=2, group=group)
        return output

    net.in_proj_qkv.forward = new_in_proj_qkv_forward

    # Patch conv1d layer
    original_conv1d_class = net.conv1d.__class__
    original_conv1d_getattr = original_conv1d_class.__getattr__

    @wraps(original_conv1d_getattr)
    def new_conv1d_getattr(self, name):
        global _in_gate_delta_net_with_sp
        attr = original_conv1d_getattr(self, name)
        # bias is None in Qwen3.5, so no need to patch for bias
        if name == "weight" and _in_gate_delta_net_with_sp:
            group = get_ulysses_sequence_parallel_group()
            return Slice.apply(group, attr, 0, True)
        return attr

    new_conv1d_class = type(
        f"UlyssesGated{original_conv1d_class.__name__}",
        (original_conv1d_class,),
        {"__getattr__": new_conv1d_getattr},
    )
    net.conv1d.__class__ = new_conv1d_class

    # Patch torch.split (once per process)
    if not getattr(torch.split, "_is_patched_by_ulysses_gate_delta_net", False):
        original_split = torch.split

        @wraps(original_split)
        def new_split(tensor, split_size_or_sections, dim=0):
            global _in_gate_delta_net_with_sp
            # Only the q/k/v split (three explicit sections on the last dim)
            # is intercepted. The isinstance check keeps plain-int chunk
            # sizes from hitting len() and raising TypeError.
            if (
                _in_gate_delta_net_with_sp
                and dim == -1
                and isinstance(split_size_or_sections, (list, tuple))
                and len(split_size_or_sections) == 3
            ):
                tensor = gather_heads_scatter_seq(tensor, seq_dim=1, head_dim=2)
            return original_split(tensor, split_size_or_sections, dim)

        torch.split = new_split
        torch.split._is_patched_by_ulysses_gate_delta_net = True

    # Patch chunk_gated_delta_rule
    original_chunk_gated_delta_rule = net.chunk_gated_delta_rule

    @wraps(original_chunk_gated_delta_rule)
    def new_chunk_gated_delta_rule(query, key, value, g, beta, **kwargs):
        query = gather_seq_scatter_heads(query, seq_dim=1, head_dim=2)
        key = gather_seq_scatter_heads(key, seq_dim=1, head_dim=2)
        value = gather_seq_scatter_heads(value, seq_dim=1, head_dim=2)
        g = gather_seq_scatter_heads(g, seq_dim=1, head_dim=2)
        beta = gather_seq_scatter_heads(beta, seq_dim=1, head_dim=2)
        output, last_recurrent_state = original_chunk_gated_delta_rule(
            query, key, value, g, beta, **kwargs
        )
        output = gather_heads_scatter_seq(output, seq_dim=1, head_dim=2)
        return output, last_recurrent_state

    net.chunk_gated_delta_rule = new_chunk_gated_delta_rule
# TODO: remove this patch once the following PR is merged:
# https://github.com/huggingface/transformers/pull/45034/changes
def gate_delta_net_forward(
    self,
    hidden_states: torch.Tensor,
    cache_params: Cache | None = None,
    attention_mask: torch.Tensor | None = None,
    **kwargs,
):
    """Forward pass for Qwen3.5 GateDeltaNet linear attention with packing support.

    Replacement forward (see Hugging Face transformers PR #45034) that forwards
    packed-sequence metadata to the convolution and delta-rule kernels. It
    handles both the single-token cached decoding path and the full-sequence
    path.

    Args:
        hidden_states: 3-D input; unpacked as ``(batch_size, seq_len, _)``.
        cache_params: Optional cache. When it already holds state for
            ``self.layer_idx`` and ``seq_len == 1``, the cached conv and
            recurrent states are used and the recurrent kernel is run.
        attention_mask: Optional mask passed to ``apply_mask_to_padding_states``.
        **kwargs: Only two keys are read here: ``seq_idx`` (passed to
            ``self.causal_conv1d_fn``) and ``cu_seqlens`` (passed to
            ``self.chunk_gated_delta_rule`` only when that callable's
            ``__module__`` starts with ``"fla."``).

    Returns:
        ``self.out_proj`` applied to the normed attention output reshaped to
        ``(batch_size, seq_len, -1)``.
    """
    hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

    # Set up dimensions for reshapes later
    batch_size, seq_len, _ = hidden_states.shape

    # Single-token decode step with an already populated cache.
    use_precomputed_states = (
        cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1
    )

    # getting projected states from cache if it exists
    if use_precomputed_states:
        conv_state = cache_params.layers[self.layer_idx].conv_states
        recurrent_state = cache_params.layers[self.layer_idx].recurrent_states

    mixed_qkv = self.in_proj_qkv(hidden_states)
    # Move the channel dim before the sequence dim for the 1-D convolution.
    mixed_qkv = mixed_qkv.transpose(1, 2)

    z = self.in_proj_z(hidden_states)
    z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)

    b = self.in_proj_b(hidden_states)
    a = self.in_proj_a(hidden_states)

    if use_precomputed_states:
        # 2. Convolution sequence transformation
        # NOTE: the conv state is updated in `causal_conv1d_update`
        mixed_qkv = self.causal_conv1d_update(
            mixed_qkv,
            conv_state,
            self.conv1d.weight.squeeze(1),
            self.conv1d.bias,
            self.activation,
        )
    else:
        if cache_params is not None:
            # Left-pad (or crop, if the pad amount is negative) to the conv
            # kernel size before storing the conv state in the cache.
            conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
            conv_state = cache_params.update_conv_state(conv_state, self.layer_idx)
        if self.causal_conv1d_fn is not None:
            # seq_idx is the packed-sequence metadata forwarded from the caller.
            seq_idx = kwargs.get("seq_idx", None)
            mixed_qkv = self.causal_conv1d_fn(
                x=mixed_qkv,
                weight=self.conv1d.weight.squeeze(1),
                bias=self.conv1d.bias,
                activation=self.activation,
                seq_idx=seq_idx,
            )
        else:
            # Fallback without the fused kernel; note that seq_idx is not used here.
            # NOTE(review): the slice assumes the conv input has seq_len
            # positions - confirm this still holds when in_proj_qkv is patched
            # for sequence parallelism.
            mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])

    mixed_qkv = mixed_qkv.transpose(1, 2)
    # Keep this exact call form (torch.split, 3-element list, dim=-1): the
    # Ulysses patch in this module intercepts torch.split by matching on it.
    query, key, value = torch.split(
        mixed_qkv,
        [
            self.key_dim,
            self.key_dim,
            self.value_dim,
        ],
        dim=-1,
    )

    query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
    key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
    value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

    beta = b.sigmoid()
    # If the model is loaded in fp16, without the .float() here, A might be -inf
    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
    # Repeat q/k heads along dim 2 when there are more value heads than key heads.
    if self.num_v_heads // self.num_k_heads > 1:
        query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
        key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

    if not use_precomputed_states:
        chunk_kwargs = {}
        # Only callables whose __module__ starts with "fla." are given cu_seqlens.
        if getattr(self.chunk_gated_delta_rule, "__module__", "").startswith("fla."):
            chunk_kwargs["cu_seqlens"] = kwargs.get("cu_seqlens", None)
        core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=None,
            output_final_state=cache_params is not None,
            use_qk_l2norm_in_kernel=True,
            **chunk_kwargs,
        )
    else:
        core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=recurrent_state,
            output_final_state=cache_params is not None,
            use_qk_l2norm_in_kernel=True,
        )

    # Update cache
    if cache_params is not None:
        cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx)

    # reshape input data into 2D tensor
    core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
    z = z.reshape(-1, self.head_v_dim)
    # self.norm takes z as a second argument (presumably a gated norm - verify).
    core_attn_out = self.norm(core_attn_out, z)

    core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)

    output = self.out_proj(core_attn_out)
    return output
# TODO: remove this patch once the following PR is merged:
# https://github.com/huggingface/transformers/pull/45034/changes
def decoder_layer_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> torch.FloatTensor:
    """Run one Qwen3.5 decoder layer, forwarding extra kwargs to the token mixer.

    The layer is two pre-norm residual sub-layers: a token mixer chosen by
    ``self.layer_type`` followed by the MLP. Unlike a forward that drops them,
    ``**kwargs`` is also handed to the linear-attention mixer, which is how
    packed-sequence metadata reaches it.

    Args:
        hidden_states: Layer input.
        position_embeddings: Passed through to ``self.self_attn`` only.
        attention_mask: Passed to whichever mixer runs.
        position_ids: Passed through to ``self.self_attn`` only.
        past_key_values: Given to ``self.linear_attn`` as ``cache_params`` or
            to ``self.self_attn`` as ``past_key_values``.
        **kwargs: Forwarded unchanged to the selected mixer.

    Returns:
        The layer output after both residual additions.
    """
    # --- Sub-layer 1: token mixer -------------------------------------
    mixer_input = self.input_layernorm(hidden_states)
    mixer_kind = self.layer_type

    if mixer_kind == "linear_attention":
        mixer_output = self.linear_attn(
            hidden_states=mixer_input,
            cache_params=past_key_values,
            attention_mask=attention_mask,
            **kwargs,
        )
    elif mixer_kind == "full_attention":
        # self_attn returns a pair; only the first element is used.
        mixer_output, _ = self.self_attn(
            hidden_states=mixer_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            position_embeddings=position_embeddings,
            **kwargs,
        )
    else:
        # Any other layer type: the normed input is used as-is.
        mixer_output = mixer_input

    after_mixer = hidden_states + mixer_output

    # --- Sub-layer 2: feed-forward ------------------------------------
    ffn_output = self.mlp(self.post_attention_layernorm(after_mixer))
    return after_mixer + ffn_output
def qwen35_vision_fast_pos_embed_interpolate(self, grid_thw):
    """Bilinearly interpolate the learned position-embedding grid per image.

    For every ``(t, h, w)`` row of ``grid_thw`` the ``num_grid_per_side`` x
    ``num_grid_per_side`` embedding table ``self.pos_embed`` is sampled on an
    ``h`` x ``w`` grid (four neighbours and four bilinear weights per point),
    repeated ``t`` times and re-ordered so that patches belonging to the same
    ``spatial_merge_size`` x ``spatial_merge_size`` block are contiguous.

    The index and weight tensors are built on the default device and then
    moved to ``grid_thw.device``; the embedding lookup result is also moved
    there before weighting.

    Args:
        grid_thw: Integer tensor with one ``(t, h, w)`` row per image/video.
            ``h`` and ``w`` must be divisible by ``config.spatial_merge_size``
            (required by the ``view`` below).

    Returns:
        Tensor of shape ``(sum(t * h * w), embedding_dim)``.
    """
    grid_thw_list = grid_thw.tolist()
    device = grid_thw.device  # modified to ensure tensors are created on the correct device
    side = self.num_grid_per_side

    # One (4, h * w) tensor per image: the rows are the four bilinear corners
    # (floor/floor, floor/ceil, ceil/floor, ceil/ceil).
    index_parts = []
    weight_parts = []
    for _, h, w in grid_thw_list:
        h_idxs = torch.linspace(0, side - 1, h)
        w_idxs = torch.linspace(0, side - 1, w)

        h_idxs_floor = h_idxs.int()
        w_idxs_floor = w_idxs.int()
        h_idxs_ceil = (h_idxs_floor + 1).clip(max=side - 1)
        w_idxs_ceil = (w_idxs_floor + 1).clip(max=side - 1)

        # Fractional offsets inside the source cell.
        dh = h_idxs - h_idxs_floor
        dw = w_idxs - w_idxs_floor

        # Row offsets into the flattened (side * side) embedding table.
        base_h = h_idxs_floor * side
        base_h_ceil = h_idxs_ceil * side

        index_parts.append(
            torch.stack(
                [
                    (base_h[None].T + w_idxs_floor[None]).flatten(),
                    (base_h[None].T + w_idxs_ceil[None]).flatten(),
                    (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
                    (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
                ]
            )
        )
        weight_parts.append(
            torch.stack(
                [
                    ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                    ((1 - dh)[None].T * dw[None]).flatten(),
                    (dh[None].T * (1 - dw)[None]).flatten(),
                    (dh[None].T * dw[None]).flatten(),
                ]
            )
        )

    # Concatenate as tensors rather than round-tripping every element
    # through Python lists; the values are unchanged.
    idx_tensor = torch.cat(index_parts, dim=1).to(device=device, dtype=torch.long)
    weight_tensor = torch.cat(weight_parts, dim=1).to(device=device, dtype=self.pos_embed.weight.dtype)

    pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
    patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]

    # Back to one chunk of h * w rows per image.
    patch_pos_embeds = patch_pos_embeds.split([h * w for _, h, w in grid_thw_list])

    merge_size = self.config.spatial_merge_size
    patch_pos_embeds_permute = []
    for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw_list):
        pos_embed = pos_embed.repeat(t, 1)
        # (t, h_blocks, merge, w_blocks, merge, dim) -> group each merge block together
        pos_embed = (
            pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
            .permute(0, 1, 3, 2, 4, 5)
            .flatten(0, 4)
        )
        patch_pos_embeds_permute.append(pos_embed)
    return torch.cat(patch_pos_embeds_permute)
@can_return_tuple
def qwen35_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    pixel_values: torch.Tensor | None = None,
    pixel_values_videos: torch.FloatTensor | None = None,
    image_grid_thw: torch.LongTensor | None = None,
    video_grid_thw: torch.LongTensor | None = None,
    mm_token_type_ids: torch.IntTensor | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | Qwen3_5ModelOutputWithPast:
    """Qwen3.5 multimodal forward that keeps the vision tower in step across ranks.

    If any rank has image (or video) inputs, every rank runs the image (or
    video) feature extractor. A rank without such inputs feeds a zero-filled
    dummy and adds ``0.0 *`` the resulting features to its embeddings, so the
    vision parameters still take part in backward on that rank.

    Args:
        input_ids: Token ids. Exactly one of ``input_ids`` / ``inputs_embeds``
            must be given.
        attention_mask: Forwarded to position-id computation and the language model.
        position_ids: If ``None``, computed via ``self.compute_3d_position_ids``.
        past_key_values: Forwarded to position-id computation and the language model.
        inputs_embeds: Pre-computed embeddings, alternative to ``input_ids``.
        pixel_values: Image patches for ``self.get_image_features``. The dummy
            used here is ``(merge_size**2, in_channels * temporal_patch_size *
            patch_size**2)``, so real inputs are presumably flattened patches
            of the same width - verify against the processor.
        pixel_values_videos: Same as ``pixel_values`` but for
            ``self.get_video_features``.
        image_grid_thw: ``(num_images, 3)`` temporal/height/width of each
            image's feature grid.
        video_grid_thw: ``(num_videos, 3)`` temporal/height/width of each
            video's feature grid.
        mm_token_type_ids: Forwarded to ``self.compute_3d_position_ids``.
        **kwargs: Forwarded to ``self.language_model``.

    Returns:
        ``Qwen3_5ModelOutputWithPast`` built from the language-model outputs
        plus ``self.rope_deltas``.

    Raises:
        ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
            are provided.
    """
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    vision_cfg = self.config.vision_config
    merge_size = vision_cfg.spatial_merge_size
    # Width of one flattened patch, used for the dummy inputs below.
    pixel_values_dim = vision_cfg.in_channels * vision_cfg.temporal_patch_size * (vision_cfg.patch_size**2)
    # One dummy image/video: a single frame of merge_size x merge_size patches.
    dummy_patch_shape = (merge_size * merge_size, pixel_values_dim)
    device = inputs_embeds.device

    # Local availability as plain bools; global availability via a sum all-reduce.
    local_has_image = pixel_values is not None
    local_has_video = pixel_values_videos is not None
    mm_flags = torch.tensor([int(local_has_image), int(local_has_video)], device=device)
    if dist.is_initialized():
        dist.all_reduce(mm_flags)
    any_rank_has_image, any_rank_has_video = (mm_flags > 0).tolist()

    # check images
    if any_rank_has_image:
        if not local_has_image:
            pixel_values = torch.zeros(dummy_patch_shape, dtype=torch.float32, device=device)
            image_grid_thw = torch.tensor([[1, merge_size, merge_size]], dtype=torch.int64, device=device)
        image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output
        image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
        if local_has_image:
            image_mask, _ = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
        else:
            # patched for backward: zero-weighted use of the dummy features
            inputs_embeds[0] = inputs_embeds[0] + image_embeds[0] * 0.0

    # check videos
    if any_rank_has_video:
        if not local_has_video:
            pixel_values_videos = torch.zeros(dummy_patch_shape, dtype=torch.float32, device=device)
            video_grid_thw = torch.tensor([[1, merge_size, merge_size]], dtype=torch.int64, device=device)
        video_embeds = self.get_video_features(
            pixel_values_videos, video_grid_thw, return_dict=True
        ).pooler_output
        video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
        if local_has_video:
            _, video_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
        else:
            # patched for backward: zero-weighted use of the dummy features
            inputs_embeds[0] = inputs_embeds[0] + video_embeds[0] * 0.0

    if position_ids is None:
        position_ids = self.compute_3d_position_ids(
            input_ids=input_ids,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            mm_token_type_ids=mm_token_type_ids,
        )

    lm_outputs = self.language_model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        **kwargs,
    )

    return Qwen3_5ModelOutputWithPast(**lm_outputs, rope_deltas=self.rope_deltas)
@dataclass
class Qwen3_5CausalLMOutputForPPO(Qwen3_5CausalLMOutputWithPast):
    """Causal-LM output extended with the two per-token quantities used for PPO.

    Both extra fields default to ``None`` and are filled in by the
    ``forward_with_*_backend`` functions of this module.
    """

    # Log-probabilities returned by the fused linear / cross-entropy backend.
    log_probs: torch.FloatTensor | None = None
    # Entropy returned by the same backend call.
    entropy: torch.FloatTensor | None = None
def forward_with_torch_backend(
    self: Qwen3_5ForConditionalGeneration,
    input_ids: torch.LongTensor = None,
    labels: Optional[torch.LongTensor] = None,
    temperature: float = 1.0,
    **kwargs,
) -> tuple | Qwen3_5CausalLMOutputForPPO:
    """Compute log-probs and entropy with verl's ``FusedLinearForPPO`` (PyTorch backend).

    Runs ``self.model`` and hands its first output, together with
    ``self.lm_head.weight`` and next-token targets, to ``FusedLinearForPPO``.

    Args:
        input_ids: Token ids passed to ``self.model``; also the source of the
            targets when ``labels`` is ``None``.
        labels: Optional explicit targets; take precedence over ``input_ids``.
        temperature: Passed through to ``FusedLinearForPPO.forward``.
        **kwargs: Forwarded unchanged to ``self.model``.

    Returns:
        ``Qwen3_5CausalLMOutputForPPO`` with ``log_probs``, ``entropy`` and the
        model's ``hidden_states`` attribute.

    Raises:
        RuntimeError: If both ``labels`` and ``input_ids`` are ``None``
            (raised after the model forward has run).

    Note:
        Targets are the source tensor rolled by -1 along the last dim, so the
        last position wraps around to the first token.
    """
    from verl.utils.experimental.torch_functional import FusedLinearForPPO

    model_outputs = self.model(input_ids=input_ids, **kwargs)
    last_hidden = model_outputs[0]

    # Explicit labels win; otherwise fall back to the input tokens.
    target_source = labels if labels is not None else input_ids
    if target_source is None:
        raise RuntimeError(
            "To use forward_with_torch_backend, either labels or input_ids must be provided."
        )
    next_token_targets = torch.roll(target_source, shifts=-1, dims=-1)

    log_probs, entropy = FusedLinearForPPO().forward(
        hidden_states=last_hidden,
        vocab_weights=self.lm_head.weight,
        input_ids=next_token_targets,
        temperature=temperature,
    )

    return Qwen3_5CausalLMOutputForPPO(
        log_probs=log_probs,
        entropy=entropy,
        hidden_states=model_outputs.hidden_states,
    )
def forward_with_triton_backend(
    self: Qwen3_5ForConditionalGeneration,
    input_ids: torch.LongTensor = None,
    labels: Optional[torch.LongTensor] = None,
    temperature: float = 1.0,
    **kwargs,
) -> tuple | Qwen3_5CausalLMOutputForPPO:
    """Compute log-probs and entropy with verl's ``linear_cross_entropy`` kernel.

    Runs ``self.model`` and hands its first output, together with
    ``self.lm_head.weight`` and next-token targets, to
    ``linear_cross_entropy`` with reduction ``"none"``.

    Args:
        input_ids: Token ids passed to ``self.model``; also the source of the
            targets when ``labels`` is ``None``.
        labels: Optional explicit targets; take precedence over ``input_ids``.
        temperature: Passed through to ``linear_cross_entropy``.
        **kwargs: Forwarded unchanged to ``self.model``.

    Returns:
        ``Qwen3_5CausalLMOutputForPPO`` with ``log_probs``, ``entropy`` and the
        model's ``hidden_states`` attribute.

    Raises:
        RuntimeError: If both ``labels`` and ``input_ids`` are ``None``
            (raised after the model forward has run).

    Note:
        Targets are the source tensor rolled by -1 along the last dim, so the
        last position wraps around to the first token.
    """
    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy

    model_outputs = self.model(input_ids=input_ids, **kwargs)
    last_hidden = model_outputs[0]

    # Explicit labels win; otherwise fall back to the input tokens.
    target_source = labels if labels is not None else input_ids
    if target_source is None:
        raise RuntimeError(
            "To use forward_with_triton_backend, either labels or input_ids must be provided."
        )
    next_token_targets = torch.roll(target_source, shifts=-1, dims=-1)

    log_probs, entropy = linear_cross_entropy(
        last_hidden,
        self.lm_head.weight,
        next_token_targets,
        temperature,
        "none",
    )

    return Qwen3_5CausalLMOutputForPPO(
        log_probs=log_probs,
        entropy=entropy,
        hidden_states=model_outputs.hidden_states,
    )