trinity.common.models.utils module

trinity.common.models.utils.tokenize_and_mask_messages_hf(tokenizer: Any, messages: List[dict], tools: List[dict] | None = None, chat_template: str | None = None, enable_thinking: bool | None = None) → dict[str, Tensor]

Calculate the assistant token mask with the chat template, using the Hugging Face tokenizer's built-in assistant mask support.

Parameters:
  • tokenizer (Any) -- The tokenizer or processor.

  • messages (List[dict]) -- Messages with role and content fields.

  • tools (Optional[List[dict]]) -- The list of tool dictionaries.

  • chat_template (Optional[str]) -- The chat template containing the {% generation %} symbol.

Returns:

A token dictionary returned by apply_chat_template, containing at least input_ids and assistant_masks.

Return type:

dict[str, torch.Tensor]

trinity.common.models.utils.tokenize_and_mask_messages_default(tokenizer: Any, messages: List[dict], tools: List[dict] | None = None, chat_template: str | None = None, enable_thinking: bool | None = None) → dict[str, Tensor]

Calculate the assistant token mask without requiring the {% generation %} symbol in the chat template.

Parameters:
  • tokenizer (Any) -- The tokenizer or processor.

  • messages (List[dict]) -- Messages with role and content fields.

  • tools (Optional[List[dict]]) -- The list of tool dictionaries.

  • chat_template (Optional[str]) -- The chat template used to render the messages.

Returns:

A token dictionary containing input_ids and assistant_masks.

Return type:

dict[str, torch.Tensor]

Note

This method assumes that, as the number of chat rounds grows, the tokens of each earlier round are exactly a prefix of the tokens of the next round. If this assumption does not hold, the function may produce incorrect results, so check the chat template before using this method.
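
A minimal sketch of the prefix assumption described in the note. It uses a hypothetical whitespace tokenizer and a toy template (toy_render, toy_tokenize and prefix_assistant_mask are illustrative names, not part of trinity). Each growing prefix of the conversation is rendered and tokenized, only the newly added tokens are appended, and the prefix check raises if a template breaks the assumption. It returns plain lists instead of tensors, and treating the assistant role header from the generation prompt as non-assistant tokens is an assumed masking convention.

```python
from typing import Callable, Dict, List


def toy_render(messages: List[dict], add_generation_prompt: bool = False) -> str:
    """Hypothetical chat template: '<role> content <end>' per message."""
    text = " ".join(f"<{m['role']}> {m['content']} <end>" for m in messages)
    if add_generation_prompt:
        # Opens an assistant turn, like a real generation prompt would.
        text = (text + " <assistant>").strip()
    return text


def toy_tokenize(text: str) -> List[str]:
    """Hypothetical tokenizer: every whitespace-separated piece is a token."""
    return text.split()


def prefix_assistant_mask(
    messages: List[dict],
    render: Callable[..., str] = toy_render,
    tokenize: Callable[[str], List[str]] = toy_tokenize,
) -> Dict[str, list]:
    """Build input_ids and assistant_masks by tokenizing growing prefixes."""
    input_ids: List[str] = []
    assistant_masks: List[int] = []

    def append_new_tokens(tokens: List[str], flag: int) -> None:
        # The prefix assumption: earlier tokens must reappear unchanged.
        if tokens[: len(input_ids)] != input_ids:
            raise ValueError("chat template violates the prefix assumption")
        new_tokens = tokens[len(input_ids):]
        input_ids.extend(new_tokens)
        assistant_masks.extend([flag] * len(new_tokens))

    for i, message in enumerate(messages):
        if message["role"] == "assistant":
            # The role header added by the generation prompt is not model output.
            append_new_tokens(tokenize(render(messages[:i], add_generation_prompt=True)), 0)
            append_new_tokens(tokenize(render(messages[: i + 1])), 1)
        else:
            append_new_tokens(tokenize(render(messages[: i + 1])), 0)
    return {"input_ids": input_ids, "assistant_masks": assistant_masks}


conversation = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello there"},
]
result = prefix_assistant_mask(conversation)
print(result["assistant_masks"])  # [0, 0, 0, 0, 1, 1, 1]
```

A template that, for example, prepends the message count to every rendering changes earlier tokens when a round is added, so the sketch raises ValueError instead of returning a silently wrong mask.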

trinity.common.models.utils.get_action_mask_method(chat_template: str | None = None) → Callable

Get the action mask method according to the chat template.

Parameters:

chat_template (Optional[str]) -- The chat template. If it contains the {% generation %} symbol, the HF tokenizer's return_assistant_tokens_mask is used.

Returns:

The action mask method.
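
A sketch of this dispatch, assuming a regular expression that also accepts whitespace variants and Jinja trim markers such as {%- generation -%}; the real function may use a plain substring test. Returning the default method when chat_template is None is likewise an assumption. hf_mask_method and default_mask_method are hypothetical stand-ins for tokenize_and_mask_messages_hf and tokenize_and_mask_messages_default.

```python
import re
from typing import Callable, Optional

# Matches "{% generation %}" with optional whitespace and Jinja "-" trim markers.
GENERATION_TAG = re.compile(r"\{%-?\s*generation\s*-?%\}")


def hf_mask_method(*args, **kwargs):
    """Stand-in for tokenize_and_mask_messages_hf."""
    raise NotImplementedError


def default_mask_method(*args, **kwargs):
    """Stand-in for tokenize_and_mask_messages_default."""
    raise NotImplementedError


def pick_mask_method(chat_template: Optional[str] = None) -> Callable:
    """Prefer the HF assistant mask when the template marks generations."""
    if chat_template is not None and GENERATION_TAG.search(chat_template):
        return hf_mask_method
    # No template, or no {% generation %} block: fall back to prefix matching.
    return default_mask_method
```

The tag check only looks for the opening {% generation %} marker, since any template that uses {% endgeneration %} must also open the block.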

trinity.common.models.utils.get_checkpoint_dir_with_step_num(checkpoint_root_path: str, trainer_type: str = 'verl', step_num: int | None = None, raise_error: bool = True) → Tuple[str, int]

Get the checkpoint directory from a root checkpoint directory.

Parameters:
  • checkpoint_root_path (str) -- The root checkpoint directory.

  • trainer_type (str) -- The trainer type. Only "verl" is supported for now.

  • step_num (Optional[int], optional) -- The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None.

  • raise_error (bool) -- Whether to raise an error if the checkpoint does not exist.

Returns:

The checkpoint directory and the step number of the checkpoint.

If the checkpoint does not exist and raise_error is False, return (None, 0).

Return type:

Tuple[str, int]
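
A minimal sketch of resolving a step directory, assuming a verl-style layout in which each saved step lives in a global_step_N subdirectory of the root (the naming mentioned for load_state_dict). find_step_checkpoint is a hypothetical helper; the real function may locate the latest step differently, for example from a tracker file written by the trainer.

```python
import os
import re
import tempfile
from typing import Dict, Optional, Tuple

# Assumed naming of per-step checkpoint directories.
STEP_DIR_PATTERN = re.compile(r"global_step_(\d+)")


def find_step_checkpoint(
    root: str, step_num: Optional[int] = None, raise_error: bool = True
) -> Tuple[Optional[str], int]:
    """Return (directory, step) for the requested or latest global_step_N."""
    steps: Dict[int, str] = {}
    if os.path.isdir(root):
        for name in os.listdir(root):
            match = STEP_DIR_PATTERN.fullmatch(name)
            path = os.path.join(root, name)
            if match and os.path.isdir(path):
                steps[int(match.group(1))] = path
    if step_num is None and steps:
        step_num = max(steps)  # latest step by number, not by modification time
    if step_num in steps:
        return steps[step_num], step_num
    if raise_error:
        raise FileNotFoundError(f"no checkpoint for step {step_num} in {root}")
    return None, 0  # mirrors the documented (None, 0) result


with tempfile.TemporaryDirectory() as demo_root:
    for step in (5, 20, 10):
        os.makedirs(os.path.join(demo_root, f"global_step_{step}"))
    print(find_step_checkpoint(demo_root)[1])  # 20
```

Steps are compared as integers, so global_step_20 is correctly newer than global_step_5 even though it sorts earlier as a string.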

trinity.common.models.utils.get_latest_state_dict(checkpoint_root_path: str, trainer_type: str = 'verl') → Tuple[str, int]

Get the latest state dict from a root checkpoint directory.

Parameters:

checkpoint_root_path (str) -- The root checkpoint directory.

Returns:

The state dict path and the iteration of the state dict.

If the state dict does not exist, return (None, 0).

Return type:

Tuple[str, int]

trinity.common.models.utils.has_huggingface_model_weights(checkpoint_path: str) → bool

Return True when checkpoint_path contains serialized HF model weights.

trinity.common.models.utils.load_state_dict_iterator(checkpoint_dir: str) → Iterator[Tuple[str, Tensor]]

Load model state dict from a checkpoint directory as an iterator of (name, tensor) tuples.

trinity.common.models.utils.load_state_dict(checkpoint_dir: str, trust_remote_code: bool = False) → dict | Tuple[str, str]

Load model state dict from a checkpoint directory.

Auto-detects the checkpoint format from directory contents:

  1. safetensors: model.safetensors produced by the unified save_state_dict path. Loaded directly and returned as a dict.

  2. HuggingFace weights: detected by has_huggingface_model_weights() in either a huggingface/ subdirectory or the directory itself. Returns ("huggingface", path) for lazy loading by the caller.

  3. FSDP shards: model_world_size_N_rank_M.pt files. Merged via load_fsdp_state_dict_from_verl_checkpoint() and returned as a dict.

  4. Megatron dist checkpoint: the fallback when none of the above match. Returns ("megatron", checkpoint_dir) for lazy loading via a converter.

Parameters:

checkpoint_dir -- Path to the checkpoint directory (typically global_step_N/actor/).

Returns:

Either a dict of model weights, or a (method, path) tuple indicating the format for lazy loading.
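
The detection order can be sketched as follows. The hypothetical detect_checkpoint_format only classifies the directory and returns a (format, path) pair; it does not load any tensors. The list of Hugging Face weight filenames is an assumption based on the transformers defaults, not necessarily the exact check performed by has_huggingface_model_weights.

```python
import os
import re
import tempfile
from typing import Tuple

# Assumed Hugging Face weight filenames (the transformers defaults).
HF_WEIGHT_FILES = (
    "model.safetensors",
    "model.safetensors.index.json",
    "pytorch_model.bin",
    "pytorch_model.bin.index.json",
)
FSDP_SHARD_PATTERN = re.compile(r"model_world_size_\d+_rank_\d+\.pt")


def has_hf_weights(path: str) -> bool:
    return any(os.path.isfile(os.path.join(path, name)) for name in HF_WEIGHT_FILES)


def detect_checkpoint_format(checkpoint_dir: str) -> Tuple[str, str]:
    """Classify a checkpoint directory; no tensors are loaded."""
    # 1. Unified safetensors file saved directly in the directory.
    safetensors_file = os.path.join(checkpoint_dir, "model.safetensors")
    if os.path.isfile(safetensors_file):
        return "safetensors", safetensors_file
    # 2. HF weights in a huggingface/ subdirectory or the directory itself.
    for candidate in (os.path.join(checkpoint_dir, "huggingface"), checkpoint_dir):
        if os.path.isdir(candidate) and has_hf_weights(candidate):
            return "huggingface", candidate
    # 3. FSDP shards such as model_world_size_2_rank_0.pt.
    if any(FSDP_SHARD_PATTERN.fullmatch(name) for name in os.listdir(checkpoint_dir)):
        return "fsdp", checkpoint_dir
    # 4. Fallback: treat the directory as a Megatron dist checkpoint.
    return "megatron", checkpoint_dir


with tempfile.TemporaryDirectory() as demo_dir:
    open(os.path.join(demo_dir, "model_world_size_2_rank_0.pt"), "w").close()
    print(detect_checkpoint_format(demo_dir)[0])  # fsdp
```

The order matters: a model.safetensors file at the top level wins over everything else, and Megatron is only assumed when no other format is recognized. In load_state_dict itself, the safetensors and FSDP cases are loaded and returned as a dict, while the other two return the tuple for lazy loading.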

trinity.common.models.utils.get_verl_checkpoint_info(checkpoint_path: str, step_num: int | None = None, raise_error: bool = True) → Tuple[str, int]

Get the checkpoint directory from a Verl root checkpoint directory.

Parameters:
  • checkpoint_path (str) -- The root checkpoint directory.

  • step_num (Optional[int], optional) -- The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None.

  • raise_error (bool) -- Whether to raise an error if the checkpoint does not exist.

Returns:

The checkpoint directory and the step number of the checkpoint.

Return type:

Tuple[str, int]

trinity.common.models.utils.load_fsdp_state_dict_from_verl_checkpoint(checkpoint_path: str) → dict

Load state dict from a Verl checkpoint.

trinity.common.models.utils.load_huggingface_state_dict(checkpoint_path: str, trust_remote_code: bool = False)

trinity.common.models.utils.get_megatron_converter(checkpoint_path: str)