trinity.algorithm.advantage_fn.jsd_advantage module

Jensen-Shannon Divergence (JSD) advantage computation.

JSD(P||Q) = beta * KL(teacher||M) + (1-beta) * KL(student||M), where M = beta*teacher + (1-beta)*student is the mixture distribution. With beta=0.5 this is the standard symmetric JSD. All computations are done in log-space (no exp). The behavior is aligned with SWIFT: beta=0 and beta=1 yield pure KL divergences, and temperature scaling and optional chunking are supported.
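As a rough illustration of the formula above, the following sketch evaluates the weighted JSD in plain probability space for two small hand-picked distributions. It is not the module's implementation, which works in log-space; the helper names `kl` and `generalized_jsd` and the example distributions are made up for this example. It also shows that plugging beta=0 directly into the mixture formula gives 0, which is presumably why the endpoints are treated as separate pure-KL cases.

```python
import math


def kl(p, q):
    """KL(p || q) for two discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def generalized_jsd(teacher, student, beta=0.5):
    """beta*KL(teacher||M) + (1-beta)*KL(student||M), M = beta*teacher + (1-beta)*student."""
    mixture = [beta * t + (1.0 - beta) * s for t, s in zip(teacher, student)]
    return beta * kl(teacher, mixture) + (1.0 - beta) * kl(student, mixture)


# Hypothetical example distributions over three classes.
teacher = [0.7, 0.2, 0.1]
student = [0.5, 0.3, 0.2]

# With beta=0.5 the value does not depend on argument order (symmetric JSD).
jsd_ts = generalized_jsd(teacher, student, beta=0.5)
jsd_st = generalized_jsd(student, teacher, beta=0.5)

# With beta=0 the mixture equals the student, so the raw formula collapses to 0.
jsd_raw_beta0 = generalized_jsd(teacher, student, beta=0.0)
```

The symmetric value is bounded above by log(2) when natural logarithms are used, which gives a quick sanity check on any implementation.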

class trinity.algorithm.advantage_fn.jsd_advantage.JSDAdvantage(lambda_coef: float = 0.5, kl_coef: float = 1.0, temperature: float = 1.0, chunk_size: int | None = None)

Bases: AdvantageFn

Advantage function using Jensen-Shannon Divergence (log-space, SWIFT-aligned).

Computes JSD in log-space only:

  • beta=0: JSD = KL(student || teacher) [pure KL]

  • beta=1: JSD = KL(teacher || student) [pure KL]

  • otherwise: JSD = beta*KL(teacher||M) + (1-beta)*KL(student||M), where M is the mixture, computed in log-space.

The workflow should store the teacher log-probabilities in Experience.teacher_logprobs during exploration.
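The three cases described for this class can be sketched as follows. This is a minimal sketch under stated assumptions, not the class's actual code: it assumes full per-class log-probability vectors are available for both models, and the names `log_mixture`, `kl_from_logs`, and `jsd_from_logs` are hypothetical. Unlike the "no exp" claim in the module description, this sketch does call `exp` inside a two-term logsumexp and when weighting the KL terms.

```python
import math


def log_mixture(log_t, log_s, beta):
    """log(beta*exp(log_t) + (1-beta)*exp(log_s)) via a stable two-term logsumexp."""
    a = math.log(beta) + log_t
    b = math.log(1.0 - beta) + log_s
    hi = max(a, b)
    return hi + math.log(math.exp(a - hi) + math.exp(b - hi))


def kl_from_logs(log_p, log_q):
    """KL(p || q) computed from per-class log-probabilities."""
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def jsd_from_logs(teacher_logp, student_logp, beta=0.5):
    """Weighted JSD with the two endpoint cases handled as pure KL."""
    if beta == 0.0:
        return kl_from_logs(student_logp, teacher_logp)
    if beta == 1.0:
        return kl_from_logs(teacher_logp, student_logp)
    log_m = [log_mixture(lt, ls, beta) for lt, ls in zip(teacher_logp, student_logp)]
    return (beta * kl_from_logs(teacher_logp, log_m)
            + (1.0 - beta) * kl_from_logs(student_logp, log_m))


# Hypothetical log-probabilities over three classes.
teacher_logp = [math.log(x) for x in (0.7, 0.2, 0.1)]
student_logp = [math.log(x) for x in (0.5, 0.3, 0.2)]

reverse_kl = jsd_from_logs(teacher_logp, student_logp, beta=0.0)  # KL(student || teacher)
forward_kl = jsd_from_logs(teacher_logp, student_logp, beta=1.0)  # KL(teacher || student)
symmetric = jsd_from_logs(teacher_logp, student_logp, beta=0.5)
```

Handling beta=0 and beta=1 before building the mixture also avoids evaluating `log(0)` for the zero-weight term.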

__init__(lambda_coef: float = 0.5, kl_coef: float = 1.0, temperature: float = 1.0, chunk_size: int | None = None) → None

Initialize JSD advantage function.

Parameters:
  • lambda_coef – Mixture weight, written as beta in the formulas: JSD = beta*KL(teacher||M) + (1-beta)*KL(student||M). beta=0 gives KL(student||teacher) and beta=1 gives KL(teacher||student). Range: [0, 1].

  • kl_coef – Overall scaling coefficient for advantages.

  • temperature – Temperature scaling for log-probs (log_probs / temperature). 1.0 = no scaling.

  • chunk_size – If set, process flattened valid tokens in chunks to reduce peak memory; None = no chunking.
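The temperature and chunk_size parameters above can be pictured with a small sketch. It assumes, as the descriptions state, that log-probs are divided by the temperature and that only valid (unmasked) tokens are flattened before chunking. The helpers `flatten_valid`, `scale_by_temperature`, and `map_in_chunks` are hypothetical and operate on plain lists rather than tensors; the class's real internals may differ.

```python
def flatten_valid(rows, mask):
    """Keep only positions whose mask entry is truthy, flattened row by row."""
    return [v for row, mrow in zip(rows, mask) for v, m in zip(row, mrow) if m]


def scale_by_temperature(log_probs, temperature=1.0):
    """Divide each log-prob by the temperature; 1.0 leaves values unchanged."""
    return [lp / temperature for lp in log_probs]


def map_in_chunks(values, fn, chunk_size=None):
    """Apply fn to the whole list, or to consecutive slices of at most chunk_size items."""
    if chunk_size is None:
        return fn(values)
    out = []
    for start in range(0, len(values), chunk_size):
        out.extend(fn(values[start:start + chunk_size]))
    return out


# Hypothetical per-token log-probs for two sequences, with one padded position.
rows = [[-1.0, -2.0, -3.0], [-0.5, -4.0, -6.0]]
mask = [[1, 1, 0], [1, 1, 1]]

flat = flatten_valid(rows, mask)
halve = lambda chunk: scale_by_temperature(chunk, temperature=2.0)
unchunked = map_in_chunks(flat, halve)
chunked = map_in_chunks(flat, halve, chunk_size=2)
```

For a per-token operation like this one, chunking changes only how much data is processed at once, not the result, so a smaller chunk_size trades speed for lower peak memory.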

classmethod default_args() → Dict
Returns:

The default init arguments for the advantage function.

Return type:

Dict