Update deprecated type hinting in vllm/device_allocator and vllm/distributed (#18126)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-05-14 12:07:57 +01:00 committed by GitHub
parent 9b5b39b650
commit dc372b9c8a
23 changed files with 105 additions and 105 deletions
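
The whole commit is a mechanical application of PEP 585: on Python 3.9+ the builtin containers (`dict`, `list`, `tuple`, `deque`) and the `collections.abc` protocols can be subscripted directly, so the deprecated `typing.Dict`/`List`/`Tuple`/`Deque`/`Sequence`/`Iterator` aliases are dropped. Ruff's pyupgrade rules UP006 (non-PEP 585 annotation) and UP035 (deprecated import) flag the old style; the per-file ignores for `vllm/device_allocator` and `vllm/distributed` are removed in the first hunk below, so those rules are now enforced there. A minimal before/after sketch of the pattern (the functions and names here are illustrative, not taken from vLLM):

```python
# Illustrative sketch only; names below are hypothetical, not from vLLM.

# Before: deprecated typing aliases (flagged by UP006/UP035)
from typing import Dict, List, Optional, Tuple

def first_and_last(cache: Dict[str, List[int]], key: str) -> Optional[Tuple[int, int]]:
    values = cache.get(key)
    return (values[0], values[-1]) if values else None

# After: builtin generics (PEP 585) plus collections.abc for abstract types
from collections.abc import Sequence
from typing import Optional

def first_and_last_new(cache: dict[str, list[int]], key: str) -> Optional[tuple[int, int]]:
    values = cache.get(key)
    return (values[0], values[-1]) if values else None

def total(items: Sequence[int]) -> int:
    # Sequence now comes from collections.abc instead of typing
    return sum(items)

print(first_and_last_new({"a": [1, 2, 3]}, "a"))  # (1, 3)
```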

View File

@@ -74,8 +74,6 @@ exclude = [
 # Python 3.8 typing. TODO: Remove these excludes after v1.0.0
 "vllm/attention/**/*.py" = ["UP006", "UP035"]
 "vllm/core/**/*.py" = ["UP006", "UP035"]
-"vllm/device_allocator/**/*.py" = ["UP006", "UP035"]
-"vllm/distributed/**/*.py" = ["UP006", "UP035"]
 "vllm/engine/**/*.py" = ["UP006", "UP035"]
 "vllm/executor/**/*.py" = ["UP006", "UP035"]
 "vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]

View File

@@ -11,7 +11,7 @@ import dataclasses
 import gc
 import os
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
 import torch
@@ -63,7 +63,7 @@ except ModuleNotFoundError:
 libcudart = None
 # py_device, py_alignedSize, py_d_mem, py_p_memHandle
-HandleType = Tuple[int, int, int, int]
+HandleType = tuple[int, int, int, int]
 @dataclasses.dataclass
@@ -148,9 +148,9 @@ class CuMemAllocator:
 "Please track https://github.com/pytorch/pytorch/issues/147851 "
 "for the latest updates.")
-self.pointer_to_data: Dict[int, AllocationData] = {}
+self.pointer_to_data: dict[int, AllocationData] = {}
 self.current_tag: str = CuMemAllocator.default_tag
-self.allocator_and_pools: Dict[str, Any] = {}
+self.allocator_and_pools: dict[str, Any] = {}
 def python_malloc_callback(self, allocation_handle: HandleType) -> None:
 """
@@ -172,7 +172,7 @@ class CuMemAllocator:
 def sleep(
 self,
-offload_tags: Optional[Union[Tuple[str, ...],
+offload_tags: Optional[Union[tuple[str, ...],
 str]] = None) -> None:
 """
 Put the allocator in sleep mode.
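
These updated annotations behave identically at runtime. A small self-contained sketch (the class body is illustrative, not vLLM's allocator) showing that builtin generics work both as module-level aliases like `HandleType` and inside `Optional`/`Union`:

```python
# Hypothetical sketch mirroring the annotation style above; not vLLM code.
from typing import Any, Optional, Union

# Builtin generics work as module-level type aliases on Python 3.9+.
HandleType = tuple[int, int, int, int]  # (device, aligned_size, d_mem, p_memHandle)

class DemoAllocator:
    def __init__(self) -> None:
        self.pointer_to_data: dict[int, Any] = {}

    def sleep(self, offload_tags: Optional[Union[tuple[str, ...], str]] = None) -> None:
        # Accept a single tag, a tuple of tags, or None, as the annotation advertises.
        if offload_tags is None:
            tags: tuple[str, ...] = ()
        elif isinstance(offload_tags, str):
            tags = (offload_tags,)
        else:
            tags = offload_tags
        print(f"offloading allocations tagged {tags!r}")

DemoAllocator().sleep("weights")
```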

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union
 import torch
 import torch.distributed
@@ -32,7 +32,7 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
 return get_tp_group().gather(input_, dst, dim)
-def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
+def broadcast_tensor_dict(tensor_dict: Optional[dict[Any, Union[torch.Tensor,
 Any]]] = None,
 src: int = 0):
 if not torch.distributed.is_initialized():

View File

@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional, Tuple
+from typing import Optional
 import torch
 import torch.distributed as dist
@@ -160,7 +160,7 @@ class DeviceCommunicatorBase:
 def dispatch(
 self, hidden_states: torch.Tensor,
-router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 """
 Dispatch the hidden states and router logits to the appropriate device.
 This is a no-op in the base class.

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import os
-from typing import List, Optional
+from typing import Optional
 import torch
 from torch.distributed import ProcessGroup
@@ -126,7 +126,7 @@ class _CPUSHMDistributed:
 def gather(self,
 input: torch.Tensor,
-gather_list: Optional[List[torch.Tensor]],
+gather_list: Optional[list[torch.Tensor]],
 dst: int = -1,
 group: Optional[ProcessGroup] = None) -> None:
 # Note: different from the torch gather, here we use local dst rank.

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional, Tuple
+from typing import Optional
 import torch
 from torch.distributed import ProcessGroup
@@ -154,7 +154,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
 def dispatch(
 self, hidden_states: torch.Tensor,
-router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 assert self.all2all_impl is not None
 hidden_states, router_logits = self.all2all_impl.dispatch(
 hidden_states, router_logits)

View File

@@ -6,7 +6,7 @@ convenient for use when we just need to call a few functions.
 import ctypes
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 # this line makes it possible to directly load `libcudart.so` using `ctypes`
 import torch # noqa
@@ -32,7 +32,7 @@ class cudaIpcMemHandle_t(ctypes.Structure):
 class Function:
 name: str
 restype: Any
-argtypes: List[Any]
+argtypes: list[Any]
 def find_loaded_library(lib_name) -> Optional[str]:
@@ -97,11 +97,11 @@ class CudaRTLibrary:
 # class attribute to store the mapping from the path to the library
 # to avoid loading the same library multiple times
-path_to_library_cache: Dict[str, Any] = {}
+path_to_library_cache: dict[str, Any] = {}
 # class attribute to store the mapping from library path
 # to the corresponding dictionary
-path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+path_to_dict_mapping: dict[str, dict[str, Any]] = {}
 def __init__(self, so_file: Optional[str] = None):
 if so_file is None:

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 from contextlib import contextmanager
-from typing import List, Optional, Union
+from typing import Optional, Union
 import torch
 import torch.distributed as dist
@@ -276,7 +276,7 @@ class CustomAllreduce:
 @staticmethod
 def create_shared_buffer(size_in_bytes: int,
 group: Optional[ProcessGroup] = None,
-uncached: Optional[bool] = False) -> List[int]:
+uncached: Optional[bool] = False) -> list[int]:
 pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
 world_size = dist.get_world_size(group=group)
@@ -284,7 +284,7 @@ class CustomAllreduce:
 handles = [None] * world_size
 dist.all_gather_object(handles, handle, group=group)
-pointers: List[int] = []
+pointers: list[int] = []
 for i, h in enumerate(handles):
 if i == rank:
 pointers.append(pointer) # type: ignore
@@ -293,7 +293,7 @@ class CustomAllreduce:
 return pointers
 @staticmethod
-def free_shared_buffer(pointers: List[int],
+def free_shared_buffer(pointers: list[int],
 group: Optional[ProcessGroup] = None,
 rank: Optional[int] = 0) -> None:
 if rank is None:

View File

@@ -7,8 +7,9 @@ import pickle
 import subprocess
 import sys
 import tempfile
+from collections.abc import Sequence
 from itertools import product
-from typing import Dict, List, Optional, Sequence
+from typing import Optional
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -149,7 +150,7 @@ def can_actually_p2p(
 p_src.join()
 p_tgt.join()
 assert p_src.exitcode == 0 and p_tgt.exitcode == 0
-result: List[bool] = []
+result: list[bool] = []
 for src, tgt in zip(batch_src, batch_tgt):
 a = result_queue.get()
 b = result_queue.get()
@@ -175,7 +176,7 @@ def can_actually_p2p(
 # e.g. used by different vllm engines. The device id in the cache file is a
 # **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
 # of visible devices in the vllm engine.
-_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
+_gpu_p2p_access_cache: Optional[dict[str, bool]] = None
 def gpu_p2p_access_check(src: int, tgt: int) -> bool:
@@ -204,7 +205,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
 # only the local master process (with local_rank == 0) can
 # enter this block to calculate the cache
 logger.info("generating GPU P2P access cache in %s", path)
-cache: Dict[str, bool] = {}
+cache: dict[str, bool] = {}
 ids = list(range(num_dev))
 # batch of all pairs of GPUs
 batch_src, batch_tgt = zip(*list(product(ids, ids)))

View File

@@ -24,7 +24,7 @@
 import ctypes
 import platform
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 import torch
 from torch.distributed import ReduceOp
@@ -121,7 +121,7 @@ class ncclRedOpTypeEnum:
 class Function:
 name: str
 restype: Any
-argtypes: List[Any]
+argtypes: list[Any]
 class NCCLLibrary:
@@ -210,11 +210,11 @@ class NCCLLibrary:
 # class attribute to store the mapping from the path to the library
 # to avoid loading the same library multiple times
-path_to_library_cache: Dict[str, Any] = {}
+path_to_library_cache: dict[str, Any] = {}
 # class attribute to store the mapping from library path
 # to the corresponding dictionary
-path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+path_to_dict_mapping: dict[str, dict[str, Any]] = {}
 def __init__(self, so_file: Optional[str] = None):
@@ -238,7 +238,7 @@ class NCCLLibrary:
 raise e
 if so_file not in NCCLLibrary.path_to_dict_mapping:
-_funcs: Dict[str, Any] = {}
+_funcs: dict[str, Any] = {}
 for func in NCCLLibrary.exported_functions:
 f = getattr(self.lib, func.name)
 f.restype = func.restype

View File

@@ -8,7 +8,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field
 from multiprocessing import shared_memory
 from threading import Event
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 from unittest.mock import patch
 import torch
@@ -173,9 +173,9 @@ class ShmRingBuffer:
 @dataclass
 class Handle:
-local_reader_ranks: List[int] = field(default_factory=list)
-buffer_handle: Optional[Tuple[int, int, int, str]] = None
+local_reader_ranks: list[int] = field(default_factory=list)
+buffer_handle: Optional[tuple[int, int, int, str]] = None
 local_subscribe_addr: Optional[str] = None
 remote_subscribe_addr: Optional[str] = None
 remote_addr_ipv6: bool = False
@@ -187,7 +187,7 @@ class MessageQueue:
 self,
 n_reader, # number of all readers
 n_local_reader, # number of local readers through shared memory
-local_reader_ranks: Optional[List[int]] = None,
+local_reader_ranks: Optional[list[int]] = None,
 max_chunk_bytes: int = 1024 * 1024 * 10,
 max_chunks: int = 10,
 connect_ip: Optional[str] = None,

View File

@@ -8,7 +8,7 @@ The class provides two primary abstract methods:
 """
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, List, Tuple, Union
+from typing import TYPE_CHECKING, Union
 import torch
@@ -55,7 +55,7 @@ class KVConnectorBase(ABC):
 self,
 model_executable: torch.nn.Module,
 model_input: "ModelInputForGPUWithSamplingMetadata",
-kv_caches: List[torch.Tensor],
+kv_caches: list[torch.Tensor],
 hidden_or_intermediate_states: Union[torch.Tensor,
 IntermediateTensors],
 ) -> None:
@@ -71,7 +71,7 @@ class KVConnectorBase(ABC):
 start and end layer information.
 model_input (ModelInputForGPUWithSamplingMetadata): The input
 metadata from vLLM.
-kv_caches (List[torch.Tensor]): List of KV caches (keys and values)
+kv_caches (list[torch.Tensor]): List of KV caches (keys and values)
 for each layer.
 hidden_or_intermediate_states (Union[torch.Tensor,
 IntermediateTensors]):
@@ -88,8 +88,8 @@ class KVConnectorBase(ABC):
 def recv_kv_caches_and_hidden_states(
 self, model_executable: torch.nn.Module,
 model_input: "ModelInputForGPUWithSamplingMetadata",
-kv_caches: List[torch.Tensor]
-) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+kv_caches: list[torch.Tensor]
+) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
 "ModelInputForGPUWithSamplingMetadata"]:
 """
 Receive KV caches and hidden states from the connector.
@@ -104,7 +104,7 @@ class KVConnectorBase(ABC):
 The model executable from vLLM modelrunner.
 model_input (ModelInputForGPUWithSamplingMetadata):
 The model input from vLLM modelrunner.
-kv_caches (List[torch.Tensor]):
+kv_caches (list[torch.Tensor]):
 List of KV caches for each layer.
 Returns:

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import importlib
-from typing import TYPE_CHECKING, Callable, Dict, Type
+from typing import TYPE_CHECKING, Callable
 import vllm.envs as envs
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
@@ -18,7 +18,7 @@ logger = init_logger(__name__)
 class KVConnectorFactory:
-_registry: Dict[str, Callable[[], Type[KVConnectorBaseType]]] = {}
+_registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {}
 @classmethod
 def register_connector(cls, name: str, module_path: str,
@@ -27,7 +27,7 @@ class KVConnectorFactory:
 if name in cls._registry:
 raise ValueError(f"Connector '{name}' is already registered.")
-def loader() -> Type[KVConnectorBaseType]:
+def loader() -> type[KVConnectorBaseType]:
 module = importlib.import_module(module_path)
 return getattr(module, class_name)
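
The registry above is the lazy-loading factory pattern that the new `dict[str, Callable[[], type[...]]]` annotation describes. A minimal self-contained sketch under assumed names (`DemoFactory`, `BaseConnector`, and the registered module path are hypothetical, not vLLM's implementation):

```python
import importlib
from typing import Callable

class BaseConnector:
    """Stand-in for KVConnectorBaseType; illustrative only."""

class DemoFactory:
    # name -> zero-arg loader that lazily imports and returns the connector class
    _registry: dict[str, Callable[[], type[BaseConnector]]] = {}

    @classmethod
    def register(cls, name: str, module_path: str, class_name: str) -> None:
        if name in cls._registry:
            raise ValueError(f"Connector '{name}' is already registered.")

        def loader() -> type[BaseConnector]:
            # The import is deferred until the connector is first requested.
            module = importlib.import_module(module_path)
            return getattr(module, class_name)

        cls._registry[name] = loader

    @classmethod
    def create(cls, name: str) -> BaseConnector:
        return cls._registry[name]()()

# Registration is cheap and import-free; "my_pkg.connectors" is a hypothetical
# module path that would only be imported if create("demo") were ever called.
DemoFactory.register("demo", "my_pkg.connectors", "MyConnector")
```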

View File

@@ -7,7 +7,7 @@ The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker
 (2) offload and share KV caches.
 """
-from typing import TYPE_CHECKING, List, Tuple, Union
+from typing import TYPE_CHECKING, Union
 import torch
@@ -63,8 +63,8 @@ class LMCacheConnector(KVConnectorBase):
 def recv_kv_caches_and_hidden_states(
 self, model_executable: torch.nn.Module,
 model_input: "ModelInputForGPUWithSamplingMetadata",
-kv_caches: List[torch.Tensor]
-) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+kv_caches: list[torch.Tensor]
+) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
 "ModelInputForGPUWithSamplingMetadata"]:
 retrieve_status = self.lmcache_should_retrieve(model_input)
@@ -78,7 +78,7 @@ class LMCacheConnector(KVConnectorBase):
 self,
 model_executable: torch.nn.Module,
 model_input: "ModelInputForGPUWithSamplingMetadata",
-kv_caches: List[torch.Tensor],
+kv_caches: list[torch.Tensor],
 hidden_or_intermediate_states: Union[torch.Tensor,
 IntermediateTensors],
 ) -> None:

View File

@@ -6,7 +6,7 @@ The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
 database-style KVStore.
 """
 import hashlib
-from typing import TYPE_CHECKING, List, Tuple, Union
+from typing import TYPE_CHECKING, Union
 import torch
@@ -70,7 +70,7 @@ class MooncakeStoreConnector(KVConnectorBase):
 self,
 model_executable: torch.nn.Module,
 model_input: "ModelInputForGPUWithSamplingMetadata",
-kv_caches: List[torch.Tensor],
+kv_caches: list[torch.Tensor],
 hidden_or_intermediate_states: Union[torch.Tensor,
 IntermediateTensors],
 ) -> None:
@@ -113,8 +113,8 @@ class MooncakeStoreConnector(KVConnectorBase):
 def recv_kv_caches_and_hidden_states(
 self, model_executable: torch.nn.Module,
 model_input: "ModelInputForGPUWithSamplingMetadata",
-kv_caches: List[torch.Tensor]
-) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+kv_caches: list[torch.Tensor]
+) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
 "ModelInputForGPUWithSamplingMetadata"]:
 bypass_model_exec = True
 input_tokens_tensor = model_input.input_tokens

View File

@@ -8,7 +8,7 @@ MooncakePipe.
 But the logic can be extended to support other pipe and lookup buffer.
 """
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Union
 import torch
@@ -133,7 +133,7 @@ class SimpleConnector(KVConnectorBase):
 )
 def select(self, input_tokens: Optional[torch.Tensor],
-roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
+roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
 assert self.consumer_buffer is not None, "Please initialize the "\
 "consumer buffer before calling select."
@@ -152,7 +152,7 @@ class SimpleConnector(KVConnectorBase):
 self,
 model_executable: torch.nn.Module,
 model_input: "ModelInputForGPUWithSamplingMetadata",
-kv_caches: List[torch.Tensor],
+kv_caches: list[torch.Tensor],
 hidden_or_intermediate_states: Union[torch.Tensor,
 IntermediateTensors],
 ) -> None:
@@ -207,8 +207,8 @@ class SimpleConnector(KVConnectorBase):
 def recv_kv_caches_and_hidden_states(
 self, model_executable: torch.nn.Module,
 model_input: "ModelInputForGPUWithSamplingMetadata",
-kv_caches: List[torch.Tensor]
-) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+kv_caches: list[torch.Tensor]
+) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
 "ModelInputForGPUWithSamplingMetadata"]:
 # When bypass_model_exec is set to False, it means that at least for one

View File

@@ -5,13 +5,13 @@ import threading
 import time
 import uuid
 from collections import defaultdict
+from collections.abc import Iterator
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Iterator
+from typing import TYPE_CHECKING, Any, Optional
 import msgspec
 import torch
 import zmq
-from typing_extensions import Optional
 from vllm import envs
 from vllm.config import VllmConfig

View File

@@ -5,7 +5,7 @@ This implementation is a shim wrapper on two APIs exposed by `kv_connector`:
 1. `send_kv_caches_and_hidden_states`
 2. `recv_kv_caches_and_hidden_states
 """
-from typing import TYPE_CHECKING, List, Tuple, Union
+from typing import TYPE_CHECKING, Union
 if TYPE_CHECKING:
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
@@ -53,7 +53,7 @@ class KVTransferAgent:
 self,
 model_executable: torch.nn.Module,
 model_input: "ModelInputForGPUWithSamplingMetadata",
-kv_caches: List[torch.Tensor],
+kv_caches: list[torch.Tensor],
 hidden_or_intermediate_states: Union[torch.Tensor,
 IntermediateTensors],
 ) -> None:
@@ -68,8 +68,8 @@ class KVTransferAgent:
 def recv_kv_caches_and_hidden_states(
 self, model_executable: torch.nn.Module,
 model_input: "ModelInputForGPUWithSamplingMetadata",
-kv_caches: List[torch.Tensor]
-) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+kv_caches: list[torch.Tensor]
+) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
 "ModelInputForGPUWithSamplingMetadata"]:
 return self.connector.recv_kv_caches_and_hidden_states(

View File

@@ -13,7 +13,7 @@ These classes above are abstracted behind class `KVCacheBufferBase`.
 """
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import Optional
 import torch
@@ -93,7 +93,7 @@ class KVLookupBufferBase(KVCacheBufferBase):
 @abstractmethod
 def drop_select(
 self, input_tokens: Optional[torch.Tensor],
-roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
+roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
 """Select and *drop* KV cache entries from the lookup buffer.
 The functionality is similar to the following python statements
@@ -111,7 +111,7 @@ class KVLookupBufferBase(KVCacheBufferBase):
 roi (torch.Tensor): A binary mask on top of the input tokens
 Returns:
-List[Optional[torch.Tensor]]: A list of tensors. Can be None.
+list[Optional[torch.Tensor]]: A list of tensors. Can be None.
 Raises:
 NotImplementedError: This method must be implemented in subclasses.

View File

@@ -11,7 +11,7 @@
 """
 import threading
 from collections import deque
-from typing import Deque, List, Optional, Union
+from typing import Optional, Union
 import torch
@@ -38,7 +38,7 @@ class SimpleBuffer(KVLookupBufferBase):
 data_pipe: on device (e.g. GPU)
 """
-self.buffer: Deque[List[torch.Tensor]] = deque()
+self.buffer: deque[list[torch.Tensor]] = deque()
 self.buffer_size = 0
 self.buffer_size_threshold = buffer_size_thresh
@@ -50,8 +50,8 @@ class SimpleBuffer(KVLookupBufferBase):
 self.normal_signal = torch.tensor([0], device="cpu")
 self.end_signal = None
-def _matches(self, tokens_roi_sender: List[torch.Tensor],
-tokens_roi_recver: List[torch.Tensor]):
+def _matches(self, tokens_roi_sender: list[torch.Tensor],
+tokens_roi_recver: list[torch.Tensor]):
 # tokens_roi_sender: tokens and roi of the producer (in the buffer)
 # tokens_roi_recver: tokens and roi of the consumer (query)
@@ -88,7 +88,7 @@ class SimpleBuffer(KVLookupBufferBase):
 tensor = tensor.float()
 self.data_pipe.send_tensor(tensor)
-def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]):
+def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]):
 if isinstance(data, torch.Tensor):
 return data.element_size() * data.numel()
@@ -151,7 +151,7 @@ class SimpleBuffer(KVLookupBufferBase):
 tokens_roi_recver = [input_tokens, roi]
 def is_buffer_available(
-tokens_roi_recver: List[torch.Tensor], ) -> bool:
+tokens_roi_recver: list[torch.Tensor], ) -> bool:
 # perform input tokens and roi matching
 # FIXME: this matching is O(n), ideally it should be O(1)
 # but this buffer size won't (and shouldn't) be too large so
@@ -184,7 +184,7 @@ class SimpleBuffer(KVLookupBufferBase):
 def drop_select(
 self, input_tokens: Optional[torch.Tensor],
-roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
+roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
 assert self.request_handling_thread is None, \
 "drop_select should be called by the KV cache consumer "\

View File

@@ -15,7 +15,7 @@
 import threading
 import time
 from concurrent.futures import ThreadPoolExecutor
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Optional
 import torch
@@ -35,7 +35,7 @@ class BrokenPipeException(Exception):
 super().__init__(self.message)
-Metadata = Dict[str, Optional[torch.Tensor]]
+Metadata = dict[str, Optional[torch.Tensor]]
 class PyNcclPipe(KVPipeBase):
@@ -83,7 +83,7 @@ class PyNcclPipe(KVPipeBase):
 def _get_device_send_recv_impl(
 self, group: StatelessProcessGroup
-) -> Tuple[Callable[[torch.Tensor, int], None], Callable[
+) -> tuple[Callable[[torch.Tensor, int], None], Callable[
 [torch.Tensor, int], None]]:
 send: Callable[[torch.Tensor, int], None]

View File

@@ -29,7 +29,7 @@ from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from multiprocessing import shared_memory
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
 from unittest.mock import patch
 import torch
@@ -54,15 +54,15 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
 def _split_tensor_dict(
-tensor_dict: Dict[str, Union[torch.Tensor, Any]]
-) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
+tensor_dict: dict[str, Union[torch.Tensor, Any]]
+) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
 """Split the tensor dictionary into two parts:
 1. A list of (key, value) pairs. If the value is a tensor, it is replaced
 by its metadata.
 2. A list of tensors.
 """
-metadata_list: List[Tuple[str, Any]] = []
-tensor_list: List[torch.Tensor] = []
+metadata_list: list[tuple[str, Any]] = []
+tensor_list: list[torch.Tensor] = []
 for key, value in tensor_dict.items():
 if isinstance(value, torch.Tensor):
 # Note: we cannot use `value.device` here,
@@ -78,7 +78,7 @@ def _split_tensor_dict(
 return metadata_list, tensor_list
-_group_name_counter: Dict[str, int] = {}
+_group_name_counter: dict[str, int] = {}
 def _get_unique_name(name: str) -> str:
@@ -94,7 +94,7 @@ def _get_unique_name(name: str) -> str:
 return newname
-_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
+_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
 def _register_group(group: "GroupCoordinator") -> None:
@@ -182,7 +182,7 @@ class GroupCoordinator:
 # available attributes:
 rank: int # global rank
-ranks: List[int] # global ranks in the group
+ranks: list[int] # global ranks in the group
 world_size: int # size of the group
 # difference between `local_rank` and `rank_in_group`:
 # if we have a group of size 4 across two nodes:
@@ -201,7 +201,7 @@ class GroupCoordinator:
 def __init__(
 self,
-group_ranks: List[List[int]],
+group_ranks: list[list[int]],
 local_rank: int,
 torch_distributed_backend: Union[str, Backend],
 use_device_communicator: bool,
@@ -435,7 +435,7 @@ class GroupCoordinator:
 return recv[0]
 def broadcast_object_list(self,
-obj_list: List[Any],
+obj_list: list[Any],
 src: int = 0,
 group: Optional[ProcessGroup] = None):
 """Broadcast the input object list.
@@ -518,11 +518,11 @@ class GroupCoordinator:
 def broadcast_tensor_dict(
 self,
-tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
+tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None,
 src: int = 0,
 group: Optional[ProcessGroup] = None,
 metadata_group: Optional[ProcessGroup] = None
-) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
 """Broadcast the input tensor dictionary.
 NOTE: `src` is the local rank of the source rank.
 """
@@ -536,7 +536,7 @@ class GroupCoordinator:
 rank_in_group = self.rank_in_group
 if rank_in_group == src:
-metadata_list: List[Tuple[Any, Any]] = []
+metadata_list: list[tuple[Any, Any]] = []
 assert isinstance(
 tensor_dict,
 dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
@@ -603,10 +603,10 @@ class GroupCoordinator:
 def send_tensor_dict(
 self,
-tensor_dict: Dict[str, Union[torch.Tensor, Any]],
+tensor_dict: dict[str, Union[torch.Tensor, Any]],
 dst: Optional[int] = None,
 all_gather_group: Optional["GroupCoordinator"] = None,
-) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
 """Send the input tensor dictionary.
 NOTE: `dst` is the local rank of the source rank.
 """
@@ -626,7 +626,7 @@ class GroupCoordinator:
 dst = (self.rank_in_group + 1) % self.world_size
 assert dst < self.world_size, f"Invalid dst rank ({dst})"
-metadata_list: List[Tuple[Any, Any]] = []
+metadata_list: list[tuple[Any, Any]] = []
 assert isinstance(
 tensor_dict,
 dict), f"Expecting a dictionary, got {type(tensor_dict)}"
@@ -661,7 +661,7 @@ class GroupCoordinator:
 self,
 src: Optional[int] = None,
 all_gather_group: Optional["GroupCoordinator"] = None,
-) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
 """Recv the input tensor dictionary.
 NOTE: `src` is the local rank of the source rank.
 """
@@ -682,7 +682,7 @@ class GroupCoordinator:
 assert src < self.world_size, f"Invalid src rank ({src})"
 recv_metadata_list = self.recv_object(src=src)
-tensor_dict: Dict[str, Any] = {}
+tensor_dict: dict[str, Any] = {}
 for key, value in recv_metadata_list:
 if isinstance(value, TensorMetadata):
 tensor = torch.empty(value.size,
@@ -764,7 +764,7 @@ class GroupCoordinator:
 def dispatch(
 self, hidden_states: torch.Tensor,
-router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 if self.device_communicator is not None:
 return self.device_communicator.dispatch(hidden_states,
 router_logits)
@@ -782,7 +782,7 @@ def get_world_group() -> GroupCoordinator:
 return _WORLD
-def init_world_group(ranks: List[int], local_rank: int,
+def init_world_group(ranks: list[int], local_rank: int,
 backend: str) -> GroupCoordinator:
 return GroupCoordinator(
 group_ranks=[ranks],
@@ -794,7 +794,7 @@ def init_world_group(ranks: List[int], local_rank: int,
 def init_model_parallel_group(
-group_ranks: List[List[int]],
+group_ranks: list[list[int]],
 local_rank: int,
 backend: str,
 use_message_queue_broadcaster: bool = False,
@@ -1182,7 +1182,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
 def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
-source_rank: int = 0) -> List[bool]:
+source_rank: int = 0) -> list[bool]:
 """
 This is a collective operation that returns if each rank is in the same node
 as the source rank. It tests if processes are attached to the same
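
For reference, the `_split_tensor_dict` signature changed at the top of this file describes a simple metadata/payload split before broadcasting. A simplified sketch of that idea using the new annotations (the metadata triple below is a stand-in for vLLM's `TensorMetadata` namedtuple, not the actual implementation):

```python
from typing import Any, Union
import torch

def split_tensor_dict(
    tensor_dict: dict[str, Union[torch.Tensor, Any]]
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
    """Replace tensor values with lightweight metadata and collect the tensors."""
    metadata_list: list[tuple[str, Any]] = []
    tensor_list: list[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Record (device, dtype, shape) so the receiver can pre-allocate.
            metadata_list.append((key, (str(value.device), value.dtype, value.shape)))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list

meta, tensors = split_tensor_dict({"x": torch.zeros(2, 3), "step": 7})
```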

View File

@@ -10,7 +10,8 @@ import pickle
 import socket
 import time
 from collections import deque
-from typing import Any, Deque, Dict, Optional, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Any, Optional
 import torch
 from torch.distributed import ProcessGroup, TCPStore
@@ -69,7 +70,7 @@ def split_tensor_along_last_dim(
 def get_pp_indices(num_hidden_layers: int, pp_rank: int,
-pp_size: int) -> Tuple[int, int]:
+pp_size: int) -> tuple[int, int]:
 """Try to evenly distribute layers across partitions.
 If the number of layers is not divisible by the number of partitions,
@@ -132,15 +133,15 @@ class StatelessProcessGroup:
 data_expiration_seconds: int = 3600 # 1 hour
 # dst rank -> counter
-send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)
 # src rank -> counter
-recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
 broadcast_send_counter: int = 0
-broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(
+broadcast_recv_src_counter: dict[int, int] = dataclasses.field(
 default_factory=dict)
 # A deque to store the data entries, with key and timestamp.
-entries: Deque[Tuple[str,
+entries: deque[tuple[str,
 float]] = dataclasses.field(default_factory=deque)
 def __post_init__(self):
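
One detail worth noting in the `StatelessProcessGroup` fields above: switching to builtin generics does not change dataclass semantics, so mutable containers still need `default_factory`, and `deque`/`tuple` can now be parameterized directly. A small illustrative sketch (field names are simplified stand-ins, not the vLLM class):

```python
import dataclasses
from collections import deque

@dataclasses.dataclass
class DemoGroupState:
    # dst rank -> number of messages sent to it
    send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)
    # src rank -> number of messages received from it
    recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
    # (key, timestamp) entries awaiting expiration
    entries: deque[tuple[str, float]] = dataclasses.field(default_factory=deque)

state = DemoGroupState()
state.send_dst_counter[1] = 3
state.entries.append(("send_to/1/0", 1715683200.0))
```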