# -*- coding: utf-8 -*-
"""Experience Class."""
from __future__ import annotations
import pickle
import uuid
from dataclasses import asdict, dataclass, field, fields
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
import torch
from torch import Tensor
if TYPE_CHECKING:
from datasets import Dataset
@dataclass
class EID:
    """Unique identifier for an experience.

    To enable the full experience-grouping functionality, custom workflows
    should set the `run` and `step` fields explicitly when creating
    experiences.
    """

    # TODO: do we need to add project/name here to make it unique across different projects?
    # Batch number (e.g. the explorer step num); set automatically by the workflow runner.
    batch: Union[int, str] = ""
    # Position of the task inside the batch (first task is 0); set automatically by the runner.
    task: Union[int, str] = ""
    # Run index within the task (first run is 0); set by custom workflows.
    run: int = 0
    # Step index within the task (first step is 0); set by custom workflows.
    step: int = 0
    # Short random suffix that disambiguates otherwise-identical IDs.
    suffix: str = field(default_factory=lambda: uuid.uuid4().hex[:6])

    @property
    def uid(self) -> str:
        """Fully-qualified unique identifier of this experience."""
        return f"{self.batch}/{self.task}/{self.run}/{self.step}/{self.suffix}"

    @property
    def sid(self) -> str:
        """Step ID: shared by all runs of the same task at the same step."""
        return f"{self.batch}/{self.task}/{self.step}"

    @property
    def rid(self) -> str:
        """Run ID: shared by all steps produced by one run of a task."""
        return f"{self.batch}/{self.task}/{self.run}"

    @property
    def tid(self) -> str:
        """Task ID: shared by all runs of one task (GRPO-style grouping)."""
        return f"{self.batch}/{self.task}"

    def __str__(self) -> str:
        return self.uid

    def __repr__(self) -> str:
        return f"EID(batch={self.batch}, task={self.task}, run={self.run}, step={self.step}, uuid={self.suffix})"

    def to_dict(self) -> dict:
        """Return the ID components as a plain dictionary."""
        return {name: getattr(self, name) for name in ("batch", "task", "run", "step", "suffix")}
@dataclass(frozen=True)
class CustomField:
    """Mapping that lifts a value out of `Experience.info` into `Experiences`.

    Attributes:
        source_field: Key to read from ``Experience.info``.
        destination_field: Attribute name created on the gathered ``Experiences``.
        data_type: torch dtype used when tensorizing the gathered values
            (e.g. ``torch.float32``, ``torch.int64``).
    """

    source_field: str
    destination_field: str
    data_type: torch.dtype
