Source code for trinity.common.patch.qwen3_5

from dataclasses import dataclass
from functools import wraps
from typing import Any, Optional

import torch
import torch.distributed as dist
from torch import Tensor
from transformers.models.qwen3_5.modeling_qwen3_5 import (
    BaseModelOutputWithPast,
    Cache,
    Qwen3_5CausalLMOutputWithPast,
    Qwen3_5ForConditionalGeneration,
    Qwen3_5ModelOutputWithPast,
    TransformersKwargs,
    Unpack,
    capture_outputs,
    create_causal_mask,
    merge_with_config_defaults,
)
from verl.utils.ulysses import all_gather_tensor


class Slice(torch.autograd.Function):
    """Keep the local shard of `global_tensor` along `dim` for this sequence-parallel
    rank; the backward pass all-gathers gradients across the group."""

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        global_tensor: Tensor,
        dim: int,
        grad_scaler: bool = True,
        async_op=False,
    ) -> Tensor:
        ctx.group = group
        ctx.dim = dim
        ctx.grad_scaler = grad_scaler
        ctx.async_op = async_op
        sp_world_size = dist.get_world_size(group=group)
        ctx.sp_world_size = sp_world_size
        sp_rank = dist.get_rank(group=group)
        ctx.sp_rank = sp_rank

        # slice the input tensor
        dim_size = global_tensor.size(dim)
        if dim_size % sp_world_size != 0:
            raise ValueError(
                f"Cannot evenly slice tensor of size {dim_size} along dim {dim} "
                f"across {sp_world_size} ranks. This would truncate data. "
                "Ensure the dimension size is divisible by the SP world size."
            )
        parts = dim_size // sp_world_size
        slc = [slice(None)] * len(global_tensor.shape)
        slc[dim] = slice(sp_rank * parts, (sp_rank + 1) * parts)
        return global_tensor[tuple(slc)].contiguous()

    @staticmethod
    def backward(ctx, grad_outputs: Tensor) -> Any:
        if ctx.grad_scaler:
            grad_outputs = grad_outputs / ctx.sp_world_size
        # all-gather the local gradients, then reassemble them along the sliced dim
        output = all_gather_tensor(grad_outputs, ctx.group, ctx.async_op)
        # one gradient per forward input: group, global_tensor, dim, grad_scaler, async_op
        return (
            None,
            torch.cat(output.split(grad_outputs.size(0), dim=0), dim=ctx.dim).contiguous(),
            None,
            None,
            None,
        )
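As a usage illustration, the sketch below shows how `Slice.apply` might shard a `[batch, seq, hidden]` activation across an Ulysses sequence-parallel group. It assumes `torch.distributed` and the process group are already initialized elsewhere; the helper name is hypothetical and not part of this module.

# Illustrative sketch (not in this module): shard along the sequence dimension (dim=1).
def shard_sequence(hidden: torch.Tensor, sp_group: dist.ProcessGroup) -> torch.Tensor:
    # Each rank keeps seq_len // sp_world_size tokens; backward all-gathers the
    # gradients and (by default) scales them by 1 / sp_world_size.
    return Slice.apply(sp_group, hidden, 1)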
# TODO: may optimize this function
def ulysses_gated_delta_net_forward_decorator(func):
    @wraps(func)
    def wrapper(
        hidden_states: torch.Tensor,
        **kwargs,
    ):
        from verl.utils.ulysses import (
            gather_outputs_and_unpad,
            get_ulysses_sequence_parallel_group,
            get_ulysses_sequence_parallel_world_size,
        )

        ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
        if ulysses_sp_size > 1:
            # gather the full sequence before running the linear-attention kernel
            hidden_states = gather_outputs_and_unpad(hidden_states, gather_dim=1)
        output = func(hidden_states, **kwargs)
        if ulysses_sp_size > 1:
            # slice the output back down to the local sequence shard
            group = get_ulysses_sequence_parallel_group()
            output = Slice.apply(group, output, 1)
        return output

    return wrapper
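The decorator above gathers the full sequence before calling the wrapped forward and re-slices its output afterwards, so the wrapped callable must receive `hidden_states` as its first positional argument. The sketch below shows one hedged way it could be attached to a linear-attention (gated delta net) module instance; the actual patch site is not shown in this file, and the helper name here is hypothetical.

# Illustrative sketch (not in this module): wrap an instance's *bound* forward so
# the wrapper's first argument is hidden_states, matching the signature above.
def patch_linear_attention_forward(module: torch.nn.Module) -> None:
    module.forward = ulysses_gated_delta_net_forward_decorator(module.forward)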
@merge_with_config_defaults
@capture_outputs
def qwen35_text_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    use_cache: bool | None = None,
    cache_position: torch.LongTensor | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if use_cache and past_key_values is None:
        from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DynamicCache

        past_key_values = Qwen3_5DynamicCache(config=self.config)

    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    # mrope: the hard coded `3` is for temporal, height and width.
    if position_ids is None:
        position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
    elif position_ids.ndim == 2:
        position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

    # a 4-row position_ids tensor carries text positions in row 0 followed by the
    # three mrope rows; otherwise row 0 doubles as the text positions
    if position_ids.ndim == 3 and position_ids.shape[0] == 4:
        text_position_ids = position_ids[0]
        position_ids = position_ids[1:]
    else:
        text_position_ids = position_ids[0]

    causal_mask = create_causal_mask(
        config=self.config,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
        position_ids=text_position_ids,
    )
    linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position)

    hidden_states = inputs_embeds
    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
        # linear-attention layers use their own mask; full-attention layers use the causal mask
        layer_mask = (
            linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
        )
        hidden_states = decoder_layer(
            hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=layer_mask,
            position_ids=text_position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

    hidden_states = self.norm(hidden_states)
    return Qwen3_5ModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=past_key_values,
    )
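For reference, the sketch below reproduces the default mrope expansion used above when `position_ids` is omitted: a `[seq_len]` `cache_position` is broadcast to `[3, batch, seq_len]`, so temporal, height and width share the same ids for text-only input. The shapes are made up for the example; a caller can instead pass a `[4, batch, seq_len]` tensor whose first row holds the text positions.

# Illustrative sketch (not in this module): default mrope position_ids expansion.
batch_size, seq_len = 2, 8
cache_position = torch.arange(seq_len)
position_ids = cache_position.view(1, 1, -1).expand(3, batch_size, -1)
assert position_ids.shape == (3, batch_size, seq_len)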
@dataclass
class Qwen3_5CausalLMOutputForPPO(Qwen3_5CausalLMOutputWithPast):
    log_probs: Optional[torch.FloatTensor] = None
    entropy: Optional[torch.FloatTensor] = None
def forward_with_torch_backend(
    self: Qwen3_5ForConditionalGeneration,
    input_ids: torch.LongTensor = None,
    labels: Optional[torch.LongTensor] = None,
    temperature: float = 1.0,
    **kwargs,
) -> tuple | Qwen3_5CausalLMOutputForPPO:
    from verl.utils.experimental.torch_functional import FusedLinearForPPO

    outputs = self.model(input_ids=input_ids, **kwargs)
    hidden_states = outputs[0]

    # Loss calculations
    if labels is not None:
        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
    elif input_ids is not None:
        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
    else:
        raise RuntimeError(
            "To use forward_with_torch_backend, either labels or input_ids must be provided."
        )

    fused_linear_for_ppo = FusedLinearForPPO()
    log_probs, entropy = fused_linear_for_ppo.forward(
        hidden_states=hidden_states,
        vocab_weights=self.lm_head.weight,
        input_ids=rolled_labels,
        temperature=temperature,
    )
    return Qwen3_5CausalLMOutputForPPO(
        log_probs=log_probs,
        entropy=entropy,
        hidden_states=outputs.hidden_states,
    )
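The `torch.roll(..., shifts=-1, dims=-1)` above lines each position's hidden state up with the token that follows it, so `log_probs[:, t]` scores the token at position `t + 1`; the wrapped-around last position has no valid next token and is typically masked by the caller. A tiny example of the roll:

# Illustrative sketch (not in this module): next-token alignment via torch.roll.
labels = torch.tensor([[11, 12, 13, 14]])
rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
# rolled_labels == tensor([[12, 13, 14, 11]]); the final 11 wraps around and
# should be ignored when computing next-token losses.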
def forward_with_triton_backend(
    self: Qwen3_5ForConditionalGeneration,
    input_ids: torch.LongTensor = None,
    labels: Optional[torch.LongTensor] = None,
    temperature: float = 1.0,
    **kwargs,
) -> tuple | Qwen3_5CausalLMOutputForPPO:
    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy

    outputs = self.model(input_ids=input_ids, **kwargs)
    hidden_states = outputs[0]

    # Loss calculations
    if labels is not None:
        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
    elif input_ids is not None:
        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
    else:
        raise RuntimeError(
            "To use forward_with_triton_backend, either labels or input_ids must be provided."
        )

    log_probs, entropy = linear_cross_entropy(
        hidden_states,
        self.lm_head.weight,
        rolled_labels,
        temperature,
        "none",
    )
    return Qwen3_5CausalLMOutputForPPO(
        log_probs=log_probs,
        entropy=entropy,
        hidden_states=outputs.hidden_states,
    )
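Both PPO forwards share the same signature and differ only in the kernel used to compute per-token log-probabilities and entropy. A minimal sketch of how one might bind a chosen variant to a model instance follows; the `attach_ppo_forward` helper and the attribute name are hypothetical, and the real wiring is not shown in this file.

import types

# Illustrative sketch (not in this module): bind a chosen PPO forward to a model.
def attach_ppo_forward(
    model: Qwen3_5ForConditionalGeneration, backend: str = "torch"
) -> None:
    fn = forward_with_torch_backend if backend == "torch" else forward_with_triton_backend
    # hypothetical attribute name; the actual patch point is defined elsewhere
    model.forward_for_ppo = types.MethodType(fn, model)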