# -*- coding: utf-8 -*-
"""Experience Class."""
from __future__ import annotations
import pickle
import struct
import uuid
from dataclasses import asdict, dataclass, field, fields
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
import torch
from safetensors.torch import load as st_load
from safetensors.torch import save as st_save
from torch import Tensor
if TYPE_CHECKING:
from datasets import Dataset
@dataclass
class EID:
    """Experience ID class to uniquely identify an experience.

    To enable the full functionality of the experience grouping, user should
    manually set the `run` and `step` fields in custom workflows.
    """

    # TODO: do we need to add project/name here to make it unique across different projects?

    # Batch number, e.g., the explorer step num.
    # Automatically set by the workflow runner.
    batch: Union[int, str] = ""
    # Position of the task within the batch; the first task in the batch has task=0.
    # Automatically set by the workflow runner.
    task: Union[int, str] = ""
    # Run index within the task; the first run has run=0.
    # User should set this field in custom workflows when creating experiences.
    run: int = 0
    # Step number when running the task; the first step has step=0.
    # User should set this field in custom workflows when creating experiences.
    step: int = 0
    # Random six-hex-character suffix that keeps otherwise-equal ids unique.
    suffix: str = field(default_factory=lambda: uuid.uuid4().hex[:6])

    @property
    def uid(self) -> str:
        """A unique identifier for the experience."""
        parts = (self.batch, self.task, self.run, self.step, self.suffix)
        return "/".join(str(part) for part in parts)

    @property
    def sid(self) -> str:
        """Step ID of the experience.

        Shared by experiences generated by all runs of the same task at the same step.
        """
        return "/".join(str(part) for part in (self.batch, self.task, self.step))

    @property
    def rid(self) -> str:
        """Run ID of the experience.

        Shared by experiences generated by one run of a task across all steps.
        """
        return "/".join(str(part) for part in (self.batch, self.task, self.run))

    @property
    def tid(self) -> str:
        """Task ID for the experience.

        Shared by experiences generated by all runs of the same task
        (e.g., one group in GRPO-like algorithms).
        """
        return "/".join(str(part) for part in (self.batch, self.task))

    def __str__(self) -> str:
        return self.uid

    def __repr__(self) -> str:
        return f"EID(batch={self.batch}, task={self.task}, run={self.run}, step={self.step}, uuid={self.suffix})"

    def to_dict(self) -> dict:
        """Convert the EID to a dictionary."""
        # Dataclass field order matches the historical explicit dict layout.
        return asdict(self)
@dataclass(frozen=True)
class CustomField:
    """Custom field for Experiences.

    This is used to store additional information into the Experiences class.

    Attributes:
        source_field: The source field name in the `Experience.info` dict.
        destination_field: The destination field name in the Experiences class.
        data_type: The torch dtype of the field (e.g., ``torch.float32``, ``torch.int64``).
    """

    source_field: str
    destination_field: str
    data_type: torch.dtype
