# Source code for the trinity.common.experience module.

# -*- coding: utf-8 -*-
"""Experience Class."""
from __future__ import annotations

import pickle
import struct
import uuid
from dataclasses import asdict, dataclass, field, fields
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union

import torch
from safetensors.torch import load as st_load
from safetensors.torch import save as st_save
from torch import Tensor

if TYPE_CHECKING:
    from datasets import Dataset


@dataclass
class EID:
    """Experience ID class to uniquely identify an experience.

    To enable the full functionality of the experience grouping, user should
    manually set the `run` and `step` fields in custom workflows.
    """

    # TODO: do we need to add project/name here to make it unique across different projects?

    # Batch number, e.g., the explorer step num.
    # Automatically set by the workflow runner.
    batch: Union[int, str] = ""

    # Task number, e.g., the task sequence in the batch; the first task in the
    # batch has task=0. Automatically set by the workflow runner.
    task: Union[int, str] = ""

    # Run id, e.g., the first run in the task has run=0.
    # User should set this field in custom workflows when creating experiences.
    run: int = 0

    # Step number when running the task, e.g., the first step in the task has
    # step=0. User should set this field in custom workflows when creating
    # experiences.
    step: int = 0

    # Unique identifier suffix (first 6 hex chars of a fresh UUID4).
    suffix: str = field(default_factory=lambda: uuid.uuid4().hex[:6])

    @property
    def uid(self) -> str:
        """An unique identifier for the experience."""
        return f"{self.batch}/{self.task}/{self.run}/{self.step}/{self.suffix}"

    @property
    def sid(self) -> str:
        """Step ID of the experience.

        For example, experiences generated by all runs of a same task at the
        same step will have the same sid.
        """
        return f"{self.batch}/{self.task}/{self.step}"

    @property
    def rid(self) -> str:
        """Run ID of the experience.

        For example, experiences generated by one run of a task at all steps
        will have the same run_id.
        """
        return f"{self.batch}/{self.task}/{self.run}"

    @property
    def tid(self) -> str:
        """Task ID for the experience.

        For example, experiences generated by all runs of a same task in
        GRPO-like algorithms will have the same tid.
        """
        return f"{self.batch}/{self.task}"

    def __str__(self):
        return self.uid

    def __repr__(self):
        return f"EID(batch={self.batch}, task={self.task}, run={self.run}, step={self.step}, uuid={self.suffix})"

    def to_dict(self) -> dict:
        """Convert the EID to a dictionary."""
        return {
            "batch": self.batch,
            "task": self.task,
            "run": self.run,
            "step": self.step,
            "suffix": self.suffix,
        }
@dataclass(frozen=True)
class CustomField:
    """Custom field for Experiences.

    This is used to store additional information into the Experiences class.
    """

    # The source field name in the Experience.info
    source_field: str
    # The destination field name in the Experiences class
    destination_field: str
    # The data type of the field, e.g., torch.float32, torch.int64, etc.
    data_type: torch.dtype
