[Config][Disaggregated] Add timeout configuration for the torch.store and add KVTransferConfig.kv_connector_extra_config (#14367)

Signed-off-by: Mathis Felardos <mathis@mistral.ai>
Author: Mathis Felardos
Date: 2025-03-13 04:15:20 +01:00 (committed by GitHub)
Parent: 128bf75283
Commit: 1bd32bc8dd
4 changed files with 22 additions and 11 deletions


@@ -2837,6 +2837,9 @@ class KVTransferConfig(BaseModel):
     # The KV connector port, used to build distributed connection
     kv_port: int = 14579
 
+    # any extra config that the connector may need
+    kv_connector_extra_config: dict[str, Any] = {}
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -2896,6 +2899,9 @@ class KVTransferConfig(BaseModel):
         return self.kv_connector is not None and \
             self.kv_role in ["kv_consumer", "kv_both"]
 
+    def get_from_extra_config(self, key, default) -> Any:
+        return self.kv_connector_extra_config.get(key, default)
+
 
 class CompilationLevel:
     # constants for the levels of the compilation process
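For illustration, a minimal sketch of how the new field and accessor work together (not part of the commit; the import path and the PyNcclConnector/kv_producer values are assumptions based on the surrounding code):

from vllm.config import KVTransferConfig  # assumed import path

# Build a transfer config carrying a connector-specific option.
config = KVTransferConfig(
    kv_connector="PyNcclConnector",
    kv_role="kv_producer",
    kv_connector_extra_config={"store_timeout": 600},
)

# The accessor falls back to the supplied default when the key is absent.
assert config.get_from_extra_config("store_timeout", 300) == 600
assert config.get_from_extra_config("missing_key", 42) == 42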


@@ -59,11 +59,13 @@ class PyNcclPipe(KVPipeBase):
         self.device = self._select_device(device)
 
         # build distributed connection and send/recv implementation
+        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
         self.group = StatelessProcessGroup.create(
             host=self.config.kv_ip,
             port=self.config.kv_port + port_offset,
             rank=self.kv_rank,
             world_size=self.kv_parallel_size,
+            store_timeout=store_timeout,
         )
         # add a barrier to make sure the connection is initiated properly
         self.group.barrier()
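For context, a sketch of how a user could raise the store timeout end to end. KVTransferConfig.from_cli and the kv_transfer_config engine argument follow vLLM's disaggregated-prefill example and are assumptions here, as is the placeholder model name:

from vllm import LLM
from vllm.config import KVTransferConfig

# Ask for a 600 s store timeout instead of the 300 s default; the
# "store_timeout" key matches the get_from_extra_config lookup above.
ktc = KVTransferConfig.from_cli(
    '{"kv_connector": "PyNcclConnector", "kv_role": "kv_producer", '
    '"kv_rank": 0, "kv_parallel_size": 2, '
    '"kv_connector_extra_config": {"store_timeout": 600}}'
)
llm = LLM(model="facebook/opt-125m", kv_transfer_config=ktc)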


@@ -5,6 +5,7 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import dataclasses
+import datetime
 import pickle
 import time
 from collections import deque
@@ -217,6 +218,7 @@ class StatelessProcessGroup:
         rank: int,
         world_size: int,
         data_expiration_seconds: int = 3600,
+        store_timeout: int = 300,
     ) -> "StatelessProcessGroup":
         """A replacement for `torch.distributed.init_process_group` that does not
         pollute the global state.
@@ -238,6 +240,7 @@ class StatelessProcessGroup:
             port=port,
             world_size=world_size,
             is_master=(rank == 0),
+            timeout=datetime.timedelta(seconds=store_timeout),
         )
         return StatelessProcessGroup(
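The timeout ultimately lands on torch's TCPStore, which backs StatelessProcessGroup. A standalone sketch of the same pattern in plain torch, using the defaults from the diff above (port 14579, 300 s):

import datetime

from torch.distributed import TCPStore

# Rank 0 hosts the store; other ranks connect as clients. Store
# operations (including the initial client connect and blocking gets)
# raise an error once `timeout` elapses instead of hanging forever.
store = TCPStore(
    host_name="127.0.0.1",
    port=14579,
    world_size=1,
    is_master=True,
    timeout=datetime.timedelta(seconds=300),
)
store.set("key", "value")
assert store.get("key") == b"value"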