trinity.common.patch.qwen3_5 module#
- class trinity.common.patch.qwen3_5.Slice(*args, **kwargs)[source]#
Bases: Function
- static forward(ctx: Any, group: ProcessGroup, global_tensor: Tensor, dim: int, grad_scaler: bool = True, async_op=False) Tensor[source]#
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
    @staticmethod
    def forward(*args: Any, **kwargs: Any) -> Any:
        pass

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
        pass
The forward no longer accepts a ctx argument.
Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.
See extending-autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- static backward(ctx, grad_outputs: Tensor) Any[source]#
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
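The forward/backward contract described above can be illustrated with a single-process sketch. LocalSlice, rank, and world_size are hypothetical names introduced for illustration; the real Slice takes a ProcessGroup, and what its backward does across ranks is not documented on this page. The sketch only shows the rule that backward returns one value per forward input, with None for non-tensor inputs.

```python
import torch


class LocalSlice(torch.autograd.Function):
    """Hypothetical stand-in: keep one rank's chunk of a tensor along `dim`."""

    @staticmethod
    def forward(ctx, global_tensor, dim, rank, world_size):
        # Non-tensor values may be stored directly on ctx.
        ctx.dim = dim
        ctx.rank = rank
        ctx.global_shape = global_tensor.shape
        # Assumes the size along `dim` is divisible by world_size.
        chunks = global_tensor.chunk(world_size, dim=dim)
        return chunks[rank].clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Scatter the local gradient back into a zero tensor of the global shape.
        grad_input = grad_output.new_zeros(ctx.global_shape)
        chunk_len = grad_output.shape[ctx.dim]
        start = ctx.rank * chunk_len
        grad_input.narrow(ctx.dim, start, chunk_len).copy_(grad_output)
        # One return value per forward input: tensor, dim, rank, world_size.
        return grad_input, None, None, None


x = torch.arange(8.0).reshape(2, 4).requires_grad_()
local = LocalSlice.apply(x, 1, 1, 2)  # second half of the last dimension
local.sum().backward()
```

After the backward call, x.grad is nonzero only in the columns that were kept by the slice.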
- trinity.common.patch.qwen3_5.ulysses_gate_delta_net_decorator(net, ulysses_sp_size)[source]#
Decorator to enable Ulysses Sequence Parallel for Qwen3.5 GateDeltaNet linear attention.
This decorator patches the GateDeltaNet module to support sequence parallelism using the Ulysses strategy. It intercepts various operations (forward pass, projections, convolutions, and attention) to properly scatter/gather tensors across sequence parallel ranks.
- Parameters:
net – The GateDeltaNet module to patch (typically a linear attention layer).
ulysses_sp_size – The sequence parallel world size. If 1, no patching is performed.
Note
This function patches the module in-place and sets a _is_patched flag to avoid double-patching.
The sequence parallel operations are controlled via a global _in_gate_delta_net_with_sp flag.
The patching includes modifications to forward, in_proj_qkv, conv1d, torch.split, and chunk_gated_delta_rule.
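The in-place, patch-once behaviour described in the note can be sketched in plain Python. ToyNet and patch_once are hypothetical names; the actual decorator also patches in_proj_qkv, conv1d, torch.split, and chunk_gated_delta_rule, none of which is reproduced here. Only the _is_patched guard and the no-op for a sequence parallel size of 1 are taken from the description above.

```python
import functools


class ToyNet:
    """Hypothetical stand-in for the module being patched."""

    def forward(self, x):
        return x + 1


def patch_once(net, sp_size):
    """Wrap net.forward in place, at most once, and only when sp_size > 1."""
    if sp_size == 1 or getattr(net, "_is_patched", False):
        return net

    original_forward = net.forward

    @functools.wraps(original_forward)
    def wrapped_forward(*args, **kwargs):
        # A real patch would scatter/gather tensors around this call.
        net.call_count = getattr(net, "call_count", 0) + 1
        return original_forward(*args, **kwargs)

    net.forward = wrapped_forward
    net._is_patched = True
    return net


net = patch_once(patch_once(ToyNet(), sp_size=2), sp_size=2)
result = net.forward(1)
untouched = patch_once(ToyNet(), sp_size=1)
```

Because the second patch_once call returns early, the wrapper runs once per forward call rather than twice.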
- trinity.common.patch.qwen3_5.gate_delta_net_forward(self, hidden_states: Tensor, cache_params: Cache | None = None, attention_mask: Tensor | None = None, **kwargs)[source]#
Forward pass for Qwen3.5 GateDeltaNet linear attention with packing support.
This implementation of the linear attention forward pass supports packed sequences for efficient training, following the approach referenced in the Hugging Face transformers PR #45034. It handles both incremental (cached) and non-cached inference modes.
- Parameters:
hidden_states – Input hidden states of shape (batch_size, seq_len, hidden_dim).
cache_params – Optional cache parameters for incremental decoding.
attention_mask – Optional attention mask to mask out padding positions.
**kwargs – Additional keyword arguments passed to sub-components (e.g., seq_idx for packed sequences).
- Returns:
Output tensor of shape (batch_size, seq_len, hidden_dim) after linear attention computation.
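For packed training, seq_idx tells each token which original sequence it belongs to. A minimal sketch of building such an index from per-sequence lengths follows; build_seq_idx is a hypothetical helper, and the exact dtype and shape that gate_delta_net_forward expects are assumptions.

```python
import torch


def build_seq_idx(seq_lens):
    """Map each packed token position to the index of its source sequence."""
    lengths = torch.as_tensor(seq_lens, dtype=torch.long)
    ids = torch.arange(lengths.numel(), dtype=torch.long)
    # Shape (1, total_tokens): one packed row holding all sequences back to back.
    return torch.repeat_interleave(ids, lengths).unsqueeze(0)


seq_idx = build_seq_idx([3, 2, 4])
```

Here three sequences of lengths 3, 2, and 4 are packed into one row of 9 tokens.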
- trinity.common.patch.qwen3_5.decoder_layer_forward(self, hidden_states: Tensor, position_embeddings: tuple[Tensor, Tensor], attention_mask: Tensor | None = None, position_ids: LongTensor | None = None, past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs]) FloatTensor[source]#
Forward pass for a Qwen3.5 decoder layer supporting packed sequences.
This function implements a full transformer decoder layer with support for packed sequences (packing training). It combines token mixing (via linear or full attention) with a feed-forward network, with residual connections around each sub-layer.
- Parameters:
hidden_states – Input hidden states of shape (batch_size, seq_len, hidden_dim).
position_embeddings – Tuple of (cos_cached, sin_cached) for rotary position embeddings.
attention_mask – Optional attention mask.
position_ids – Optional position IDs for the sequence.
past_key_values – Optional cache for incremental decoding.
**kwargs – Additional arguments, including:
- layer_type: Either ‘linear_attention’ or ‘full_attention’ to determine token mixer.
- seq_idx: Sequence indices for packed sequence training.
- Returns:
Output hidden states of same shape as input (batch_size, seq_len, hidden_dim).
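The layer structure described above (a token mixer chosen by layer_type, then a feed-forward network, each wrapped in a residual connection) can be sketched with toy modules. ToyDecoderLayer is hypothetical: the linear layers stand in for GateDeltaNet and full attention, and the pre-norm placement and LayerNorm are assumptions rather than details taken from this page.

```python
import torch
from torch import nn


class ToyDecoderLayer(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.input_norm = nn.LayerNorm(hidden_dim)
        self.post_mixer_norm = nn.LayerNorm(hidden_dim)
        # Stand-ins for the two token mixers.
        self.linear_attn = nn.Linear(hidden_dim, hidden_dim)
        self.self_attn = nn.Linear(hidden_dim, hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.SiLU(),
            nn.Linear(4 * hidden_dim, hidden_dim),
        )

    def forward(self, hidden_states, layer_type="linear_attention"):
        # Sub-layer 1: token mixing with a residual connection.
        residual = hidden_states
        hidden_states = self.input_norm(hidden_states)
        if layer_type == "linear_attention":
            hidden_states = self.linear_attn(hidden_states)
        elif layer_type == "full_attention":
            hidden_states = self.self_attn(hidden_states)
        else:
            raise ValueError(f"unknown layer_type: {layer_type!r}")
        hidden_states = residual + hidden_states
        # Sub-layer 2: feed-forward network with a residual connection.
        residual = hidden_states
        hidden_states = self.mlp(self.post_mixer_norm(hidden_states))
        return residual + hidden_states


layer = ToyDecoderLayer(hidden_dim=8)
x = torch.randn(2, 5, 8)
out_linear = layer(x, layer_type="linear_attention")
out_full = layer(x, layer_type="full_attention")
```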
- trinity.common.patch.qwen3_5.qwen35_vision_fast_pos_embed_interpolate(self, grid_thw)[source]#
Interpolate vision position embeddings for variable resolution inputs with proper device handling.
This function performs bilinear interpolation of position embeddings to support variable spatial resolutions. It fixes the device handling issue that occurred during CPU offloading, ensuring all tensors are created and operated on the same device as the input.
- Parameters:
grid_thw – Tensor of shape (num_images, 3) containing temporal, height, and width dimensions for each image in the batch.
- Returns:
Interpolated position embeddings of shape (total_patches, embedding_dim) after merging, where total_patches is the sum of all h*w for each image after spatial merging.
Note
The function supports batch processing of multiple images with different resolutions.
Spatial merging is applied based on config.spatial_merge_size.
All tensors are properly placed on the same device as the input grid_thw.
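The idea of resizing a learned position grid per image while staying on the input's device can be sketched with torch.nn.functional.interpolate. This is a simplified stand-in, not the patched function: interpolate_pos_embed is a hypothetical name, the square source grid and the reuse of one grid for all t frames are assumptions, and the reordering for spatial merging is omitted.

```python
import math

import torch
import torch.nn.functional as F


def interpolate_pos_embed(pos_embed, grid_thw):
    """Resize a square position grid to each (h, w) listed in grid_thw."""
    device = grid_thw.device
    num_positions, dim = pos_embed.shape
    side = math.isqrt(num_positions)
    # (num_positions, dim) -> (1, dim, side, side), moved to the input's device.
    grid = pos_embed.to(device).reshape(side, side, dim).permute(2, 0, 1).unsqueeze(0)
    outputs = []
    for t, h, w in grid_thw.tolist():
        resized = F.interpolate(grid, size=(h, w), mode="bilinear", align_corners=False)
        per_frame = resized.squeeze(0).permute(1, 2, 0).reshape(h * w, dim)
        # Assumption: the same embeddings are reused for each of the t frames.
        outputs.append(per_frame.repeat(t, 1))
    return torch.cat(outputs, dim=0)


pos_embed = torch.randn(16, 8)                     # 4 x 4 grid, 8 channels
grid_thw = torch.tensor([[1, 2, 3], [1, 4, 4]])    # two images, different sizes
embeds = interpolate_pos_embed(pos_embed, grid_thw)
```

The output stacks 2*3 + 4*4 = 22 rows, one per patch, and the second image (already 4 x 4) gets the original grid back unchanged.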
- trinity.common.patch.qwen3_5.qwen35_model_forward(self, input_ids: LongTensor = None, attention_mask: Tensor | None = None, position_ids: LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: FloatTensor | None = None, pixel_values: Tensor | None = None, pixel_values_videos: FloatTensor | None = None, image_grid_thw: LongTensor | None = None, video_grid_thw: LongTensor | None = None, mm_token_type_ids: IntTensor | None = None, **kwargs: Unpack[TransformersKwargs]) tuple | Qwen3_5ModelOutputWithPast[source]#
Qwen3.5 model forward pass with multimodal support and gradient synchronization across ranks.
This forward function handles multimodal training (images and/or videos) across multiple GPU ranks with proper synchronization. When a rank doesn’t have image/video inputs but other ranks do (common in distributed training with different data samples), it creates dummy images/videos to maintain consistency and avoid hanging in collective operations.
- Parameters:
input_ids – Token IDs of shape (batch_size, seq_len).
attention_mask – Attention mask for padding tokens.
position_ids – Position IDs for embeddings.
past_key_values – Cached key-values for incremental decoding.
inputs_embeds – Pre-computed input embeddings (alternative to input_ids).
pixel_values – Image pixel values of shape (num_images, channels, height, width).
pixel_values_videos – Video pixel values of shape (num_videos, frames, channels, height, width).
image_grid_thw – Grid dimensions (temporal, height, width) for images.
video_grid_thw – Grid dimensions (temporal, height, width) for videos.
mm_token_type_ids – Token type IDs to distinguish image, video, and text tokens.
**kwargs – Additional arguments.
- Returns:
Qwen3_5ModelOutputWithPast containing language model outputs with rope_deltas for position embeddings.
Note
Dummy images/videos are created with shape based on spatial_merge_size when needed for gradient synchronization.
Uses distributed communication (dist.all_reduce) to synchronize multimodal input availability across ranks.
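The availability check described in the note can be sketched as a flag reduction. any_rank_has_inputs is a hypothetical helper, and the choice of ReduceOp.MAX is an assumption; this page only states that dist.all_reduce is used. The sketch falls back to the local flag when no process group is initialized, so it also runs in a single process.

```python
import torch
import torch.distributed as dist


def any_rank_has_inputs(has_local_inputs, device="cpu"):
    """Return True if at least one rank holds image/video inputs."""
    if not (dist.is_available() and dist.is_initialized()):
        return bool(has_local_inputs)  # single-process fallback
    flag = torch.tensor([int(has_local_inputs)], device=device)
    # MAX over 0/1 flags acts as a logical OR across ranks.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())


needs_dummy = any_rank_has_inputs(False)
has_real = any_rank_has_inputs(True)
```

A rank whose local flag is False but whose reduced flag is True is the case where dummy inputs would be needed so that every rank runs the same collective operations.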
- class trinity.common.patch.qwen3_5.Qwen3_5CausalLMOutputForPPO(loss: torch.FloatTensor | None = None, logits: torch.FloatTensor | None = None, past_key_values: transformers.cache_utils.Cache | None = None, hidden_states: tuple[torch.FloatTensor] | None = None, attentions: tuple[torch.FloatTensor] | None = None, rope_deltas: torch.LongTensor | None = None, log_probs: torch.FloatTensor | None = None, entropy: torch.FloatTensor | None = None)[source]#
Bases: Qwen3_5CausalLMOutputWithPast
- log_probs: FloatTensor | None = None#
- entropy: FloatTensor | None = None#
- __init__(loss: FloatTensor | None = None, logits: FloatTensor | None = None, past_key_values: Cache | None = None, hidden_states: tuple[FloatTensor] | None = None, attentions: tuple[FloatTensor] | None = None, rope_deltas: LongTensor | None = None, log_probs: FloatTensor | None = None, entropy: FloatTensor | None = None) None#
Args:
- loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided):
Language modeling loss (for next-token prediction).
- logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (Cache, optional, returned when use_cache=True is passed or when config.use_cache=True):
It is a cache_utils.Cache instance. For more details, see the kv cache guide: https://huggingface.co/docs/transformers/en/kv_cache
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.
- hidden_states (tuple[torch.FloatTensor], optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True):
Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (tuple[torch.FloatTensor], optional, returned when output_attentions=True is passed or when config.output_attentions=True):
Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- rope_deltas (torch.LongTensor of shape (batch_size, ), optional):
The rope index difference between sequence length and multimodal rope.
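The relationship between the base output class and the PPO variant can be sketched with plain dataclasses. ToyCausalLMOutput and ToyCausalLMOutputForPPO are hypothetical stand-ins with a reduced field set; the real classes derive from the transformers output types, which are not reproduced here. The point shown is only that the subclass appends log_probs and entropy after the inherited fields.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class ToyCausalLMOutput:
    loss: Optional[Any] = None
    logits: Optional[Any] = None
    rope_deltas: Optional[Any] = None


@dataclass
class ToyCausalLMOutputForPPO(ToyCausalLMOutput):
    # Extra per-token quantities used by RL losses.
    log_probs: Optional[Any] = None
    entropy: Optional[Any] = None


out = ToyCausalLMOutputForPPO(log_probs=[-0.1, -2.3], entropy=[0.5, 1.2])
field_names = [f.name for f in fields(out)]
```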
- trinity.common.patch.qwen3_5.forward_with_torch_backend(self: Qwen3_5ForConditionalGeneration, input_ids: LongTensor = None, labels: LongTensor | None = None, temperature: float = 1.0, **kwargs) tuple | Qwen3_5CausalLMOutputForPPO[source]#
Compute log probabilities and entropy for reinforcement learning using PyTorch backend.
This function computes per-token log probabilities and entropy from the language model’s hidden states using a fused PyTorch-based linear projection. It’s designed for PPO and other RL algorithms that require per-token probability distributions over the vocabulary.
- Parameters:
input_ids – Token IDs of shape (batch_size, seq_len).
labels – Optional labels for loss computation. If None, input_ids are rolled to compute shifted targets.
temperature – Temperature scaling for softmax (default: 1.0). Used to control probability distribution sharpness.
**kwargs – Additional arguments passed to the model (e.g., attention_mask).
- Returns:
Qwen3_5CausalLMOutputForPPO containing:
log_probs: Log probabilities of shape (batch_size, seq_len).
entropy: Entropy values of shape (batch_size, seq_len).
hidden_states: Hidden states from the model forward pass.
- Return type:
Qwen3_5CausalLMOutputForPPO
- Raises:
RuntimeError – If neither labels nor input_ids is provided.
Note
Uses FusedLinearForPPO for efficient torch-based computation.
The log probability target is computed by rolling labels (or input_ids) by -1 to create next-token prediction targets.
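The quantities this function returns can be written out in unfused form. The sketch below is not FusedLinearForPPO: token_log_probs_and_entropy is a hypothetical helper that starts from already-computed logits, whereas the real path works from hidden states. It shows the target roll by -1, the temperature scaling, and the per-token log probability and entropy.

```python
import torch


def token_log_probs_and_entropy(logits, input_ids, labels=None, temperature=1.0):
    """Unfused reference: per-token log-prob of the next token, and entropy."""
    targets = input_ids if labels is None else labels
    # Position i is scored against token i + 1. The last position wraps around
    # to the first token and is normally masked out by the caller.
    targets = torch.roll(targets, shifts=-1, dims=-1)
    log_p = torch.log_softmax(logits.float() / temperature, dim=-1)
    log_probs = log_p.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    entropy = -(log_p.exp() * log_p).sum(dim=-1)
    return log_probs, entropy


logits = torch.zeros(1, 3, 4)            # uniform distribution over 4 tokens
input_ids = torch.tensor([[0, 1, 2]])
log_probs, entropy = token_log_probs_and_entropy(logits, input_ids)
```

With uniform logits over 4 tokens, every log probability is log(1/4) and every entropy is log(4), which makes the example easy to check by hand.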
- trinity.common.patch.qwen3_5.forward_with_triton_backend(self: Qwen3_5ForConditionalGeneration, input_ids: LongTensor = None, labels: LongTensor | None = None, temperature: float = 1.0, **kwargs) tuple | Qwen3_5CausalLMOutputForPPO[source]#
Compute log probabilities and entropy for reinforcement learning using Triton kernel backend.
This function computes per-token log probabilities and entropy from the language model’s hidden states using an optimized Triton kernel (linear_cross_entropy). It provides better performance compared to the PyTorch backend for large vocabularies, suitable for PPO and other RL algorithms.
- Parameters:
input_ids – Token IDs of shape (batch_size, seq_len).
labels – Optional labels for loss computation. If None, input_ids are rolled to compute shifted targets.
temperature – Temperature scaling for softmax (default: 1.0). Used to control probability distribution sharpness.
**kwargs – Additional arguments passed to the model (e.g., attention_mask).
- Returns:
Qwen3_5CausalLMOutputForPPO containing:
log_probs: Log probabilities of shape (batch_size, seq_len).
entropy: Entropy values of shape (batch_size, seq_len).
hidden_states: Hidden states from the model forward pass.
- Return type:
Qwen3_5CausalLMOutputForPPO
- Raises:
RuntimeError – If neither labels nor input_ids is provided.
Note
Uses the linear_cross_entropy Triton kernel from verl for highly optimized computation.
The log probability target is computed by rolling labels (or input_ids) by -1 to create next-token prediction targets.
Generally faster than forward_with_torch_backend for large vocabulary sizes.