Source code for trinity.common.rewards.dapo_reward

# -*- coding: utf-8 -*-
"""Reward Function with Overlong Reward Shaping described in DAPO (https://arxiv.org/pdf/2503.14476)"""
from typing import Optional

import torch

from trinity.common.rewards.naive_dapo_score import compute_score
from trinity.common.rewards.reward_fn import RewardFn


class MathDAPORewardFn(RewardFn):
    """A reward function that follows the definition in DAPO for math task.

    The reward has two components: an outcome ``accuracy`` score of +1 / -1
    and an optional length-based ``format_score`` (overlong reward shaping).
    """

    def __init__(
        self,
        enable_overlong_penalty: Optional[bool] = None,
        penalty_factor: Optional[float] = None,
        max_response_length: Optional[int] = None,
        cache_length: Optional[int] = None,
    ) -> None:
        """Initialize DAPO math reward settings.

        Args:
            enable_overlong_penalty: Whether to apply overlong response shaping.
            penalty_factor: Magnitude for overlong penalties.
            max_response_length: Maximum allowed response length in tokens.
            cache_length: Soft-penalty transition window in tokens.
        """
        self.enable_overlong_penalty = enable_overlong_penalty
        self.penalty_factor = penalty_factor
        self.max_response_length = max_response_length
        self.cache_length = cache_length

    def __call__(  # type: ignore
        self,
        response: str,
        response_token: torch.Tensor,
        truth: str,
        **kwargs,
    ) -> dict[str, float]:
        """Compute DAPO reward components for one response.

        Args:
            response: Model-generated response text.
            response_token: Response token ids.
            truth: Ground-truth answer string.
            **kwargs: Extra arguments for compatibility with reward API. If an
                ``info`` entry is present and not None, it is updated in place
                with ``ground_truth`` and ``extracted_answer``.

        Returns:
            dict[str, float]: Reward components containing accuracy and
            format_score.
        """
        score, extracted_answer = compute_score(response, truth)

        # DAPO paper (Sec. 2.4): +1 / -1 rule-based outcome reward
        accuracy_score = 1.0 if score >= 0.5 else -1.0

        # Length shaping is opt-in; without it the format component is neutral.
        format_score = 0.0
        if self.enable_overlong_penalty:
            format_score = self.compute_overlong_penalty(response_token)

        # Optionally expose details to the caller through a mutable "info" mapping.
        info = kwargs.get("info")
        if info is not None:
            info["ground_truth"] = truth
            info["extracted_answer"] = extracted_answer

        return {
            "accuracy": accuracy_score,
            "format_score": format_score,
        }

    def compute_overlong_penalty(self, response_token: torch.Tensor) -> float:
        """Compute soft/hard penalty for long responses.

        With ``expected_len = max_response_length - cache_length``:

        * ``len <= expected_len``: no penalty (``0.0``).
        * ``expected_len < len <= max_response_length``: penalty grows
          linearly from ``0`` down to ``-penalty_factor``.
        * ``len > max_response_length``: full penalty (``-penalty_factor``).

        Args:
            response_token: Response token ids; only ``len(response_token)``
                is used.

        Returns:
            float: Length-based shaping value, where negative values penalize
            overlong outputs.

        Raises:
            AssertionError: If the overlong-penalty settings are incomplete or
                ``max_response_length`` is not greater than ``cache_length``.
        """
        assert (
            self.max_response_length is not None
            and self.cache_length is not None
            and self.penalty_factor is not None
        ), "When enable_overlong_penalty = true, max_response_length, penalty_factor, cache_length must be set"
        assert (
            self.max_response_length > self.cache_length
        ), "max_response_length must be greater than cache_length"

        response_len = len(response_token)
        expected_len = self.max_response_length - self.cache_length

        # Inclusive bound: a response of exactly expected_len is not penalized.
        # This also keeps cache_length == 0 away from the division below, which
        # previously raised ZeroDivisionError when len == max_response_length.
        if response_len <= expected_len:
            return 0.0
        if response_len > self.max_response_length:
            return -self.penalty_factor
        # Here expected_len < response_len <= max_response_length, which implies
        # cache_length > 0, so the division is safe.
        return (expected_len - response_len) / self.cache_length * self.penalty_factor