trinity.common.models.utils 源代码

# -*- coding: utf-8 -*-
import os
import re
from collections.abc import Iterator
from copy import deepcopy
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import transformers

from trinity.utils.log import get_logger


def _get_common_kwargs(
    tokenizer: Any,
    *,
    tools: Optional[List[dict]] = None,
    chat_template: Optional[str] = None,
    enable_thinking: Optional[bool] = None,
):
    common_kwargs = dict(
        tools=tools,
        chat_template=chat_template,
        tokenize=True,
        return_dict=True,
        enable_thinking=enable_thinking,
    )
    text_kwargs = dict(
        padding=False,
        truncation=True,
        add_special_tokens=False,
    )
    if isinstance(tokenizer, transformers.ProcessorMixin):
        common_kwargs["processor_kwargs"] = text_kwargs
    else:
        common_kwargs.update(text_kwargs)
    return common_kwargs


[文档] def tokenize_and_mask_messages_hf( tokenizer: Any, messages: List[dict], tools: Optional[List[dict]] = None, chat_template: Optional[str] = None, enable_thinking: Optional[bool] = None, ) -> dict[str, torch.Tensor]: """Calculate the assistant token mask with `chat_template`. Args: tokenizer (Any): The tokenizer or processor. messages (List[dict]): Messages with `role` and `content` fields. tools (Optional[List[dict]]): The list of tool dictionaries. chat_template (str): The chat template with `{% generation %}` symbol. Returns: `dict[str, torch.Tensor]`: A token dictionary returned by `apply_chat_template`, containing at least `input_ids` and `assistant_masks`. """ common_kwargs = _get_common_kwargs( tokenizer, tools=tools, chat_template=chat_template, enable_thinking=enable_thinking, ) token_dict = tokenizer.apply_chat_template( messages, add_generation_prompt=False, return_assistant_tokens_mask=True, return_tensors="pt", **common_kwargs, ) token_dict.pop("attention_mask", None) # remove attention mask if exists return token_dict
[文档] def tokenize_and_mask_messages_default( tokenizer: Any, messages: List[dict], tools: Optional[List[dict]] = None, chat_template: Optional[str] = None, enable_thinking: Optional[bool] = None, ) -> dict[str, torch.Tensor]: """Calculate the assistant token mask. Args: tokenizer (Any): The tokenizer or processor. messages (List[dict]): Messages with `role` and `content` fields. tools (Optional[List[dict]]): The list of tool dictionaries. chat_template (str): The chat template with `{% generation %}` symbol. Returns: `dict[str, torch.Tensor]`: A token dictionary containing `input_ids` and `assistant_masks`. Note: This method is based on the assumption that as the number of chat rounds increases, the tokens of the previous round are exactly the prefix tokens of the next round. If the assumption is not met, the function may produce incorrect results. Please check the chat template before using this method. """ if len(messages) == 0: raise ValueError("Messages should not be empty") common_kwargs = _get_common_kwargs( tokenizer, tools=tools, chat_template=chat_template, enable_thinking=enable_thinking, ) generation_messages = [] response_messages = [] start_idx = 0 if "<think>" in (chat_template or tokenizer.chat_template): # find last user message for thinking template for idx in range(len(messages) - 1, -1, -1): message = messages[idx] if message["role"] == "user": start_idx = idx break for idx in range(start_idx, len(messages)): message = messages[idx] if message["role"] == "assistant": generation_messages.append(messages[:idx]) response_messages.append(messages[: idx + 1]) token_dict = tokenizer.apply_chat_template( messages, add_generation_prompt=False, return_tensors="pt", **deepcopy(common_kwargs), ) assistant_masks = torch.zeros_like(token_dict["input_ids"]).squeeze() if len(generation_messages) != 0: first_generation_message_empty_flag = len(generation_messages[0]) == 0 if first_generation_message_empty_flag: # the first message is from assistant, so generation_messages[0] is empty generation_messages[0] = response_messages[0] prompt_token_ids_list = tokenizer.apply_chat_template( generation_messages, add_generation_prompt=True, **deepcopy(common_kwargs), )["input_ids"] response_token_ids_list = tokenizer.apply_chat_template( response_messages, add_generation_prompt=False, **deepcopy(common_kwargs), )["input_ids"] if first_generation_message_empty_flag: # the first message is from assistant, so set the first prompt_token_ids to empty prompt_token_ids_list[0] = [] for prompt_token_ids, response_token_ids in zip( prompt_token_ids_list, response_token_ids_list ): prompt_len = len(prompt_token_ids) response_len = len(response_token_ids) assistant_masks[prompt_len:response_len] = 1 token_dict.pop("attention_mask", None) # remove attention mask if exists token_dict["assistant_masks"] = assistant_masks.unsqueeze(0) return token_dict
[文档] def get_action_mask_method(chat_template: Optional[str] = None) -> Callable: """Get the action mask method according to the chat template. Args: chat_template (str): The chat template. If { % generation % } is present, use HF tokenizer's `return_assistant_tokens_mask`. Returns: The action mask method. """ if chat_template is None: return tokenize_and_mask_messages_default # check if the chat template contains `{% generation %}` symbol elif re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template): return tokenize_and_mask_messages_hf else: return tokenize_and_mask_messages_default
[文档] def get_checkpoint_dir_with_step_num( checkpoint_root_path: str, trainer_type: str = "verl", step_num: Optional[int] = None, raise_error: bool = True, ) -> Tuple[str, int]: """Get the checkpoint directory from a root checkpoint directory. Args: checkpoint_root_path (str): The root checkpoint directory. trainer_type (str): The trainer type. Only support "verl" for now. step_num (Optional[int], optional): The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None. raise_error (bool): Whether to raise an error if the checkpoint does not exist. Returns: Tuple[str, int]: The checkpoint directory and the step number of the checkpoint. If the checkpoint does not exist and `raise_error` is False, return (None, 0). """ if trainer_type == "verl": return get_verl_checkpoint_info( checkpoint_path=checkpoint_root_path, step_num=step_num, raise_error=raise_error ) else: raise NotImplementedError(f"Unsupported trainer type {trainer_type}")
[文档] def get_latest_state_dict( checkpoint_root_path: str, trainer_type: str = "verl", ) -> Tuple[str, int]: """Get the latest state dict from a root checkpoint directory. Args: checkpoint_root_path (str): The root checkpoint directory. Returns: Tuple[str, int]: The state dict path and the iteration of the state dict. If the state dict does not exist, return (None, 0). """ if trainer_type != "verl": raise NotImplementedError(f"Unsupported trainer type {trainer_type}") latest_state_dict_iteration_path = os.path.join( checkpoint_root_path, "latest_state_dict_iteration.txt" ) if os.path.exists(latest_state_dict_iteration_path): with open(latest_state_dict_iteration_path, "r", encoding="utf-8") as f: iteration = f.read().strip() state_dict_path = os.path.join( checkpoint_root_path, f"global_step_{iteration}", "actor" ) return state_dict_path, int(iteration) return None, 0 # type: ignore
[文档] def has_huggingface_model_weights(checkpoint_path: str) -> bool: """Return True when ``checkpoint_path`` contains serialized HF model weights.""" weight_file_prefixes = ( "model.safetensors", "pytorch_model", "adapter_model", ) if not os.path.isdir(checkpoint_path): return False return any(name.startswith(weight_file_prefixes) for name in os.listdir(checkpoint_path))
[文档] def load_state_dict_iterator(checkpoint_dir: str) -> Iterator[Tuple[str, torch.Tensor]]: """Load model state dict from a checkpoint directory as an iterator of (name, tensor) tuples.""" state_dict = load_state_dict(checkpoint_dir) if isinstance(state_dict, dict): for name, tensor in state_dict.items(): yield name, tensor else: raise ValueError(f"Unsupported state dict format: {type(state_dict)}")
[文档] def load_state_dict( checkpoint_dir: str, trust_remote_code: bool = False ) -> Union[dict, Tuple[str, str]]: """Load model state dict from a checkpoint directory. Auto-detects the checkpoint format from directory contents: 1. **safetensors** — ``model.safetensors`` produced by the unified ``save_state_dict`` path. Loaded directly and returned as a dict. 2. **HuggingFace weights** — detected by :func:`has_huggingface_model_weights` in either a ``huggingface/`` subdirectory or the directory itself. Returns ``("huggingface", path)`` for lazy loading by the caller. 3. **FSDP shards** — ``model_world_size_N_rank_M.pt`` files. Merged via :func:`load_fsdp_state_dict_from_verl_checkpoint` and returned as a dict. 4. **Megatron dist checkpoint** — fallback. Returns ``("megatron", checkpoint_dir)`` for lazy loading via converter. Args: checkpoint_dir: Path to the checkpoint directory (typically ``global_step_N/actor/``). Returns: Either a ``dict`` of model weights, or a ``(method, path)`` tuple indicating the format for lazy loading. """ import glob # 1. safetensors (unified save_state_dict format) safetensors_path = os.path.join(checkpoint_dir, "model.safetensors") if os.path.exists(safetensors_path): from safetensors.torch import load_file return load_file(safetensors_path, device="cpu") # 2. HuggingFace weights in huggingface/ subdirectory huggingface_dir = os.path.join(checkpoint_dir, "huggingface") if has_huggingface_model_weights(huggingface_dir): return load_huggingface_state_dict(huggingface_dir, trust_remote_code=trust_remote_code) # 3. FSDP shards → merge if glob.glob(os.path.join(checkpoint_dir, "model_world_size_*_rank_*.pt")): return load_fsdp_state_dict_from_verl_checkpoint(checkpoint_dir) # 4. Megatron dist_ckpt (fallback) return get_megatron_converter(checkpoint_dir).get_state_dict(checkpoint_dir)
[文档] def get_verl_checkpoint_info( checkpoint_path: str, step_num: Optional[int] = None, raise_error: bool = True ) -> Tuple[str, int]: """Get the checkpoint directory from a Verl root checkpoint directory. Args: checkpoint_path (str): The root checkpoint directory. step_num (Optional[int], optional): The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None. raise_error (bool): Whether to raise an error if the checkpoint does not exist. Returns: Tuple[str, int]: The checkpoint directory and the step number of the checkpoint. """ if step_num is None: # load latest checkpoint iteration_file = os.path.join(checkpoint_path, "latest_checkpointed_iteration.txt") if os.path.exists(iteration_file): with open( iteration_file, "r", encoding="utf-8" ) as f: # TODO: this file may be modified simultaneously iteration = f.read().strip() return os.path.join(checkpoint_path, f"global_step_{iteration}"), int(iteration) elif raise_error: raise FileNotFoundError(f"No iteration file found in {checkpoint_path}") else: return None, 0 # type: ignore else: # load specific iteration checkpoint path = os.path.join(checkpoint_path, f"global_step_{step_num}") if not os.path.exists(path) and raise_error: raise FileNotFoundError(f"Checkpoint {path} not found") return path, step_num
# modified from verl/model_merger/fsdp_model_merger.py def _infer_world_size_from_checkpoint(checkpoint_path: str) -> int: """Infer FSDP world_size from shard filenames in *checkpoint_path*. The sharded state dicts are named ``model_world_size_{N}_rank_{M}.pt``. We glob for rank-0 files and extract *N*. This avoids depending on ``fsdp_config.json`` which ``save_state_dict`` (weight-sync shortcut) does not produce. """ import glob import re pattern = os.path.join(checkpoint_path, "model_world_size_*_rank_0.pt") matches = glob.glob(pattern) if not matches: raise FileNotFoundError( f"No FSDP shard files matching {pattern} found in {checkpoint_path}" ) # Extract world_size from the first (and usually only) match. m = re.search(r"model_world_size_(\d+)_rank_0\.pt$", matches[0]) if m is None: raise ValueError(f"Cannot parse world_size from filename: {matches[0]}") return int(m.group(1))
[文档] def load_fsdp_state_dict_from_verl_checkpoint(checkpoint_path: str) -> dict: # noqa: C901 """Load state dict from a Verl checkpoint.""" from verl.model_merger.base_model_merger import ModelMergerConfig from verl.model_merger.fsdp_model_merger import FSDPModelMerger logger = get_logger(__name__) config = ModelMergerConfig( operation="merge", backend="fsdp", trust_remote_code=False, is_value_model=False, local_dir=checkpoint_path, hf_model_config_path=os.path.join(checkpoint_path, "huggingface"), ) merger = FSDPModelMerger(config) # Prefer fsdp_config.json (written by full checkpoints), fall back to # inferring from shard filenames (weight-sync state dicts). try: world_size = merger._get_world_size() except FileNotFoundError: world_size = _infer_world_size_from_checkpoint(checkpoint_path) logger.info(f"Inferred world_size={world_size} from shard filenames") rank_zero_state_dict = merger._load_rank_zero_state_dict(world_size) mesh, mesh_dim_names = merger._extract_device_mesh_info(rank_zero_state_dict, world_size) logger.info(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") total_shards, mesh_shape = merger._calculate_shard_configuration(mesh, mesh_dim_names) logger.info(f"Processing model shards with {total_shards} {mesh_shape} in total") merged_state_dict = merger._load_and_merge_state_dicts( world_size, total_shards, mesh_shape, mesh_dim_names ) return merged_state_dict
[文档] def load_huggingface_state_dict(checkpoint_path: str, trust_remote_code: bool = False): import transformers from verl.utils.model import get_hf_auto_model_class model_config = transformers.AutoConfig.from_pretrained( checkpoint_path, trust_remote_code=trust_remote_code, ) auto_model_cls = get_hf_auto_model_class(model_config) model = auto_model_cls.from_pretrained( checkpoint_path, trust_remote_code=trust_remote_code, ) return model.state_dict()
[文档] def get_megatron_converter(checkpoint_path: str): import builtins from contextlib import contextmanager from verl.model_merger.base_model_merger import ModelMergerConfig from verl.model_merger.megatron_model_merger import MegatronModelMerger from trinity.trainer.verl_legacy.utils import patch_rope_theta_in_hf_config # modified from verl/model_merger/megatron_model_merger.py class MegatronStateDictConverter(MegatronModelMerger): def __init__(self, config: ModelMergerConfig): # Patch Megatron-Core ModelType enum compatibility: # newer mcore renamed encoder_and_decoder → encoder_or_decoder, # but verl's get_model() still references the old name. from megatron.core.enums import ModelType if not hasattr(ModelType, "encoder_and_decoder"): ModelType.encoder_and_decoder = getattr(ModelType, "encoder_or_decoder", None) original_init_process_group = torch.distributed.init_process_group original_get_rank = torch.distributed.get_rank original_get_world_size = torch.distributed.get_world_size torch.distributed.init_process_group = lambda *args, **kwargs: None torch.distributed.get_rank = lambda: 0 torch.distributed.get_world_size = lambda: 1 self.logger = get_logger(__name__) with self._redirect_print_to_logger(): super().__init__(config) torch.distributed.init_process_group = original_init_process_group torch.distributed.get_rank = original_get_rank torch.distributed.get_world_size = original_get_world_size # start of patch for verl to support transformers v5 patch_rope_theta_in_hf_config(self.hf_config) # end of patch for verl to support transformers v5 @contextmanager def _redirect_print_to_logger(self): original_print = builtins.print def logger_print(*args, **kwargs): message = " ".join(str(arg) for arg in args) self.logger.debug(message) builtins.print = logger_print try: yield finally: builtins.print = original_print def get_state_dict(self, checkpoint_path): self.config.local_dir = checkpoint_path from verl.utils.megatron_utils import get_dist_checkpoint_path with self._redirect_print_to_logger(): model_ckpt_path = get_dist_checkpoint_path(self.config.local_dir) model_state_dict = self._load_state_dicts(model_ckpt_path) merged_state_dict = self._merge_state_dicts(model_state_dict) del model_state_dict return merged_state_dict config = ModelMergerConfig( operation="merge", backend="megatron", tie_word_embedding=False, trust_remote_code=False, is_value_model=False, local_dir=checkpoint_path, hf_model_config_path=os.path.join(checkpoint_path, "huggingface"), use_cpu_initialization=True, ) converter = MegatronStateDictConverter(config) return converter