trinity.buffer.operators.filters.reward_filter module

class trinity.buffer.operators.filters.reward_filter.RewardFilter(threshold: float = 0.0)[source]

Bases: ExperienceOperator

Filter experiences based on the reward value.

Note: This filter assumes that the reward is already calculated and stored in the Experience object.

__init__(threshold: float = 0.0)[source]
process(exps: List[Experience]) → Tuple[List[Experience], dict][source]

Filter experiences based on reward value.
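A minimal sketch of reward-threshold filtering follows. The `Experience` dataclass and the `reward_filter` function are hypothetical stand-ins, not Trinity's classes. The sketch assumes an experience is kept when its reward is strictly greater than `threshold`; whether the real operator uses `>` or `>=`, and which metric keys it reports, is not stated here. In line with the note above, every reward is assumed to be already computed.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Experience:
    """Hypothetical stand-in for Trinity's Experience; only the reward is modeled."""
    reward: float  # assumed to be computed before the filter runs


def reward_filter(exps: List[Experience], threshold: float = 0.0) -> Tuple[List[Experience], dict]:
    # Assumed comparison: keep rewards strictly above the threshold.
    kept = [e for e in exps if e.reward > threshold]
    # Hypothetical metric key; the real operator may report different metrics.
    return kept, {"filtered_count": len(exps) - len(kept)}


exps = [Experience(1.0), Experience(0.0), Experience(-0.5)]
kept, metrics = reward_filter(exps)
print([e.reward for e in kept], metrics)
```

With the default threshold of 0.0, only the experience with reward 1.0 survives.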

class trinity.buffer.operators.filters.reward_filter.RewardSTDFilter(threshold: float = 0.0)[source]

Bases: ExperienceOperator

Filter experiences based on the standard deviation of rewards within each group.

Note: This filter assumes that the reward is already calculated and stored in the Experience object.

__init__(threshold: float = 0.0)[source]
process(exps: List[Experience]) → Tuple[List[Experience], dict][source]

Filter experiences based on reward std.
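The sketch below shows one way to filter groups by reward standard deviation. It makes several assumptions. Experiences are grouped by a hypothetical `task_id` field. The spread is measured with the population standard deviation. A whole group is kept only when its standard deviation exceeds `threshold`. The real operator's group key, its choice of population versus sample standard deviation, and its comparison may all differ.

```python
import statistics
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Experience:
    """Hypothetical stand-in: task_id is the assumed group key."""
    task_id: str
    reward: float


def reward_std_filter(exps: List[Experience], threshold: float = 0.0) -> Tuple[List[Experience], dict]:
    # Collect all rollouts that belong to the same task into one group.
    groups: Dict[str, List[Experience]] = defaultdict(list)
    for e in exps:
        groups[e.task_id].append(e)

    kept: List[Experience] = []
    for group in groups.values():
        # Population standard deviation; a one-element group yields 0.0.
        std = statistics.pstdev([e.reward for e in group])
        # Assumed rule: keep the whole group only if its spread exceeds the threshold.
        if std > threshold:
            kept.extend(group)
    # Hypothetical metric key.
    return kept, {"filtered_count": len(exps) - len(kept)}


exps = [
    Experience("t1", 1.0), Experience("t1", 0.0),  # std 0.5, kept
    Experience("t2", 1.0), Experience("t2", 1.0),  # std 0.0, dropped
]
kept, metrics = reward_std_filter(exps)
print([e.task_id for e in kept], metrics)
```

Groups whose rollouts all receive the same reward have zero spread, so they carry no group-relative learning signal. This is why such groups are the ones removed at the default threshold.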

class trinity.buffer.operators.filters.reward_filter.DAPODynamicSamplingFilter(metric_key: str = 'accuracy', correct_threshold: float = 0.0)[source]

Bases: ExperienceOperator

DAPO dynamic sampling (arXiv:2503.14476 Sec. 3.2).

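Keeps a task group only when some, but not all, of its rollouts are correct: 0 < |{correct}| < G, where G is the number of rollouts in the group. Correctness is judged from the outcome accuracy stored in the experience metrics, not from the total reward, which may include length-based shaping.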

__init__(metric_key: str = 'accuracy', correct_threshold: float = 0.0) → None[source]

Initialize the dynamic sampling filter.

Parameters:
  • metric_key -- Metric name used to determine rollout correctness.

  • correct_threshold -- Minimum score treated as correct.

process(exps: List[Experience]) → Tuple[List[Experience], dict][source]

Keep only mixed-correctness groups for DAPO training.

Parameters:

exps -- Experiences grouped by task id during filtering.

Returns:

Filtered experiences and filtering metrics.

Return type:

Tuple[List[Experience], dict]
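A minimal sketch of the DAPO dynamic-sampling rule is shown below. It uses hypothetical stand-ins (`Experience` with `task_id` and `metrics`, and a `dapo_dynamic_sampling` function), so it is not Trinity's implementation. The sketch treats a rollout as correct when `metrics[metric_key]` is strictly greater than `correct_threshold`. With the default threshold of 0.0, this means an accuracy of 0 counts as incorrect. A missing metric is also treated as incorrect. Both choices are assumptions, and the returned metric names are hypothetical.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class Experience:
    """Hypothetical stand-in: task_id groups rollouts, metrics holds outcome scores."""
    task_id: str
    metrics: Dict[str, float] = field(default_factory=dict)


def dapo_dynamic_sampling(
    exps: List[Experience], metric_key: str = "accuracy", correct_threshold: float = 0.0
) -> Tuple[List[Experience], dict]:
    groups: Dict[str, List[Experience]] = defaultdict(list)
    for e in exps:
        groups[e.task_id].append(e)

    kept: List[Experience] = []
    all_correct = all_wrong = 0
    for group in groups.values():
        # Assumption: strictly above the threshold is "correct"; missing metric is incorrect.
        n_correct = sum(1 for e in group if e.metrics.get(metric_key, 0.0) > correct_threshold)
        if 0 < n_correct < len(group):  # mixed-correctness group, kept for training
            kept.extend(group)
        elif n_correct == 0:
            all_wrong += 1
        else:
            all_correct += 1
    # Hypothetical metric names.
    return kept, {"groups_all_correct": all_correct, "groups_all_wrong": all_wrong}


exps = [
    Experience("t1", {"accuracy": 1.0}), Experience("t1", {"accuracy": 0.0}),
    Experience("t2", {"accuracy": 1.0}), Experience("t2", {"accuracy": 1.0}),
    Experience("t3", {"accuracy": 0.0}), Experience("t3", {"accuracy": 0.0}),
]
kept, metrics = dapo_dynamic_sampling(exps)
print([e.task_id for e in kept], metrics)
```

In this example, only task t1 is kept. When every rollout in a group agrees, whether all correct or all wrong, the group-normalized advantage is zero for every rollout, so the group contributes no gradient. DAPO drops such groups for this reason.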

class trinity.buffer.operators.filters.reward_filter.MaskResponseTruncatedOperator[source]

Bases: ExperienceOperator

DAPO overlong filtering stage 1 (Sec. 3.4): exclude truncated responses from loss.

Zeros the action_mask of truncated rollouts so that they do not contribute to the policy gradient.

process(exps: List[Experience]) → Tuple[List[Experience], dict][source]

Mask action positions for truncated responses.

Parameters:

exps -- Experiences to process.

Returns:

Original experiences and masking metrics.

Return type:

Tuple[List[Experience], dict]
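The following sketch illustrates the masking idea. It makes simplifying assumptions. Truncation is signalled by a hypothetical boolean `truncated` field, whereas the real operator may derive truncation differently. Masks are plain Python lists, whereas the real `Experience` likely stores them as tensors. The metric name is also hypothetical. Unlike the filters above, nothing is dropped: the same experiences are returned, with truncated ones fully masked.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Experience:
    """Hypothetical stand-in; the real class likely stores masks as tensors."""
    action_mask: List[int]
    truncated: bool  # hypothetical flag: response was cut off at the length limit


def mask_response_truncated(exps: List[Experience]) -> Tuple[List[Experience], dict]:
    masked = 0
    for e in exps:
        if e.truncated:
            # Zero every action position so this rollout adds nothing to the loss.
            e.action_mask = [0] * len(e.action_mask)
            masked += 1
    # The original list is returned; no experience is removed.
    return exps, {"masked_count": masked}


exps = [Experience([1, 1, 1], truncated=True), Experience([1, 1], truncated=False)]
out, metrics = mask_response_truncated(exps)
print([e.action_mask for e in out], metrics)
```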

class trinity.buffer.operators.filters.reward_filter.InvalidRewardFilter[source]

Bases: ExperienceOperator

Filters out experiences with invalid reward values.

Note: This operator assumes that rewards are already computed and stored in the Experience object. Any experience with a missing (None) or invalid (NaN) reward is removed to prevent low-quality data from entering the training pipeline.

process(exps: List[Experience]) → Tuple[List[Experience], dict][source]

Process a list of experiences and return a transformed list.

Parameters:

exps (List[Experience]) -- List of experiences to process, which contains all experiences generated by the Explorer in one explore step.

Returns:

A tuple containing the processed list of experiences and a dictionary of metrics.

Return type:

Tuple[List[Experience], Dict]
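A minimal sketch of the None/NaN check described in the note above follows. The `Experience` dataclass, the `invalid_reward_filter` function, and the metric key are hypothetical stand-ins. The check order matters here: `None` is tested first because `math.isnan(None)` would raise a `TypeError`.

```python
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Experience:
    """Hypothetical stand-in with only the reward field."""
    reward: Optional[float]


def invalid_reward_filter(exps: List[Experience]) -> Tuple[List[Experience], dict]:
    # Reject missing rewards before calling math.isnan, which rejects None.
    kept = [e for e in exps if e.reward is not None and not math.isnan(e.reward)]
    # Hypothetical metric key.
    return kept, {"filtered_count": len(exps) - len(kept)}


exps = [Experience(1.0), Experience(None), Experience(float("nan")), Experience(0.0)]
kept, metrics = invalid_reward_filter(exps)
print([e.reward for e in kept], metrics)
```

A zero reward is valid and is kept. Only missing and NaN rewards are removed.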