@dataclass
class Experience:
    """A single rollout experience (single-turn, multi-turn, or DPO).

    Tensor-valued fields hold per-token data; the remaining fields are
    per-experience metadata. Batches are (de)serialized in a compact binary
    format: a magic/version header, a pickled metadata section, and a
    safetensors-packed tensor section.

    NOTE: although declared with ``@dataclass``, the class defines its own
    ``__init__`` below, which ``dataclass`` keeps (it never overwrites an
    existing ``__init__``); list inputs are therefore coerced to tensors and
    the experience type is inferred at construction time.
    """

    # Binary payload prefix used to distinguish the batched format from
    # legacy plain-pickle payloads.
    _SER_MAGIC = b"TRXP"
    # Format version written into the payload header.
    _SER_VERSION = 1
    # Tensor-valued fields packed via safetensors in serialize_many().
    _TENSOR_FIELDS = (
        "tokens",
        "logprobs",
        "token_level_reward",
        "advantages",
        "returns",
        "action_mask",
        "chosen",
        "rejected",
        "teacher_logprobs",
    )
    # Non-tensor fields pickled into the metadata section in serialize_many().
    _META_FIELDS = (
        "eid",
        "reward",
        "truncate_status",
        "info",
        "metrics",
        "prompt_length",
        "response_text",
        "prompt_text",
        "messages",
        "tools",
        "chosen_messages",
        "rejected_messages",
    )
    eid: EID = field(default_factory=EID)  # Unique identifier for the experience
    tokens: Optional[Tensor] = None  # [seq_length]
    prompt_length: int = 1  # Length of the prompt in tokens, used for generating attention masks
    logprobs: Optional[Tensor] = None  # [resp_length]
    reward: Optional[float] = None
    token_level_reward: Optional[Tensor] = None  # [resp_length]
    advantages: Optional[Tensor] = None  # [resp_length]
    returns: Optional[Tensor] = None  # [resp_length]
    truncate_status: Optional[
        str
    ] = None  # The status of truncation, e.g., "prompt_truncated", "response_truncated"; Not working for openai api
    info: dict = field(
        default_factory=dict
    )  # Additional information about the experience, can also be used to store custom fields
    metrics: dict[str, float] = field(
        default_factory=dict
    )  # Metrics associated with the experience, directly used by the monitor
    # for single-turn experiences
    response_text: Optional[str] = None  # Text of the response
    prompt_text: Optional[str] = None  # Text of the prompt
    # for multi-turn experiences
    # Action mask indicates which tokens are generated by the model
    action_mask: Optional[Tensor] = None  # [resp_length]
    messages: Optional[List[dict]] = None  # List of messages
    tools: Optional[List[dict]] = None
    # for dpo experiences
    chosen: Optional[Tensor] = None  # Token ids of the chosen response [resp_length]
    rejected: Optional[Tensor] = None  # Token ids of the rejected response [resp_length]
    chosen_messages: Optional[List[dict]] = None  # Chosen message list (Include prompt message)
    rejected_messages: Optional[List[dict]] = None  # Rejected message list (Include prompt message)
    # for multi-modal data
    multi_modal_inputs: Optional[Dict[str, Tensor]] = None  # Multi-modal inputs for verl trainer
    # for on-policy distillation
    teacher_logprobs: Optional[Tensor] = None  # [resp_length]
    custom_fields: List[CustomField] = field(default_factory=list)

    def __init__(  # noqa: C901
        self,
        *,
        eid=None,
        tokens,
        logprobs=None,
        reward=None,
        token_level_reward=None,
        advantages=None,
        returns=None,
        truncate_status=None,
        info=None,
        metrics=None,
        prompt_length=1,
        response_text=None,
        prompt_text=None,
        action_mask=None,
        messages=None,
        tools=None,
        chosen=None,
        rejected=None,
        chosen_messages=None,
        rejected_messages=None,
        multi_modal_inputs=None,
        teacher_logprobs=None,
        custom_fields=None,
    ):
        """Build an experience, inferring its type and coercing list inputs to tensors.

        The type is inferred from the provided arguments:
        ``multi_turn`` when ``action_mask`` is given, ``dpo`` when both
        ``chosen`` and ``rejected`` are given, and ``single_turn`` otherwise
        (where an action mask is derived from ``prompt_length``).
        """
        # Infer the experience type from which optional inputs are present.
        if action_mask is not None:
            experience_type = "multi_turn"
        elif chosen is not None and rejected is not None:
            experience_type = "dpo"
        else:
            experience_type = "single_turn"
        if experience_type == "single_turn":
            assert (
                prompt_length > 0
            ), "Prompt length must be greater than 0 for single-turn experiences."
            if truncate_status != "prompt_truncated":
                assert (
                    len(tokens) > prompt_length
                ), f"Token ids must be larger than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}."
                # Every token after the prompt counts as model-generated.
                action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
            else:
                # Prompt truncated: no response tokens are trainable.
                # NOTE(review): this path takes len(logprobs), so it assumes
                # `logprobs` is provided whenever truncate_status ==
                # "prompt_truncated" — confirm with callers.
                action_mask = torch.zeros(len(logprobs), dtype=torch.bool)
        elif experience_type == "dpo":
            # For DPO, `tokens` holds only the prompt; responses live in chosen/rejected.
            prompt_length = len(tokens)
        # Accept an EID instance, a plain dict (e.g., from deserialization), or nothing.
        if eid is None:
            self.eid = EID()
        elif isinstance(eid, dict):
            self.eid = EID(**eid)
        else:
            self.eid = eid
        # Coerce plain-list inputs to tensors with fixed dtypes.
        if isinstance(tokens, list):
            tokens = torch.tensor(tokens, dtype=torch.int32)
        self.tokens = tokens
        if isinstance(logprobs, list):
            logprobs = torch.tensor(logprobs, dtype=torch.float32)
        self.logprobs = logprobs
        self.reward = reward
        if isinstance(token_level_reward, list):
            token_level_reward = torch.tensor(token_level_reward, dtype=torch.float32)
        self.token_level_reward = token_level_reward
        if isinstance(advantages, list):
            advantages = torch.tensor(advantages, dtype=torch.float32)
        self.advantages = advantages
        if isinstance(returns, list):
            returns = torch.tensor(returns, dtype=torch.float32)
        self.returns = returns
        # Not a declared dataclass field and not in _META_FIELDS: the type is
        # recomputed from the inputs whenever an Experience is constructed.
        self.experience_type = experience_type
        self.info = info or {}
        self.metrics = metrics or {}
        self.truncate_status = truncate_status
        self.prompt_length = prompt_length
        self.response_text = response_text
        self.prompt_text = prompt_text
        if isinstance(action_mask, list):
            action_mask = torch.tensor(action_mask, dtype=torch.bool)
        self.action_mask = action_mask
        self.messages = messages
        self.tools = tools
        if isinstance(chosen, list):
            chosen = torch.tensor(chosen, dtype=torch.int32)
        self.chosen = chosen
        if isinstance(rejected, list):
            rejected = torch.tensor(rejected, dtype=torch.int32)
        self.rejected = rejected
        self.chosen_messages = chosen_messages
        self.rejected_messages = rejected_messages
        self.multi_modal_inputs = multi_modal_inputs
        if multi_modal_inputs is not None:
            # NOTE(review): this rebuilds the dict (replacing the assignment
            # above), converting any non-tensor values via torch.tensor.
            self.multi_modal_inputs = {}
            for key, value in multi_modal_inputs.items():
                if not isinstance(value, Tensor):
                    self.multi_modal_inputs[key] = torch.tensor(value)
                else:
                    self.multi_modal_inputs[key] = value
        # Handle teacher_logprobs
        if isinstance(teacher_logprobs, list):
            teacher_logprobs = torch.tensor(teacher_logprobs, dtype=torch.float32)
        self.teacher_logprobs = teacher_logprobs
        # Final safety net: coerce any remaining non-tensor inputs
        # (e.g., tuples or numpy arrays) that the list checks above missed.
        if not isinstance(self.tokens, Tensor):
            self.tokens = torch.tensor(self.tokens)
        if self.logprobs is not None and not isinstance(self.logprobs, Tensor):
            self.logprobs = torch.tensor(self.logprobs)
        if self.action_mask is not None and not isinstance(self.action_mask, Tensor):
            self.action_mask = torch.tensor(self.action_mask)
        if self.chosen is not None and not isinstance(self.chosen, Tensor):
            self.chosen = torch.tensor(self.chosen)
        if self.rejected is not None and not isinstance(self.rejected, Tensor):
            self.rejected = torch.tensor(self.rejected)
        if self.teacher_logprobs is not None and not isinstance(self.teacher_logprobs, Tensor):
            self.teacher_logprobs = torch.tensor(self.teacher_logprobs, dtype=torch.float32)
        self.custom_fields = custom_fields or []

    def serialize(self) -> bytes:
        """Serialize the experience to bytes."""
        return self.serialize_many([self])

    @classmethod
    def deserialize(cls, data: bytes) -> Experience:
        """Deserialize a payload that must contain exactly one experience.

        Raises:
            ValueError: if the payload holds a batch of a different size.
        """
        experiences = cls.deserialize_many(data)
        if len(experiences) != 1:
            raise ValueError(
                f"Expected a single Experience payload, got batch size {len(experiences)}. "
                "Use Experience.deserialize_many for batched payloads."
            )
        return experiences[0]

    @staticmethod
    def _serialize_custom_fields(custom_fields: Optional[List[CustomField]]) -> list[dict]:
        """Convert CustomField entries into plain dicts (dtype stored via str())."""
        if not custom_fields:
            return []
        # NOTE(review): the comprehension variable shadows the imported
        # `dataclasses.field`; harmless here but worth renaming eventually.
        return [
            {
                "source_field": field.source_field,
                "destination_field": field.destination_field,
                "data_type": str(field.data_type),
            }
            for field in custom_fields
        ]

    @staticmethod
    def _deserialize_custom_fields(serialized_fields: Optional[List[dict]]) -> List[CustomField]:
        """Rebuild CustomField entries, mapping "torch.<name>" dtype strings back to torch dtypes."""
        if not serialized_fields:
            return []
        custom_fields = []
        for field_dict in serialized_fields:
            dtype_name = field_dict["data_type"].replace("torch.", "")
            dtype = getattr(torch, dtype_name)
            custom_fields.append(
                CustomField(
                    source_field=field_dict["source_field"],
                    destination_field=field_dict["destination_field"],
                    data_type=dtype,
                )
            )
        return custom_fields

    @classmethod
    def serialize_many(cls, experiences: List[Experience]) -> bytes:
        """Serialize a list of experiences into a compact bytes payload.

        Tensor fields are packed with safetensors while non-tensor fields are packed
        as metadata via pickle.

        Payload layout: 4-byte magic, 1-byte version, two little-endian uint64
        section lengths (metadata, tensors), then the two sections.
        """
        metadata = {"version": cls._SER_VERSION, "num_items": len(experiences), "items": []}
        tensor_data = {}
        for index, exp in enumerate(experiences):
            item_meta = {}
            for field_name in cls._META_FIELDS:
                value = getattr(exp, field_name)
                if field_name == "eid" and value is not None:
                    # Store EIDs as plain dicts so the metadata stays portable.
                    item_meta[field_name] = value.to_dict() if isinstance(value, EID) else value
                else:
                    item_meta[field_name] = value
            item_meta["custom_fields"] = cls._serialize_custom_fields(exp.custom_fields)
            for field_name in cls._TENSOR_FIELDS:
                value = getattr(exp, field_name)
                if value is None:
                    continue
                # Keys are "<item index>:<field name>" so one flat safetensors
                # dict can hold the whole batch.
                tensor_data[f"{index}:{field_name}"] = value.detach().cpu().contiguous()
            if exp.multi_modal_inputs is None:
                item_meta["multi_modal_input_keys"] = []
            else:
                # Record the key list so deserialization knows which tensors belong here.
                mm_keys = list(exp.multi_modal_inputs.keys())
                item_meta["multi_modal_input_keys"] = mm_keys
                for key in mm_keys:
                    value = exp.multi_modal_inputs[key]
                    tensor_data[f"{index}:multi_modal_inputs:{key}"] = (
                        value.detach()
                        .cpu()
                        .contiguous()
                        .clone()  # clone to avoid shared memory issues
                    )
            metadata["items"].append(item_meta)
        metadata_bytes = pickle.dumps(metadata, protocol=pickle.HIGHEST_PROTOCOL)
        tensor_bytes = st_save(tensor_data)
        header = (
            cls._SER_MAGIC
            + struct.pack("<B", cls._SER_VERSION)
            + struct.pack("<Q", len(metadata_bytes))
            + struct.pack("<Q", len(tensor_bytes))
        )
        return header + metadata_bytes + tensor_bytes

    @classmethod
    def deserialize_many(cls, data: bytes) -> List[Experience]:
        """Deserialize bytes into a list of experiences.

        Supports both new batched payloads and legacy single-experience pickle payloads.

        WARNING: both paths unpickle part of `data` (the whole payload for
        legacy format, the metadata section otherwise), so payloads must come
        from a trusted source.
        """
        if not data.startswith(cls._SER_MAGIC):
            # Legacy format: the entire payload is a pickled Experience (or list).
            legacy = pickle.loads(data)
            if isinstance(legacy, list):
                return legacy
            return [legacy]
        offset = len(cls._SER_MAGIC)
        version = struct.unpack("<B", data[offset : offset + 1])[0]
        offset += 1
        if version != cls._SER_VERSION:
            raise ValueError(
                f"Unsupported Experience serialization version: {version}, expected {cls._SER_VERSION}."
            )
        # Read the two little-endian uint64 section lengths from the header.
        metadata_len = struct.unpack("<Q", data[offset : offset + 8])[0]
        offset += 8
        tensor_len = struct.unpack("<Q", data[offset : offset + 8])[0]
        offset += 8
        metadata_bytes = data[offset : offset + metadata_len]
        offset += metadata_len
        tensor_bytes = data[offset : offset + tensor_len]
        metadata = pickle.loads(metadata_bytes)
        tensor_data = st_load(tensor_bytes)
        experiences = []
        for index, item_meta in enumerate(metadata["items"]):
            init_kwargs = {
                "eid": item_meta.get("eid"),
                "reward": item_meta.get("reward"),
                "truncate_status": item_meta.get("truncate_status"),
                "info": item_meta.get("info"),
                "metrics": item_meta.get("metrics"),
                "prompt_length": item_meta.get("prompt_length", 1),
                "response_text": item_meta.get("response_text"),
                "prompt_text": item_meta.get("prompt_text"),
                "messages": item_meta.get("messages"),
                "tools": item_meta.get("tools"),
                "chosen_messages": item_meta.get("chosen_messages"),
                "rejected_messages": item_meta.get("rejected_messages"),
                "custom_fields": cls._deserialize_custom_fields(item_meta.get("custom_fields")),
            }
            # Absent tensors deserialize to None, matching serialization's skip.
            for field_name in cls._TENSOR_FIELDS:
                tensor_key = f"{index}:{field_name}"
                init_kwargs[field_name] = tensor_data.get(tensor_key)
            mm_keys = item_meta.get("multi_modal_input_keys", [])
            if mm_keys:
                init_kwargs["multi_modal_inputs"] = {
                    key: tensor_data[f"{index}:multi_modal_inputs:{key}"] for key in mm_keys
                }
            else:
                init_kwargs["multi_modal_inputs"] = None
            experiences.append(cls(**init_kwargs))
        return experiences

    def to_dict(self) -> dict:
        """Convert the experience to a dictionary."""
        res = {
            "eid": self.eid.to_dict(),
            "type": self.experience_type,
            "prompt_length": self.prompt_length,
            "response_length": len(self.tokens) - self.prompt_length,  # type: ignore [arg-type]
            "info": self.info,
            "metrics": self.metrics,
        }
        # Optional fields are only included when present.
        if self.prompt_text is not None:
            res["prompt_text"] = self.prompt_text
        if self.response_text is not None:
            res["response_text"] = self.response_text
        if self.messages is not None:
            res["messages"] = self.messages
        if self.tools is not None:
            res["tools"] = self.tools
        if self.chosen_messages is not None:
            res["chosen_messages"] = self.chosen_messages
        if self.rejected_messages is not None:
            res["rejected_messages"] = self.rejected_messages
        if self.reward is not None:
            res["reward"] = float(self.reward)
        if self.truncate_status is not None:
            res["truncate_status"] = self.truncate_status
        return res
