Update deprecated type hinting in vllm/device_allocator and vllm/distributed (#18126)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 9b5b39b650
commit dc372b9c8a
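The change applied throughout the diff below is the PEP 585 migration: on Python 3.9 and newer the builtin dict, list, tuple and type are subscriptable in annotations, so the deprecated typing.Dict, typing.List, typing.Tuple and typing.Type aliases can be dropped, while Optional, Union, Any and Callable still come from typing. A minimal before/after sketch of the pattern (illustrative names only, not code from the tree):

# Before: Python 3.8-compatible annotations using typing aliases.
from typing import Dict, List, Optional, Tuple

def gather_handles(handles: List[int],
                   cache: Dict[str, bool],
                   dst: Optional[int] = None) -> Tuple[int, ...]:
    # Keep only handles marked usable in the cache, optionally truncated to dst.
    selected = handles if dst is None else handles[:dst]
    return tuple(h for h in selected if cache.get(str(h), False))

# After: PEP 585 builtin generics on Python 3.9+; only Optional still needs
# to be imported from typing.
from typing import Optional

def gather_handles_585(handles: list[int],
                       cache: dict[str, bool],
                       dst: Optional[int] = None) -> tuple[int, ...]:
    selected = handles if dst is None else handles[:dst]
    return tuple(h for h in selected if cache.get(str(h), False))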
@@ -74,8 +74,6 @@ exclude = [
 # Python 3.8 typing. TODO: Remove these excludes after v1.0.0
 "vllm/attention/**/*.py" = ["UP006", "UP035"]
 "vllm/core/**/*.py" = ["UP006", "UP035"]
-"vllm/device_allocator/**/*.py" = ["UP006", "UP035"]
-"vllm/distributed/**/*.py" = ["UP006", "UP035"]
 "vllm/engine/**/*.py" = ["UP006", "UP035"]
 "vllm/executor/**/*.py" = ["UP006", "UP035"]
 "vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
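Dropping vllm/device_allocator/** and vllm/distributed/** from this per-file exclude list is what makes the rest of the diff necessary: Ruff's pyupgrade rules now run on those trees. As I understand those rules, UP035 reports imports of deprecated typing members (Dict, List, Tuple, Type, Deque, and the typing copies of collections.abc types), and UP006 reports annotations that use those aliases instead of the PEP 585 builtins. A hedged sketch of the kind of code the rules would report, and the accepted fix (hypothetical names, not from the repository):

# Would now be reported: UP035 for the deprecated imports, UP006 for the
# annotation style.
from typing import Any, Callable, Dict, Type

registry_old: Dict[str, Callable[[], Type[Any]]] = {}

# Accepted fix: builtins are subscriptable on Python 3.9+, so only Callable
# and Any still need to come from typing.
registry_new: dict[str, Callable[[], type[Any]]] = {}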
@@ -11,7 +11,7 @@ import dataclasses
 import gc
 import os
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
 
 import torch

@@ -63,7 +63,7 @@ except ModuleNotFoundError:
     libcudart = None
 
 # py_device, py_alignedSize, py_d_mem, py_p_memHandle
-HandleType = Tuple[int, int, int, int]
+HandleType = tuple[int, int, int, int]
 
 
 @dataclasses.dataclass

@@ -148,9 +148,9 @@ class CuMemAllocator:
                 "Please track https://github.com/pytorch/pytorch/issues/147851 "
                 "for the latest updates.")
 
-        self.pointer_to_data: Dict[int, AllocationData] = {}
+        self.pointer_to_data: dict[int, AllocationData] = {}
         self.current_tag: str = CuMemAllocator.default_tag
-        self.allocator_and_pools: Dict[str, Any] = {}
+        self.allocator_and_pools: dict[str, Any] = {}
 
     def python_malloc_callback(self, allocation_handle: HandleType) -> None:
         """

@@ -172,7 +172,7 @@ class CuMemAllocator:
 
     def sleep(
             self,
-            offload_tags: Optional[Union[Tuple[str, ...],
+            offload_tags: Optional[Union[tuple[str, ...],
                                          str]] = None) -> None:
         """
         Put the allocator in sleep mode.

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.distributed

@@ -32,7 +32,7 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
     return get_tp_group().gather(input_, dst, dim)
 
 
-def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
+def broadcast_tensor_dict(tensor_dict: Optional[dict[Any, Union[torch.Tensor,
                                                                  Any]]] = None,
                           src: int = 0):
     if not torch.distributed.is_initialized():
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.distributed as dist

@@ -160,7 +160,7 @@ class DeviceCommunicatorBase:
 
     def dispatch(
             self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Dispatch the hidden states and router logits to the appropriate device.
         This is a no-op in the base class.

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
-from typing import List, Optional
+from typing import Optional
 
 import torch
 from torch.distributed import ProcessGroup

@@ -126,7 +126,7 @@ class _CPUSHMDistributed:
 
     def gather(self,
                input: torch.Tensor,
-               gather_list: Optional[List[torch.Tensor]],
+               gather_list: Optional[list[torch.Tensor]],
                dst: int = -1,
                group: Optional[ProcessGroup] = None) -> None:
         # Note: different from the torch gather, here we use local dst rank.

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 from torch.distributed import ProcessGroup

@@ -154,7 +154,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
 
     def dispatch(
             self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         assert self.all2all_impl is not None
         hidden_states, router_logits = self.all2all_impl.dispatch(
             hidden_states, router_logits)
@@ -6,7 +6,7 @@ convenient for use when we just need to call a few functions.
 
 import ctypes
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 
 # this line makes it possible to directly load `libcudart.so` using `ctypes`
 import torch  # noqa

@@ -32,7 +32,7 @@ class cudaIpcMemHandle_t(ctypes.Structure):
 class Function:
     name: str
     restype: Any
-    argtypes: List[Any]
+    argtypes: list[Any]
 
 
 def find_loaded_library(lib_name) -> Optional[str]:

@@ -97,11 +97,11 @@ class CudaRTLibrary:
 
     # class attribute to store the mapping from the path to the library
     # to avoid loading the same library multiple times
-    path_to_library_cache: Dict[str, Any] = {}
+    path_to_library_cache: dict[str, Any] = {}
 
     # class attribute to store the mapping from library path
     # to the corresponding dictionary
-    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+    path_to_dict_mapping: dict[str, dict[str, Any]] = {}
 
     def __init__(self, so_file: Optional[str] = None):
         if so_file is None:

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from contextlib import contextmanager
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 import torch
 import torch.distributed as dist

@@ -276,7 +276,7 @@ class CustomAllreduce:
     @staticmethod
     def create_shared_buffer(size_in_bytes: int,
                              group: Optional[ProcessGroup] = None,
-                             uncached: Optional[bool] = False) -> List[int]:
+                             uncached: Optional[bool] = False) -> list[int]:
         pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
 
         world_size = dist.get_world_size(group=group)

@@ -284,7 +284,7 @@ class CustomAllreduce:
         handles = [None] * world_size
         dist.all_gather_object(handles, handle, group=group)
 
-        pointers: List[int] = []
+        pointers: list[int] = []
         for i, h in enumerate(handles):
             if i == rank:
                 pointers.append(pointer)  # type: ignore

@@ -293,7 +293,7 @@ class CustomAllreduce:
         return pointers
 
     @staticmethod
-    def free_shared_buffer(pointers: List[int],
+    def free_shared_buffer(pointers: list[int],
                            group: Optional[ProcessGroup] = None,
                            rank: Optional[int] = 0) -> None:
         if rank is None:
@@ -7,8 +7,9 @@ import pickle
 import subprocess
 import sys
 import tempfile
+from collections.abc import Sequence
 from itertools import product
-from typing import Dict, List, Optional, Sequence
+from typing import Optional
 
 import torch.distributed as dist
 import torch.multiprocessing as mp
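This hunk also shows the second half of the cleanup: abstract container types move off typing as well, since typing.Sequence (and, in later hunks, typing.Iterator and typing.Deque) are deprecated aliases of collections.abc.Sequence, collections.abc.Iterator and collections.deque. A small hypothetical sketch of that side of the migration (not code from the tree):

from collections import deque
from collections.abc import Iterator, Sequence


def sliding_windows(values: Sequence[int], size: int) -> Iterator[deque[int]]:
    # deque, Sequence and Iterator are all subscriptable directly on Python 3.9+.
    window: deque[int] = deque(maxlen=size)
    for value in values:
        window.append(value)
        yield deque(window)  # yield a copy so callers cannot mutate the buffer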
@@ -149,7 +150,7 @@ def can_actually_p2p(
     p_src.join()
     p_tgt.join()
     assert p_src.exitcode == 0 and p_tgt.exitcode == 0
-    result: List[bool] = []
+    result: list[bool] = []
     for src, tgt in zip(batch_src, batch_tgt):
         a = result_queue.get()
         b = result_queue.get()

@@ -175,7 +176,7 @@ def can_actually_p2p(
 # e.g. used by different vllm engines. The device id in the cache file is a
 # **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
 # of visible devices in the vllm engine.
-_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
+_gpu_p2p_access_cache: Optional[dict[str, bool]] = None
 
 
 def gpu_p2p_access_check(src: int, tgt: int) -> bool:

@@ -204,7 +205,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
         # only the local master process (with local_rank == 0) can
         # enter this block to calculate the cache
         logger.info("generating GPU P2P access cache in %s", path)
-        cache: Dict[str, bool] = {}
+        cache: dict[str, bool] = {}
         ids = list(range(num_dev))
         # batch of all pairs of GPUs
         batch_src, batch_tgt = zip(*list(product(ids, ids)))

@@ -24,7 +24,7 @@
 import ctypes
 import platform
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 
 import torch
 from torch.distributed import ReduceOp

@@ -121,7 +121,7 @@ class ncclRedOpTypeEnum:
 class Function:
     name: str
     restype: Any
-    argtypes: List[Any]
+    argtypes: list[Any]
 
 
 class NCCLLibrary:

@@ -210,11 +210,11 @@ class NCCLLibrary:
 
     # class attribute to store the mapping from the path to the library
     # to avoid loading the same library multiple times
-    path_to_library_cache: Dict[str, Any] = {}
+    path_to_library_cache: dict[str, Any] = {}
 
     # class attribute to store the mapping from library path
     # to the corresponding dictionary
-    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+    path_to_dict_mapping: dict[str, dict[str, Any]] = {}
 
     def __init__(self, so_file: Optional[str] = None):

@@ -238,7 +238,7 @@ class NCCLLibrary:
             raise e
 
         if so_file not in NCCLLibrary.path_to_dict_mapping:
-            _funcs: Dict[str, Any] = {}
+            _funcs: dict[str, Any] = {}
             for func in NCCLLibrary.exported_functions:
                 f = getattr(self.lib, func.name)
                 f.restype = func.restype

@@ -8,7 +8,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field
 from multiprocessing import shared_memory
 from threading import Event
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 from unittest.mock import patch
 
 import torch
@@ -173,9 +173,9 @@ class ShmRingBuffer:
 
 @dataclass
 class Handle:
-    local_reader_ranks: List[int] = field(default_factory=list)
+    local_reader_ranks: list[int] = field(default_factory=list)
 
-    buffer_handle: Optional[Tuple[int, int, int, str]] = None
+    buffer_handle: Optional[tuple[int, int, int, str]] = None
     local_subscribe_addr: Optional[str] = None
     remote_subscribe_addr: Optional[str] = None
     remote_addr_ipv6: bool = False

@@ -187,7 +187,7 @@ class MessageQueue:
         self,
         n_reader,  # number of all readers
         n_local_reader,  # number of local readers through shared memory
-        local_reader_ranks: Optional[List[int]] = None,
+        local_reader_ranks: Optional[list[int]] = None,
         max_chunk_bytes: int = 1024 * 1024 * 10,
         max_chunks: int = 10,
         connect_ip: Optional[str] = None,

@@ -8,7 +8,7 @@ The class provides two primary abstract methods:
 """
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, List, Tuple, Union
+from typing import TYPE_CHECKING, Union
 
 import torch
 

@@ -55,7 +55,7 @@ class KVConnectorBase(ABC):
         self,
         model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor],
+        kv_caches: list[torch.Tensor],
         hidden_or_intermediate_states: Union[torch.Tensor,
                                              IntermediateTensors],
     ) -> None:

@@ -71,7 +71,7 @@ class KVConnectorBase(ABC):
                 start and end layer information.
             model_input (ModelInputForGPUWithSamplingMetadata): The input
                 metadata from vLLM.
-            kv_caches (List[torch.Tensor]): List of KV caches (keys and values)
+            kv_caches (list[torch.Tensor]): List of KV caches (keys and values)
                 for each layer.
             hidden_or_intermediate_states (Union[torch.Tensor,
             IntermediateTensors]):

@@ -88,8 +88,8 @@ class KVConnectorBase(ABC):
     def recv_kv_caches_and_hidden_states(
         self, model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor]
-    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+        kv_caches: list[torch.Tensor]
+    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
                "ModelInputForGPUWithSamplingMetadata"]:
         """
         Receive KV caches and hidden states from the connector.

@@ -104,7 +104,7 @@ class KVConnectorBase(ABC):
                 The model executable from vLLM modelrunner.
             model_input (ModelInputForGPUWithSamplingMetadata):
                 The model input from vLLM modelrunner.
-            kv_caches (List[torch.Tensor]):
+            kv_caches (list[torch.Tensor]):
                 List of KV caches for each layer.
 
         Returns:
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import importlib
-from typing import TYPE_CHECKING, Callable, Dict, Type
+from typing import TYPE_CHECKING, Callable
 
 import vllm.envs as envs
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType

@@ -18,7 +18,7 @@ logger = init_logger(__name__)
 
 
 class KVConnectorFactory:
-    _registry: Dict[str, Callable[[], Type[KVConnectorBaseType]]] = {}
+    _registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {}
 
     @classmethod
     def register_connector(cls, name: str, module_path: str,

@@ -27,7 +27,7 @@ class KVConnectorFactory:
         if name in cls._registry:
             raise ValueError(f"Connector '{name}' is already registered.")
 
-        def loader() -> Type[KVConnectorBaseType]:
+        def loader() -> type[KVConnectorBaseType]:
             module = importlib.import_module(module_path)
             return getattr(module, class_name)
 

@@ -7,7 +7,7 @@ The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker
 (2) offload and share KV caches.
 """
 
-from typing import TYPE_CHECKING, List, Tuple, Union
+from typing import TYPE_CHECKING, Union
 
 import torch
 

@@ -63,8 +63,8 @@ class LMCacheConnector(KVConnectorBase):
     def recv_kv_caches_and_hidden_states(
         self, model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor]
-    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+        kv_caches: list[torch.Tensor]
+    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
                "ModelInputForGPUWithSamplingMetadata"]:
 
         retrieve_status = self.lmcache_should_retrieve(model_input)

@@ -78,7 +78,7 @@ class LMCacheConnector(KVConnectorBase):
         self,
         model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor],
+        kv_caches: list[torch.Tensor],
         hidden_or_intermediate_states: Union[torch.Tensor,
                                              IntermediateTensors],
     ) -> None:
@@ -6,7 +6,7 @@ The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
 database-style KVStore.
 """
 import hashlib
-from typing import TYPE_CHECKING, List, Tuple, Union
+from typing import TYPE_CHECKING, Union
 
 import torch
 

@@ -70,7 +70,7 @@ class MooncakeStoreConnector(KVConnectorBase):
         self,
         model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor],
+        kv_caches: list[torch.Tensor],
         hidden_or_intermediate_states: Union[torch.Tensor,
                                              IntermediateTensors],
     ) -> None:

@@ -113,8 +113,8 @@ class MooncakeStoreConnector(KVConnectorBase):
     def recv_kv_caches_and_hidden_states(
         self, model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor]
-    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+        kv_caches: list[torch.Tensor]
+    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
                "ModelInputForGPUWithSamplingMetadata"]:
         bypass_model_exec = True
         input_tokens_tensor = model_input.input_tokens

@@ -8,7 +8,7 @@ MooncakePipe.
 
 But the logic can be extended to support other pipe and lookup buffer.
 """
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 

@@ -133,7 +133,7 @@ class SimpleConnector(KVConnectorBase):
             )
 
     def select(self, input_tokens: Optional[torch.Tensor],
-               roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
+               roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
 
         assert self.consumer_buffer is not None, "Please initialize the "\
             "consumer buffer before calling select."

@@ -152,7 +152,7 @@ class SimpleConnector(KVConnectorBase):
         self,
         model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor],
+        kv_caches: list[torch.Tensor],
         hidden_or_intermediate_states: Union[torch.Tensor,
                                              IntermediateTensors],
     ) -> None:

@@ -207,8 +207,8 @@ class SimpleConnector(KVConnectorBase):
     def recv_kv_caches_and_hidden_states(
         self, model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor]
-    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+        kv_caches: list[torch.Tensor]
+    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
                "ModelInputForGPUWithSamplingMetadata"]:
 
         # When bypass_model_exec is set to False, it means that at least for one
@@ -5,13 +5,13 @@ import threading
 import time
 import uuid
 from collections import defaultdict
+from collections.abc import Iterator
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Iterator
+from typing import TYPE_CHECKING, Any, Optional
 
 import msgspec
 import torch
 import zmq
-from typing_extensions import Optional
 
 from vllm import envs
 from vllm.config import VllmConfig

@@ -5,7 +5,7 @@ This implementation is a shim wrapper on two APIs exposed by `kv_connector`:
 1. `send_kv_caches_and_hidden_states`
 2. `recv_kv_caches_and_hidden_states
 """
-from typing import TYPE_CHECKING, List, Tuple, Union
+from typing import TYPE_CHECKING, Union
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

@@ -53,7 +53,7 @@ class KVTransferAgent:
         self,
         model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor],
+        kv_caches: list[torch.Tensor],
         hidden_or_intermediate_states: Union[torch.Tensor,
                                              IntermediateTensors],
     ) -> None:

@@ -68,8 +68,8 @@ class KVTransferAgent:
     def recv_kv_caches_and_hidden_states(
         self, model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor]
-    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+        kv_caches: list[torch.Tensor]
+    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
                "ModelInputForGPUWithSamplingMetadata"]:
 
         return self.connector.recv_kv_caches_and_hidden_states(
@@ -13,7 +13,7 @@ These classes above are abstracted behind class `KVCacheBufferBase`.
 """
 
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import Optional
 
 import torch
 

@@ -93,7 +93,7 @@ class KVLookupBufferBase(KVCacheBufferBase):
     @abstractmethod
     def drop_select(
             self, input_tokens: Optional[torch.Tensor],
-            roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
+            roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
         """Select and *drop* KV cache entries from the lookup buffer.
 
         The functionality is similar to the following python statements

@@ -111,7 +111,7 @@ class KVLookupBufferBase(KVCacheBufferBase):
             roi (torch.Tensor): A binary mask on top of the input tokens
 
         Returns:
-            List[Optional[torch.Tensor]]: A list of tensors. Can be None.
+            list[Optional[torch.Tensor]]: A list of tensors. Can be None.
 
         Raises:
             NotImplementedError: This method must be implemented in subclasses.

@@ -11,7 +11,7 @@
 """
 import threading
 from collections import deque
-from typing import Deque, List, Optional, Union
+from typing import Optional, Union
 
 import torch
 

@@ -38,7 +38,7 @@ class SimpleBuffer(KVLookupBufferBase):
             data_pipe: on device (e.g. GPU)
         """
 
-        self.buffer: Deque[List[torch.Tensor]] = deque()
+        self.buffer: deque[list[torch.Tensor]] = deque()
 
         self.buffer_size = 0
         self.buffer_size_threshold = buffer_size_thresh

@@ -50,8 +50,8 @@ class SimpleBuffer(KVLookupBufferBase):
         self.normal_signal = torch.tensor([0], device="cpu")
         self.end_signal = None
 
-    def _matches(self, tokens_roi_sender: List[torch.Tensor],
-                 tokens_roi_recver: List[torch.Tensor]):
+    def _matches(self, tokens_roi_sender: list[torch.Tensor],
+                 tokens_roi_recver: list[torch.Tensor]):
 
         # tokens_roi_sender: tokens and roi of the producer (in the buffer)
         # tokens_roi_recver: tokens and roi of the consumer (query)

@@ -88,7 +88,7 @@ class SimpleBuffer(KVLookupBufferBase):
             tensor = tensor.float()
         self.data_pipe.send_tensor(tensor)
 
-    def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]):
+    def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]):
 
         if isinstance(data, torch.Tensor):
             return data.element_size() * data.numel()

@@ -151,7 +151,7 @@ class SimpleBuffer(KVLookupBufferBase):
         tokens_roi_recver = [input_tokens, roi]
 
         def is_buffer_available(
-            tokens_roi_recver: List[torch.Tensor], ) -> bool:
+            tokens_roi_recver: list[torch.Tensor], ) -> bool:
             # perform input tokens and roi matching
             # FIXME: this matching is O(n), ideally it should be O(1)
             # but this buffer size won't (and shouldn't) be too large so

@@ -184,7 +184,7 @@ class SimpleBuffer(KVLookupBufferBase):
 
     def drop_select(
             self, input_tokens: Optional[torch.Tensor],
-            roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
+            roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
 
         assert self.request_handling_thread is None, \
             "drop_select should be called by the KV cache consumer "\
@@ -15,7 +15,7 @@
 import threading
 import time
 from concurrent.futures import ThreadPoolExecutor
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Optional
 
 import torch
 

@@ -35,7 +35,7 @@ class BrokenPipeException(Exception):
         super().__init__(self.message)
 
 
-Metadata = Dict[str, Optional[torch.Tensor]]
+Metadata = dict[str, Optional[torch.Tensor]]
 
 
 class PyNcclPipe(KVPipeBase):

@@ -83,7 +83,7 @@ class PyNcclPipe(KVPipeBase):
 
     def _get_device_send_recv_impl(
         self, group: StatelessProcessGroup
-    ) -> Tuple[Callable[[torch.Tensor, int], None], Callable[
+    ) -> tuple[Callable[[torch.Tensor, int], None], Callable[
         [torch.Tensor, int], None]]:
 
         send: Callable[[torch.Tensor, int], None]

@@ -29,7 +29,7 @@ from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from multiprocessing import shared_memory
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
 from unittest.mock import patch
 
 import torch

@@ -54,15 +54,15 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
 
 
 def _split_tensor_dict(
-    tensor_dict: Dict[str, Union[torch.Tensor, Any]]
-) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
+    tensor_dict: dict[str, Union[torch.Tensor, Any]]
+) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
        by its metadata.
     2. A list of tensors.
     """
-    metadata_list: List[Tuple[str, Any]] = []
-    tensor_list: List[torch.Tensor] = []
+    metadata_list: list[tuple[str, Any]] = []
+    tensor_list: list[torch.Tensor] = []
     for key, value in tensor_dict.items():
         if isinstance(value, torch.Tensor):
             # Note: we cannot use `value.device` here,

@@ -78,7 +78,7 @@ def _split_tensor_dict(
     return metadata_list, tensor_list
 
 
-_group_name_counter: Dict[str, int] = {}
+_group_name_counter: dict[str, int] = {}
 
 
 def _get_unique_name(name: str) -> str:
@@ -94,7 +94,7 @@ def _get_unique_name(name: str) -> str:
     return newname
 
 
-_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
+_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
 
 
 def _register_group(group: "GroupCoordinator") -> None:

@@ -182,7 +182,7 @@ class GroupCoordinator:
 
     # available attributes:
     rank: int  # global rank
-    ranks: List[int]  # global ranks in the group
+    ranks: list[int]  # global ranks in the group
     world_size: int  # size of the group
     # difference between `local_rank` and `rank_in_group`:
     # if we have a group of size 4 across two nodes:

@@ -201,7 +201,7 @@ class GroupCoordinator:
 
     def __init__(
         self,
-        group_ranks: List[List[int]],
+        group_ranks: list[list[int]],
         local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_device_communicator: bool,

@@ -435,7 +435,7 @@ class GroupCoordinator:
         return recv[0]
 
     def broadcast_object_list(self,
-                              obj_list: List[Any],
+                              obj_list: list[Any],
                               src: int = 0,
                               group: Optional[ProcessGroup] = None):
         """Broadcast the input object list.

@@ -518,11 +518,11 @@ class GroupCoordinator:
 
     def broadcast_tensor_dict(
         self,
-        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
+        tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None,
         src: int = 0,
         group: Optional[ProcessGroup] = None,
         metadata_group: Optional[ProcessGroup] = None
-    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Broadcast the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
@@ -536,7 +536,7 @@ class GroupCoordinator:
 
         rank_in_group = self.rank_in_group
         if rank_in_group == src:
-            metadata_list: List[Tuple[Any, Any]] = []
+            metadata_list: list[tuple[Any, Any]] = []
             assert isinstance(
                 tensor_dict,
                 dict), (f"Expecting a dictionary, got {type(tensor_dict)}")

@@ -603,10 +603,10 @@ class GroupCoordinator:
 
     def send_tensor_dict(
         self,
-        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
+        tensor_dict: dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
-    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
         """

@@ -626,7 +626,7 @@ class GroupCoordinator:
             dst = (self.rank_in_group + 1) % self.world_size
         assert dst < self.world_size, f"Invalid dst rank ({dst})"
 
-        metadata_list: List[Tuple[Any, Any]] = []
+        metadata_list: list[tuple[Any, Any]] = []
         assert isinstance(
             tensor_dict,
             dict), f"Expecting a dictionary, got {type(tensor_dict)}"

@@ -661,7 +661,7 @@ class GroupCoordinator:
         self,
         src: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
-    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """

@@ -682,7 +682,7 @@ class GroupCoordinator:
         assert src < self.world_size, f"Invalid src rank ({src})"
 
         recv_metadata_list = self.recv_object(src=src)
-        tensor_dict: Dict[str, Any] = {}
+        tensor_dict: dict[str, Any] = {}
         for key, value in recv_metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,

@@ -764,7 +764,7 @@ class GroupCoordinator:
 
     def dispatch(
             self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         if self.device_communicator is not None:
             return self.device_communicator.dispatch(hidden_states,
                                                      router_logits)
@@ -782,7 +782,7 @@ def get_world_group() -> GroupCoordinator:
     return _WORLD
 
 
-def init_world_group(ranks: List[int], local_rank: int,
+def init_world_group(ranks: list[int], local_rank: int,
                      backend: str) -> GroupCoordinator:
     return GroupCoordinator(
         group_ranks=[ranks],

@@ -794,7 +794,7 @@ def init_world_group(ranks: List[int], local_rank: int,
 
 
 def init_model_parallel_group(
-    group_ranks: List[List[int]],
+    group_ranks: list[list[int]],
     local_rank: int,
     backend: str,
     use_message_queue_broadcaster: bool = False,

@@ -1182,7 +1182,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
 
 
 def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
-                        source_rank: int = 0) -> List[bool]:
+                        source_rank: int = 0) -> list[bool]:
     """
     This is a collective operation that returns if each rank is in the same node
     as the source rank. It tests if processes are attached to the same

@@ -10,7 +10,8 @@ import pickle
 import socket
 import time
 from collections import deque
-from typing import Any, Deque, Dict, Optional, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Any, Optional
 
 import torch
 from torch.distributed import ProcessGroup, TCPStore

@@ -69,7 +70,7 @@ def split_tensor_along_last_dim(
 
 
 def get_pp_indices(num_hidden_layers: int, pp_rank: int,
-                   pp_size: int) -> Tuple[int, int]:
+                   pp_size: int) -> tuple[int, int]:
     """Try to evenly distribute layers across partitions.
 
     If the number of layers is not divisible by the number of partitions,

@@ -132,15 +133,15 @@ class StatelessProcessGroup:
     data_expiration_seconds: int = 3600  # 1 hour
 
     # dst rank -> counter
-    send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)
     # src rank -> counter
-    recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
     broadcast_send_counter: int = 0
-    broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(
+    broadcast_recv_src_counter: dict[int, int] = dataclasses.field(
         default_factory=dict)
 
     # A deque to store the data entries, with key and timestamp.
-    entries: Deque[Tuple[str,
+    entries: deque[tuple[str,
                          float]] = dataclasses.field(default_factory=deque)
 
     def __post_init__(self):