Update deprecated type hinting in vllm/device_allocator and vllm/distributed (#18126)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-05-14 12:07:57 +01:00 committed by GitHub
parent 9b5b39b650
commit dc372b9c8a
23 changed files with 105 additions and 105 deletions
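Every hunk below applies the same mechanical rewrite: container annotations imported from typing (Dict, List, Tuple, Deque), deprecated since Python 3.9, become the builtin generics standardised by PEP 585 (dict, list, tuple, collections.deque), and abstract types such as Sequence and Iterator are imported from collections.abc rather than typing. Optional and Union are left untouched. A minimal before/after sketch of the pattern, using an illustrative function that does not appear in the diff:

# Before: typing aliases, flagged by ruff's UP006/UP035 rules
from typing import Dict, List, Tuple

def tally(names: List[str]) -> Tuple[Dict[str, int], int]:
    counts = {name: names.count(name) for name in names}
    return counts, len(counts)

# After: PEP 585 builtin generics, valid at runtime on Python 3.9+
def tally(names: list[str]) -> tuple[dict[str, int], int]:
    counts = {name: names.count(name) for name in names}
    return counts, len(counts)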

View File

@ -74,8 +74,6 @@ exclude = [
# Python 3.8 typing. TODO: Remove these excludes after v1.0.0
"vllm/attention/**/*.py" = ["UP006", "UP035"]
"vllm/core/**/*.py" = ["UP006", "UP035"]
"vllm/device_allocator/**/*.py" = ["UP006", "UP035"]
"vllm/distributed/**/*.py" = ["UP006", "UP035"]
"vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]

View File

@ -11,7 +11,7 @@ import dataclasses
import gc
import os
from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch
@ -63,7 +63,7 @@ except ModuleNotFoundError:
libcudart = None
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
HandleType = Tuple[int, int, int, int]
HandleType = tuple[int, int, int, int]
@dataclasses.dataclass
@ -148,9 +148,9 @@ class CuMemAllocator:
"Please track https://github.com/pytorch/pytorch/issues/147851 "
"for the latest updates.")
self.pointer_to_data: Dict[int, AllocationData] = {}
self.pointer_to_data: dict[int, AllocationData] = {}
self.current_tag: str = CuMemAllocator.default_tag
self.allocator_and_pools: Dict[str, Any] = {}
self.allocator_and_pools: dict[str, Any] = {}
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
"""
@ -172,7 +172,7 @@ class CuMemAllocator:
def sleep(
self,
offload_tags: Optional[Union[Tuple[str, ...],
offload_tags: Optional[Union[tuple[str, ...],
str]] = None) -> None:
"""
Put the allocator in sleep mode.

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union
import torch
import torch.distributed
@ -32,7 +32,7 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
return get_tp_group().gather(input_, dst, dim)
def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
def broadcast_tensor_dict(tensor_dict: Optional[dict[Any, Union[torch.Tensor,
Any]]] = None,
src: int = 0):
if not torch.distributed.is_initialized():

View File

@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
import torch.distributed as dist
@ -160,7 +160,7 @@ class DeviceCommunicatorBase:
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import os
from typing import List, Optional
from typing import Optional
import torch
from torch.distributed import ProcessGroup
@ -126,7 +126,7 @@ class _CPUSHMDistributed:
def gather(self,
input: torch.Tensor,
gather_list: Optional[List[torch.Tensor]],
gather_list: Optional[list[torch.Tensor]],
dst: int = -1,
group: Optional[ProcessGroup] = None) -> None:
# Note: different from the torch gather, here we use local dst rank.

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
from torch.distributed import ProcessGroup
@ -154,7 +154,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_impl is not None
hidden_states, router_logits = self.all2all_impl.dispatch(
hidden_states, router_logits)

View File

@ -6,7 +6,7 @@ convenient for use when we just need to call a few functions.
import ctypes
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Optional
# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch # noqa
@ -32,7 +32,7 @@ class cudaIpcMemHandle_t(ctypes.Structure):
class Function:
name: str
restype: Any
argtypes: List[Any]
argtypes: list[Any]
def find_loaded_library(lib_name) -> Optional[str]:
@ -97,11 +97,11 @@ class CudaRTLibrary:
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
path_to_library_cache: dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the corresponding dictionary
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
if so_file is None:

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager
from typing import List, Optional, Union
from typing import Optional, Union
import torch
import torch.distributed as dist
@ -276,7 +276,7 @@ class CustomAllreduce:
@staticmethod
def create_shared_buffer(size_in_bytes: int,
group: Optional[ProcessGroup] = None,
uncached: Optional[bool] = False) -> List[int]:
uncached: Optional[bool] = False) -> list[int]:
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
world_size = dist.get_world_size(group=group)
@ -284,7 +284,7 @@ class CustomAllreduce:
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
pointers: List[int] = []
pointers: list[int] = []
for i, h in enumerate(handles):
if i == rank:
pointers.append(pointer) # type: ignore
@ -293,7 +293,7 @@ class CustomAllreduce:
return pointers
@staticmethod
def free_shared_buffer(pointers: List[int],
def free_shared_buffer(pointers: list[int],
group: Optional[ProcessGroup] = None,
rank: Optional[int] = 0) -> None:
if rank is None:

View File

@ -7,8 +7,9 @@ import pickle
import subprocess
import sys
import tempfile
from collections.abc import Sequence
from itertools import product
from typing import Dict, List, Optional, Sequence
from typing import Optional
import torch.distributed as dist
import torch.multiprocessing as mp
@ -149,7 +150,7 @@ def can_actually_p2p(
p_src.join()
p_tgt.join()
assert p_src.exitcode == 0 and p_tgt.exitcode == 0
result: List[bool] = []
result: list[bool] = []
for src, tgt in zip(batch_src, batch_tgt):
a = result_queue.get()
b = result_queue.get()
@ -175,7 +176,7 @@ def can_actually_p2p(
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
_gpu_p2p_access_cache: Optional[dict[str, bool]] = None
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
@ -204,7 +205,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
# only the local master process (with local_rank == 0) can
# enter this block to calculate the cache
logger.info("generating GPU P2P access cache in %s", path)
cache: Dict[str, bool] = {}
cache: dict[str, bool] = {}
ids = list(range(num_dev))
# batch of all pairs of GPUs
batch_src, batch_tgt = zip(*list(product(ids, ids)))

View File

@ -24,7 +24,7 @@
import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
from torch.distributed import ReduceOp
@ -121,7 +121,7 @@ class ncclRedOpTypeEnum:
class Function:
name: str
restype: Any
argtypes: List[Any]
argtypes: list[Any]
class NCCLLibrary:
@ -210,11 +210,11 @@ class NCCLLibrary:
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
path_to_library_cache: dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the corresponding dictionary
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
@ -238,7 +238,7 @@ class NCCLLibrary:
raise e
if so_file not in NCCLLibrary.path_to_dict_mapping:
_funcs: Dict[str, Any] = {}
_funcs: dict[str, Any] = {}
for func in NCCLLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype

View File

@ -8,7 +8,7 @@ from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from threading import Event
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, Union
from unittest.mock import patch
import torch
@ -173,9 +173,9 @@ class ShmRingBuffer:
@dataclass
class Handle:
local_reader_ranks: List[int] = field(default_factory=list)
local_reader_ranks: list[int] = field(default_factory=list)
buffer_handle: Optional[Tuple[int, int, int, str]] = None
buffer_handle: Optional[tuple[int, int, int, str]] = None
local_subscribe_addr: Optional[str] = None
remote_subscribe_addr: Optional[str] = None
remote_addr_ipv6: bool = False
@ -187,7 +187,7 @@ class MessageQueue:
self,
n_reader, # number of all readers
n_local_reader, # number of local readers through shared memory
local_reader_ranks: Optional[List[int]] = None,
local_reader_ranks: Optional[list[int]] = None,
max_chunk_bytes: int = 1024 * 1024 * 10,
max_chunks: int = 10,
connect_ip: Optional[str] = None,

View File

@ -8,7 +8,7 @@ The class provides two primary abstract methods:
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Tuple, Union
from typing import TYPE_CHECKING, Union
import torch
@ -55,7 +55,7 @@ class KVConnectorBase(ABC):
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
@ -71,7 +71,7 @@ class KVConnectorBase(ABC):
start and end layer information.
model_input (ModelInputForGPUWithSamplingMetadata): The input
metadata from vLLM.
kv_caches (List[torch.Tensor]): List of KV caches (keys and values)
kv_caches (list[torch.Tensor]): List of KV caches (keys and values)
for each layer.
hidden_or_intermediate_states (Union[torch.Tensor,
IntermediateTensors]):
@ -88,8 +88,8 @@ class KVConnectorBase(ABC):
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
kv_caches: list[torch.Tensor]
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
"""
Receive KV caches and hidden states from the connector.
@ -104,7 +104,7 @@ class KVConnectorBase(ABC):
The model executable from vLLM modelrunner.
model_input (ModelInputForGPUWithSamplingMetadata):
The model input from vLLM modelrunner.
kv_caches (List[torch.Tensor]):
kv_caches (list[torch.Tensor]):
List of KV caches for each layer.
Returns:

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import importlib
from typing import TYPE_CHECKING, Callable, Dict, Type
from typing import TYPE_CHECKING, Callable
import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
@ -18,7 +18,7 @@ logger = init_logger(__name__)
class KVConnectorFactory:
_registry: Dict[str, Callable[[], Type[KVConnectorBaseType]]] = {}
_registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {}
@classmethod
def register_connector(cls, name: str, module_path: str,
@ -27,7 +27,7 @@ class KVConnectorFactory:
if name in cls._registry:
raise ValueError(f"Connector '{name}' is already registered.")
def loader() -> Type[KVConnectorBaseType]:
def loader() -> type[KVConnectorBaseType]:
module = importlib.import_module(module_path)
return getattr(module, class_name)

View File

@ -7,7 +7,7 @@ The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker
(2) offload and share KV caches.
"""
from typing import TYPE_CHECKING, List, Tuple, Union
from typing import TYPE_CHECKING, Union
import torch
@ -63,8 +63,8 @@ class LMCacheConnector(KVConnectorBase):
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
kv_caches: list[torch.Tensor]
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
retrieve_status = self.lmcache_should_retrieve(model_input)
@ -78,7 +78,7 @@ class LMCacheConnector(KVConnectorBase):
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:

View File

@ -6,7 +6,7 @@ The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
database-style KVStore.
"""
import hashlib
from typing import TYPE_CHECKING, List, Tuple, Union
from typing import TYPE_CHECKING, Union
import torch
@ -70,7 +70,7 @@ class MooncakeStoreConnector(KVConnectorBase):
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
@ -113,8 +113,8 @@ class MooncakeStoreConnector(KVConnectorBase):
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
kv_caches: list[torch.Tensor]
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
bypass_model_exec = True
input_tokens_tensor = model_input.input_tokens

View File

@ -8,7 +8,7 @@ MooncakePipe.
But the logic can be extended to support other pipe and lookup buffer.
"""
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Union
import torch
@ -133,7 +133,7 @@ class SimpleConnector(KVConnectorBase):
)
def select(self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
assert self.consumer_buffer is not None, "Please initialize the "\
"consumer buffer before calling select."
@ -152,7 +152,7 @@ class SimpleConnector(KVConnectorBase):
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
@ -207,8 +207,8 @@ class SimpleConnector(KVConnectorBase):
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
kv_caches: list[torch.Tensor]
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
# When bypass_model_exec is set to False, it means that at least for one

View File

@ -5,13 +5,13 @@ import threading
import time
import uuid
from collections import defaultdict
from collections.abc import Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator
from typing import TYPE_CHECKING, Any, Optional
import msgspec
import torch
import zmq
from typing_extensions import Optional
from vllm import envs
from vllm.config import VllmConfig
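Beyond the alias renames, this hunk also corrects an import source: Optional was previously pulled from typing_extensions, which only re-exports it, and now comes from typing; Iterator likewise moves to its canonical home in collections.abc. The resulting import shape, sketched in isolation:

from collections.abc import Iterator             # typing.Iterator is a deprecated alias
from typing import TYPE_CHECKING, Any, Optional  # Optional lives in typing itself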

View File

@ -5,7 +5,7 @@ This implementation is a shim wrapper on two APIs exposed by `kv_connector`:
1. `send_kv_caches_and_hidden_states`
2. `recv_kv_caches_and_hidden_states
"""
from typing import TYPE_CHECKING, List, Tuple, Union
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
@ -53,7 +53,7 @@ class KVTransferAgent:
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
@ -68,8 +68,8 @@ class KVTransferAgent:
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
kv_caches: list[torch.Tensor]
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
return self.connector.recv_kv_caches_and_hidden_states(

View File

@ -13,7 +13,7 @@ These classes above are abstracted behind class `KVCacheBufferBase`.
"""
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import Optional
import torch
@ -93,7 +93,7 @@ class KVLookupBufferBase(KVCacheBufferBase):
@abstractmethod
def drop_select(
self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
"""Select and *drop* KV cache entries from the lookup buffer.
The functionality is similar to the following python statements
@ -111,7 +111,7 @@ class KVLookupBufferBase(KVCacheBufferBase):
roi (torch.Tensor): A binary mask on top of the input tokens
Returns:
List[Optional[torch.Tensor]]: A list of tensors. Can be None.
list[Optional[torch.Tensor]]: A list of tensors. Can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.

View File

@ -11,7 +11,7 @@
"""
import threading
from collections import deque
from typing import Deque, List, Optional, Union
from typing import Optional, Union
import torch
@ -38,7 +38,7 @@ class SimpleBuffer(KVLookupBufferBase):
data_pipe: on device (e.g. GPU)
"""
self.buffer: Deque[List[torch.Tensor]] = deque()
self.buffer: deque[list[torch.Tensor]] = deque()
self.buffer_size = 0
self.buffer_size_threshold = buffer_size_thresh
@ -50,8 +50,8 @@ class SimpleBuffer(KVLookupBufferBase):
self.normal_signal = torch.tensor([0], device="cpu")
self.end_signal = None
def _matches(self, tokens_roi_sender: List[torch.Tensor],
tokens_roi_recver: List[torch.Tensor]):
def _matches(self, tokens_roi_sender: list[torch.Tensor],
tokens_roi_recver: list[torch.Tensor]):
# tokens_roi_sender: tokens and roi of the producer (in the buffer)
# tokens_roi_recver: tokens and roi of the consumer (query)
@ -88,7 +88,7 @@ class SimpleBuffer(KVLookupBufferBase):
tensor = tensor.float()
self.data_pipe.send_tensor(tensor)
def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]):
def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]):
if isinstance(data, torch.Tensor):
return data.element_size() * data.numel()
@ -151,7 +151,7 @@ class SimpleBuffer(KVLookupBufferBase):
tokens_roi_recver = [input_tokens, roi]
def is_buffer_available(
tokens_roi_recver: List[torch.Tensor], ) -> bool:
tokens_roi_recver: list[torch.Tensor], ) -> bool:
# perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so
@ -184,7 +184,7 @@ class SimpleBuffer(KVLookupBufferBase):
def drop_select(
self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
assert self.request_handling_thread is None, \
"drop_select should be called by the KV cache consumer "\

View File

@ -15,7 +15,7 @@
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Optional, Tuple
from typing import Callable, Optional
import torch
@ -35,7 +35,7 @@ class BrokenPipeException(Exception):
super().__init__(self.message)
Metadata = Dict[str, Optional[torch.Tensor]]
Metadata = dict[str, Optional[torch.Tensor]]
class PyNcclPipe(KVPipeBase):
@ -83,7 +83,7 @@ class PyNcclPipe(KVPipeBase):
def _get_device_send_recv_impl(
self, group: StatelessProcessGroup
) -> Tuple[Callable[[torch.Tensor, int], None], Callable[
) -> tuple[Callable[[torch.Tensor, int], None], Callable[
[torch.Tensor, int], None]]:
send: Callable[[torch.Tensor, int], None]

View File

@ -29,7 +29,7 @@ from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
from unittest.mock import patch
import torch
@ -54,15 +54,15 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def _split_tensor_dict(
tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
tensor_dict: dict[str, Union[torch.Tensor, Any]]
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
metadata_list: List[Tuple[str, Any]] = []
tensor_list: List[torch.Tensor] = []
metadata_list: list[tuple[str, Any]] = []
tensor_list: list[torch.Tensor] = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
@ -78,7 +78,7 @@ def _split_tensor_dict(
return metadata_list, tensor_list
_group_name_counter: Dict[str, int] = {}
_group_name_counter: dict[str, int] = {}
def _get_unique_name(name: str) -> str:
@ -94,7 +94,7 @@ def _get_unique_name(name: str) -> str:
return newname
_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
def _register_group(group: "GroupCoordinator") -> None:
@ -182,7 +182,7 @@ class GroupCoordinator:
# available attributes:
rank: int # global rank
ranks: List[int] # global ranks in the group
ranks: list[int] # global ranks in the group
world_size: int # size of the group
# difference between `local_rank` and `rank_in_group`:
# if we have a group of size 4 across two nodes:
@ -201,7 +201,7 @@ class GroupCoordinator:
def __init__(
self,
group_ranks: List[List[int]],
group_ranks: list[list[int]],
local_rank: int,
torch_distributed_backend: Union[str, Backend],
use_device_communicator: bool,
@ -435,7 +435,7 @@ class GroupCoordinator:
return recv[0]
def broadcast_object_list(self,
obj_list: List[Any],
obj_list: list[Any],
src: int = 0,
group: Optional[ProcessGroup] = None):
"""Broadcast the input object list.
@ -518,11 +518,11 @@ class GroupCoordinator:
def broadcast_tensor_dict(
self,
tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
@ -536,7 +536,7 @@ class GroupCoordinator:
rank_in_group = self.rank_in_group
if rank_in_group == src:
metadata_list: List[Tuple[Any, Any]] = []
metadata_list: list[tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
@ -603,10 +603,10 @@ class GroupCoordinator:
def send_tensor_dict(
self,
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
tensor_dict: dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank.
"""
@ -626,7 +626,7 @@ class GroupCoordinator:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
metadata_list: List[Tuple[Any, Any]] = []
metadata_list: list[tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), f"Expecting a dictionary, got {type(tensor_dict)}"
@ -661,7 +661,7 @@ class GroupCoordinator:
self,
src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
@ -682,7 +682,7 @@ class GroupCoordinator:
assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src)
tensor_dict: Dict[str, Any] = {}
tensor_dict: dict[str, Any] = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
@ -764,7 +764,7 @@ class GroupCoordinator:
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
if self.device_communicator is not None:
return self.device_communicator.dispatch(hidden_states,
router_logits)
@ -782,7 +782,7 @@ def get_world_group() -> GroupCoordinator:
return _WORLD
def init_world_group(ranks: List[int], local_rank: int,
def init_world_group(ranks: list[int], local_rank: int,
backend: str) -> GroupCoordinator:
return GroupCoordinator(
group_ranks=[ranks],
@ -794,7 +794,7 @@ def init_world_group(ranks: List[int], local_rank: int,
def init_model_parallel_group(
group_ranks: List[List[int]],
group_ranks: list[list[int]],
local_rank: int,
backend: str,
use_message_queue_broadcaster: bool = False,
@ -1182,7 +1182,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
source_rank: int = 0) -> List[bool]:
source_rank: int = 0) -> list[bool]:
"""
This is a collective operation that returns if each rank is in the same node
as the source rank. It tests if processes are attached to the same

View File

@ -10,7 +10,8 @@ import pickle
import socket
import time
from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import Any, Optional
import torch
from torch.distributed import ProcessGroup, TCPStore
@ -69,7 +70,7 @@ def split_tensor_along_last_dim(
def get_pp_indices(num_hidden_layers: int, pp_rank: int,
pp_size: int) -> Tuple[int, int]:
pp_size: int) -> tuple[int, int]:
"""Try to evenly distribute layers across partitions.
If the number of layers is not divisible by the number of partitions,
@ -132,15 +133,15 @@ class StatelessProcessGroup:
data_expiration_seconds: int = 3600 # 1 hour
# dst rank -> counter
send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)
# src rank -> counter
recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
broadcast_send_counter: int = 0
broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(
broadcast_recv_src_counter: dict[int, int] = dataclasses.field(
default_factory=dict)
# A deque to store the data entries, with key and timestamp.
entries: Deque[Tuple[str,
entries: deque[tuple[str,
float]] = dataclasses.field(default_factory=deque)
def __post_init__(self):
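This last file covers the remaining alias families, typing.Deque and typing.Sequence: under PEP 585 the concrete collections.deque class is subscriptable directly, and the Sequence ABC comes from collections.abc. A standalone sketch of the same dataclass-field pattern, with hypothetical class and field names, valid on Python 3.9+:

import dataclasses
from collections import deque
from collections.abc import Sequence


@dataclasses.dataclass
class Mailbox:
    # (name, timestamp) entries, oldest first
    entries: deque[tuple[str, float]] = dataclasses.field(default_factory=deque)

    def extend(self, items: Sequence[tuple[str, float]]) -> None:
        self.entries.extend(items)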