Operator 开发指南#
步骤 0:Operator 模块基本概念#
Operator 模块负责处理由 Explorer 所生成的轨迹数据(我们称之为 Experience)。它原生支持来自 Data-Juicer 的数据处理功能,也允许开发者实现自己的算子。
通过自定义数据处理算子,开发者可以实现各种数据处理功能,如数据增强、过滤和转换。你甚至可以将优势值/回报值计算实现为 Operator,如 算法 部分所示。
DataJuicerOperator (
trinity.buffer.operators.DataJuicerOperator):封装后的 Data-Juicer 算子,使用时只需在配置文件中标明想要使用的 Data-Juicer 算子列表即可。完整的 Data-Juicer 算子列表请见 此处。ExperienceOperator (
trinity.buffer.operators.ExperienceOperator):用于 experience 数据处理的所有数据处理算子的基类。定义了所有数据处理算子应具备的接口和通用功能。每个算子处理一批 experience 数据,并返回处理后的数据及用于日志记录的指标。ExperiencePipeline (
trinity.buffer.pipelines.ExperiencePipeline):管理一系列数据处理算子的 experience 数据处理流水线。它从Explorer获取原始 experience,通过流水线中的每个算子处理,最后将最终处理过的 experience 写入Trainer的输入缓冲区。
备注
除了 ExperiencePipeline,Trinity-RFT 还提供 TaskPipeline 用于任务数据处理。
当前版本中,TaskPipeline 仅支持使用 Data-Juicer 算子。详情请参见 数据处理 部分。
开发者可通过以下步骤实现并使用自己的算子。
步骤 1:实现数据处理算子#
ExperienceOperatorV1 接口仅包含一个 process 方法。ExperiencePipeline 将调用此方法,传入 Explorer 在一次探索步骤中生成的一组 Experience。process 方法应返回一个元组,包含处理后的 Experience 列表和用于日志记录的指标字典。
class ExperienceOperatorV1(ABC):
@abstractmethod
async def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
"""Process a list of experiences and return a transformed list.
Args:
exps (List[Experience]): List of experiences to process, which contains
all experiences generated by the Explorer in one explore step.
Returns:
Tuple[List[Experience], Dict]: A tuple containing the processed list of experiences and a dictionary of metrics.
"""
以下是一个简单数据处理算子的实现示例,该算子过滤掉奖励低于某一阈值的 experience:
from trinity.buffer.operators import ExperienceOperatorV1
from trinity.common.experience import Experience
class RewardFilter(ExperienceOperatorV1):
def __init__(self, threshold: float = 0.0) -> None:
self.threshold = threshold
async def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
filtered_exps = [exp for exp in exps if exp.reward >= self.threshold]
metrics = {"filtered_count": len(exps) - len(filtered_exps)}
return filtered_exps, metrics
实现后,你需要在 trinity/buffer/operators/__init__.py 中的 default_mapping 中注册此模块。注册后,该模块可在配置文件中使用注册名称进行配置。
EXPERIENCE_OPERATORS = Registry(
"experience_operators",
default_mapping={
"reward_filter": "trinity.buffer.operators.filters.reward_filter.RewardFilter",
},
)
步骤 2:使用此算子#
完成上述步骤后,你可以通过 YAML 配置文件使用新注册的算子。
# some other configs
data_processor:
experience_pipeline:
operators:
- name: "reward_filter"
args:
threshold: 0.1
synchronizer:
sync_method: nccl
sync_style: explorer_driven
sync_interval: 2
# some other configs
小技巧
RewardFilter 会减少 experience 数量,可能导致 Trainer 无法获得足够的 experience 来启动训练流程。为避免此问题,你可以使用 Trinity-RFT 提供的 动态同步 功能 (explorer_driven)。
上述设置意味着 Explorer 每运行 2 步就会与 Trainer 同步一次,且无论 Trainer 当前完成了多少步都会继续运行。这确保了只要 Explorer 在运行,Trainer 就总能获得足够的 experience 来启动训练步骤。
进阶特性#
在 Operator 中使用辅助模型#
如 工作流开发指南 所介绍,Trinity-RFT 支持部署辅助模型并通过 OpenAI API 调用它们。该特性同样可以在 Operator 中使用,使你能够利用强大的模型对 experience 进行判断和处理。这对于实现需要复杂推理或自然语言理解的数据处理算子尤其有用。
假设你在 YAML 配置文件中有如下辅助模型配置:
explorer:
auxiliary_models:
- model_path: Qwen/Qwen2.5-32B-Instruct
name: qwen2.5-32B
engine_num: 1
tensor_parallel_size: 2
enable_thinking: false
max_prompt_tokens: 12288
max_response_tokens: 12288
max_model_len: 16384
- model_path: Qwen/Qwen3-8B
name: qwen3-8B
engine_num: 2
tensor_parallel_size: 1
enable_thinking: false
max_prompt_tokens: 12288
max_response_tokens: 12288
max_model_len: 16384
Trinity-RFT 会自动将已部署的辅助模型以 self.auxiliary_models 的形式注入到 Operator 中。该属性是一个字典,键为配置文件中模型的 name,值为模型实例列表(Dict[str, List[openai.AsyncOpenAI]]),每个模型实例的数量由 engine_num 决定。
你可以在算子的 process 方法中调用模型的推理 API,根据 experience 数据获得模型的响应。下面是一个在 Operator 中使用辅助模型的示例:
from trinity.buffer.operators import ExperienceOperatorV1
from trinity.common.experience import Experience
class OperatorWithModel(ExperienceOperatorV1):
async def judge_experience(self, exp: Experience) -> bool:
# 从 experience 中提取必要信息并准备模型输入
# messages = ...
# 调用模型推理 API 获取响应
response = await self.auxiliary_models["qwen2.5-32B"][0].chat.completions.create(
model=self.auxiliary_models["qwen2.5-32B"][0].model_path, # Trinity-RFT 会自动设置 model_path,便于调用
messages=messages,
)
# 处理模型响应并根据需要更新 experience
# ...
return exp
async def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
await asyncio.gather(*(self.judge_experience(exp) for exp in exps))
return exps, {}