trinity.algorithm.advantage_fn.on_policy_distill_advantage module

On-Policy Distillation advantage computation.

Reference: Tinker library’s on-policy distillation.

advantages = -(student_logprobs - teacher_logprobs) = teacher_logprobs - student_logprobs

class trinity.algorithm.advantage_fn.on_policy_distill_advantage.OnPolicyDistillAdvantage(kl_coef: float = 1.0)

Bases: AdvantageFn

Advantage function for on-policy distillation.

Computes: advantages = kl_coef * (teacher_logprobs - student_logprobs)

The workflow is expected to store the teacher's log-probabilities in Experience.teacher_logprobs during exploration.
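The formula above can be illustrated with a minimal, framework-free sketch. It is not the library's implementation: the function name `distill_advantages` and the sample log-probabilities are hypothetical, and plain Python lists stand in for whatever tensor type the real class operates on.

```python
from typing import List, Sequence


def distill_advantages(
    teacher_logprobs: Sequence[float],
    student_logprobs: Sequence[float],
    kl_coef: float = 1.0,
) -> List[float]:
    """Per-token advantage: kl_coef * (teacher_logprob - student_logprob)."""
    if len(teacher_logprobs) != len(student_logprobs):
        raise ValueError("teacher and student log-probs must align token by token")
    return [kl_coef * (t - s) for t, s in zip(teacher_logprobs, student_logprobs)]


# Hypothetical log-probs for a three-token response sampled from the student.
teacher = [-0.5, -2.0, -1.0]
student = [-1.0, -1.0, -1.0]
advantages = distill_advantages(teacher, student, kl_coef=1.0)
```

In this sketch a token gets a positive advantage when the teacher assigns it a higher log-probability than the student did, a negative advantage when the teacher finds it less likely, and zero when the two agree. kl_coef only scales these values.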

__init__(kl_coef: float = 1.0) → None
classmethod default_args() → Dict
Returns: The default init arguments for the advantage function.

Return type: Dict