# Source code for trinity.trainer.verl.utils

"""Utils for compatibility issues with verl."""
from logging import Logger
from typing import List

import numpy as np
import torch
from transformers import PreTrainedModel
from verl import DataProto
from verl.trainer.ppo.metric_utils import _compute_response_info

from trinity.common.experience import (
    Experience,
    gather_action_masks,
    gather_attention_masks,
    gather_response_attrs,
    gather_token_ids,
    split_dpo_experience_to_single_turn,
)


def _gather_routed_experts(
    experiences: List[Experience], max_prompt_length: int, max_response_length: int
) -> torch.Tensor:
    """Pad routed experts to the full left-padded sequence layout expected by verl."""
    batch_size = len(experiences)
    total_length = max_prompt_length + max_response_length
    routed_experts = experiences[0].routed_experts
    assert routed_experts is not None, "No routed_experts provided."
    _, layer_num, topk_num = routed_experts.shape
    batch_routed_experts = torch.zeros(
        batch_size,
        total_length,
        layer_num,
        topk_num,
        dtype=routed_experts.dtype,
    )

    for idx, exp in enumerate(experiences):
        exp_routed_experts = exp.routed_experts
        assert exp_routed_experts is not None, "No routed_experts provided."
        if exp_routed_experts.ndim != 3:
            raise ValueError(
                "Experience.routed_experts must have shape [seq_length - 1, layer_num, topk]."
            )
        if exp_routed_experts.shape[1:] != (layer_num, topk_num):
            raise ValueError("routed_experts shape is inconsistent across experiences.")

        start_pos = max_prompt_length - exp.prompt_length
        end_pos = min(start_pos + exp_routed_experts.shape[0], total_length)
        batch_routed_experts[idx, start_pos:end_pos] = exp_routed_experts[: end_pos - start_pos]

    return batch_routed_experts