@dataclass
class Experience:
    """A single rollout/training experience, with batched (de)serialization.

    Tensor fields are packed with safetensors while everything else travels as
    pickled metadata (see :meth:`serialize_many` / :meth:`deserialize_many`).
    """

    # Binary payload framing: magic prefix + format version.
    _SER_MAGIC = b"TRXP"
    _SER_VERSION = 1

    # Fields serialized via safetensors.
    _TENSOR_FIELDS = (
        "tokens",
        "logprobs",
        "token_level_reward",
        "advantages",
        "returns",
        "action_mask",
        "chosen",
        "rejected",
        "teacher_logprobs",
    )
    # Fields serialized as pickled metadata.
    _META_FIELDS = (
        "eid",
        "reward",
        "truncate_status",
        "info",
        "metrics",
        "prompt_length",
        "response_text",
        "prompt_text",
        "messages",
        "tools",
        "chosen_messages",
        "rejected_messages",
    )

    eid: EID = field(default_factory=EID)  # Unique identifier for the experience
    tokens: Optional[Tensor] = None  # [seq_length]
    prompt_length: int = 1  # Length of the prompt in tokens, used for generating attention masks
    logprobs: Optional[Tensor] = None  # [resp_length]
    reward: Optional[float] = None
    token_level_reward: Optional[Tensor] = None  # [resp_length]
    advantages: Optional[Tensor] = None  # [resp_length]
    returns: Optional[Tensor] = None  # [resp_length]
    # The status of truncation, e.g., "prompt_truncated", "response_truncated";
    # Not working for openai api
    truncate_status: Optional[str] = None
    # Additional information about the experience, can also be used to store custom fields
    info: dict = field(default_factory=dict)
    # Metrics associated with the experience, directly used by the monitor
    metrics: dict[str, float] = field(default_factory=dict)

    # for single-turn experiences
    response_text: Optional[str] = None  # Text of the response
    prompt_text: Optional[str] = None  # Text of the prompt

    # for multi-turn experiences
    # Action mask indicates which tokens are generated by the model
    action_mask: Optional[Tensor] = None  # [resp_length]
    messages: Optional[List[dict]] = None  # List of messages
    tools: Optional[List[dict]] = None

    # for dpo experiences
    chosen: Optional[Tensor] = None  # Token ids of the chosen response [resp_length]
    rejected: Optional[Tensor] = None  # Token ids of the rejected response [resp_length]
    chosen_messages: Optional[List[dict]] = None  # Chosen message list (Include prompt message)
    rejected_messages: Optional[List[dict]] = None  # Rejected message list (Include prompt message)

    # for multi-modal data
    multi_modal_inputs: Optional[Dict[str, Tensor]] = None  # Multi-modal inputs for verl trainer

    # for on-policy distillation
    teacher_logprobs: Optional[Tensor] = None  # [resp_length]

    custom_fields: List[CustomField] = field(default_factory=list)

    def __init__(  # noqa: C901
        self,
        *,
        eid=None,
        tokens,
        logprobs=None,
        reward=None,
        token_level_reward=None,
        advantages=None,
        returns=None,
        truncate_status=None,
        info=None,
        metrics=None,
        prompt_length=1,
        response_text=None,
        prompt_text=None,
        action_mask=None,
        messages=None,
        tools=None,
        chosen=None,
        rejected=None,
        chosen_messages=None,
        rejected_messages=None,
        multi_modal_inputs=None,
        teacher_logprobs=None,
        custom_fields=None,
    ):
        # Infer the experience kind: an explicit action mask means multi-turn,
        # a chosen/rejected pair means DPO, anything else is single-turn.
        if action_mask is not None:
            experience_type = "multi_turn"
        elif chosen is not None and rejected is not None:
            experience_type = "dpo"
        else:
            experience_type = "single_turn"

        if experience_type == "single_turn":
            assert (
                prompt_length > 0
            ), "Prompt length must be greater than 0 for single-turn experiences."
            if truncate_status != "prompt_truncated":
                assert (
                    len(tokens) > prompt_length
                ), f"Token ids must be larger than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}."
                # Every response token counts as a model action.
                action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
            else:
                # Prompt was truncated: no response tokens are actions.
                # NOTE(review): assumes logprobs is not None in this branch — confirm with callers.
                action_mask = torch.zeros(len(logprobs), dtype=torch.bool)
        elif experience_type == "dpo":
            # For DPO, `tokens` holds only the prompt.
            prompt_length = len(tokens)

        if eid is None:
            self.eid = EID()
        elif isinstance(eid, dict):
            self.eid = EID(**eid)
        else:
            self.eid = eid

        # Coerce plain lists into tensors with the conventional dtypes.
        if isinstance(tokens, list):
            tokens = torch.tensor(tokens, dtype=torch.int32)
        self.tokens = tokens
        if isinstance(logprobs, list):
            logprobs = torch.tensor(logprobs, dtype=torch.float32)
        self.logprobs = logprobs
        self.reward = reward
        if isinstance(token_level_reward, list):
            token_level_reward = torch.tensor(token_level_reward, dtype=torch.float32)
        self.token_level_reward = token_level_reward
        if isinstance(advantages, list):
            advantages = torch.tensor(advantages, dtype=torch.float32)
        self.advantages = advantages
        if isinstance(returns, list):
            returns = torch.tensor(returns, dtype=torch.float32)
        self.returns = returns

        self.experience_type = experience_type
        self.info = info or {}
        self.metrics = metrics or {}
        self.truncate_status = truncate_status
        self.prompt_length = prompt_length
        self.response_text = response_text
        self.prompt_text = prompt_text
        if isinstance(action_mask, list):
            action_mask = torch.tensor(action_mask, dtype=torch.bool)
        self.action_mask = action_mask
        self.messages = messages
        self.tools = tools
        if isinstance(chosen, list):
            chosen = torch.tensor(chosen, dtype=torch.int32)
        self.chosen = chosen
        if isinstance(rejected, list):
            rejected = torch.tensor(rejected, dtype=torch.int32)
        self.rejected = rejected
        self.chosen_messages = chosen_messages
        self.rejected_messages = rejected_messages

        self.multi_modal_inputs = multi_modal_inputs
        if multi_modal_inputs is not None:
            # Normalize every multi-modal value to a Tensor.
            self.multi_modal_inputs = {}
            for key, value in multi_modal_inputs.items():
                if not isinstance(value, Tensor):
                    self.multi_modal_inputs[key] = torch.tensor(value)
                else:
                    self.multi_modal_inputs[key] = value

        # Handle teacher_logprobs
        if isinstance(teacher_logprobs, list):
            teacher_logprobs = torch.tensor(teacher_logprobs, dtype=torch.float32)
        self.teacher_logprobs = teacher_logprobs

        # Final safety net: any remaining non-Tensor sequences become tensors.
        if not isinstance(self.tokens, Tensor):
            self.tokens = torch.tensor(self.tokens)
        if self.logprobs is not None and not isinstance(self.logprobs, Tensor):
            self.logprobs = torch.tensor(self.logprobs)
        if self.action_mask is not None and not isinstance(self.action_mask, Tensor):
            self.action_mask = torch.tensor(self.action_mask)
        if self.chosen is not None and not isinstance(self.chosen, Tensor):
            self.chosen = torch.tensor(self.chosen)
        if self.rejected is not None and not isinstance(self.rejected, Tensor):
            self.rejected = torch.tensor(self.rejected)
        if self.teacher_logprobs is not None and not isinstance(self.teacher_logprobs, Tensor):
            self.teacher_logprobs = torch.tensor(self.teacher_logprobs, dtype=torch.float32)

        self.custom_fields = custom_fields or []

    def serialize(self) -> bytes:
        """Serialize the experience to bytes."""
        return self.serialize_many([self])

    @classmethod
    def deserialize(cls, data: bytes) -> Experience:
        """Deserialize bytes holding exactly one experience.

        Raises:
            ValueError: if the payload contains more than one experience.
        """
        experiences = cls.deserialize_many(data)
        if len(experiences) != 1:
            raise ValueError(
                f"Expected a single Experience payload, got batch size {len(experiences)}. "
                "Use Experience.deserialize_many for batched payloads."
            )
        return experiences[0]

    @staticmethod
    def _serialize_custom_fields(custom_fields: Optional[List[CustomField]]) -> list[dict]:
        """Turn CustomField objects into plain dicts (dtype stored as its str form)."""
        if not custom_fields:
            return []
        return [
            {
                "source_field": cf.source_field,
                "destination_field": cf.destination_field,
                "data_type": str(cf.data_type),
            }
            for cf in custom_fields
        ]

    @staticmethod
    def _deserialize_custom_fields(serialized_fields: Optional[List[dict]]) -> List[CustomField]:
        """Rebuild CustomField objects from dicts produced by _serialize_custom_fields."""
        if not serialized_fields:
            return []
        custom_fields = []
        for field_dict in serialized_fields:
            # "torch.float32" -> torch.float32
            dtype_name = field_dict["data_type"].replace("torch.", "")
            dtype = getattr(torch, dtype_name)
            custom_fields.append(
                CustomField(
                    source_field=field_dict["source_field"],
                    destination_field=field_dict["destination_field"],
                    data_type=dtype,
                )
            )
        return custom_fields

    @classmethod
    def serialize_many(cls, experiences: List[Experience]) -> bytes:
        """Serialize a list of experiences into a compact bytes payload.

        Tensor fields are packed with safetensors while non-tensor fields are
        packed as metadata via pickle.
        """
        metadata = {"version": cls._SER_VERSION, "num_items": len(experiences), "items": []}
        tensor_data = {}
        for index, exp in enumerate(experiences):
            item_meta = {}
            for field_name in cls._META_FIELDS:
                value = getattr(exp, field_name)
                if field_name == "eid" and value is not None:
                    item_meta[field_name] = value.to_dict() if isinstance(value, EID) else value
                else:
                    item_meta[field_name] = value
            item_meta["custom_fields"] = cls._serialize_custom_fields(exp.custom_fields)
            for field_name in cls._TENSOR_FIELDS:
                value = getattr(exp, field_name)
                if value is None:
                    continue
                # Keys are "<item index>:<field name>" so one flat dict can
                # carry the whole batch.
                tensor_data[f"{index}:{field_name}"] = value.detach().cpu().contiguous()
            if exp.multi_modal_inputs is None:
                item_meta["multi_modal_input_keys"] = []
            else:
                mm_keys = list(exp.multi_modal_inputs.keys())
                item_meta["multi_modal_input_keys"] = mm_keys
                for key in mm_keys:
                    value = exp.multi_modal_inputs[key]
                    tensor_data[f"{index}:multi_modal_inputs:{key}"] = (
                        value.detach()
                        .cpu()
                        .contiguous()
                        .clone()  # clone to avoid shared memory issues
                    )
            metadata["items"].append(item_meta)
        metadata_bytes = pickle.dumps(metadata, protocol=pickle.HIGHEST_PROTOCOL)
        tensor_bytes = st_save(tensor_data)
        # Layout: MAGIC | version (1B) | metadata length (8B LE) | tensor length (8B LE) | payloads.
        header = (
            cls._SER_MAGIC
            + struct.pack("<B", cls._SER_VERSION)
            + struct.pack("<Q", len(metadata_bytes))
            + struct.pack("<Q", len(tensor_bytes))
        )
        return header + metadata_bytes + tensor_bytes

    @classmethod
    def deserialize_many(cls, data: bytes) -> List[Experience]:
        """Deserialize bytes into a list of experiences.

        Supports both new batched payloads and legacy single-experience pickle
        payloads.
        """
        if not data.startswith(cls._SER_MAGIC):
            # Legacy path: the whole payload is a pickle of one Experience or a list.
            legacy = pickle.loads(data)
            if isinstance(legacy, list):
                return legacy
            return [legacy]
        offset = len(cls._SER_MAGIC)
        version = struct.unpack("<B", data[offset : offset + 1])[0]
        offset += 1
        if version != cls._SER_VERSION:
            raise ValueError(
                f"Unsupported Experience serialization version: {version}, expected {cls._SER_VERSION}."
            )
        metadata_len = struct.unpack("<Q", data[offset : offset + 8])[0]
        offset += 8
        tensor_len = struct.unpack("<Q", data[offset : offset + 8])[0]
        offset += 8
        metadata_bytes = data[offset : offset + metadata_len]
        offset += metadata_len
        tensor_bytes = data[offset : offset + tensor_len]
        metadata = pickle.loads(metadata_bytes)
        tensor_data = st_load(tensor_bytes)
        experiences = []
        for index, item_meta in enumerate(metadata["items"]):
            init_kwargs = {
                "eid": item_meta.get("eid"),
                "reward": item_meta.get("reward"),
                "truncate_status": item_meta.get("truncate_status"),
                "info": item_meta.get("info"),
                "metrics": item_meta.get("metrics"),
                "prompt_length": item_meta.get("prompt_length", 1),
                "response_text": item_meta.get("response_text"),
                "prompt_text": item_meta.get("prompt_text"),
                "messages": item_meta.get("messages"),
                "tools": item_meta.get("tools"),
                "chosen_messages": item_meta.get("chosen_messages"),
                "rejected_messages": item_meta.get("rejected_messages"),
                "custom_fields": cls._deserialize_custom_fields(item_meta.get("custom_fields")),
            }
            for field_name in cls._TENSOR_FIELDS:
                tensor_key = f"{index}:{field_name}"
                init_kwargs[field_name] = tensor_data.get(tensor_key)
            mm_keys = item_meta.get("multi_modal_input_keys", [])
            if mm_keys:
                init_kwargs["multi_modal_inputs"] = {
                    key: tensor_data[f"{index}:multi_modal_inputs:{key}"] for key in mm_keys
                }
            else:
                init_kwargs["multi_modal_inputs"] = None
            experiences.append(cls(**init_kwargs))
        return experiences

    def to_dict(self) -> dict:
        """Convert the experience to a dictionary."""
        res = {
            "eid": self.eid.to_dict(),
            "type": self.experience_type,
            "prompt_length": self.prompt_length,
            "response_length": len(self.tokens) - self.prompt_length,  # type: ignore [arg-type]
            "info": self.info,
            "metrics": self.metrics,
        }
        if self.prompt_text is not None:
            res["prompt_text"] = self.prompt_text
        if self.response_text is not None:
            res["response_text"] = self.response_text
        if self.messages is not None:
            res["messages"] = self.messages
        if self.tools is not None:
            res["tools"] = self.tools
        if self.chosen_messages is not None:
            res["chosen_messages"] = self.chosen_messages
        if self.rejected_messages is not None:
            res["rejected_messages"] = self.rejected_messages
        if self.reward is not None:
            res["reward"] = float(self.reward)
        if self.truncate_status is not None:
            res["truncate_status"] = self.truncate_status
        return res
