trinity.utils.distributed module
Utilities for distributed training with multiple process groups.
-
trinity.utils.distributed.is_ipv6_address(ip_str: str) → bool
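The entry above carries no description, so the following is only a minimal sketch of how such a check could be written with the standard library. The name `is_ipv6_address_sketch` is hypothetical, and this is not necessarily how Trinity implements the function.

```python
import ipaddress


def is_ipv6_address_sketch(ip_str: str) -> bool:
    """Return True if ip_str parses as an IPv6 address (illustrative only)."""
    try:
        # IPv6Address raises AddressValueError (a ValueError subclass)
        # for anything that is not a valid IPv6 literal.
        ipaddress.IPv6Address(ip_str)
        return True
    except ValueError:
        return False
```

Under this sketch, `"::1"` is accepted, while IPv4 literals such as `"127.0.0.1"` and hostnames such as `"localhost"` are rejected.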
-
trinity.utils.distributed.get_available_port() → int
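A common way to obtain a free port is to bind a socket to port 0 and let the operating system choose. The sketch below illustrates that technique under the assumption that a loopback IPv4 port is wanted; `get_available_port_sketch` is a hypothetical name and the real function may behave differently.

```python
import socket


def get_available_port_sketch() -> int:
    """Ask the OS for a currently unused TCP port (illustrative only)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Port 0 tells the OS to pick any free ephemeral port.
        s.bind(("127.0.0.1", 0))
        # getsockname() returns (host, port) for AF_INET sockets.
        return s.getsockname()[1]
```

The socket is closed before the port number is returned, so another process could claim the port in the meantime; callers that need certainty should be prepared to retry.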
-
trinity.utils.distributed.get_endpoint(host: str, port: int) → str
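The return format of this function is not documented here. As a rough illustration only, the sketch below assumes a plain `host:port` string in which IPv6 literals are wrapped in brackets, as is conventional for endpoints. The name `get_endpoint_sketch` is hypothetical, and the real function may add a scheme or format the result differently.

```python
import ipaddress


def get_endpoint_sketch(host: str, port: int) -> str:
    """Join host and port, bracketing IPv6 literals (assumed format)."""
    try:
        is_v6 = isinstance(ipaddress.ip_address(host), ipaddress.IPv6Address)
    except ValueError:
        # Not an IP literal at all, e.g. a hostname such as "localhost".
        is_v6 = False
    return f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
```

The brackets are what keep the colons inside an IPv6 address from being confused with the port separator.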
-
trinity.utils.distributed.is_port_available(port: int, host='127.0.0.1') → bool
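One plausible way to test port availability is to attempt a bind and treat failure as "in use". The sketch below assumes an IPv4 TCP socket and uses the hypothetical name `is_port_available_sketch`; it is not taken from the Trinity source.

```python
import socket


def is_port_available_sketch(port: int, host: str = "127.0.0.1") -> bool:
    """Return True if (host, port) can currently be bound (illustrative only)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind((host, port))
            return True
        except OSError:
            # Typically EADDRINUSE or a permission error.
            return False
```

Like any check of this kind, the result is only a snapshot: the port may be taken between the check and the later bind.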
-
trinity.utils.distributed.init_process_group(host: str, port: int, group_name: str, backend: str | Backend = 'nccl', timeout: float | None = None, world_size: int = -1, rank: int = -1, pg_options: Any | None = None, device_id: device | None = None)
Initialize the process group. Requires torch >= 2.6.0.
-
class trinity.utils.distributed.WeightTransferEngine
Bases: object
-
abstractmethod sync_weight(iterator)
Perform the weight sync.
-
abstractmethod teardown()
Tear down the weight sync group.
-
static create(engine_type: str, master_address: str, master_port: int, world_size: int, group_name: str)
Factory method to create the appropriate weight transfer engine based on the rollout engine type.
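The factory pattern described above can be sketched as follows. Everything here is illustrative: the class names, the `"vllm"` and `"sglang"` type strings, and the assumption that `iterator` yields `(name, tensor)` pairs are hypothetical, and the stand-in engines only record parameter names instead of performing any NCCL transfer.

```python
from abc import ABC, abstractmethod
from typing import List


class EngineSketch(ABC):
    """Hypothetical stand-in for an abstract weight transfer engine."""

    def __init__(self, master_address: str, master_port: int,
                 world_size: int, group_name: str) -> None:
        self.master_address = master_address
        self.master_port = master_port
        self.world_size = world_size
        self.group_name = group_name
        self.synced: List[str] = []  # names "transferred" so far

    @abstractmethod
    def sync_weight(self, iterator) -> None:
        """Consume the iterator and transfer each weight."""

    @abstractmethod
    def teardown(self) -> None:
        """Release any resources held by the engine."""

    @staticmethod
    def create(engine_type: str, master_address: str, master_port: int,
               world_size: int, group_name: str) -> "EngineSketch":
        # Assumed type strings; the real accepted values are not documented here.
        registry = {"vllm": VLLMEngineSketch, "sglang": SGLangEngineSketch}
        if engine_type not in registry:
            raise ValueError(f"Unsupported engine type: {engine_type!r}")
        return registry[engine_type](master_address, master_port,
                                     world_size, group_name)


class _RecordingEngine(EngineSketch):
    """Shared dummy behavior: record names rather than send tensors."""

    def sync_weight(self, iterator) -> None:
        for name, _tensor in iterator:  # assumed (name, tensor) pairs
            self.synced.append(name)

    def teardown(self) -> None:
        self.synced.clear()


class VLLMEngineSketch(_RecordingEngine):
    pass


class SGLangEngineSketch(_RecordingEngine):
    pass
```

A caller would then select the engine by name, for example `EngineSketch.create("vllm", "127.0.0.1", 29500, 2, "weight_sync")`, and use only the shared `sync_weight` and `teardown` interface afterwards.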
-
class trinity.utils.distributed.VLLMWeightTransferEngine(master_address: str, master_port: int, world_size: int, group_name: str)
Bases: WeightTransferEngine
A helper class to manage NCCL weight synchronization using vLLM's API.
-
__init__(master_address: str, master_port: int, world_size: int, group_name: str)
Initialize the NCCL process group for weight sync with vLLM's API.
-
sync_weight(iterator)
Perform the NCCL weight sync using vLLM's API.
-
teardown()
Tear down the weight sync group.
-
class trinity.utils.distributed.SGLangWeightTransferEngine(master_address: str, master_port: int, world_size: int, group_name: str)
Bases: WeightTransferEngine
A helper class to manage NCCL weight synchronization using SGLang's API.
-
__init__(master_address: str, master_port: int, world_size: int, group_name: str)
Initialize the NCCL process group for weight sync with SGLang's API.
-
sync_weight(iterator)
Perform the NCCL weight sync using SGLang's API.
-
teardown()
Tear down the weight sync group.