def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[Experience]:
    """Expand each DPO experience into two single-turn experiences.

    For every input, the chosen completion and then the rejected completion are
    each appended to the prompt tokens to form a new experience; the prompt
    length of each new experience is the full length of the original `tokens`.
    """
    result = []
    for exp in experiences:
        # Chosen first, then rejected — preserving the historical output order.
        for completion, completion_messages in (
            (exp.chosen, exp.chosen_messages),
            (exp.rejected, exp.rejected_messages),
        ):
            result.append(
                Experience(
                    eid=EID(
                        batch=exp.eid.batch,
                        task=exp.eid.task,
                        step=exp.eid.step,
                        run=exp.eid.run,
                    ),
                    tokens=torch.cat([exp.tokens, completion]),
                    reward=exp.reward,
                    info=exp.info,
                    metrics=exp.metrics,
                    prompt_length=len(exp.tokens),  # type: ignore [arg-type]
                    prompt_text=exp.prompt_text,
                    messages=completion_messages,
                )
            )
    return result
def gather_token_ids(
    experiences, max_prompt_length: int, max_response_length: int, pad_token_id: int
) -> Tensor:
    """Stack experiences' token ids into one batch tensor.

    Each row is left-padded so prompts end at `max_prompt_length` and
    right-padded with `pad_token_id` up to `max_prompt_length + max_response_length`.
    """
    dtype = experiences[0].tokens.dtype
    rows = []
    for exp in experiences:
        left_pad = max_prompt_length - exp.prompt_length
        right_pad = max_response_length + exp.prompt_length - len(exp.tokens)
        rows.append(
            torch.cat(
                [
                    torch.full((left_pad,), pad_token_id, dtype=dtype),
                    exp.tokens,
                    torch.full((right_pad,), pad_token_id, dtype=dtype),
                ]
            )
        )
    return torch.stack(rows)