def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[Experience]:
    """Expand each DPO experience into two single-turn experiences.

    For every input experience, one output is built from the chosen response
    and one from the rejected response; `tokens` becomes prompt + response and
    `prompt_length` is the original prompt length.
    """
    result: List[Experience] = []
    for exp in experiences:

        def build(response: Tensor, response_messages) -> Experience:
            # Each call mints a fresh EID (new random suffix), mirroring the
            # batch/task/run/step of the source experience.
            return Experience(
                eid=EID(
                    batch=exp.eid.batch,
                    task=exp.eid.task,
                    step=exp.eid.step,
                    run=exp.eid.run,
                ),
                tokens=torch.cat([exp.tokens, response]),
                reward=exp.reward,
                info=exp.info,
                metrics=exp.metrics,
                prompt_length=len(exp.tokens),  # type: ignore [arg-type]
                prompt_text=exp.prompt_text,
                messages=response_messages,
            )

        result.append(build(exp.chosen, exp.chosen_messages))
        result.append(build(exp.rejected, exp.rejected_messages))
    return result
def gather_token_ids(
    experiences, max_prompt_length: int, max_response_length: int, pad_token_id: int
) -> Tensor:
    """Stack token ids into one [batch, max_prompt_length + max_response_length] tensor.

    Each row is left-padded so the prompt ends at `max_prompt_length`, then
    right-padded with `pad_token_id` out to the full width.
    """
    dtype = experiences[0].tokens.dtype
    rows = []
    for exp in experiences:
        left_pad = torch.full(
            (max_prompt_length - exp.prompt_length,), pad_token_id, dtype=dtype
        )
        right_pad = torch.full(
            (max_response_length + exp.prompt_length - len(exp.tokens),),
            pad_token_id,
            dtype=dtype,
        )
        rows.append(torch.cat([left_pad, exp.tokens, right_pad]))
    return torch.stack(rows)
