trinity.algorithm.advantage_fn.multi_step_grpo_advantage module
GRPO advantage computation for multi-step scenarios
- class trinity.algorithm.advantage_fn.multi_step_grpo_advantage.StepWiseGRPOAdvantageFn(epsilon: float = 1e-06, enable_step_norm: bool = False, std_cal_level: str = 'group', std_threshold: float | None = None, **kwargs)
Bases: AdvantageFn, ExperienceOperator
An advantage function that computes a GRPO advantage from the last step of each run and broadcasts it to all previous steps of that run. Inspired by rLLM (rllm-org/rllm).
- __init__(epsilon: float = 1e-06, enable_step_norm: bool = False, std_cal_level: str = 'group', std_threshold: float | None = None, **kwargs) → None
Initialize the Step-wise GRPO advantage function.
- Parameters:
epsilon (float) – A small value to avoid division by zero.
enable_step_norm (bool) – If True, normalize advantages by trajectory length.
std_cal_level (str) – The scope for calculating the reward standard deviation: ‘group’ (default) computes it per task group; ‘batch’ computes it across all last-step rewards in the entire batch. The mean is always computed per task group.
std_threshold (Optional[float]) – If provided, task groups whose reward standard deviation is equal to or below this threshold are skipped.
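To illustrate the two `std_cal_level` settings, the sketch below computes the standard deviation each task group would be normalized by. It is a minimal, hypothetical helper (`std_per_group`) over plain reward lists, not the library's code, and it assumes the population standard deviation (`statistics.pstdev`); the actual implementation may use a different estimator.

```python
from statistics import pstdev
from typing import Dict, List


def std_per_group(
    last_step_rewards: Dict[str, List[float]],
    std_cal_level: str = "group",
) -> Dict[str, float]:
    """Hypothetical helper: the std used to normalize each task group.

    `last_step_rewards` maps a task-group ID to the last-step rewards of its runs.
    """
    if std_cal_level == "group":
        # 'group': each task group uses the spread of its own rewards.
        return {task: pstdev(rs) for task, rs in last_step_rewards.items()}
    if std_cal_level == "batch":
        # 'batch': one std over every last-step reward in the batch, shared by all groups.
        all_rewards = [r for rs in last_step_rewards.values() for r in rs]
        batch_std = pstdev(all_rewards)
        return {task: batch_std for task in last_step_rewards}
    raise ValueError(f"unknown std_cal_level: {std_cal_level!r}")


rewards = {"task_a": [1.0, 0.0, 1.0, 0.0], "task_b": [1.0, 1.0, 1.0, 1.0]}
print(std_per_group(rewards, "group"))  # {'task_a': 0.5, 'task_b': 0.0}
print(std_per_group(rewards, "batch"))  # both groups share one batch-wide std
```

Because the mean stays per group in both modes, every run of `task_b` still gets a centered reward of zero under ‘batch’; only the divisor changes.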
- calculate_last_step_advantage(exps: Dict[str, Experience], precomputed_std: Tensor | None = None) → Tuple[Dict[str, float], Dict[str, float], bool]
Calculate the group advantage for a given group of last-step experiences.
- Parameters:
exps (Dict[str, Experience]) – One experience per run, keyed by run ID.
precomputed_std (Optional[torch.Tensor]) – Precomputed standard deviation for batch-level calculation.
- Returns:
A tuple of three elements. Dict[str, float]: Scores for each run. Dict[str, float]: Metrics for logging. bool: Whether this group should be skipped.
- Return type:
Tuple[Dict[str, float], Dict[str, float], bool]
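A minimal sketch of a GRPO-style score for one task group, returning the same `(scores, metrics, skip)` shape documented above. The function name `last_step_scores`, the metric keys, and the use of plain floats instead of `Experience` objects and tensors are hypothetical. It also assumes that the skip test compares the group's own standard deviation against `std_threshold` and that a skipped group yields no scores.

```python
from statistics import mean, pstdev
from typing import Dict, Optional, Tuple


def last_step_scores(
    rewards: Dict[str, float],
    precomputed_std: Optional[float] = None,
    epsilon: float = 1e-6,
    std_threshold: Optional[float] = None,
) -> Tuple[Dict[str, float], Dict[str, float], bool]:
    """Hypothetical sketch: GRPO-style scores for one task group.

    `rewards` maps a run ID to the reward of that run's last step.
    """
    values = list(rewards.values())
    group_mean = mean(values)  # the mean is always taken per task group
    group_std = pstdev(values)
    # Use the batch-level std when one is supplied, otherwise the group's own std.
    std = precomputed_std if precomputed_std is not None else group_std
    metrics = {"reward_mean": group_mean, "reward_std": group_std}
    # Assumption: a group is skipped when its own std is at or below the threshold.
    if std_threshold is not None and group_std <= std_threshold:
        return {}, metrics, True
    # epsilon keeps the division finite when all rewards are equal.
    scores = {run: (r - group_mean) / (std + epsilon) for run, r in rewards.items()}
    return scores, metrics, False


scores, metrics, skip = last_step_scores({"r0": 1.0, "r1": 0.0})
print(scores, skip)  # r0 close to +1, r1 close to -1, not skipped
```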
- broadcast_advantages(run_exps: Dict[str, List[Experience]], scores: Dict[str, float]) → Dict[str, List[Experience]]
Broadcast the calculated advantages to all previous steps in each run.
- Parameters:
run_exps (Dict[str, List[Experience]]) – Experiences grouped by run ID.
scores (Dict[str, float]) – Calculated scores for each run.
- Returns:
Updated experiences, with each run's score broadcast as the advantage of all of its steps.
- Return type:
Dict[str, List[Experience]]
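The broadcast step can be sketched as copying each run's last-step score onto every step of that run. This is a simplified illustration under stated assumptions: `Exp` is a hypothetical stand-in for `Experience` holding a scalar advantage (real experiences likely carry per-token tensors), and `enable_step_norm` is assumed to mean dividing the score by the number of steps in the run.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Exp:
    """Hypothetical stand-in for an Experience: one step of one run."""
    step: int
    advantage: Optional[float] = None


def broadcast(
    run_exps: Dict[str, List[Exp]],
    scores: Dict[str, float],
    enable_step_norm: bool = False,
) -> Dict[str, List[Exp]]:
    """Copy each scored run's last-step score onto every step of that run."""
    for run_id, score in scores.items():
        steps = run_exps[run_id]
        if enable_step_norm:
            # Assumption: "normalize by trajectory length" divides by the step count.
            score = score / len(steps)
        for exp in steps:
            exp.advantage = score
    return run_exps


runs = {"r0": [Exp(0), Exp(1)], "r1": [Exp(0)]}
out = broadcast(runs, {"r0": 1.0, "r1": -1.0})
print([e.advantage for e in out["r0"]])  # [1.0, 1.0]
```

Runs missing from `scores` (for example, from a skipped group) keep their original advantage in this sketch.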
- process(exps: List[Experience]) → Tuple[List[Experience], Dict]
Process a list of experiences and return a transformed list.
- Parameters:
exps (List[Experience]) – List of experiences to process, containing all experiences generated by the Explorer in one explore step.
- Returns:
A tuple containing the processed list of experiences and a dictionary of metrics.
- Return type:
Tuple[List[Experience], Dict]
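Since `process` receives a flat list from one explore step, the experiences presumably have to be grouped by task and by run before per-group scores can be computed and broadcast. A minimal sketch of that bookkeeping follows; the `Exp` class and its `task_id`, `run_id`, and `step` fields are hypothetical, and treating the highest step index as the "last step" is an assumption.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Exp:
    """Hypothetical stand-in for an Experience, tagged with task, run and step."""
    task_id: str
    run_id: str
    step: int
    reward: float = 0.0


def group_by_task_and_run(exps: List[Exp]) -> Dict[str, Dict[str, List[Exp]]]:
    """Group a flat batch into task -> run -> steps, each run ordered by step."""
    grouped: Dict[str, Dict[str, List[Exp]]] = defaultdict(lambda: defaultdict(list))
    for exp in exps:
        grouped[exp.task_id][exp.run_id].append(exp)
    for runs in grouped.values():
        for steps in runs.values():
            steps.sort(key=lambda e: e.step)
    # Convert the nested defaultdicts back to plain dicts.
    return {task: dict(runs) for task, runs in grouped.items()}


def last_steps(runs: Dict[str, List[Exp]]) -> Dict[str, Exp]:
    """The final step of each run, whose reward drives the group advantage."""
    return {run_id: steps[-1] for run_id, steps in runs.items()}


batch = [Exp("t0", "r0", 1, 1.0), Exp("t0", "r0", 0), Exp("t0", "r1", 0, 0.0)]
grouped = group_by_task_and_run(batch)
print({run: e.step for run, e in last_steps(grouped["t0"]).items()})  # {'r0': 1, 'r1': 0}
```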