def to_data_proto(  # noqa: C901
    experiences: List[Experience], pad_token_id: int, model: PreTrainedModel, logger: Logger
) -> DataProto:
    """Convert List[Experience] to verl DataProto.

    DPO experiences (detected from the first element's ``experience_type``) are
    first split into single-turn experiences. Token ids, attention masks and
    action masks are then gathered into padded batch tensors, and optional
    per-experience attributes (rewards, log-probs, advantages, returns, teacher
    log-probs, routed experts, multi-modal inputs, custom fields) are added when
    present.

    Args:
        experiences: Non-empty list of experiences to batch.
        pad_token_id: Padding token id forwarded to ``gather_token_ids``.
        model: Only inspected for a ``get_rope_index`` method; when present it is
            called per experience to build multi-modal position ids.
        logger: Used to warn when both scalar and token-level rewards are set.

    Returns:
        The result of ``DataProto.from_single_dict`` over the assembled batch
        dictionary, with ``model_versions`` stored in ``meta_info``.

    Raises:
        AssertionError: If ``experiences`` is empty, or rewards are provided
            without ``logprobs`` on every experience.
        ValueError: If only some experiences carry ``routed_experts``, or if
            ``custom_fields`` differ between experiences.
    """
    assert len(experiences) > 0, "No experiences provided."
    if experiences[0].experience_type == "dpo":
        experiences = split_dpo_experience_to_single_turn(experiences)
    max_prompt_length = max(exp.prompt_length for exp in experiences)
    max_response_length = max(len(exp.tokens) - exp.prompt_length for exp in experiences)  # type: ignore

    attention_mask = gather_attention_masks(
        experiences, max_prompt_length, max_response_length
    ).long()
    cumsum = torch.cumsum(attention_mask, dim=-1)
    position_ids = torch.clip(cumsum - 1, 0, None).long()
    tokens = gather_token_ids(
        experiences, max_prompt_length, max_response_length, pad_token_id
    ).long()
    batch_dict = {
        "uid": np.array([exp.eid.tid for exp in experiences]),
        "unique_ids": np.array([exp.eid.uid for exp in experiences]),
        "position_ids": position_ids,
        "input_ids": tokens,
        "prompts": tokens[:, :max_prompt_length],
        "responses": tokens[:, max_prompt_length:],
        "attention_mask": attention_mask,
        "response_mask": gather_action_masks(experiences, max_response_length),
    }

    have_reward = all(exp.reward is not None for exp in experiences)
    have_token_level_reward = all(exp.token_level_reward is not None for exp in experiences)
    if have_reward or have_token_level_reward:
        assert all(exp.logprobs is not None for exp in experiences), "No logprobs provided."
        if have_token_level_reward:
            if have_reward:
                logger.warning(
                    "Both experiences.rewards and experiences.token_level_rewards are provided. "
                    "Using experiences.token_level_rewards."
                )
            token_level_rewards = gather_response_attrs(
                experiences, "token_level_reward", max_response_length
            )
        else:
            # Place each scalar reward at the index where the attention-mask
            # cumsum peaks, then keep only the response columns.
            token_level_rewards = torch.zeros(attention_mask.shape, dtype=torch.float32)
            eos_mask_idx = cumsum.argmax(dim=-1)
            token_level_rewards[torch.arange(len(experiences)), eos_mask_idx] = torch.tensor(
                [exp.reward for exp in experiences],
                dtype=torch.float32,
            )
            token_level_rewards = token_level_rewards[:, max_prompt_length:]
        batch_dict.update(
            {
                "token_level_scores": token_level_rewards,
                "rollout_log_probs": gather_response_attrs(
                    experiences, "logprobs", max_response_length
                ),
            }
        )

    # Optional response-level attributes are only batched when every experience has them.
    for attr in ["advantages", "returns", "teacher_logprobs"]:
        if all(getattr(exp, attr, None) is not None for exp in experiences):
            batch_dict[attr] = gather_response_attrs(experiences, attr, max_response_length)

    if any(exp.routed_experts is not None for exp in experiences):
        if not all(exp.routed_experts is not None for exp in experiences):
            raise ValueError("routed_experts are not consistent across experiences.")
        batch_dict["routed_experts"] = _gather_routed_experts(
            experiences, max_prompt_length, max_response_length
        )

    if hasattr(model, "get_rope_index"):  # used for multi-modal model
        import inspect

        # Adapted from verl/experimental/agent_loop/agent_loop.py
        # The signature does not depend on the experience, so inspect it once.
        rope_param_names = [
            name
            for name in inspect.signature(model.get_rope_index).parameters
            if name not in {"self", "input_ids", "attention_mask", "kwargs"}
        ]
        position_ids_list, multi_modal_inputs = [], []
        for idx, exp in enumerate(experiences):
            # Shallow-copy so that popping "mm_token_type_ids" below does not
            # mutate the Experience's own dict (a repeated conversion would
            # otherwise lose the token-type ids).
            mm_inputs = dict(exp.multi_modal_inputs or {})
            input_ids = batch_dict["input_ids"][idx].unsqueeze(0)
            sample_attention_mask = batch_dict["attention_mask"][idx].unsqueeze(0)
            get_rope_index_kwargs = {}
            for key in rope_param_names:
                if key == "mm_token_type_ids":
                    # Re-align the per-sample token-type ids with the padded layout.
                    pad_data = torch.zeros_like(input_ids)
                    if key in mm_inputs:
                        data = mm_inputs.pop(key)
                        start = max_prompt_length - exp.prompt_length
                        end = start + data.size(1)
                        pad_data[:, start:end] = data
                    get_rope_index_kwargs[key] = pad_data
                else:
                    get_rope_index_kwargs[key] = mm_inputs.get(key, None)
            vision_position_ids, _ = model.get_rope_index(
                input_ids=input_ids,
                attention_mask=sample_attention_mask,
                **get_rope_index_kwargs,
            )  # (3, 1, seq_len)
            vision_position_ids = vision_position_ids.squeeze(1)  # (3, seq_len)
            text_position_ids = batch_dict["position_ids"][idx].unsqueeze(0)  # (1, seq_length)
            sample_position_ids = torch.cat(
                (text_position_ids, vision_position_ids), dim=0
            )  # (4, seq_length)
            position_ids_list.append(sample_position_ids)
            multi_modal_inputs.append(mm_inputs)
        batch_dict["position_ids"] = torch.stack(
            position_ids_list, dim=0
        ).long()  # (bs, 4, seq_length)
        batch_dict["multi_modal_inputs"] = np.array(multi_modal_inputs, dtype=object)

    custom_fields_set = set(tuple(exp.custom_fields) for exp in experiences)
    if len(custom_fields_set) == 1:
        custom_fields = list(custom_fields_set)[0]
        for custom_field in custom_fields:
            batch_dict[custom_field.destination_field] = torch.tensor(
                [exp.info[custom_field.source_field] for exp in experiences],
                dtype=custom_field.data_type,
            )
    else:
        raise ValueError("Custom fields are not consistent across experiences.")
    meta_info = {
        "model_versions": np.array([exp.info.get("model_version", 0) for exp in experiences])
    }
    return DataProto.from_single_dict(batch_dict, meta_info=meta_info)


