trinity.algorithm.advantage_fn.jsd_advantage module
Jensen-Shannon Divergence (JSD) advantage computation.
The generalized JSD between the teacher and student distributions is JSD = beta * KL(teacher||M) + (1-beta) * KL(student||M), where M = beta*teacher + (1-beta)*student is their mixture. With beta=0.5 this is the standard symmetric JSD. All computations stay in log-space (no exp). The behavior is aligned with SWIFT: beta=0 and beta=1 yield pure KL divergences, and temperature scaling and optional chunking are supported.
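For reference, the definition above can be written out as follows, with p_T and p_S denoting the teacher and student distributions over tokens v. The log-sum-exp form of log M is a standard identity for building the mixture without leaving log-space; treating it as the module's exact computation is an assumption.

```latex
\log M(v) = \log\big(\beta\,p_T(v) + (1-\beta)\,p_S(v)\big)
          = \operatorname{logsumexp}\!\big(\log\beta + \log p_T(v),\; \log(1-\beta) + \log p_S(v)\big)

\mathrm{JSD}_\beta = \beta\,\mathrm{KL}(p_T \,\|\, M) + (1-\beta)\,\mathrm{KL}(p_S \,\|\, M), \qquad 0 < \beta < 1
```

At beta = 1/2 both weights equal 1/2 and M = (p_T + p_S)/2, which recovers the symmetric JSD.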
- class trinity.algorithm.advantage_fn.jsd_advantage.JSDAdvantage(lambda_coef: float = 0.5, kl_coef: float = 1.0, temperature: float = 1.0, chunk_size: int | None = None)
Bases: AdvantageFn
Advantage function using the Jensen-Shannon Divergence (computed in log-space, SWIFT-aligned).
Computes the JSD in log-space only:
- beta=0: JSD = KL(student || teacher) [pure KL]
- beta=1: JSD = KL(teacher || student) [pure KL]
- otherwise: JSD = beta*KL(teacher||M) + (1-beta)*KL(student||M), where M is the beta-weighted mixture of teacher and student, formed in log-space.
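The case split above can be sketched in plain Python as a reference computation over full log-probability vectors for a toy vocabulary. This is an illustration under assumptions: the helper names (`logsumexp2`, `kl_from_logprobs`, `generalized_jsd`) are hypothetical, and unlike the module's exp-free path, the sketch exponentiates log P to weight the KL terms. The endpoints need their own branches: at beta=0 the mixture collapses to M = student and the mixture formula evaluates to zero, and likewise at beta=1 with M = teacher.

```python
import math
from typing import List, Sequence

# Hypothetical helpers for illustration only; they are not part of the trinity module.


def logsumexp2(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    m = max(a, b)
    if m == -math.inf:
        return -math.inf  # both inputs are log(0)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def kl_from_logprobs(logp: Sequence[float], logq: Sequence[float]) -> float:
    """KL(P || Q) over a vocabulary from log P and log Q (exponentiates log P for the weights)."""
    total = 0.0
    for lp, lq in zip(logp, logq):
        if lp == -math.inf:
            continue  # zero-probability tokens contribute nothing
        total += math.exp(lp) * (lp - lq)
    return total


def generalized_jsd(teacher_logp: Sequence[float], student_logp: Sequence[float], beta: float) -> float:
    """Beta-weighted JSD following the documented case split."""
    if not 0.0 <= beta <= 1.0:
        raise ValueError("beta must be in [0, 1]")
    if beta == 0.0:
        return kl_from_logprobs(student_logp, teacher_logp)  # KL(student || teacher)
    if beta == 1.0:
        return kl_from_logprobs(teacher_logp, student_logp)  # KL(teacher || student)
    # log M = log(beta * p_T + (1 - beta) * p_S), formed without leaving log-space.
    log_beta, log_one_minus_beta = math.log(beta), math.log1p(-beta)
    log_m: List[float] = [
        logsumexp2(log_beta + t, log_one_minus_beta + s)
        for t, s in zip(teacher_logp, student_logp)
    ]
    return (beta * kl_from_logprobs(teacher_logp, log_m)
            + (1.0 - beta) * kl_from_logprobs(student_logp, log_m))
```

For example, with teacher probabilities (0.7, 0.2, 0.1) and student probabilities (0.5, 0.3, 0.2), swapping the arguments leaves the beta=0.5 result unchanged, while beta=0 and beta=1 return the two different directed KL divergences.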
The teacher_logprobs should be stored in Experience.teacher_logprobs by the workflow during exploration.
- __init__(lambda_coef: float = 0.5, kl_coef: float = 1.0, temperature: float = 1.0, chunk_size: int | None = None) -> None
Initialize JSD advantage function.
- Parameters:
lambda_coef – Mixture weight beta, in the range [0, 1]. JSD = beta*KL(teacher||M) + (1-beta)*KL(student||M); beta=0 gives KL(student||teacher) and beta=1 gives KL(teacher||student).
kl_coef – Overall scaling coefficient applied to the advantages.
temperature – Temperature applied to the log-probs (log_probs / temperature); 1.0 means no scaling.
chunk_size – If set, the flattened valid tokens are processed in chunks of this size to reduce peak memory; None disables chunking.
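The parameters above can be illustrated with a small, framework-free sketch of how temperature, chunk_size and kl_coef might fit together. It rests on stated assumptions: `scaled_token_values` and `token_fn` are hypothetical names, the per-token JSD term is left as a caller-supplied function rather than reproduced, and temperature is assumed to apply to both teacher and student log-probs. Chunking over the flattened valid tokens does not change the result, only how many tokens are handled at once.

```python
from typing import Callable, List, Optional, Sequence

# Hypothetical per-token function: (teacher_logprob, student_logprob) -> value.
# The real per-token JSD term used by JSDAdvantage is not reproduced here.
TokenFn = Callable[[float, float], float]


def scaled_token_values(
    teacher_logprobs: Sequence[Sequence[float]],  # [batch][seq_len]
    student_logprobs: Sequence[Sequence[float]],  # [batch][seq_len]
    mask: Sequence[Sequence[int]],  # 1 = valid token, 0 = padding
    token_fn: TokenFn,
    kl_coef: float = 1.0,
    temperature: float = 1.0,
    chunk_size: Optional[int] = None,
) -> List[List[float]]:
    """Hypothetical sketch of temperature scaling, chunked processing and kl_coef scaling."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if chunk_size is not None and chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer or None")
    # Flatten the (row, column) positions of valid tokens only.
    positions = [(i, j) for i, row in enumerate(mask) for j, m in enumerate(row) if m]
    out = [[0.0] * len(row) for row in mask]  # padding positions stay 0.0
    # None means a single chunk containing every valid token.
    step = chunk_size if chunk_size is not None else max(len(positions), 1)
    for start in range(0, len(positions), step):
        for i, j in positions[start:start + step]:
            # Assumption: both log-probs are divided by temperature, as in log_probs / temperature.
            t = teacher_logprobs[i][j] / temperature
            s = student_logprobs[i][j] / temperature
            out[i][j] = kl_coef * token_fn(t, s)
    return out
```

In pure Python the chunked loop saves no memory; it only shows the loop structure that, with tensors, bounds the size of the intermediate buffers allocated per step.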