@dataclass
class Experience:
    """A single rollout sample (prompt, response, and training signals).

    Declared as a dataclass so that `fields()` / `asdict()` keep working,
    but it defines its own `__init__` below (the dataclass decorator does
    not overwrite methods already defined in the class body), which infers
    the experience type, validates lengths, and converts lists to tensors.
    """

    eid: EID = field(default_factory=EID)  # Unique identifier for the experience
    tokens: Optional[Tensor] = None  # [seq_length]
    prompt_length: int = 1  # Length of the prompt in tokens, used for generating attention masks
    logprobs: Optional[Tensor] = None  # [resp_length]
    reward: Optional[float] = None  # Scalar reward for the whole experience
    token_level_reward: Optional[Tensor] = None  # [resp_length]
    advantages: Optional[Tensor] = None  # [resp_length]
    returns: Optional[Tensor] = None  # [resp_length]
    truncate_status: Optional[
        str
    ] = None  # The status of truncation, e.g., "prompt_truncated", "response_truncated"; Not working for openai api
    info: dict = field(
        default_factory=dict
    )  # Additional information about the experience, can also be used to store custom fields
    metrics: dict[str, float] = field(
        default_factory=dict
    )  # Metrics associated with the experience, directly used by the monitor
    # for single-turn experiences
    response_text: Optional[str] = None  # Text of the response
    prompt_text: Optional[str] = None  # Text of the prompt
    # for multi-turn experiences
    # Action mask indicates which tokens are generated by the model
    action_mask: Optional[Tensor] = None  # [resp_length]
    messages: Optional[List[dict]] = None  # List of messages
    tools: Optional[List[dict]] = None  # Tool schemas available during the rollout
    # for dpo experiences
    chosen: Optional[Tensor] = None  # Token ids of the chosen response [resp_length]
    rejected: Optional[Tensor] = None  # Token ids of the rejected response [resp_length]
    chosen_messages: Optional[List[dict]] = None  # Chosen message list (Include prompt message)
    rejected_messages: Optional[List[dict]] = None  # Rejected message list (Include prompt message)
    # for multi-modal data
    multi_modal_inputs: Optional[Dict[str, Tensor]] = None  # Multi-modal inputs for verl trainer
    # for on-policy distillation
    teacher_logprobs: Optional[Tensor] = None  # [resp_length]
    custom_fields: List[CustomField] = field(default_factory=list)

    def __init__(  # noqa: C901
        self,
        *,
        eid=None,
        tokens,
        logprobs=None,
        reward=None,
        token_level_reward=None,
        advantages=None,
        returns=None,
        truncate_status=None,
        info=None,
        metrics=None,
        prompt_length=1,
        response_text=None,
        prompt_text=None,
        action_mask=None,
        messages=None,
        tools=None,
        chosen=None,
        rejected=None,
        chosen_messages=None,
        rejected_messages=None,
        multi_modal_inputs=None,
        teacher_logprobs=None,
        custom_fields=None,
    ):
        """Build an experience, inferring its type from the provided fields.

        Type inference:
          * "multi_turn"  if `action_mask` is given,
          * "dpo"         if both `chosen` and `rejected` are given,
          * "single_turn" otherwise.

        List-valued arguments (`tokens`, `logprobs`, `advantages`, ...) are
        converted to tensors.
        """
        if action_mask is not None:
            experience_type = "multi_turn"
        elif chosen is not None and rejected is not None:
            experience_type = "dpo"
        else:
            experience_type = "single_turn"
        if experience_type == "single_turn":
            assert (
                prompt_length > 0
            ), "Prompt length must be greater than 0 for single-turn experiences."
            if truncate_status != "prompt_truncated":
                assert (
                    len(tokens) > prompt_length
                ), f"Token ids must be larger than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}."
                # Everything after the prompt is model-generated.
                action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
            else:
                # NOTE(review): this path assumes `logprobs` is provided when the
                # prompt was truncated — verify against callers.
                action_mask = torch.zeros(len(logprobs), dtype=torch.bool)
        elif experience_type == "dpo":
            # For DPO, `tokens` holds only the prompt; responses live in chosen/rejected.
            prompt_length = len(tokens)
        if eid is None:
            self.eid = EID()
        elif isinstance(eid, dict):
            # Allows round-tripping through `EID.to_dict()` / HF datasets rows.
            self.eid = EID(**eid)
        else:
            self.eid = eid
        if isinstance(tokens, list):
            tokens = torch.tensor(tokens, dtype=torch.int32)
        self.tokens = tokens
        if isinstance(logprobs, list):
            logprobs = torch.tensor(logprobs, dtype=torch.float32)
        self.logprobs = logprobs
        self.reward = reward
        if isinstance(token_level_reward, list):
            token_level_reward = torch.tensor(token_level_reward, dtype=torch.float32)
        self.token_level_reward = token_level_reward
        if isinstance(advantages, list):
            advantages = torch.tensor(advantages, dtype=torch.float32)
        self.advantages = advantages
        if isinstance(returns, list):
            returns = torch.tensor(returns, dtype=torch.float32)
        self.returns = returns
        self.experience_type = experience_type
        self.info = info or {}
        self.metrics = metrics or {}
        self.truncate_status = truncate_status
        self.prompt_length = prompt_length
        self.response_text = response_text
        self.prompt_text = prompt_text
        if isinstance(action_mask, list):
            action_mask = torch.tensor(action_mask, dtype=torch.bool)
        self.action_mask = action_mask
        self.messages = messages
        self.tools = tools
        if isinstance(chosen, list):
            chosen = torch.tensor(chosen, dtype=torch.int32)
        self.chosen = chosen
        if isinstance(rejected, list):
            rejected = torch.tensor(rejected, dtype=torch.int32)
        self.rejected = rejected
        self.chosen_messages = chosen_messages
        self.rejected_messages = rejected_messages
        self.multi_modal_inputs = multi_modal_inputs
        if multi_modal_inputs is not None:
            # Rebuild the dict so every value is a tensor.
            self.multi_modal_inputs = {}
            for key, value in multi_modal_inputs.items():
                if not isinstance(value, Tensor):
                    self.multi_modal_inputs[key] = torch.tensor(value)
                else:
                    self.multi_modal_inputs[key] = value
        # Handle teacher_logprobs
        if isinstance(teacher_logprobs, list):
            teacher_logprobs = torch.tensor(teacher_logprobs, dtype=torch.float32)
        self.teacher_logprobs = teacher_logprobs
        # Final safety net: coerce anything tensor-like (e.g. numpy arrays) to tensors.
        if not isinstance(self.tokens, Tensor):
            self.tokens = torch.tensor(self.tokens)
        if self.logprobs is not None and not isinstance(self.logprobs, Tensor):
            self.logprobs = torch.tensor(self.logprobs)
        if self.action_mask is not None and not isinstance(self.action_mask, Tensor):
            self.action_mask = torch.tensor(self.action_mask)
        if self.chosen is not None and not isinstance(self.chosen, Tensor):
            self.chosen = torch.tensor(self.chosen)
        if self.rejected is not None and not isinstance(self.rejected, Tensor):
            self.rejected = torch.tensor(self.rejected)
        if self.teacher_logprobs is not None and not isinstance(self.teacher_logprobs, Tensor):
            self.teacher_logprobs = torch.tensor(self.teacher_logprobs, dtype=torch.float32)
        self.custom_fields = custom_fields or []

    def serialize(self) -> bytes:
        """Serialize the experience to bytes."""
        return pickle.dumps(self)

    @classmethod
    def deserialize(cls, data: bytes) -> Experience:
        """Restore an experience from `serialize()` output.

        SECURITY NOTE: pickle-based — only deserialize data from trusted sources.
        """
        return pickle.loads(data)

    def to_dict(self) -> dict:
        """Convert the experience to a dictionary.

        Only lightweight fields are included; tensors (tokens, logprobs, ...)
        are omitted except for the derived lengths.
        """
        res = {
            "eid": self.eid.to_dict(),
            "type": self.experience_type,
            "prompt_length": self.prompt_length,
            "response_length": len(self.tokens) - self.prompt_length,  # type: ignore [arg-type]
            "info": self.info,
            "metrics": self.metrics,
        }
        if self.prompt_text is not None:
            res["prompt_text"] = self.prompt_text
        if self.response_text is not None:
            res["response_text"] = self.response_text
        if self.messages is not None:
            res["messages"] = self.messages
        if self.tools is not None:
            res["tools"] = self.tools
        if self.chosen_messages is not None:
            res["chosen_messages"] = self.chosen_messages
        if self.rejected_messages is not None:
            res["rejected_messages"] = self.rejected_messages
        if self.reward is not None:
            res["reward"] = float(self.reward)
        if self.truncate_status is not None:
            res["truncate_status"] = self.truncate_status
        return res

    @classmethod
    def gather(
        cls,
        experiences: List[Experience],
        pad_token_id: int = 0,
        custom_fields: Optional[List[CustomField]] = None,
    ) -> Experiences:
        """Pad and stack a list of experiences into a dense `Experiences` batch.

        DPO experiences are first split into single-turn chosen/rejected pairs.
        Optional attributes (logprobs, advantages, ...) are gathered only when
        present on every experience; otherwise the batch field is None.
        """
        if len(experiences) == 0:
            return empty_experiences(custom_fields)
        exp_type = experiences[0].experience_type
        if exp_type == "dpo":
            experiences = split_dpo_experience_to_single_turn(experiences)
        max_prompt_length = max([exp.prompt_length for exp in experiences])  # type: ignore [type-var]
        max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences])  # type: ignore [arg-type]
        eids = [exp.eid for exp in experiences]
        # Gather tokens
        tokens = gather_token_ids(experiences, max_prompt_length, max_response_length, pad_token_id)
        # Gather rewards
        if experiences[0].reward is not None:
            rewards = torch.tensor([exp.reward for exp in experiences], dtype=torch.float)
        else:
            rewards = None
        # Gather token level rewards
        if all(exp.token_level_reward is not None for exp in experiences):
            token_level_rewards = gather_response_attrs(
                experiences, "token_level_reward", max_response_length
            )
        else:
            token_level_rewards = None
        # gather action_masks
        action_masks = gather_action_masks(experiences, max_response_length)
        # gather attention_masks
        attention_masks = gather_attention_masks(
            experiences, max_prompt_length, max_response_length
        )
        # gather logprobs
        if all(exp.logprobs is not None for exp in experiences):
            logprobs = gather_response_attrs(experiences, "logprobs", max_response_length)
        else:
            logprobs = None
        # gather advantages
        if all(exp.advantages is not None for exp in experiences):
            advantages = gather_response_attrs(experiences, "advantages", max_response_length)
        else:
            advantages = None
        # gather returns
        if all(exp.returns is not None for exp in experiences):
            returns = gather_response_attrs(experiences, "returns", max_response_length)
        else:
            returns = None
        # gather multi_modal_inputs
        if all(exp.multi_modal_inputs is not None for exp in experiences):
            multi_modal_inputs = gather_multi_modal_inputs(experiences)
        else:
            multi_modal_inputs = None
        # gather teacher_logprobs
        if all(exp.teacher_logprobs is not None for exp in experiences):
            teacher_logprobs = gather_response_attrs(
                experiences, "teacher_logprobs", max_response_length
            )
        else:
            teacher_logprobs = None
        exps = Experiences(
            eids=eids,
            tokens=tokens,
            rewards=rewards,
            token_level_rewards=token_level_rewards,
            advantages=advantages,
            returns=returns,
            attention_masks=attention_masks,
            action_masks=action_masks,
            prompt_length=max_prompt_length,
            logprobs=logprobs,
            multi_modal_inputs=multi_modal_inputs,
            teacher_logprobs=teacher_logprobs,
        )
        if custom_fields is not None:
            # Lift scalar values out of exp.info into batch-level tensors.
            for custom_field in custom_fields:
                exps.custom_fields.append(custom_field.destination_field)
                setattr(
                    exps,
                    custom_field.destination_field,
                    torch.tensor(
                        [exp.info[custom_field.source_field] for exp in experiences],
                        dtype=custom_field.data_type,
                    ),
                )
        return exps