def compute_data_metrics(batch: DataProto) -> dict:
    """Compute score, reward, advantage, return and length metrics for a batch.

    Modified from verl.trainer.ppo.metric_utils.compute_data_metrics

    Args:
        batch: A DataProto whose ``batch`` holds ``responses`` and
            ``attention_mask`` and, optionally, ``token_level_scores``,
            ``token_level_rewards``, ``advantages`` and ``returns``.

    Returns:
        A dictionary of Python scalars containing:
            - critic/score/mean, max, min and critic/rewards/mean, max, min
              (only when both ``token_level_scores`` and ``token_level_rewards``
              are present)
            - response_length/mean, max, min, clip_ratio
            - prompt_length/mean, max, min, clip_ratio
            - critic/advantages/mean, max, min (only when ``advantages`` is present)
            - critic/returns/mean, max, min (only when ``returns`` is present)
    """
    metrics = {}

    if "token_level_rewards" in batch.batch and "token_level_scores" in batch.batch:
        sequence_score = batch.batch["token_level_scores"].sum(-1)
        sequence_reward = batch.batch["token_level_rewards"].sum(-1)
        metrics.update(
            {
                # score
                "critic/score/mean": torch.mean(sequence_score).detach().item(),
                "critic/score/max": torch.max(sequence_score).detach().item(),
                "critic/score/min": torch.min(sequence_score).detach().item(),
                # reward
                "critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
                "critic/rewards/max": torch.max(sequence_reward).detach().item(),
                "critic/rewards/min": torch.min(sequence_reward).detach().item(),
            }
        )

    max_response_length = batch.batch["responses"].shape[-1]

    # The attention mask covers prompt + response; split it at the response width.
    prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
    response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()

    max_prompt_length = prompt_mask.size(-1)

    response_info = _compute_response_info(batch)
    prompt_length = response_info["prompt_length"]
    response_length = response_info["response_length"]

    metrics.update(
        {
            # response length
            "response_length/mean": torch.mean(response_length).detach().item(),
            "response_length/max": torch.max(response_length).detach().item(),
            "response_length/min": torch.min(response_length).detach().item(),
            "response_length/clip_ratio": torch.mean(
                torch.eq(response_length, max_response_length).float()
            )
            .detach()
            .item(),
            # prompt length
            "prompt_length/mean": torch.mean(prompt_length).detach().item(),
            "prompt_length/max": torch.max(prompt_length).detach().item(),
            "prompt_length/min": torch.min(prompt_length).detach().item(),
            "prompt_length/clip_ratio": torch.mean(
                torch.eq(prompt_length, max_prompt_length).float()
            )
            .detach()
            .item(),
        }
    )

    def _valid_entries(values: torch.Tensor) -> torch.Tensor:
        """Return the entries of *values* under the response mask, or a zero placeholder."""
        if response_mask.numel() > 0:
            selected = torch.masked_select(values, response_mask)
            # An all-False mask yields an empty selection, on which torch.max /
            # torch.min raise; fall through to the placeholder in that case too.
            if selected.numel() > 0:
                return selected
        return torch.zeros(1)

    if "advantages" in batch.batch:
        # adv
        valid_adv = _valid_entries(batch.batch["advantages"])
        metrics.update(
            {
                # adv
                "critic/advantages/mean": torch.mean(valid_adv).detach().item(),
                "critic/advantages/max": torch.max(valid_adv).detach().item(),
                "critic/advantages/min": torch.min(valid_adv).detach().item(),
            }
        )

    if "returns" in batch.batch:
        # returns
        valid_returns = _valid_entries(batch.batch["returns"])
        metrics.update(
            {
                "critic/returns/mean": torch.mean(valid_returns).detach().item(),
                "critic/returns/max": torch.max(valid_returns).detach().item(),
                "critic/returns/min": torch.min(valid_returns).detach().item(),
            }
        )

    return metrics