trinity.algorithm.advantage_fn.multi_step_grpo_advantage module

GRPO advantage computation for multi-step scenarios

class trinity.algorithm.advantage_fn.multi_step_grpo_advantage.StepWiseGRPOAdvantageFn(epsilon: float = 1e-06, enable_step_norm: bool = False, std_cal_level: str = 'group', std_threshold: float | None = None, **kwargs)

Bases: AdvantageFn, ExperienceOperator

An advantage function that broadcasts advantages from the last step to previous steps. Inspired by rLLM (rllm-org/rllm).

__init__(epsilon: float = 1e-06, enable_step_norm: bool = False, std_cal_level: str = 'group', std_threshold: float | None = None, **kwargs) -> None

Initialize the Step-wise GRPO advantage function.

Parameters:
  • epsilon (float) – A small value to avoid division by zero.

  • enable_step_norm (bool) – If True, normalize advantages by trajectory length.

  • std_cal_level (str) – The scope for calculating the reward standard deviation. ‘group’ (default): the standard deviation is calculated per task group. ‘batch’: the standard deviation is calculated across all last-step rewards in the entire batch. In both cases the mean is calculated per task group.

  • std_threshold (Optional[float]) – If provided, task groups whose reward standard deviation is less than or equal to this threshold are skipped.
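The difference between the two std_cal_level settings can be illustrated with a minimal, standalone sketch. It does not use the Trinity API: the function name normalize_rewards and the plain-float reward lists are hypothetical, and the use of the sample standard deviation is an assumption, since the documentation above does not say which estimator is used.

```python
from statistics import mean, stdev
from typing import Dict, List


def normalize_rewards(
    groups: Dict[str, List[float]],
    std_cal_level: str = "group",
    epsilon: float = 1e-6,
) -> Dict[str, List[float]]:
    """Hypothetical helper: normalize last-step rewards per task group.

    The mean is always taken per task group. The standard deviation is
    taken per group ("group") or over every reward in the batch ("batch").
    Assumption: sample standard deviation (statistics.stdev).
    """
    batch_std = None
    if std_cal_level == "batch":
        all_rewards = [r for rewards in groups.values() for r in rewards]
        batch_std = stdev(all_rewards)

    normalized = {}
    for task_id, rewards in groups.items():
        group_mean = mean(rewards)
        scale = batch_std if batch_std is not None else stdev(rewards)
        # epsilon guards against division by zero when all rewards are equal
        normalized[task_id] = [(r - group_mean) / (scale + epsilon) for r in rewards]
    return normalized
```

With ‘group’, two task groups whose rewards differ only in scale end up with roughly the same normalized values. With ‘batch’, the group with the wider reward spread keeps proportionally larger values.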

calculate_last_step_advantage(exps: Dict[str, Experience], precomputed_std: Tensor | None = None) -> Tuple[Dict[str, float], Dict[str, float], bool]

Calculate the group advantage from the last-step experiences of one task group.

Parameters:
  • exps (Dict[str, Experience]) – One experience per run, keyed by run ID.

  • precomputed_std (Optional[torch.Tensor]) – Precomputed standard deviation for batch-level calculation.

Returns:

A tuple of three items. Dict[str, float]: Scores for each run. Dict[str, float]: Metrics for logging. bool: Whether this group should be skipped.

Return type:

Tuple[Dict[str, float], Dict[str, float], bool]
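As an illustration of the three-part return value, here is a minimal sketch that works on plain floats instead of Experience objects. The function name last_step_scores, the metric keys, the sample standard deviation, and the fallback for a group with a single run are all assumptions made for this example, not the library's behavior.

```python
from statistics import mean, stdev
from typing import Dict, Optional, Tuple


def last_step_scores(
    rewards: Dict[str, float],
    epsilon: float = 1e-6,
    std_threshold: Optional[float] = None,
    precomputed_std: Optional[float] = None,
) -> Tuple[Dict[str, float], Dict[str, float], bool]:
    """Hypothetical stand-in: rewards maps run ID to the last-step reward."""
    values = list(rewards.values())
    if len(values) < 2:
        # Assumption: a single run cannot be compared with anything, so use
        # neutral statistics instead of failing on stdev().
        group_mean, group_std = 0.0, 1.0
    else:
        group_mean, group_std = mean(values), stdev(values)

    metrics = {"reward_mean": group_mean, "reward_std": group_std}

    # Skip groups whose rewards barely differ (std at or below the threshold).
    if std_threshold is not None and group_std <= std_threshold:
        return {}, metrics, True

    # A batch-level std, when supplied, replaces the per-group std.
    scale = precomputed_std if precomputed_std is not None else group_std
    scores = {
        run_id: (reward - group_mean) / (scale + epsilon)
        for run_id, reward in rewards.items()
    }
    return scores, metrics, False
```

A caller would unpack the result as `scores, metrics, skip = last_step_scores(...)` and drop the group when `skip` is true.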

broadcast_advantages(run_exps: Dict[str, List[Experience]], scores: Dict[str, float]) -> Dict[str, List[Experience]]

Broadcast the calculated advantages to all previous steps in each run.

Parameters:
  • run_exps (Dict[str, List[Experience]]) – Experiences grouped by run ID.

  • scores (Dict[str, float]) – Calculated scores for each run.

Returns:

Updated experiences with the advantages broadcast to every step.

Return type:

Dict[str, List[Experience]]
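The broadcasting step can be sketched as follows. StepRecord is a hypothetical stand-in for Experience that stores a single scalar advantage per step, and the reading of enable_step_norm as "divide the score by the number of steps in the run" is an assumption based on the parameter description above.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class StepRecord:
    """Hypothetical stand-in for Experience (one record per step)."""
    run_id: str
    step: int
    advantage: Optional[float] = None


def broadcast_scores(
    run_steps: Dict[str, List[StepRecord]],
    scores: Dict[str, float],
    enable_step_norm: bool = False,
) -> Dict[str, List[StepRecord]]:
    """Copy each run's last-step score onto every step of that run."""
    for run_id, steps in run_steps.items():
        score = scores[run_id]
        if enable_step_norm:
            # Assumption: normalize by trajectory length (number of steps).
            score = score / len(steps)
        for record in steps:
            record.advantage = score
    return run_steps
```

Every step of a run receives the same value, so credit for the final outcome is shared equally across the trajectory.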

process(exps: List[Experience]) -> Tuple[List[Experience], Dict]

Process a list of experiences and return a transformed list.

Parameters:

exps (List[Experience]) – List of experiences to process, which contains all experiences generated by the Explorer in one explore step.

Returns:

A tuple containing the processed list of experiences and a dictionary of metrics.

Return type:

Tuple[List[Experience], Dict]
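Because process receives one flat list, the experiences presumably have to be grouped by task and by run before the last step of each run can be scored. The sketch below shows one way to do that grouping with plain dictionaries. The keys "task", "run" and "step" are hypothetical and do not reflect the actual fields of Experience.

```python
from collections import defaultdict
from typing import Any, Dict, List, Tuple

Record = Dict[str, Any]


def group_by_task_and_run(
    records: List[Record],
) -> Tuple[Dict[str, Dict[str, List[Record]]], Dict[str, Dict[str, Record]]]:
    """Group a flat list by task and run, and pick each run's last step."""
    grouped: Dict[str, Dict[str, List[Record]]] = defaultdict(lambda: defaultdict(list))
    for record in records:
        grouped[record["task"]][record["run"]].append(record)

    last_steps: Dict[str, Dict[str, Record]] = {}
    for task_id, runs in grouped.items():
        last_steps[task_id] = {}
        for run_id, steps in runs.items():
            steps.sort(key=lambda r: r["step"])  # order steps within a run
            last_steps[task_id][run_id] = steps[-1]
    return grouped, last_steps
```

The per-task dictionary of last steps has the same shape as the exps argument of calculate_last_step_advantage (one item per run, keyed by run ID), and the per-run step lists have the same shape as run_exps in broadcast_advantages.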

classmethod compute_in_trainer() -> bool

Whether the advantage should be computed in the trainer loop.

classmethod default_args() -> Dict

Return the default configuration for this strategy.