def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[Experience]:
    """Expand each DPO experience into two single-turn experiences.

    For every input experience, one experience is appended for the chosen
    response and one for the rejected response (in that order). Each new
    experience concatenates the prompt tokens with the response tokens and
    marks the whole original `tokens` span as the prompt. The batch/task/
    run/step of the source EID are preserved; a fresh suffix is generated so
    the two halves get distinct uids.

    Args:
        experiences: DPO experiences (with `chosen` and `rejected` set).

    Returns:
        A flat list with two single-turn experiences per input experience.
    """
    single_turn_experiences: List[Experience] = []
    for exp in experiences:
        # The chosen/rejected halves differ only in response tokens + messages,
        # so build both from one loop instead of duplicating the constructor call.
        for response_tokens, response_messages in (
            (exp.chosen, exp.chosen_messages),
            (exp.rejected, exp.rejected_messages),
        ):
            single_turn_experiences.append(
                Experience(
                    eid=EID(
                        batch=exp.eid.batch,
                        task=exp.eid.task,
                        step=exp.eid.step,
                        run=exp.eid.run,
                    ),
                    tokens=torch.cat([exp.tokens, response_tokens]),
                    reward=exp.reward,
                    info=exp.info,
                    metrics=exp.metrics,
                    prompt_length=len(exp.tokens),  # type: ignore [arg-type]
                    prompt_text=exp.prompt_text,
                    messages=response_messages,
                )
            )
    return single_turn_experiences
