trinity.common.models.utils module
- trinity.common.models.utils.tokenize_and_mask_messages_hf(tokenizer: Any, messages: List[dict], tools: List[dict] | None = None, chat_template: str | None = None, enable_thinking: bool | None = None) -> dict[str, Tensor] [source]
Calculate the assistant token mask with chat_template.
- Parameters:
tokenizer (Any) -- The tokenizer or processor.
messages (List[dict]) -- Messages with role and content fields.
tools (Optional[List[dict]]) -- The list of tool dictionaries.
chat_template (str) -- The chat template with {% generation %} symbol.
- Returns:
A token dictionary returned by apply_chat_template, containing at least input_ids and assistant_masks.
- Return type:
dict[str, torch.Tensor]
- trinity.common.models.utils.tokenize_and_mask_messages_default(tokenizer: Any, messages: List[dict], tools: List[dict] | None = None, chat_template: str | None = None, enable_thinking: bool | None = None) -> dict[str, Tensor] [source]
Calculate the assistant token mask.
- Parameters:
tokenizer (Any) -- The tokenizer or processor.
messages (List[dict]) -- Messages with role and content fields.
tools (Optional[List[dict]]) -- The list of tool dictionaries.
chat_template (str) -- The chat template with {% generation %} symbol.
- Returns:
A token dictionary containing input_ids and assistant_masks.
- Return type:
dict[str, torch.Tensor]
Note
This method assumes that, as the number of chat rounds increases, the tokens of each earlier round form an exact prefix of the tokens of the next round. If that assumption does not hold, the function may produce incorrect results. Check the chat template before using this method.
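The prefix assumption in the note can be illustrated with a small, self-contained sketch. It does not call the library: `render`, `toy_tokenize`, and `mask_by_prefix` are hypothetical stand-ins (a toy chat template and a whitespace tokenizer), and the real function may compute the mask differently. The idea shown is that every token appended by an assistant turn, beyond the already-rendered prompt, is marked 1.

```python
from typing import Dict, List


def render(messages: List[dict], add_generation_prompt: bool = False) -> str:
    """Toy chat template (hypothetical): '<role> content <end>' per message."""
    parts = [f"<{m['role']}> {m['content']} <end>" for m in messages]
    if add_generation_prompt:
        # Header that precedes an assistant reply; it is not part of the reply.
        parts.append("<assistant>")
    return " ".join(parts)


def toy_tokenize(text: str) -> List[str]:
    """Whitespace tokenizer standing in for a real tokenizer."""
    return text.split()


def mask_by_prefix(messages: List[dict]) -> Dict[str, list]:
    """Mark assistant tokens by diffing successive renderings of the chat."""
    full = toy_tokenize(render(messages))
    masks = [0] * len(full)
    for i, message in enumerate(messages):
        if message["role"] != "assistant":
            continue
        # Tokens the model was conditioned on before producing this turn.
        prompt = toy_tokenize(render(messages[:i], add_generation_prompt=True))
        # Tokens after this turn has been appended.
        upto = toy_tokenize(render(messages[: i + 1]))
        # The method is only valid if each rendering is a prefix of the next.
        if upto[: len(prompt)] != prompt or full[: len(upto)] != upto:
            raise ValueError("chat template violates the prefix assumption")
        for j in range(len(prompt), len(upto)):
            masks[j] = 1
    return {"input_ids": full, "assistant_masks": masks}


example = mask_by_prefix(
    [
        {"role": "user", "content": "hi there"},
        {"role": "assistant", "content": "hello"},
        {"role": "user", "content": "bye"},
        {"role": "assistant", "content": "see you"},
    ]
)
```

A template that rewrites earlier turns when a new one is added (for example, one that strips reasoning text from previous assistant messages) would fail the prefix check in this sketch, which is the situation the note warns about.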
- trinity.common.models.utils.get_action_mask_method(chat_template: str | None = None) -> Callable [source]
Get the action mask method according to the chat template.
- Parameters:
chat_template (str) -- The chat template. If {% generation %} is present, the HF tokenizer's return_assistant_tokens_mask is used.
- Returns:
The action mask method.
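The selection rule described above can be sketched as follows. This is an illustration under assumptions, not the library's code: `hf_mask_method` and `default_mask_method` are placeholder callables, and the regular expression used to spot the generation tag is a guess at a reasonable check (the real function may use a plain substring test).

```python
import re
from typing import Callable, Optional

# Matches "{% generation %}", allowing extra whitespace and Jinja trim markers
# such as "{%- generation -%}". The exact pattern is an assumption.
GENERATION_TAG = re.compile(r"\{%-?\s*generation\s*-?%\}")


def hf_mask_method(*args, **kwargs) -> str:
    """Placeholder for the template-based masking path."""
    return "hf"


def default_mask_method(*args, **kwargs) -> str:
    """Placeholder for the prefix-based masking path."""
    return "default"


def choose_mask_method(chat_template: Optional[str] = None) -> Callable:
    """Pick the template-based path only when the template marks generations."""
    if chat_template is not None and GENERATION_TAG.search(chat_template):
        return hf_mask_method
    return default_mask_method


with_tag = "{% generation %}{{ message['content'] }}{% endgeneration %}"
without_tag = "{{ message['content'] }}"
```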
- trinity.common.models.utils.get_checkpoint_dir_with_step_num(checkpoint_root_path: str, trainer_type: str = 'verl', step_num: int | None = None, raise_error: bool = True) -> Tuple[str, int] [source]
Get the checkpoint directory from a root checkpoint directory.
- Parameters:
checkpoint_root_path (str) -- The root checkpoint directory.
trainer_type (str) -- The trainer type. Only "verl" is supported for now.
step_num (Optional[int], optional) -- The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None.
raise_error (bool) -- Whether to raise an error if the checkpoint does not exist.
- Returns:
The checkpoint directory and the step number of the checkpoint. If the checkpoint does not exist and raise_error is False, returns (None, 0).
- Return type:
Tuple[str, int]
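The documented contract (latest step when step_num is None, (None, 0) when nothing is found and raise_error is False) can be sketched with the standard library alone. `find_checkpoint_dir` is a hypothetical helper, and scanning for `global_step_N` directories is an assumption about the layout; the real function may locate the latest step by other means, such as a tracker file.

```python
import os
import re
import tempfile
from typing import Optional, Tuple

STEP_DIR = re.compile(r"^global_step_(\d+)$")


def find_checkpoint_dir(
    root: str, step_num: Optional[int] = None, raise_error: bool = True
) -> Tuple[Optional[str], int]:
    """Return (directory, step) for the requested or latest checkpoint."""
    steps = []
    if os.path.isdir(root):
        for name in os.listdir(root):
            match = STEP_DIR.match(name)
            if match and os.path.isdir(os.path.join(root, name)):
                steps.append(int(match.group(1)))
    if step_num is None and steps:
        # Compare numerically: as strings, "global_step_5" sorts after "global_step_100".
        step_num = max(steps)
    if step_num is None or step_num not in steps:
        if raise_error:
            raise FileNotFoundError(f"No checkpoint found under {root}")
        return None, 0
    return os.path.join(root, f"global_step_{step_num}"), step_num


with tempfile.TemporaryDirectory() as root:
    for step in (5, 20, 100):
        os.makedirs(os.path.join(root, f"global_step_{step}"))
    latest = find_checkpoint_dir(root)
    pinned = find_checkpoint_dir(root, step_num=20)
    missing = find_checkpoint_dir(root, step_num=7, raise_error=False)
```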
- trinity.common.models.utils.get_latest_state_dict(checkpoint_root_path: str, trainer_type: str = 'verl') -> Tuple[str, int] [source]
Get the latest state dict from a root checkpoint directory.
- Parameters:
checkpoint_root_path (str) -- The root checkpoint directory.
- Returns:
The state dict path and the iteration of the state dict. If the state dict does not exist, returns (None, 0).
- Return type:
Tuple[str, int]
- trinity.common.models.utils.has_huggingface_model_weights(checkpoint_path: str) -> bool [source]
Return True when checkpoint_path contains serialized HF model weights.
- trinity.common.models.utils.load_state_dict_iterator(checkpoint_dir: str) -> Iterator[Tuple[str, Tensor]] [source]
Load model state dict from a checkpoint directory as an iterator of (name, tensor) tuples.
- trinity.common.models.utils.load_state_dict(checkpoint_dir: str, trust_remote_code: bool = False) -> dict | Tuple[str, str] [source]
Load model state dict from a checkpoint directory.
Auto-detects the checkpoint format from directory contents:
- safetensors: model.safetensors produced by the unified save_state_dict path. Loaded directly and returned as a dict.
- HuggingFace weights: detected by has_huggingface_model_weights() in either a huggingface/ subdirectory or the directory itself. Returns ("huggingface", path) for lazy loading by the caller.
- FSDP shards: model_world_size_N_rank_M.pt files. Merged via load_fsdp_state_dict_from_verl_checkpoint() and returned as a dict.
- Megatron dist checkpoint: fallback. Returns ("megatron", checkpoint_dir) for lazy loading via converter.
- Parameters:
checkpoint_dir -- Path to the checkpoint directory (typically global_step_N/actor/).
- Returns:
Either a dict of model weights, or a (method, path) tuple indicating the format for lazy loading.
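The format detection described above can be sketched without loading any tensors. Everything here is illustrative: `detect_checkpoint_format` and `looks_like_hf_weights` are hypothetical helpers, the file-extension test is a simplified stand-in for has_huggingface_model_weights(), and the check order is assumed to follow the order in which the formats are listed.

```python
import os
import re
import tempfile

# File name pattern for FSDP shards, as given in the description above.
FSDP_SHARD = re.compile(r"^model_world_size_\d+_rank_\d+\.pt$")


def looks_like_hf_weights(path: str) -> bool:
    """Simplified stand-in: any .safetensors or .bin file counts as HF weights."""
    if not os.path.isdir(path):
        return False
    return any(name.endswith((".safetensors", ".bin")) for name in os.listdir(path))


def detect_checkpoint_format(checkpoint_dir: str) -> str:
    """Classify a checkpoint directory by its contents (assumed check order)."""
    names = os.listdir(checkpoint_dir)
    if "model.safetensors" in names:
        return "safetensors"
    for candidate in (os.path.join(checkpoint_dir, "huggingface"), checkpoint_dir):
        if looks_like_hf_weights(candidate):
            return "huggingface"
    if any(FSDP_SHARD.match(name) for name in names):
        return "fsdp"
    return "megatron"  # fallback when nothing else matches


def touch(directory: str, relative_path: str) -> None:
    """Create an empty file, including any missing parent directories."""
    path = os.path.join(directory, relative_path)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    open(path, "w").close()


# Example layouts; the Megatron file name is made up for the demonstration.
layouts = {
    "safetensors": ["model.safetensors"],
    "huggingface": ["huggingface/model-00001-of-00002.safetensors"],
    "fsdp": ["model_world_size_2_rank_0.pt", "model_world_size_2_rank_1.pt"],
    "megatron": ["dist_ckpt/metadata.json"],
}
detected = {}
for expected, files in layouts.items():
    with tempfile.TemporaryDirectory() as directory:
        for relative_path in files:
            touch(directory, relative_path)
        detected[expected] = detect_checkpoint_format(directory)
```

A caller following the documented return contract would then branch on the result type: a dict can be used directly, while a (method, path) tuple signals that the weights still need to be loaded lazily.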
- trinity.common.models.utils.get_verl_checkpoint_info(checkpoint_path: str, step_num: int | None = None, raise_error: bool = True) -> Tuple[str, int] [source]
Get the checkpoint directory from a Verl root checkpoint directory.
- Parameters:
checkpoint_path (str) -- The root checkpoint directory.
step_num (Optional[int], optional) -- The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None.
raise_error (bool) -- Whether to raise an error if the checkpoint does not exist.
- Returns:
The checkpoint directory and the step number of the checkpoint.
- Return type:
Tuple[str, int]
- trinity.common.models.utils.load_fsdp_state_dict_from_verl_checkpoint(checkpoint_path: str) -> dict [source]
Load state dict from a Verl checkpoint.