[Config][Disaggregated] Add timeout configuration for the torch.store and add KVTransferConfig.kv_connector_extra_config (#14367)

Signed-off-by: Mathis Felardos <mathis@mistral.ai>
This commit is contained in:
Mathis Felardos 2025-03-13 04:15:20 +01:00 committed by GitHub
parent 128bf75283
commit 1bd32bc8dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 22 additions and 11 deletions

View File

@ -2837,6 +2837,9 @@ class KVTransferConfig(BaseModel):
# The KV connector port, used to build distributed connection # The KV connector port, used to build distributed connection
kv_port: int = 14579 kv_port: int = 14579
# any extra config that the connector may need
kv_connector_extra_config: dict[str, Any] = {}
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
@ -2896,6 +2899,9 @@ class KVTransferConfig(BaseModel):
return self.kv_connector is not None and \ return self.kv_connector is not None and \
self.kv_role in ["kv_consumer", "kv_both"] self.kv_role in ["kv_consumer", "kv_both"]
def get_from_extra_config(self, key, default) -> Any:
return self.kv_connector_extra_config.get(key, default)
class CompilationLevel: class CompilationLevel:
# constants for the levels of the compilation process # constants for the levels of the compilation process

View File

@ -6,7 +6,7 @@
- Distributed KV cache transmission using PyNccl pipes. - Distributed KV cache transmission using PyNccl pipes.
- Non-blocking `insert`, blocking `drop_select`. - Non-blocking `insert`, blocking `drop_select`.
- Use CPU signal pipe to avoid racing condition - Use CPU signal pipe to avoid racing condition
- Handles buffer size constraints and provide backpressure mechanism to - Handles buffer size constraints and provide backpressure mechanism to
stop the prefill instance when the decode instance is slow. stop the prefill instance when the decode instance is slow.
""" """
import threading import threading

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
This module implements a PyNccl pipe for sending and receiving This module implements a PyNccl pipe for sending and receiving
Optional[torch.Tensor] between distributed ranks with advanced Optional[torch.Tensor] between distributed ranks with advanced
communication features. communication features.
Key Features: Key Features:
@ -59,11 +59,13 @@ class PyNcclPipe(KVPipeBase):
self.device = self._select_device(device) self.device = self._select_device(device)
# build distributed connection and send/recv implementation # build distributed connection and send/recv implementation
store_timeout = self.config.get_from_extra_config("store_timeout", 300)
self.group = StatelessProcessGroup.create( self.group = StatelessProcessGroup.create(
host=self.config.kv_ip, host=self.config.kv_ip,
port=self.config.kv_port + port_offset, port=self.config.kv_port + port_offset,
rank=self.kv_rank, rank=self.kv_rank,
world_size=self.kv_parallel_size, world_size=self.kv_parallel_size,
store_timeout=store_timeout,
) )
# add a barrier to make sure the connection is initiated properly # add a barrier to make sure the connection is initiated properly
self.group.barrier() self.group.barrier()
@ -134,11 +136,11 @@ class PyNcclPipe(KVPipeBase):
Create a buffer to receive the tensor based on the provided metadata. Create a buffer to receive the tensor based on the provided metadata.
Parameters: Parameters:
- metadata: A dictionary with keys "dtype" and "shape", describing - metadata: A dictionary with keys "dtype" and "shape", describing
the tensor's data type and shape. the tensor's data type and shape.
Returns: Returns:
- buffer: A tensor of the specified type and shape, allocated on - buffer: A tensor of the specified type and shape, allocated on
self.device. self.device.
""" """
return torch.empty(metadata["shape"], return torch.empty(metadata["shape"],
@ -159,18 +161,18 @@ class PyNcclPipe(KVPipeBase):
Receive the metadata dictionary from the target rank. Receive the metadata dictionary from the target rank.
Returns: Returns:
- metadata: A dictionary with keys "dtype" and "shape" describing - metadata: A dictionary with keys "dtype" and "shape" describing
the tensor. the tensor.
""" """
return self.group.recv_obj(self.target_rank_for_recv) return self.group.recv_obj(self.target_rank_for_recv)
def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
""" """
The actual implementation of sending the tensor and its metadata to the The actual implementation of sending the tensor and its metadata to the
target rank. target rank.
Parameters: Parameters:
- tensor: The input tensor to be sent, or None if no tensor is - tensor: The input tensor to be sent, or None if no tensor is
being sent. being sent.
""" """
metadata = self._make_metadata(tensor) metadata = self._make_metadata(tensor)
@ -181,7 +183,7 @@ class PyNcclPipe(KVPipeBase):
def _recv_impl(self) -> Optional[torch.Tensor]: def _recv_impl(self) -> Optional[torch.Tensor]:
""" """
The actual implementation of receiving a tensor and its metadata from The actual implementation of receiving a tensor and its metadata from
the target rank. the target rank.
Returns: Returns:
@ -213,7 +215,7 @@ class PyNcclPipe(KVPipeBase):
def block_if_full(self): def block_if_full(self):
""" """
Block the current thread if the buffer size is larger than the Block the current thread if the buffer size is larger than the
threshold. threshold.
""" """
while self.buffer_size > self.buffer_size_thresh: while self.buffer_size > self.buffer_size_thresh:
@ -222,7 +224,7 @@ class PyNcclPipe(KVPipeBase):
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
""" """
Sends a tensor and its metadata to the destination rank in a Sends a tensor and its metadata to the destination rank in a
non-blocking way. non-blocking way.
Parameters: Parameters:

View File

@ -5,6 +5,7 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses import dataclasses
import datetime
import pickle import pickle
import time import time
from collections import deque from collections import deque
@ -217,6 +218,7 @@ class StatelessProcessGroup:
rank: int, rank: int,
world_size: int, world_size: int,
data_expiration_seconds: int = 3600, data_expiration_seconds: int = 3600,
store_timeout: int = 300,
) -> "StatelessProcessGroup": ) -> "StatelessProcessGroup":
"""A replacement for `torch.distributed.init_process_group` that does not """A replacement for `torch.distributed.init_process_group` that does not
pollute the global state. pollute the global state.
@ -238,6 +240,7 @@ class StatelessProcessGroup:
port=port, port=port,
world_size=world_size, world_size=world_size,
is_master=(rank == 0), is_master=(rank == 0),
timeout=datetime.timedelta(seconds=store_timeout),
) )
return StatelessProcessGroup( return StatelessProcessGroup(