def gather_action_masks(experiences, max_response_length: int) -> Tensor:
    """Stack per-experience action masks, right-padded with False to max_response_length."""
    rows = []
    for exp in experiences:
        pad = torch.zeros(max_response_length - len(exp.action_mask), dtype=torch.bool)
        rows.append(torch.cat([exp.action_mask, pad]))
    return torch.stack(rows)
def gather_attention_masks(experiences, max_prompt_length: int, max_response_length: int) -> Tensor:
    """Build a boolean attention mask of shape [batch, max_prompt_length + max_response_length].

    True marks real tokens: each row's span is aligned so the prompt ends at
    `max_prompt_length` (left padding) and the response follows it.
    """
    width = max_prompt_length + max_response_length
    masks = torch.zeros((len(experiences), width), dtype=torch.bool)
    for row, exp in enumerate(experiences):
        begin = max_prompt_length - exp.prompt_length
        masks[row, begin : begin + len(exp.tokens)] = 1
    return masks
def gather_response_attrs(
    experiences, attr_name: str, max_response_length: int, pad_value: int = 0
) -> Tensor:
    """Stack the per-response tensor attribute `attr_name`, right-padded to max_response_length.

    The pad value is cast to the attribute's dtype (taken from the first
    experience) before padding.
    """
    dtype = getattr(experiences[0], attr_name).dtype
    fill = torch.tensor(pad_value, dtype=dtype)
    rows = []
    for exp in experiences:
        values = getattr(exp, attr_name)
        pad = torch.full((max_response_length - len(values),), fill, dtype=dtype)
        rows.append(torch.cat([values, pad]))
    return torch.stack(rows)
