Operator 开发指南#

步骤 0:Operator 模块基本概念#

Operator 模块负责处理由 Explorer 所生成的轨迹数据(我们称之为 Experience)。它原生支持来自 Data-Juicer 的数据处理功能,也允许开发者实现自己的算子。 通过自定义数据处理算子,开发者可以实现各种数据处理功能,如数据增强、过滤和转换。你甚至可以将优势值/回报值计算实现为 Operator,如 算法 部分所示。

  • DataJuicerOperator (trinity.buffer.operators.DataJuicerOperator):封装后的 Data-Juicer 算子,使用时只需在配置文件中标明想要使用的 Data-Juicer 算子列表即可。完整的 Data-Juicer 算子列表请见 此处

  • ExperienceOperator (trinity.buffer.operators.ExperienceOperator):用于 experience 数据处理的所有数据处理算子的基类。定义了所有数据处理算子应具备的接口和通用功能。每个算子处理一批 experience 数据,并返回处理后的数据及用于日志记录的指标。

  • ExperiencePipeline (trinity.buffer.pipelines.ExperiencePipeline):管理一系列数据处理算子的 experience 数据处理流水线。它从 Explorer 获取原始 experience,通过流水线中的每个算子处理,最后将最终处理过的 experience 写入 Trainer 的输入缓冲区。

备注

除了 ExperiencePipeline,Trinity-RFT 还提供 TaskPipeline 用于任务数据处理。 当前版本中,TaskPipeline 仅支持使用 Data-Juicer 算子。详情请参见 数据处理 部分。


开发者可通过以下步骤实现并使用自己的算子。

步骤 1:实现数据处理算子#

ExperienceOperatorV1 接口仅包含一个 process 方法。ExperiencePipeline 将调用此方法,传入 Explorer 在一次探索步骤中生成的一组 Experienceprocess 方法应返回一个元组,包含处理后的 Experience 列表和用于日志记录的指标字典。

class ExperienceOperatorV1(ABC):

    @abstractmethod
    async def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        """Process a list of experiences and return a transformed list.

        Args:
            exps (List[Experience]): List of experiences to process, which contains
                all experiences generated by the Explorer in one explore step.
        Returns:
            Tuple[List[Experience], Dict]: A tuple containing the processed list of experiences and a dictionary of metrics.
        """

以下是一个简单数据处理算子的实现示例,该算子过滤掉奖励低于某一阈值的 experience:

from trinity.buffer.operators import ExperienceOperatorV1
from trinity.common.experience import Experience


class RewardFilter(ExperienceOperatorV1):

    def __init__(self, threshold: float = 0.0) -> None:
        self.threshold = threshold

    async def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        filtered_exps = [exp for exp in exps if exp.reward >= self.threshold]
        metrics = {"filtered_count": len(exps) - len(filtered_exps)}
        return filtered_exps, metrics

实现后,你需要在 trinity/buffer/operators/__init__.py 中的 default_mapping 中注册此模块。注册后,该模块可在配置文件中使用注册名称进行配置。

EXPERIENCE_OPERATORS = Registry(
    "experience_operators",
    default_mapping={
        "reward_filter": "trinity.buffer.operators.filters.reward_filter.RewardFilter",
    },
)

步骤 2:使用此算子#

完成上述步骤后,你可以通过 YAML 配置文件使用新注册的算子。

# some other configs
data_processor:
  experience_pipeline:
    operators:
      - name: "reward_filter"
        args:
          threshold: 0.1
synchronizer:
  sync_method: nccl
  sync_style: explorer_driven
  sync_interval: 2
# some other configs

小技巧

RewardFilter 会减少 experience 数量,可能导致 Trainer 无法获得足够的 experience 来启动训练流程。为避免此问题,你可以使用 Trinity-RFT 提供的 动态同步 功能 (explorer_driven)。 上述设置意味着 Explorer 每运行 2 步就会与 Trainer 同步一次,且无论 Trainer 当前完成了多少步都会继续运行。这确保了只要 Explorer 在运行,Trainer 就总能获得足够的 experience 来启动训练步骤。

进阶特性#

在 Operator 中使用辅助模型#

工作流开发指南 所介绍,Trinity-RFT 支持部署辅助模型并通过 OpenAI API 调用它们。该特性同样可以在 Operator 中使用,使你能够利用强大的模型对 experience 进行判断和处理。这对于实现需要复杂推理或自然语言理解的数据处理算子尤其有用。

假设你在 YAML 配置文件中有如下辅助模型配置:

explorer:
  auxiliary_models:
    - model_path: Qwen/Qwen2.5-32B-Instruct
      name: qwen2.5-32B
      engine_num: 1
      tensor_parallel_size: 2
      enable_thinking: false
      max_prompt_tokens: 12288
      max_response_tokens: 12288
      max_model_len: 16384
    - model_path: Qwen/Qwen3-8B
      name: qwen3-8B
      engine_num: 2
      tensor_parallel_size: 1
      enable_thinking: false
      max_prompt_tokens: 12288
      max_response_tokens: 12288
      max_model_len: 16384

Trinity-RFT 会自动将已部署的辅助模型以 self.auxiliary_models 的形式注入到 Operator 中。该属性是一个字典,键为配置文件中模型的 name,值为模型实例列表(Dict[str, List[openai.AsyncOpenAI]]),每个模型实例的数量由 engine_num 决定。

你可以在算子的 process 方法中调用模型的推理 API,根据 experience 数据获得模型的响应。下面是一个在 Operator 中使用辅助模型的示例:

from trinity.buffer.operators import ExperienceOperatorV1
from trinity.common.experience import Experience


class OperatorWithModel(ExperienceOperatorV1):

    async def judge_experience(self, exp: Experience) -> bool:
        # 从 experience 中提取必要信息并准备模型输入
        # messages = ...
        # 调用模型推理 API 获取响应
        response = await self.auxiliary_models["qwen2.5-32B"][0].chat.completions.create(
            model=self.auxiliary_models["qwen2.5-32B"][0].model_path,  # Trinity-RFT 会自动设置 model_path,便于调用
            messages=messages,
        )
        # 处理模型响应并根据需要更新 experience
        # ...
        return exp

    async def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        await asyncio.gather(*(self.judge_experience(exp) for exp in exps))
        return exps, {}