def gather_action_masks(experiences, max_response_length: int) -> Tensor:
    """Stack experiences' action masks, right-padded with False to `max_response_length`."""
    padded = []
    for exp in experiences:
        pad_len = max_response_length - len(exp.action_mask)
        tail = torch.full((pad_len,), 0, dtype=torch.bool)
        padded.append(torch.cat([exp.action_mask, tail]))
    return torch.stack(padded)
def gather_attention_masks(experiences, max_prompt_length: int, max_response_length: int) -> Tensor:
    """Build a boolean attention mask batch.

    Row i is True exactly over experience i's tokens, positioned so the prompt
    ends at column `max_prompt_length` (left-padding is masked out).
    """
    total_length = max_prompt_length + max_response_length
    masks = torch.zeros((len(experiences), total_length), dtype=torch.bool)
    for row, exp in enumerate(experiences):
        begin = max_prompt_length - exp.prompt_length
        masks[row, begin : begin + len(exp.tokens)] = 1
    return masks
def gather_response_attrs(
    experiences, attr_name: str, max_response_length: int, pad_value: int = 0
) -> Tensor:
    """Stack a per-response tensor attribute across experiences.

    Each experience's `attr_name` tensor is right-padded with `pad_value`
    (cast to the attribute's dtype) up to `max_response_length`.
    """
    dtype = getattr(experiences[0], attr_name).dtype
    fill = torch.tensor(pad_value, dtype=dtype)
    rows = []
    for exp in experiences:
        values = getattr(exp, attr_name)
        tail = torch.full((max_response_length - len(values),), fill, dtype=dtype)
        rows.append(torch.cat([values, tail]))
    return torch.stack(rows)