def gather_multi_modal_inputs(experiences) -> Dict[str, List[Tensor]]:
    """Collect multi-modal input tensors across experiences, keyed by input name.

    Fix: the return annotation previously claimed ``Dict[str, Tensor]``, but the
    values are *lists* of per-experience tensors — they are intentionally not
    stacked, since multi-modal inputs may differ in shape across experiences.

    Args:
        experiences: sequence of experiences; the first one's
            ``multi_modal_inputs`` keys define which keys are gathered, and every
            experience is assumed to carry those keys.

    Returns:
        Mapping from input name to the list of that input's tensors, in
        experience order.
    """
    keys = experiences[0].multi_modal_inputs.keys()
    return {key: [exp.multi_modal_inputs[key] for exp in experiences] for key in keys}
def group_by(
    experiences: List[Experience], id_type: Literal["task", "run", "step"]
) -> Dict[str, List[Experience]]:
    """Group experiences by ID."""
    # Map the public grouping name to the EID property that provides it.
    attr_by_type = {"task": "tid", "run": "rid", "step": "sid"}
    if id_type not in attr_by_type:
        raise ValueError(f"Unknown id_type: {id_type}")
    attr = attr_by_type[id_type]
    grouped: Dict[str, List[Experience]] = {}
    for exp in experiences:
        grouped.setdefault(getattr(exp.eid, attr), []).append(exp)
    return grouped
def to_hf_datasets(experiences: list[Experience]) -> "Dataset":
    """
    Convert a list of Experience objects to a HuggingFace Dataset, preserving all fields.
    """
    # Imported lazily so `datasets` stays an optional dependency.
    from datasets import Dataset

    rows = [asdict(exp) for exp in experiences]
    return Dataset.from_list(rows)
def from_hf_datasets(dataset: "Dataset") -> List[Experience]:
    """
    Convert a HuggingFace Dataset back to a list of Experience objects.
    """

    def build(row: dict) -> Experience:
        # Drop any dataset columns that are not Experience dataclass fields.
        allowed = {f.name for f in fields(Experience)}
        return Experience(**{key: value for key, value in row.items() if key in allowed})

    return [build(row) for row in dataset.to_list()]