Source code for trinity.common.experience

# -*- coding: utf-8 -*-
"""Experience Class."""
from __future__ import annotations

import pickle
import uuid
from dataclasses import asdict, dataclass, field, fields
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

import torch
from torch import Tensor

if TYPE_CHECKING:
    from datasets import Dataset


@dataclass
class EID:
    """Experience ID class to uniquely identify an experience.

    To enable the full functionality of experience grouping, users should
    manually set the `run` and `step` fields in custom workflows.
    """

    # TODO: do we need to add project/name here to make it unique across different projects?

    # Batch number, e.g., the explorer step num
    # Automatically set by the workflow runner
    batch: Union[int, str] = ""

    # Task number, e.g., the task sequence in the batch; the first task in the batch has task=0
    # Automatically set by the workflow runner
    task: Union[int, str] = ""

    # Run id, e.g., the first run in the task has run=0
    # Users should set this field in custom workflows when creating experiences
    run: int = 0

    # Step number when running the task, e.g., the first step in the task has step=0
    # Users should set this field in custom workflows when creating experiences
    step: int = 0

    suffix: str = field(
        default_factory=lambda: uuid.uuid4().hex[:6]
    )  # Unique identifier suffix, e.g., a truncated UUID

    @property
    def uid(self) -> str:
        """A unique identifier for the experience."""
        return f"{self.batch}/{self.task}/{self.run}/{self.step}/{self.suffix}"

    @property
    def sid(self) -> str:
        """Step ID of the experience.

        For example, experiences generated by all runs of the same task at the
        same step will have the same sid.
        """
        return f"{self.batch}/{self.task}/{self.step}"

    @property
    def rid(self) -> str:
        """Run ID of the experience.

        For example, experiences generated by one run of a task at all steps
        will have the same rid.
        """
        return f"{self.batch}/{self.task}/{self.run}"

    @property
    def tid(self) -> str:
        """Task ID of the experience.

        For example, experiences generated by all runs of the same task in
        GRPO-like algorithms will have the same tid.
        """
        return f"{self.batch}/{self.task}"

    def __str__(self):
        return self.uid

    def __repr__(self):
        return f"EID(batch={self.batch}, task={self.task}, run={self.run}, step={self.step}, suffix={self.suffix})"
    def to_dict(self) -> dict:
        """Convert the EID to a dictionary."""
        return {
            "batch": self.batch,
            "task": self.task,
            "run": self.run,
            "step": self.step,
            "suffix": self.suffix,
        }
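
# --- Editor's usage sketch (not part of the original module) ----------------
# Shows how the EID properties compose into grouping keys; the concrete
# values below are hypothetical.
def _eid_usage_example() -> None:
    eid = EID(batch=3, task=7, run=1, step=2)
    assert eid.tid == "3/7"  # shared by all runs/steps of task 7
    assert eid.rid == "3/7/1"  # shared by all steps of run 1
    assert eid.sid == "3/7/2"  # shared by all runs at step 2
    # eid.uid additionally appends the random suffix, e.g. "3/7/1/2/1a2b3c"
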
@dataclass(frozen=True)
class CustomField:
    """Custom field for Experiences.

    This is used to store additional information into the Experiences class.
    """

    source_field: str  # The source field name in Experience.info
    destination_field: str  # The destination field name in the Experiences class
    data_type: torch.dtype  # The data type of the field, e.g., torch.float32, torch.int64, etc.
@dataclass
class Experience:
    eid: EID = field(default_factory=EID)  # Unique identifier for the experience
    tokens: Optional[Tensor] = None  # [seq_length]
    prompt_length: int = 1  # Length of the prompt in tokens, used for generating attention masks
    logprobs: Optional[Tensor] = None  # [resp_length]
    reward: Optional[float] = None
    token_level_reward: Optional[Tensor] = None  # [resp_length]
    advantages: Optional[Tensor] = None  # [resp_length]
    returns: Optional[Tensor] = None  # [resp_length]
    truncate_status: Optional[str] = None  # Truncation status, e.g., "prompt_truncated", "response_truncated"; not available when using the OpenAI API
    info: dict = field(
        default_factory=dict
    )  # Additional information about the experience; can also be used to store custom fields
    metrics: dict[str, float] = field(
        default_factory=dict
    )  # Metrics associated with the experience, used directly by the monitor

    # for single-turn experiences
    response_text: Optional[str] = None  # Text of the response
    prompt_text: Optional[str] = None  # Text of the prompt

    # for multi-turn experiences
    # Action mask indicates which tokens are generated by the model
    action_mask: Optional[Tensor] = None  # [resp_length]
    messages: Optional[List[dict]] = None  # List of messages
    tools: Optional[List[dict]] = None

    # for dpo experiences
    chosen: Optional[Tensor] = None  # Token ids of the chosen response [resp_length]
    rejected: Optional[Tensor] = None  # Token ids of the rejected response [resp_length]
    chosen_messages: Optional[List[dict]] = None  # Chosen message list (includes the prompt message)
    rejected_messages: Optional[List[dict]] = None  # Rejected message list (includes the prompt message)

    # for multi-modal data
    multi_modal_inputs: Optional[Dict[str, Tensor]] = None  # Multi-modal inputs for the verl trainer

    # for on-policy distillation
    teacher_logprobs: Optional[Tensor] = None  # [resp_length]

    custom_fields: List[CustomField] = field(default_factory=list)
    def __init__(  # noqa: C901
        self,
        *,
        eid=None,
        tokens,
        logprobs=None,
        reward=None,
        token_level_reward=None,
        advantages=None,
        returns=None,
        truncate_status=None,
        info=None,
        metrics=None,
        prompt_length=1,
        response_text=None,
        prompt_text=None,
        action_mask=None,
        messages=None,
        tools=None,
        chosen=None,
        rejected=None,
        chosen_messages=None,
        rejected_messages=None,
        multi_modal_inputs=None,
        teacher_logprobs=None,
        custom_fields=None,
    ):
        # Infer the experience type from the provided fields
        if action_mask is not None:
            experience_type = "multi_turn"
        elif chosen is not None and rejected is not None:
            experience_type = "dpo"
        else:
            experience_type = "single_turn"

        if experience_type == "single_turn":
            assert (
                prompt_length > 0
            ), "Prompt length must be greater than 0 for single-turn experiences."
            if truncate_status != "prompt_truncated":
                assert (
                    len(tokens) > prompt_length
                ), f"Token ids must be longer than the prompt. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}."
                action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
            else:
                action_mask = torch.zeros(len(logprobs), dtype=torch.bool)
        elif experience_type == "dpo":
            prompt_length = len(tokens)

        if eid is None:
            self.eid = EID()
        elif isinstance(eid, dict):
            self.eid = EID(**eid)
        else:
            self.eid = eid

        # Convert list inputs to tensors
        if isinstance(tokens, list):
            tokens = torch.tensor(tokens, dtype=torch.int32)
        self.tokens = tokens
        if isinstance(logprobs, list):
            logprobs = torch.tensor(logprobs, dtype=torch.float32)
        self.logprobs = logprobs
        self.reward = reward
        if isinstance(token_level_reward, list):
            token_level_reward = torch.tensor(token_level_reward, dtype=torch.float32)
        self.token_level_reward = token_level_reward
        if isinstance(advantages, list):
            advantages = torch.tensor(advantages, dtype=torch.float32)
        self.advantages = advantages
        if isinstance(returns, list):
            returns = torch.tensor(returns, dtype=torch.float32)
        self.returns = returns
        self.experience_type = experience_type
        self.info = info or {}
        self.metrics = metrics or {}
        self.truncate_status = truncate_status
        self.prompt_length = prompt_length
        self.response_text = response_text
        self.prompt_text = prompt_text
        if isinstance(action_mask, list):
            action_mask = torch.tensor(action_mask, dtype=torch.bool)
        self.action_mask = action_mask
        self.messages = messages
        self.tools = tools
        if isinstance(chosen, list):
            chosen = torch.tensor(chosen, dtype=torch.int32)
        self.chosen = chosen
        if isinstance(rejected, list):
            rejected = torch.tensor(rejected, dtype=torch.int32)
        self.rejected = rejected
        self.chosen_messages = chosen_messages
        self.rejected_messages = rejected_messages
        self.multi_modal_inputs = multi_modal_inputs
        if multi_modal_inputs is not None:
            self.multi_modal_inputs = {}
            for key, value in multi_modal_inputs.items():
                if not isinstance(value, Tensor):
                    self.multi_modal_inputs[key] = torch.tensor(value)
                else:
                    self.multi_modal_inputs[key] = value
        # Handle teacher_logprobs
        if isinstance(teacher_logprobs, list):
            teacher_logprobs = torch.tensor(teacher_logprobs, dtype=torch.float32)
        self.teacher_logprobs = teacher_logprobs

        # Fallback conversions for any remaining non-tensor inputs (e.g., numpy arrays)
        if not isinstance(self.tokens, Tensor):
            self.tokens = torch.tensor(self.tokens)
        if self.logprobs is not None and not isinstance(self.logprobs, Tensor):
            self.logprobs = torch.tensor(self.logprobs)
        if self.action_mask is not None and not isinstance(self.action_mask, Tensor):
            self.action_mask = torch.tensor(self.action_mask)
        if self.chosen is not None and not isinstance(self.chosen, Tensor):
            self.chosen = torch.tensor(self.chosen)
        if self.rejected is not None and not isinstance(self.rejected, Tensor):
            self.rejected = torch.tensor(self.rejected)
        if self.teacher_logprobs is not None and not isinstance(self.teacher_logprobs, Tensor):
            self.teacher_logprobs = torch.tensor(self.teacher_logprobs, dtype=torch.float32)
        self.custom_fields = custom_fields or []
    def serialize(self) -> bytes:
        """Serialize the experience to bytes."""
        return pickle.dumps(self)
    @classmethod
    def deserialize(cls, data: bytes) -> Experience:
        """Deserialize an experience from bytes."""
        return pickle.loads(data)
    def to_dict(self) -> dict:
        """Convert the experience to a dictionary."""
        res = {
            "eid": self.eid.to_dict(),
            "type": self.experience_type,
            "prompt_length": self.prompt_length,
            "response_length": len(self.tokens) - self.prompt_length,  # type: ignore [arg-type]
            "info": self.info,
            "metrics": self.metrics,
        }
        if self.prompt_text is not None:
            res["prompt_text"] = self.prompt_text
        if self.response_text is not None:
            res["response_text"] = self.response_text
        if self.messages is not None:
            res["messages"] = self.messages
        if self.tools is not None:
            res["tools"] = self.tools
        if self.chosen_messages is not None:
            res["chosen_messages"] = self.chosen_messages
        if self.rejected_messages is not None:
            res["rejected_messages"] = self.rejected_messages
        if self.reward is not None:
            res["reward"] = float(self.reward)
        if self.truncate_status is not None:
            res["truncate_status"] = self.truncate_status
        return res
    @classmethod
    def gather(
        cls,
        experiences: List[Experience],
        pad_token_id: int = 0,
        custom_fields: Optional[List[CustomField]] = None,
    ) -> Experiences:
        if len(experiences) == 0:
            return empty_experiences(custom_fields)
        exp_type = experiences[0].experience_type
        if exp_type == "dpo":
            experiences = split_dpo_experience_to_single_turn(experiences)
        max_prompt_length = max([exp.prompt_length for exp in experiences])  # type: ignore [type-var]
        max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences])  # type: ignore [arg-type]
        eids = [exp.eid for exp in experiences]
        # Gather tokens
        tokens = gather_token_ids(experiences, max_prompt_length, max_response_length, pad_token_id)
        # Gather rewards
        if experiences[0].reward is not None:
            rewards = torch.tensor([exp.reward for exp in experiences], dtype=torch.float)
        else:
            rewards = None
        # Gather token-level rewards
        if all(exp.token_level_reward is not None for exp in experiences):
            token_level_rewards = gather_response_attrs(
                experiences, "token_level_reward", max_response_length
            )
        else:
            token_level_rewards = None
        # Gather action masks
        action_masks = gather_action_masks(experiences, max_response_length)
        # Gather attention masks
        attention_masks = gather_attention_masks(
            experiences, max_prompt_length, max_response_length
        )
        # Gather logprobs
        if all(exp.logprobs is not None for exp in experiences):
            logprobs = gather_response_attrs(experiences, "logprobs", max_response_length)
        else:
            logprobs = None
        # Gather advantages
        if all(exp.advantages is not None for exp in experiences):
            advantages = gather_response_attrs(experiences, "advantages", max_response_length)
        else:
            advantages = None
        # Gather returns
        if all(exp.returns is not None for exp in experiences):
            returns = gather_response_attrs(experiences, "returns", max_response_length)
        else:
            returns = None
        # Gather multi-modal inputs
        if all(exp.multi_modal_inputs is not None for exp in experiences):
            multi_modal_inputs = gather_multi_modal_inputs(experiences)
        else:
            multi_modal_inputs = None
        # Gather teacher logprobs
        if all(exp.teacher_logprobs is not None for exp in experiences):
            teacher_logprobs = gather_response_attrs(
                experiences, "teacher_logprobs", max_response_length
            )
        else:
            teacher_logprobs = None
        exps = Experiences(
            eids=eids,
            tokens=tokens,
            rewards=rewards,
            token_level_rewards=token_level_rewards,
            advantages=advantages,
            returns=returns,
            attention_masks=attention_masks,
            action_masks=action_masks,
            prompt_length=max_prompt_length,
            logprobs=logprobs,
            multi_modal_inputs=multi_modal_inputs,
            teacher_logprobs=teacher_logprobs,
        )
        if custom_fields is not None:
            for custom_field in custom_fields:
                exps.custom_fields.append(custom_field.destination_field)
                setattr(
                    exps,
                    custom_field.destination_field,
                    torch.tensor(
                        [exp.info[custom_field.source_field] for exp in experiences],
                        dtype=custom_field.data_type,
                    ),
                )
        return exps
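
# --- Editor's usage sketch (not part of the original module) ----------------
# Constructing a single-turn Experience. The token ids, logprobs, and reward
# below are hypothetical. List inputs are converted to tensors, and the
# action mask is derived from `prompt_length` automatically.
def _experience_usage_example() -> Experience:
    exp = Experience(
        tokens=[101, 102, 103, 104, 105],  # 2 prompt tokens + 3 response tokens
        prompt_length=2,
        logprobs=[-0.5, -0.25, -0.75],  # one logprob per response token
        reward=1.0,
    )
    assert exp.experience_type == "single_turn"
    assert exp.action_mask.tolist() == [True, True, True]  # covers the response
    assert exp.tokens.dtype == torch.int32
    return exp
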
def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[Experience]:
    """Split each DPO experience into two single-turn experiences (chosen and rejected)."""
    single_turn_experiences = []
    for exp in experiences:
        single_turn_experiences.append(
            Experience(
                eid=EID(
                    batch=exp.eid.batch,
                    task=exp.eid.task,
                    step=exp.eid.step,
                    run=exp.eid.run,
                ),
                tokens=torch.cat([exp.tokens, exp.chosen]),
                reward=exp.reward,
                info=exp.info,
                metrics=exp.metrics,
                prompt_length=len(exp.tokens),  # type: ignore [arg-type]
                prompt_text=exp.prompt_text,
                messages=exp.chosen_messages,
            )
        )
        single_turn_experiences.append(
            Experience(
                eid=EID(
                    batch=exp.eid.batch,
                    task=exp.eid.task,
                    step=exp.eid.step,
                    run=exp.eid.run,
                ),
                tokens=torch.cat([exp.tokens, exp.rejected]),
                reward=exp.reward,
                info=exp.info,
                metrics=exp.metrics,
                prompt_length=len(exp.tokens),  # type: ignore [arg-type]
                prompt_text=exp.prompt_text,
                messages=exp.rejected_messages,
            )
        )
    return single_turn_experiences
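
# --- Editor's usage sketch (not part of the original module) ----------------
# A DPO experience stores only the prompt in `tokens`; splitting yields two
# single-turn experiences that share the prompt. Token ids are hypothetical.
def _dpo_split_example() -> None:
    dpo = Experience(tokens=[1, 2], chosen=[3, 4], rejected=[5, 6])
    chosen_exp, rejected_exp = split_dpo_experience_to_single_turn([dpo])
    assert chosen_exp.tokens.tolist() == [1, 2, 3, 4]
    assert rejected_exp.tokens.tolist() == [1, 2, 5, 6]
    assert chosen_exp.prompt_length == 2  # the original prompt length
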
@dataclass
class Experiences:
    """A container for a batch of experiences, for high-performance communication usage.

    Example:

    >>>         |<- prompt_length ->|<- response_length ->|
    >>> tokens: ('P' represents prompt, 'O' represents output)
    >>> exp1:   |........PPPPPPPPPPP|OOOOOOOOOO.....|
    >>> exp2:   |......PPPPPPPPPPPPP|OOOOOOO........|
    >>>
    >>> attention_masks: ('.' represents False and '1' represents True)
    >>> exp1:   |........11111111111|1111111111.....|
    >>> exp2:   |......1111111111111|1111111........|
    """

    eids: List[EID]  # Experience IDs of each experience in the batch
    tokens: Tensor  # [batch_size, seq_length]
    # At least one of `rewards` or `token_level_rewards` must be provided (not None).
    # If both are provided, `token_level_rewards` will be used and `rewards` will be ignored.
    rewards: Tensor  # [batch_size]
    token_level_rewards: Tensor  # [batch_size, response_length]
    advantages: Optional[Tensor]  # [batch_size, response_length]
    returns: Optional[Tensor]  # [batch_size, response_length]
    attention_masks: Tensor  # [batch_size, seq_length]
    action_masks: Optional[Tensor]  # [batch_size, response_length]
    prompt_length: int
    logprobs: Optional[Tensor]  # [batch_size, response_length]
    multi_modal_inputs: Optional[Any]
    custom_fields: List[str] = field(
        default_factory=list
    )  # Custom fields to include in the gathered experiences
    teacher_logprobs: Optional[Tensor] = None  # [batch_size, response_length]

    @property
    def batch_size(self) -> int:
        """Get the batch size."""
        return self.tokens.size(0)
    @classmethod
    def gather_experiences(
        cls,
        experiences: list[Experience],
        pad_token_id: int = 0,
        custom_fields: Optional[List[CustomField]] = None,
    ) -> Experiences:
        """Gather a batch of experiences from a list of experiences.

        This method automatically pads the `tokens` and `logprobs` of the
        input experiences to the same length.

        Args:
            experiences (list[Experience]): A list of experiences to gather.
            pad_token_id (int): The token ID to use for padding. Default is 0.
            custom_fields (Optional[List[CustomField]]): Custom fields to include
                in the gathered experiences.
        """
        if len(experiences) == 0:
            return empty_experiences(custom_fields)
        return experiences[0].__class__.gather(
            experiences, pad_token_id=pad_token_id, custom_fields=custom_fields
        )
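
# --- Editor's usage sketch (not part of the original module) ----------------
# Gathering two variable-length experiences into one padded batch, and using
# a CustomField to lift a value from `Experience.info` into a tensor attribute
# on the batch. All field names and values here are hypothetical.
def _gather_usage_example() -> Experiences:
    exps = [
        Experience(tokens=[1, 2, 3, 4], prompt_length=2, reward=1.0, info={"turn": 0}),
        Experience(tokens=[5, 6, 7], prompt_length=1, reward=0.0, info={"turn": 1}),
    ]
    cf = CustomField(source_field="turn", destination_field="turns", data_type=torch.int64)
    batch = Experiences.gather_experiences(exps, pad_token_id=0, custom_fields=[cf])
    assert batch.batch_size == 2
    assert batch.tokens.shape == (2, 4)  # max prompt (2) + max response (2)
    assert batch.turns.tolist() == [0, 1]  # lifted from Experience.info
    return batch
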
def empty_experiences(custom_fields: Optional[List[CustomField]]) -> Experiences:
    """Create an empty Experiences container."""
    exps = Experiences(
        tokens=torch.empty(0, dtype=torch.int32),
        rewards=torch.empty(0, dtype=torch.float32),
        token_level_rewards=torch.empty(0, dtype=torch.float32),
        advantages=torch.empty(0, dtype=torch.float32),
        returns=torch.empty(0, dtype=torch.float32),
        attention_masks=torch.empty(0, dtype=torch.bool),
        action_masks=torch.empty(0, dtype=torch.bool),
        logprobs=torch.empty(0, dtype=torch.float32),
        prompt_length=0,  # `prompt_length` is an int; an empty batch has no prompt
        eids=[],
        multi_modal_inputs=torch.empty(0, dtype=torch.float32),
    )
    if custom_fields is not None:
        for custom_field in custom_fields:
            exps.custom_fields.append(custom_field.destination_field)
            setattr(
                exps, custom_field.destination_field, torch.empty(0, dtype=custom_field.data_type)
            )
    return exps
def gather_token_ids(
    experiences, max_prompt_length: int, max_response_length: int, pad_token_id: int
) -> Tensor:
    """Left-pad each prompt and right-pad each response so all sequences align."""
    token_ids_dtype = experiences[0].tokens.dtype
    return torch.stack(
        [
            torch.cat(
                [
                    torch.full(
                        (max_prompt_length - exp.prompt_length,),
                        pad_token_id,
                        dtype=token_ids_dtype,
                    ),
                    exp.tokens,
                    torch.full(
                        (max_response_length + exp.prompt_length - len(exp.tokens),),
                        pad_token_id,
                        dtype=token_ids_dtype,
                    ),
                ]
            )
            for exp in experiences
        ]
    )
def gather_action_masks(experiences, max_response_length: int) -> Tensor:
    """Right-pad each action mask with False up to `max_response_length`."""
    return torch.stack(
        [
            torch.cat(
                [
                    exp.action_mask,
                    torch.full(
                        (max_response_length - len(exp.action_mask),),
                        0,
                        dtype=torch.bool,
                    ),
                ]
            )
            for exp in experiences
        ]
    )
def gather_attention_masks(experiences, max_prompt_length: int, max_response_length: int) -> Tensor:
    """Mark the positions of each experience's real tokens as True."""
    attention_masks = torch.zeros(
        (len(experiences), max_prompt_length + max_response_length), dtype=torch.bool
    )
    for i, exp in enumerate(experiences):
        start = max_prompt_length - exp.prompt_length
        end = start + len(exp.tokens)
        attention_masks[i, start:end] = 1
    return attention_masks
def gather_response_attrs(
    experiences, attr_name: str, max_response_length: int, pad_value: int = 0
) -> Tensor:
    """Right-pad a per-response attribute (e.g., logprobs) to `max_response_length`."""
    dtype = getattr(experiences[0], attr_name).dtype
    pad_value = torch.tensor(pad_value, dtype=dtype)
    return torch.stack(
        [
            torch.cat(
                [
                    getattr(exp, attr_name),
                    torch.full(
                        (max_response_length - len(getattr(exp, attr_name)),),
                        pad_value,
                        dtype=dtype,
                    ),
                ]
            )
            for exp in experiences
        ]
    )
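
# --- Editor's usage sketch (not part of the original module) ----------------
# Right-padding per-response logprobs to a common length; the values are
# hypothetical (chosen to be exactly representable in float32).
def _response_attr_padding_example() -> None:
    exps = [
        Experience(tokens=[1, 2, 3], prompt_length=1, logprobs=[-0.5, -0.25]),
        Experience(tokens=[4, 5], prompt_length=1, logprobs=[-0.75]),
    ]
    padded = gather_response_attrs(exps, "logprobs", max_response_length=2)
    assert padded.tolist() == [[-0.5, -0.25], [-0.75, 0.0]]  # second row padded
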
def gather_multi_modal_inputs(experiences) -> Dict[str, List[Tensor]]:
    """Group multi-modal inputs by key, collecting each experience's tensor into a list."""
    keys = experiences[0].multi_modal_inputs.keys()
    return {key: [exp.multi_modal_inputs[key] for exp in experiences] for key in keys}
def group_by(
    experiences: List[Experience], id_type: Literal["task", "run", "step"]
) -> Dict[str, List[Experience]]:
    """Group experiences by ID."""
    if id_type == "task":
        id_type = "tid"
    elif id_type == "run":
        id_type = "rid"
    elif id_type == "step":
        id_type = "sid"
    else:
        raise ValueError(f"Unknown id_type: {id_type}")
    grouped = {}
    for exp in experiences:
        group_id = getattr(exp.eid, id_type)
        if group_id not in grouped:
            grouped[group_id] = []
        grouped[group_id].append(exp)
    return grouped
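
# --- Editor's usage sketch (not part of the original module) ----------------
# Grouping by "task" collects all runs of the same task (same `tid`), which is
# the grouping GRPO-style baselines operate over. Values are hypothetical.
def _group_by_usage_example() -> None:
    exps = [
        Experience(eid=EID(batch=0, task=0, run=r), tokens=[1, 2, 3], prompt_length=1)
        for r in range(4)
    ]
    grouped = group_by(exps, id_type="task")
    assert list(grouped.keys()) == ["0/0"]  # one task -> one group
    assert len(grouped["0/0"]) == 4  # holding all four runs
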
def to_hf_datasets(experiences: list[Experience]) -> "Dataset":
    """
    Convert a list of Experience objects to a HuggingFace Dataset, preserving all fields.
    """
    from datasets import Dataset

    return Dataset.from_list([asdict(exp) for exp in experiences])
def from_hf_datasets(dataset: "Dataset") -> List[Experience]:
    """
    Convert a HuggingFace Dataset back to a list of Experience objects.
    """

    def dict_to_dataclass(cls, d):
        valid_keys = {f.name for f in fields(cls)}
        filtered = {k: v for k, v in d.items() if k in valid_keys}
        return cls(**filtered)

    experiences = [dict_to_dataclass(Experience, row) for row in dataset.to_list()]
    return experiences
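
# --- Editor's usage sketch (not part of the original module) ----------------
# Round-tripping experiences through a HuggingFace Dataset, assuming `datasets`
# can ingest the tensor fields produced by `asdict` (as the docstrings above
# imply). Tensor fields come back as lists, which Experience.__init__ converts
# back to tensors. Token values are hypothetical.
def _hf_round_trip_example() -> None:
    original = [Experience(tokens=[1, 2, 3], prompt_length=1, reward=0.5)]
    restored = from_hf_datasets(to_hf_datasets(original))
    assert restored[0].tokens.tolist() == [1, 2, 3]
    assert restored[0].reward == 0.5
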