@dataclass
class Experiences:
    """A padded batch of experiences, packed for high-performance transport.

    Layout ('P' = prompt token, 'O' = output token, '.' = padding):

        |<- prompt_length ->|
        exp1:  |........PPPPPPPPPPP|OOOOOOOOOO.....|
        exp2:  |......PPPPPPPPPPPPP|OOOOOOO........|

    `attention_masks` is True exactly on the non-padding positions:

        exp1:  |........11111111111|1111111111.....|
        exp2:  |......1111111111111|1111111........|
    """

    eids: List[EID]  # Experience IDs of each experience in the batch
    tokens: Tensor  # [batch_size, seq_length]
    # At least one of `rewards` or `token_level_rewards` must be provided (not None).
    # If both are provided, `token_level_rewards` will be used and `rewards` will be ignored.
    rewards: Tensor  # [batch_size]
    token_level_rewards: Tensor  # [batch_size, response_length]
    advantages: Optional[Tensor]  # [batch_size, response_length]
    returns: Optional[Tensor]  # [batch_size, response_length]
    attention_masks: Tensor  # [batch_size, sequence_length]
    action_masks: Optional[Tensor]  # [batch_size, response_length]
    prompt_length: int  # Shared (maximum) prompt length after left-padding
    logprobs: Optional[Tensor]  # [batch_size, response_length]
    multi_modal_inputs: Optional[Any]
    # Names of extra attributes attached by CustomField gathering.
    custom_fields: List[str] = field(default_factory=list)
    teacher_logprobs: Optional[Tensor] = None  # [batch_size, response_length]

    @property
    def batch_size(self) -> int:
        """Number of experiences in the batch (first dim of `tokens`)."""
        return self.tokens.size(0)

    @classmethod
    def gather_experiences(
        cls,
        experiences: list[Experience],
        pad_token_id: int = 0,
        custom_fields: Optional[List[CustomField]] = None,
    ) -> Experiences:
        """Gather a batch of experiences from a list of experiences.

        This method will automatically pad the `tokens` and `logprobs` of input experiences to the same length.

        Args:
            experiences (list[Experience]): A list of experiences to gather.
            pad_token_id (int): The token ID to use for padding. Default is 0.
            custom_fields (Optional[List[CustomField]]): Custom fields to include in the gathered experiences.
        """
        if not experiences:
            return empty_experiences(custom_fields)
        # Dispatch to the concrete Experience subclass so overrides of
        # `gather` are honored.
        experience_cls = experiences[0].__class__
        return experience_cls.gather(
            experiences, pad_token_id=pad_token_id, custom_fields=custom_fields
        )
