trinity.utils.distributed module

Utilities for distributed training with multiple process groups.

trinity.utils.distributed.is_ipv6_address(ip_str: str) -> bool [source]
trinity.utils.distributed.get_available_port() -> int [source]
trinity.utils.distributed.get_endpoint(host: str, port: int) -> str [source]
trinity.utils.distributed.is_port_available(port: int, host='127.0.0.1') -> bool [source]
trinity.utils.distributed.init_process_group(host: str, port: int, group_name: str, backend: str | Backend = 'nccl', timeout: float | None = None, world_size: int = -1, rank: int = -1, pg_options: Any | None = None, device_id: device | None = None) [source]

Initializes the process group. Requires torch >= 2.6.0.

class trinity.utils.distributed.WeightTransferEngine [source]

Bases: object

abstractmethod sync_weight(iterator) [source]

Perform the weight sync.

abstractmethod teardown() [source]

Tear down the weight sync group.

static create(engine_type: str, master_address: str, master_port: int, world_size: int, group_name: str) [source]

Factory method to create the appropriate weight transfer engine based on the rollout engine type.

class trinity.utils.distributed.VLLMWeightTransferEngine(master_address: str, master_port: int, world_size: int, group_name: str) [source]

Bases: WeightTransferEngine

A helper class to manage NCCL weight synchronization using vLLM's API.

__init__(master_address: str, master_port: int, world_size: int, group_name: str) [source]

Initialize the NCCL process group for weight sync with vLLM's API.

sync_weight(iterator) [source]

Perform the NCCL weight sync using vLLM's API.

teardown() [source]

Tear down the weight sync group.

class trinity.utils.distributed.SGLangWeightTransferEngine(master_address: str, master_port: int, world_size: int, group_name: str) [source]

Bases: WeightTransferEngine

A helper class to manage NCCL weight synchronization using SGLang's API.

__init__(master_address: str, master_port: int, world_size: int, group_name: str) [source]

Initialize the NCCL process group for weight sync with SGLang's API.

sync_weight(iterator) [source]

Perform the NCCL weight sync using SGLang's API.

teardown() [source]

Tear down the weight sync group.