def group_by(
    experiences: List[Experience], id_type: Literal["task", "run", "step"]
) -> Dict[str, List[Experience]]:
    """Group experiences by ID."""
    # Map the public grouping name to the corresponding EID property.
    attr_by_type = {"task": "tid", "run": "rid", "step": "sid"}
    if id_type not in attr_by_type:
        raise ValueError(f"Unknown id_type: {id_type}")
    attr = attr_by_type[id_type]
    grouped: Dict[str, List[Experience]] = {}
    for exp in experiences:
        grouped.setdefault(getattr(exp.eid, attr), []).append(exp)
    return grouped
def to_hf_datasets(experiences: list[Experience]) -> "Dataset":
    """
    Convert a list of Experience objects to a HuggingFace Dataset,
    preserving all fields.
    """
    from datasets import Dataset

    rows = [asdict(exp) for exp in experiences]
    return Dataset.from_list(rows)
def from_hf_datasets(dataset: "Dataset") -> List[Experience]:
    """
    Convert a HuggingFace Dataset back to a list of Experience objects.
    """

    def row_to_experience(row: dict) -> Experience:
        # Drop keys that are not declared Experience dataclass fields before
        # forwarding to the keyword-only constructor.
        allowed = {f.name for f in fields(Experience)}
        return Experience(**{key: value for key, value in row.items() if key in allowed})

    return [row_to_experience(row) for row in dataset.to_list()]