def empty_experiences(custom_fields: Optional[List[CustomField]]) -> Experiences:
    """Build an `Experiences` batch with zero experiences.

    All tensor fields are zero-length tensors. Fix: `prompt_length` is
    declared as `int` on `Experiences`, so it is set to 0 here instead of
    an empty int32 tensor — the previous tensor value broke any caller
    doing integer arithmetic (e.g. `seq_length - prompt_length`) on the
    empty path.

    Args:
        custom_fields: Optional custom fields; for each, an empty tensor of
            the requested dtype is attached under its destination name.

    Returns:
        An empty `Experiences` instance.
    """
    exps = Experiences(
        tokens=torch.empty(0, dtype=torch.int32),
        rewards=torch.empty(0, dtype=torch.float32),
        token_level_rewards=torch.empty(0, dtype=torch.float32),
        advantages=torch.empty(0, dtype=torch.float32),
        returns=torch.empty(0, dtype=torch.float32),
        attention_masks=torch.empty(0, dtype=torch.bool),
        action_masks=torch.empty(0, dtype=torch.bool),
        logprobs=torch.empty(0, dtype=torch.float32),
        prompt_length=0,  # was torch.empty(0, dtype=torch.int32); field is typed int
        eids=[],
        multi_modal_inputs=torch.empty(0, dtype=torch.float32),
    )
    if custom_fields is not None:
        for custom_field in custom_fields:
            exps.custom_fields.append(custom_field.destination_field)
            setattr(
                exps, custom_field.destination_field, torch.empty(0, dtype=custom_field.data_type)
            )
    return exps
