trinity.trainer.verl.losses module
Trinity policy loss function for veRL’s engine-based training API. This module provides a loss function compatible with veRL’s BaseEngine.forward_backward_batch() interface, replacing the old DataParallelPPOActor.update_policy() approach. The loss function signature expected by veRL’s engine:
def loss_fn(model_output, data: TensorDict, dp_group=None) -> (loss, metrics)
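To make the calling convention concrete, here is a minimal sketch of a function with the same `(model_output, data, dp_group=None) -> (loss, metrics)` shape. It is not Trinity's or veRL's implementation: it uses plain Python lists and dicts in place of tensors and `TensorDict`, and the name `toy_loss_fn`, the keys `log_probs`, `advantages`, and `response_mask`, and the metric names are all assumptions made for illustration.

```python
from typing import Any, Dict, Mapping, Optional, Tuple


def toy_loss_fn(
    model_output: Mapping[str, Any],
    data: Mapping[str, Any],
    dp_group: Optional[Any] = None,
) -> Tuple[float, Dict[str, float]]:
    """Hypothetical stand-in that follows the (loss, metrics) return shape."""
    log_probs = model_output["log_probs"]  # assumed key, per-token log-probs
    advantages = data["advantages"]        # assumed key, per-token advantages
    mask = data["response_mask"]           # assumed key, 1 = token counts

    # Masked mean of the vanilla policy-gradient term -A * log pi.
    total = sum(-a * lp * m for a, lp, m in zip(advantages, log_probs, mask))
    count = sum(mask)
    loss = total / max(count, 1)

    # dp_group is accepted but unused here; a real loss might use it to
    # reduce metrics across data-parallel ranks.
    metrics = {"toy/pg_loss": loss, "toy/valid_tokens": float(count)}
    return loss, metrics


loss, metrics = toy_loss_fn(
    {"log_probs": [-0.5, -1.0, -2.0]},
    {"advantages": [1.0, -1.0, 3.0], "response_mask": [1, 1, 0]},
)
```

The point of the sketch is only the contract: one scalar loss to backpropagate and one flat dict of metrics to log.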
- class trinity.trainer.verl.losses.TrinityPolicyLoss(algo_config: AlgorithmConfig)
Bases: object
Picklable policy loss callable for veRL’s engine API. Wraps Trinity’s POLICY_LOSS_FN, KL_FN, and ENTROPY_LOSS_FN registries into a single callable that can be serialized by Ray and sent to remote workers via set_loss_fn().
- __init__(algo_config: AlgorithmConfig)
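The sketch below illustrates one way a callable of this kind can stay picklable: the instance stores only plain configuration data and looks up the actual loss function in a module-level registry when it is called. This is an assumption about the general pattern, not a description of how `TrinityPolicyLoss` is written. `TOY_POLICY_LOSS_FN`, `ToyAlgoConfig`, `ToyPolicyLoss`, and `build_toy_loss` are hypothetical stand-ins, and the standard-library `pickle` module is used here in place of Ray's serializer.

```python
import pickle
from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple


def neg_mean(values: Sequence[float]) -> float:
    """Toy loss: negative mean of the given values."""
    return -sum(values) / len(values)


# Hypothetical registry standing in for a name -> loss function mapping.
TOY_POLICY_LOSS_FN: Dict[str, Callable[[Sequence[float]], float]] = {
    "neg_mean": neg_mean,
}


@dataclass
class ToyAlgoConfig:
    """Hypothetical config: holds only the registry key, not the function."""
    policy_loss_fn: str = "neg_mean"


class ToyPolicyLoss:
    """Callable whose state is plain data, so it can be pickled."""

    def __init__(self, algo_config: ToyAlgoConfig) -> None:
        self.algo_config = algo_config

    def __call__(
        self,
        model_output: Mapping[str, Any],
        data: Mapping[str, Any],
        dp_group: Optional[Any] = None,
    ) -> Tuple[float, Dict[str, float]]:
        # Resolve the function by name at call time, on whichever process runs it.
        fn = TOY_POLICY_LOSS_FN[self.algo_config.policy_loss_fn]
        loss = fn(model_output["log_probs"])  # "log_probs" is an assumed key
        return loss, {"toy/loss": loss}


def build_toy_loss(algo_config: ToyAlgoConfig) -> ToyPolicyLoss:
    """Factory mirroring the shape of a build_*_loss helper."""
    return ToyPolicyLoss(algo_config)


loss_fn = build_toy_loss(ToyAlgoConfig())
restored = pickle.loads(pickle.dumps(loss_fn))  # simulate shipping to a worker
loss, metrics = restored({"log_probs": [-1.0, -3.0]}, data={})
```

Keeping function objects out of the instance state means the serialized payload is just the config, and each worker picks up its own copy of the registry on import.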
- trinity.trainer.verl.losses.build_trinity_loss(algo_config: AlgorithmConfig) → TrinityPolicyLoss
Build a TrinityPolicyLoss instance for veRL’s engine API.