trinity.trainer.verl.losses module


Trinity policy loss function for veRL’s engine-based training API. This module provides a loss function compatible with veRL’s BaseEngine.forward_backward_batch() interface, replacing the earlier DataParallelPPOActor.update_policy() approach. veRL’s engine expects a loss function with the following signature:

def loss_fn(model_output, data: TensorDict, dp_group=None) -> (loss, metrics)
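The calling convention above can be illustrated with a toy, stdlib-only loss function. This is not Trinity's loss. The dictionary keys (`log_probs`, `advantages`, `response_mask`) are hypothetical, plain dicts stand in for `TensorDict`, and Python floats stand in for tensors. The sketch computes a masked token-mean of `-advantage * log_prob` and returns it together with a metrics dict.

```python
from typing import Any, Dict, Mapping, Optional, Tuple


def toy_loss_fn(
    model_output: Mapping[str, Any],
    data: Mapping[str, Any],
    dp_group: Optional[Any] = None,
) -> Tuple[float, Dict[str, float]]:
    """Toy loss with the (model_output, data, dp_group) -> (loss, metrics) shape.

    Hypothetical keys: model_output["log_probs"], data["advantages"] and
    data["response_mask"], all flat sequences of equal length.
    """
    log_probs = model_output["log_probs"]
    advantages = data["advantages"]
    mask = data["response_mask"]
    # Only tokens with a truthy mask value (response tokens) contribute.
    terms = [-a * lp for lp, a, m in zip(log_probs, advantages, mask) if m]
    num_tokens = len(terms)
    # Masked token-mean; an all-padding batch yields a zero loss.
    loss = sum(terms) / num_tokens if num_tokens else 0.0
    # dp_group is unused: this single-process sketch does not all-reduce
    # anything across data-parallel ranks.
    metrics = {"pg_loss": loss, "num_tokens": float(num_tokens)}
    return loss, metrics


loss, metrics = toy_loss_fn(
    {"log_probs": [-0.5, -1.0, -2.0]},
    {"advantages": [1.0, 2.0, 3.0], "response_mask": [1, 1, 0]},
)
print(loss, metrics)  # 1.25 {'pg_loss': 1.25, 'num_tokens': 2.0}
```

The third token is masked out, so the loss is the mean of 0.5 and 2.0.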

class trinity.trainer.verl.losses.TrinityPolicyLoss(algo_config: AlgorithmConfig)

Bases: object

Picklable policy loss callable for veRL’s engine API. Wraps Trinity’s POLICY_LOSS_FN, KL_FN, and ENTROPY_LOSS_FN registries into a single callable that can be serialized by Ray and sent to remote workers via set_loss_fn().
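The "picklable callable over registries" pattern can be sketched as follows. TrinityPolicyLoss's real attributes, registry keys and loss combination are not documented here, so every name below (`ToyPolicyLoss`, `ToyAlgoConfig`, the registry keys and coefficients) is hypothetical. The formula `pg + kl_coef * kl - entropy_coef * entropy` is also an assumption. The instance stores only plain config data and looks functions up by name in module-level registries, so it round-trips through serialization. Ray actually uses cloudpickle; the stdlib `pickle` is used here only as a stand-in to check the round trip.

```python
import pickle
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


def _mean(xs: List[float]) -> float:
    return sum(xs) / len(xs) if xs else 0.0


def pg_loss(logp, adv):  # toy policy-gradient term: mean(-adv * logp)
    return _mean([-a * lp for lp, a in zip(logp, adv)])


def k1_kl(logp, ref_logp):  # k1 KL estimator: mean(logp - ref_logp)
    return _mean([lp - r for lp, r in zip(logp, ref_logp)])


def mean_entropy(entropy):
    return _mean(entropy)


# Hypothetical stand-ins for POLICY_LOSS_FN / KL_FN / ENTROPY_LOSS_FN:
# plain dicts mapping a name to a module-level function.
POLICY_LOSS_FN: Dict[str, Callable] = {"toy_pg": pg_loss}
KL_FN: Dict[str, Callable] = {"k1": k1_kl}
ENTROPY_LOSS_FN: Dict[str, Callable] = {"mean": mean_entropy}


@dataclass
class ToyAlgoConfig:  # hypothetical subset of an AlgorithmConfig
    policy_loss_fn: str = "toy_pg"
    kl_fn: str = "k1"
    entropy_loss_fn: str = "mean"
    kl_coef: float = 0.1
    entropy_coef: float = 0.01


class ToyPolicyLoss:
    """Callable that holds only plain config, so it pickles cleanly."""

    def __init__(self, cfg: ToyAlgoConfig):
        self.cfg = cfg

    def __call__(self, model_output, data, dp_group=None) -> Tuple[float, dict]:
        cfg = self.cfg
        # Registry entries are resolved by name on every call.
        pg = POLICY_LOSS_FN[cfg.policy_loss_fn](model_output["log_probs"], data["advantages"])
        kl = KL_FN[cfg.kl_fn](model_output["log_probs"], data["ref_log_probs"])
        ent = ENTROPY_LOSS_FN[cfg.entropy_loss_fn](model_output["entropy"])
        # Assumed combination; dp_group is unused in this single-process sketch.
        loss = pg + cfg.kl_coef * kl - cfg.entropy_coef * ent
        return loss, {"pg_loss": pg, "kl": kl, "entropy": ent}


loss_fn = ToyPolicyLoss(ToyAlgoConfig())
restored = pickle.loads(pickle.dumps(loss_fn))  # serialize/deserialize round trip
out = {"log_probs": [-1.0, -2.0], "entropy": [0.5, 1.5]}
batch = {"advantages": [1.0, 1.0], "ref_log_probs": [-1.0, -1.0]}
print(restored(out, batch))  # loss is approximately 1.44
```

Keeping the callable's state to a small config object, rather than captured tensors or engine handles, keeps the serialized payload small and makes the worker-side object easy to reconstruct.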

__init__(algo_config: AlgorithmConfig)

trinity.trainer.verl.losses.build_trinity_loss(algo_config: AlgorithmConfig) -> TrinityPolicyLoss

Build a TrinityPolicyLoss instance for veRL’s engine API.