def gather_token_ids(
    experiences, max_prompt_length: int, max_response_length: int, pad_token_id: int
) -> Tensor:
    """Stack token ids into one tensor, left-padding prompts and right-padding responses.

    Each row is: [pad * (max_prompt_length - prompt_length)] + tokens +
    [pad * (max_response_length - response_length)], so all rows share the
    length max_prompt_length + max_response_length.
    """
    dtype = experiences[0].tokens.dtype
    rows = []
    for exp in experiences:
        left_pad = max_prompt_length - exp.prompt_length
        right_pad = max_response_length + exp.prompt_length - len(exp.tokens)
        row = torch.cat(
            [
                torch.full((left_pad,), pad_token_id, dtype=dtype),
                exp.tokens,
                torch.full((right_pad,), pad_token_id, dtype=dtype),
            ]
        )
        rows.append(row)
    return torch.stack(rows)
def gather_action_masks(experiences, max_response_length: int) -> Tensor:
    """Stack per-experience action masks, right-padding each with False."""
    padded = []
    for exp in experiences:
        tail_len = max_response_length - len(exp.action_mask)
        padded.append(
            torch.cat([exp.action_mask, torch.zeros(tail_len, dtype=torch.bool)])
        )
    return torch.stack(padded)
def gather_attention_masks(experiences, max_prompt_length: int, max_response_length: int) -> Tensor:
    """Build boolean attention masks marking each experience's real tokens.

    Row layout matches `gather_token_ids`: prompts are left-padded to
    `max_prompt_length`, so each row is True on the span covering the
    experience's actual tokens and False on the padding.
    """
    total_length = max_prompt_length + max_response_length
    masks = torch.zeros((len(experiences), total_length), dtype=torch.bool)
    for row, exp in zip(masks, experiences):
        begin = max_prompt_length - exp.prompt_length
        row[begin : begin + len(exp.tokens)] = True
    return masks
def gather_response_attrs(
    experiences, attr_name: str, max_response_length: int, pad_value: int = 0
) -> Tensor:
    """Stack a response-aligned tensor attribute, right-padded with `pad_value`.

    The dtype of the first experience's attribute determines the batch dtype.
    """
    dtype = getattr(experiences[0], attr_name).dtype
    fill = torch.tensor(pad_value, dtype=dtype)
    padded = []
    for exp in experiences:
        values = getattr(exp, attr_name)
        tail = torch.full((max_response_length - len(values),), fill, dtype=dtype)
        padded.append(torch.cat([values, tail]))
    return torch.stack(padded)
def group_by(
    experiences: List[Experience], id_type: Literal["task", "run", "step"]
) -> Dict[str, List[Experience]]:
    """Group experiences by task, run, or step ID.

    Raises:
        ValueError: If `id_type` is not one of "task", "run", "step".
    """
    attr_by_type = {"task": "tid", "run": "rid", "step": "sid"}
    if id_type not in attr_by_type:
        raise ValueError(f"Unknown id_type: {id_type}")
    attr = attr_by_type[id_type]
    grouped: Dict[str, List[Experience]] = {}
    for exp in experiences:
        grouped.setdefault(getattr(exp.eid, attr), []).append(exp)
    return grouped
def to_hf_datasets(experiences: list[Experience]) -> "Dataset":
    """Convert Experience objects into a HuggingFace Dataset.

    Every dataclass field of each experience is preserved via `asdict`.
    """
    # Imported lazily so `datasets` is only required when this helper is used.
    from datasets import Dataset

    rows = [asdict(exp) for exp in experiences]
    return Dataset.from_list(rows)
def from_hf_datasets(dataset: "Dataset") -> List[Experience]:
    """Rebuild Experience objects from a HuggingFace Dataset row list.

    Row keys that are not Experience dataclass fields are silently dropped,
    so datasets carrying extra columns still round-trip.
    """

    def build(row):
        # Keep only keys Experience actually declares as fields.
        allowed = {f.name for f in fields(Experience)}
        return Experience(**{k: v for k, v in row.items() if k in allowed})

    return [build(row) for row in dataset.to_list()]