Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 22:45:50 +08:00)
Update deprecated type hinting in vllm/device_allocator and vllm/distributed (#18126)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent 9b5b39b650
commit dc372b9c8a
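The diff below applies one mechanical change throughout: container generics imported from `typing` (`Dict`, `List`, `Tuple`, `Type`, `Deque`) become the built-in generics (`dict`, `list`, `tuple`, `type`, `collections.deque`), and abstract collection types such as `Sequence` and `Iterator` are imported from `collections.abc` instead of `typing`. Removing the two per-file-ignores from pyproject.toml lets ruff enforce this via UP006 (non-PEP 585 annotation) and UP035 (deprecated import) in vllm/device_allocator and vllm/distributed. A minimal before/after sketch of the pattern (illustrative only; the names echo but are not copied verbatim from the diff):

# Before: typing aliases deprecated by PEP 585 (Python 3.9+)
from typing import Dict, List, Optional, Tuple

HandleType = Tuple[int, int, int, int]
p2p_cache: Optional[Dict[str, bool]] = None
pointers: List[int] = []

# After: built-in generics; Optional/Union/Callable stay in typing
from typing import Optional

HandleType = tuple[int, int, int, int]
p2p_cache: Optional[dict[str, bool]] = None
pointers: list[int] = []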
@@ -74,8 +74,6 @@ exclude = [
 # Python 3.8 typing. TODO: Remove these excludes after v1.0.0
 "vllm/attention/**/*.py" = ["UP006", "UP035"]
 "vllm/core/**/*.py" = ["UP006", "UP035"]
-"vllm/device_allocator/**/*.py" = ["UP006", "UP035"]
-"vllm/distributed/**/*.py" = ["UP006", "UP035"]
 "vllm/engine/**/*.py" = ["UP006", "UP035"]
 "vllm/executor/**/*.py" = ["UP006", "UP035"]
 "vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
@@ -11,7 +11,7 @@ import dataclasses
 import gc
 import os
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union

 import torch

@@ -63,7 +63,7 @@ except ModuleNotFoundError:
     libcudart = None

 # py_device, py_alignedSize, py_d_mem, py_p_memHandle
-HandleType = Tuple[int, int, int, int]
+HandleType = tuple[int, int, int, int]


 @dataclasses.dataclass
@@ -148,9 +148,9 @@ class CuMemAllocator:
             "Please track https://github.com/pytorch/pytorch/issues/147851 "
             "for the latest updates.")

-        self.pointer_to_data: Dict[int, AllocationData] = {}
+        self.pointer_to_data: dict[int, AllocationData] = {}
         self.current_tag: str = CuMemAllocator.default_tag
-        self.allocator_and_pools: Dict[str, Any] = {}
+        self.allocator_and_pools: dict[str, Any] = {}

     def python_malloc_callback(self, allocation_handle: HandleType) -> None:
         """
@@ -172,7 +172,7 @@ class CuMemAllocator:

     def sleep(
             self,
-            offload_tags: Optional[Union[Tuple[str, ...],
+            offload_tags: Optional[Union[tuple[str, ...],
                                          str]] = None) -> None:
         """
         Put the allocator in sleep mode.
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union

 import torch
 import torch.distributed
@@ -32,7 +32,7 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
     return get_tp_group().gather(input_, dst, dim)


-def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
+def broadcast_tensor_dict(tensor_dict: Optional[dict[Any, Union[torch.Tensor,
                                                                  Any]]] = None,
                           src: int = 0):
     if not torch.distributed.is_initialized():
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional, Tuple
+from typing import Optional

 import torch
 import torch.distributed as dist
@@ -160,7 +160,7 @@ class DeviceCommunicatorBase:

     def dispatch(
             self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Dispatch the hidden states and router logits to the appropriate device.
         This is a no-op in the base class.
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
-from typing import List, Optional
+from typing import Optional

 import torch
 from torch.distributed import ProcessGroup
@@ -126,7 +126,7 @@ class _CPUSHMDistributed:

     def gather(self,
                input: torch.Tensor,
-               gather_list: Optional[List[torch.Tensor]],
+               gather_list: Optional[list[torch.Tensor]],
                dst: int = -1,
                group: Optional[ProcessGroup] = None) -> None:
         # Note: different from the torch gather, here we use local dst rank.
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Optional, Tuple
+from typing import Optional

 import torch
 from torch.distributed import ProcessGroup
@@ -154,7 +154,7 @@ class CudaCommunicator(DeviceCommunicatorBase):

     def dispatch(
             self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         assert self.all2all_impl is not None
         hidden_states, router_logits = self.all2all_impl.dispatch(
             hidden_states, router_logits)
@@ -6,7 +6,7 @@ convenient for use when we just need to call a few functions.

 import ctypes
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 # this line makes it possible to directly load `libcudart.so` using `ctypes`
 import torch  # noqa
@@ -32,7 +32,7 @@ class cudaIpcMemHandle_t(ctypes.Structure):
 class Function:
     name: str
     restype: Any
-    argtypes: List[Any]
+    argtypes: list[Any]


 def find_loaded_library(lib_name) -> Optional[str]:
@@ -97,11 +97,11 @@ class CudaRTLibrary:

     # class attribute to store the mapping from the path to the library
     # to avoid loading the same library multiple times
-    path_to_library_cache: Dict[str, Any] = {}
+    path_to_library_cache: dict[str, Any] = {}

     # class attribute to store the mapping from library path
     # to the corresponding dictionary
-    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+    path_to_dict_mapping: dict[str, dict[str, Any]] = {}

     def __init__(self, so_file: Optional[str] = None):
         if so_file is None:
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from contextlib import contextmanager
-from typing import List, Optional, Union
+from typing import Optional, Union

 import torch
 import torch.distributed as dist
@@ -276,7 +276,7 @@ class CustomAllreduce:
     @staticmethod
     def create_shared_buffer(size_in_bytes: int,
                              group: Optional[ProcessGroup] = None,
-                             uncached: Optional[bool] = False) -> List[int]:
+                             uncached: Optional[bool] = False) -> list[int]:
         pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)

         world_size = dist.get_world_size(group=group)
@@ -284,7 +284,7 @@ class CustomAllreduce:
         handles = [None] * world_size
         dist.all_gather_object(handles, handle, group=group)

-        pointers: List[int] = []
+        pointers: list[int] = []
         for i, h in enumerate(handles):
             if i == rank:
                 pointers.append(pointer)  # type: ignore
@@ -293,7 +293,7 @@ class CustomAllreduce:
         return pointers

     @staticmethod
-    def free_shared_buffer(pointers: List[int],
+    def free_shared_buffer(pointers: list[int],
                            group: Optional[ProcessGroup] = None,
                            rank: Optional[int] = 0) -> None:
         if rank is None:
@@ -7,8 +7,9 @@ import pickle
 import subprocess
 import sys
 import tempfile
+from collections.abc import Sequence
 from itertools import product
-from typing import Dict, List, Optional, Sequence
+from typing import Optional

 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -149,7 +150,7 @@ def can_actually_p2p(
     p_src.join()
     p_tgt.join()
     assert p_src.exitcode == 0 and p_tgt.exitcode == 0
-    result: List[bool] = []
+    result: list[bool] = []
     for src, tgt in zip(batch_src, batch_tgt):
         a = result_queue.get()
         b = result_queue.get()
@@ -175,7 +176,7 @@ def can_actually_p2p(
 # e.g. used by different vllm engines. The device id in the cache file is a
 # **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
 # of visible devices in the vllm engine.
-_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
+_gpu_p2p_access_cache: Optional[dict[str, bool]] = None


 def gpu_p2p_access_check(src: int, tgt: int) -> bool:
@@ -204,7 +205,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
         # only the local master process (with local_rank == 0) can
         # enter this block to calculate the cache
         logger.info("generating GPU P2P access cache in %s", path)
-        cache: Dict[str, bool] = {}
+        cache: dict[str, bool] = {}
         ids = list(range(num_dev))
         # batch of all pairs of GPUs
         batch_src, batch_tgt = zip(*list(product(ids, ids)))
@@ -24,7 +24,7 @@
 import ctypes
 import platform
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 import torch
 from torch.distributed import ReduceOp
@@ -121,7 +121,7 @@ class ncclRedOpTypeEnum:
 class Function:
     name: str
     restype: Any
-    argtypes: List[Any]
+    argtypes: list[Any]


 class NCCLLibrary:
@@ -210,11 +210,11 @@ class NCCLLibrary:

     # class attribute to store the mapping from the path to the library
     # to avoid loading the same library multiple times
-    path_to_library_cache: Dict[str, Any] = {}
+    path_to_library_cache: dict[str, Any] = {}

     # class attribute to store the mapping from library path
     # to the corresponding dictionary
-    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+    path_to_dict_mapping: dict[str, dict[str, Any]] = {}

     def __init__(self, so_file: Optional[str] = None):

@@ -238,7 +238,7 @@ class NCCLLibrary:
             raise e

         if so_file not in NCCLLibrary.path_to_dict_mapping:
-            _funcs: Dict[str, Any] = {}
+            _funcs: dict[str, Any] = {}
             for func in NCCLLibrary.exported_functions:
                 f = getattr(self.lib, func.name)
                 f.restype = func.restype
@@ -8,7 +8,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field
 from multiprocessing import shared_memory
 from threading import Event
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 from unittest.mock import patch

 import torch
@@ -173,9 +173,9 @@ class ShmRingBuffer:

 @dataclass
 class Handle:
-    local_reader_ranks: List[int] = field(default_factory=list)
+    local_reader_ranks: list[int] = field(default_factory=list)

-    buffer_handle: Optional[Tuple[int, int, int, str]] = None
+    buffer_handle: Optional[tuple[int, int, int, str]] = None
     local_subscribe_addr: Optional[str] = None
     remote_subscribe_addr: Optional[str] = None
     remote_addr_ipv6: bool = False
@@ -187,7 +187,7 @@ class MessageQueue:
         self,
         n_reader,  # number of all readers
         n_local_reader,  # number of local readers through shared memory
-        local_reader_ranks: Optional[List[int]] = None,
+        local_reader_ranks: Optional[list[int]] = None,
         max_chunk_bytes: int = 1024 * 1024 * 10,
         max_chunks: int = 10,
         connect_ip: Optional[str] = None,
@@ -8,7 +8,7 @@ The class provides two primary abstract methods:
 """

 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, List, Tuple, Union
+from typing import TYPE_CHECKING, Union

 import torch

@@ -55,7 +55,7 @@ class KVConnectorBase(ABC):
         self,
         model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor],
+        kv_caches: list[torch.Tensor],
         hidden_or_intermediate_states: Union[torch.Tensor,
                                              IntermediateTensors],
     ) -> None:
@@ -71,7 +71,7 @@ class KVConnectorBase(ABC):
                 start and end layer information.
             model_input (ModelInputForGPUWithSamplingMetadata): The input
                 metadata from vLLM.
-            kv_caches (List[torch.Tensor]): List of KV caches (keys and values)
+            kv_caches (list[torch.Tensor]): List of KV caches (keys and values)
                 for each layer.
             hidden_or_intermediate_states (Union[torch.Tensor,
                 IntermediateTensors]):
@@ -88,8 +88,8 @@ class KVConnectorBase(ABC):
     def recv_kv_caches_and_hidden_states(
         self, model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor]
-    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+        kv_caches: list[torch.Tensor]
+    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
         """
         Receive KV caches and hidden states from the connector.
@@ -104,7 +104,7 @@ class KVConnectorBase(ABC):
                 The model executable from vLLM modelrunner.
             model_input (ModelInputForGPUWithSamplingMetadata):
                 The model input from vLLM modelrunner.
-            kv_caches (List[torch.Tensor]):
+            kv_caches (list[torch.Tensor]):
                 List of KV caches for each layer.

         Returns:
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import importlib
-from typing import TYPE_CHECKING, Callable, Dict, Type
+from typing import TYPE_CHECKING, Callable

 import vllm.envs as envs
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
@@ -18,7 +18,7 @@ logger = init_logger(__name__)


 class KVConnectorFactory:
-    _registry: Dict[str, Callable[[], Type[KVConnectorBaseType]]] = {}
+    _registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {}

     @classmethod
     def register_connector(cls, name: str, module_path: str,
@@ -27,7 +27,7 @@ class KVConnectorFactory:
         if name in cls._registry:
             raise ValueError(f"Connector '{name}' is already registered.")

-        def loader() -> Type[KVConnectorBaseType]:
+        def loader() -> type[KVConnectorBaseType]:
             module = importlib.import_module(module_path)
             return getattr(module, class_name)

@@ -7,7 +7,7 @@ The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker
 (2) offload and share KV caches.
 """

-from typing import TYPE_CHECKING, List, Tuple, Union
+from typing import TYPE_CHECKING, Union

 import torch

@@ -63,8 +63,8 @@ class LMCacheConnector(KVConnectorBase):
     def recv_kv_caches_and_hidden_states(
         self, model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor]
-    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+        kv_caches: list[torch.Tensor]
+    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:

         retrieve_status = self.lmcache_should_retrieve(model_input)
@@ -78,7 +78,7 @@ class LMCacheConnector(KVConnectorBase):
         self,
         model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor],
+        kv_caches: list[torch.Tensor],
         hidden_or_intermediate_states: Union[torch.Tensor,
                                              IntermediateTensors],
     ) -> None:
@@ -6,7 +6,7 @@ The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
 database-style KVStore.
 """
 import hashlib
-from typing import TYPE_CHECKING, List, Tuple, Union
+from typing import TYPE_CHECKING, Union

 import torch

@@ -70,7 +70,7 @@ class MooncakeStoreConnector(KVConnectorBase):
         self,
         model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor],
+        kv_caches: list[torch.Tensor],
         hidden_or_intermediate_states: Union[torch.Tensor,
                                              IntermediateTensors],
     ) -> None:
@@ -113,8 +113,8 @@ class MooncakeStoreConnector(KVConnectorBase):
     def recv_kv_caches_and_hidden_states(
         self, model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor]
-    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+        kv_caches: list[torch.Tensor]
+    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
         bypass_model_exec = True
         input_tokens_tensor = model_input.input_tokens
@@ -8,7 +8,7 @@ MooncakePipe.

 But the logic can be extended to support other pipe and lookup buffer.
 """
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Union

 import torch

@@ -133,7 +133,7 @@ class SimpleConnector(KVConnectorBase):
             )

     def select(self, input_tokens: Optional[torch.Tensor],
-               roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
+               roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:

         assert self.consumer_buffer is not None, "Please initialize the "\
             "consumer buffer before calling select."
@@ -152,7 +152,7 @@ class SimpleConnector(KVConnectorBase):
         self,
         model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor],
+        kv_caches: list[torch.Tensor],
         hidden_or_intermediate_states: Union[torch.Tensor,
                                              IntermediateTensors],
     ) -> None:
@@ -207,8 +207,8 @@ class SimpleConnector(KVConnectorBase):
     def recv_kv_caches_and_hidden_states(
         self, model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor]
-    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+        kv_caches: list[torch.Tensor]
+    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:

         # When bypass_model_exec is set to False, it means that at least for one
@@ -5,13 +5,13 @@ import threading
 import time
 import uuid
 from collections import defaultdict
+from collections.abc import Iterator
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Iterator
+from typing import TYPE_CHECKING, Any, Optional

 import msgspec
 import torch
 import zmq
-from typing_extensions import Optional

 from vllm import envs
 from vllm.config import VllmConfig
@@ -5,7 +5,7 @@ This implementation is a shim wrapper on two APIs exposed by `kv_connector`:
 1. `send_kv_caches_and_hidden_states`
 2. `recv_kv_caches_and_hidden_states
 """
-from typing import TYPE_CHECKING, List, Tuple, Union
+from typing import TYPE_CHECKING, Union

 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
@@ -53,7 +53,7 @@ class KVTransferAgent:
         self,
         model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor],
+        kv_caches: list[torch.Tensor],
         hidden_or_intermediate_states: Union[torch.Tensor,
                                              IntermediateTensors],
     ) -> None:
@@ -68,8 +68,8 @@ class KVTransferAgent:
     def recv_kv_caches_and_hidden_states(
         self, model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: List[torch.Tensor]
-    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+        kv_caches: list[torch.Tensor]
+    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:

         return self.connector.recv_kv_caches_and_hidden_states(
@@ -13,7 +13,7 @@ These classes above are abstracted behind class `KVCacheBufferBase`.
 """

 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import Optional

 import torch

@@ -93,7 +93,7 @@ class KVLookupBufferBase(KVCacheBufferBase):
     @abstractmethod
     def drop_select(
             self, input_tokens: Optional[torch.Tensor],
-            roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
+            roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
         """Select and *drop* KV cache entries from the lookup buffer.

         The functionality is similar to the following python statements
@@ -111,7 +111,7 @@ class KVLookupBufferBase(KVCacheBufferBase):
             roi (torch.Tensor): A binary mask on top of the input tokens

         Returns:
-            List[Optional[torch.Tensor]]: A list of tensors. Can be None.
+            list[Optional[torch.Tensor]]: A list of tensors. Can be None.

         Raises:
             NotImplementedError: This method must be implemented in subclasses.
@@ -11,7 +11,7 @@
 """
 import threading
 from collections import deque
-from typing import Deque, List, Optional, Union
+from typing import Optional, Union

 import torch

@@ -38,7 +38,7 @@ class SimpleBuffer(KVLookupBufferBase):
             data_pipe: on device (e.g. GPU)
         """

-        self.buffer: Deque[List[torch.Tensor]] = deque()
+        self.buffer: deque[list[torch.Tensor]] = deque()

         self.buffer_size = 0
         self.buffer_size_threshold = buffer_size_thresh
@@ -50,8 +50,8 @@ class SimpleBuffer(KVLookupBufferBase):
         self.normal_signal = torch.tensor([0], device="cpu")
         self.end_signal = None

-    def _matches(self, tokens_roi_sender: List[torch.Tensor],
-                 tokens_roi_recver: List[torch.Tensor]):
+    def _matches(self, tokens_roi_sender: list[torch.Tensor],
+                 tokens_roi_recver: list[torch.Tensor]):

         # tokens_roi_sender: tokens and roi of the producer (in the buffer)
         # tokens_roi_recver: tokens and roi of the consumer (query)
@@ -88,7 +88,7 @@ class SimpleBuffer(KVLookupBufferBase):
             tensor = tensor.float()
         self.data_pipe.send_tensor(tensor)

-    def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]):
+    def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]):

         if isinstance(data, torch.Tensor):
             return data.element_size() * data.numel()
@@ -151,7 +151,7 @@ class SimpleBuffer(KVLookupBufferBase):
         tokens_roi_recver = [input_tokens, roi]

         def is_buffer_available(
-            tokens_roi_recver: List[torch.Tensor], ) -> bool:
+            tokens_roi_recver: list[torch.Tensor], ) -> bool:
             # perform input tokens and roi matching
             # FIXME: this matching is O(n), ideally it should be O(1)
             # but this buffer size won't (and shouldn't) be too large so
@@ -184,7 +184,7 @@ class SimpleBuffer(KVLookupBufferBase):

     def drop_select(
             self, input_tokens: Optional[torch.Tensor],
-            roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
+            roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:

         assert self.request_handling_thread is None, \
             "drop_select should be called by the KV cache consumer "\
@@ -15,7 +15,7 @@
 import threading
 import time
 from concurrent.futures import ThreadPoolExecutor
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Optional

 import torch

@@ -35,7 +35,7 @@ class BrokenPipeException(Exception):
         super().__init__(self.message)


-Metadata = Dict[str, Optional[torch.Tensor]]
+Metadata = dict[str, Optional[torch.Tensor]]


 class PyNcclPipe(KVPipeBase):
@@ -83,7 +83,7 @@ class PyNcclPipe(KVPipeBase):

     def _get_device_send_recv_impl(
         self, group: StatelessProcessGroup
-    ) -> Tuple[Callable[[torch.Tensor, int], None], Callable[
+    ) -> tuple[Callable[[torch.Tensor, int], None], Callable[
         [torch.Tensor, int], None]]:

         send: Callable[[torch.Tensor, int], None]
@@ -29,7 +29,7 @@ from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from multiprocessing import shared_memory
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
 from unittest.mock import patch

 import torch
@@ -54,15 +54,15 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


 def _split_tensor_dict(
-    tensor_dict: Dict[str, Union[torch.Tensor, Any]]
-) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
+    tensor_dict: dict[str, Union[torch.Tensor, Any]]
+) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
        by its metadata.
     2. A list of tensors.
     """
-    metadata_list: List[Tuple[str, Any]] = []
-    tensor_list: List[torch.Tensor] = []
+    metadata_list: list[tuple[str, Any]] = []
+    tensor_list: list[torch.Tensor] = []
     for key, value in tensor_dict.items():
         if isinstance(value, torch.Tensor):
             # Note: we cannot use `value.device` here,
@@ -78,7 +78,7 @@ def _split_tensor_dict(
     return metadata_list, tensor_list


-_group_name_counter: Dict[str, int] = {}
+_group_name_counter: dict[str, int] = {}


 def _get_unique_name(name: str) -> str:
@@ -94,7 +94,7 @@ def _get_unique_name(name: str) -> str:
     return newname


-_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
+_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}


 def _register_group(group: "GroupCoordinator") -> None:
@@ -182,7 +182,7 @@ class GroupCoordinator:

     # available attributes:
     rank: int  # global rank
-    ranks: List[int]  # global ranks in the group
+    ranks: list[int]  # global ranks in the group
     world_size: int  # size of the group
     # difference between `local_rank` and `rank_in_group`:
     # if we have a group of size 4 across two nodes:
@@ -201,7 +201,7 @@ class GroupCoordinator:

     def __init__(
         self,
-        group_ranks: List[List[int]],
+        group_ranks: list[list[int]],
         local_rank: int,
         torch_distributed_backend: Union[str, Backend],
         use_device_communicator: bool,
@@ -435,7 +435,7 @@ class GroupCoordinator:
         return recv[0]

     def broadcast_object_list(self,
-                              obj_list: List[Any],
+                              obj_list: list[Any],
                               src: int = 0,
                               group: Optional[ProcessGroup] = None):
         """Broadcast the input object list.
@@ -518,11 +518,11 @@ class GroupCoordinator:

     def broadcast_tensor_dict(
         self,
-        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
+        tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None,
         src: int = 0,
         group: Optional[ProcessGroup] = None,
         metadata_group: Optional[ProcessGroup] = None
-    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Broadcast the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
@@ -536,7 +536,7 @@ class GroupCoordinator:

         rank_in_group = self.rank_in_group
         if rank_in_group == src:
-            metadata_list: List[Tuple[Any, Any]] = []
+            metadata_list: list[tuple[Any, Any]] = []
             assert isinstance(
                 tensor_dict,
                 dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
@@ -603,10 +603,10 @@ class GroupCoordinator:

     def send_tensor_dict(
         self,
-        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
+        tensor_dict: dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
-    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
         """
@@ -626,7 +626,7 @@ class GroupCoordinator:
             dst = (self.rank_in_group + 1) % self.world_size
         assert dst < self.world_size, f"Invalid dst rank ({dst})"

-        metadata_list: List[Tuple[Any, Any]] = []
+        metadata_list: list[tuple[Any, Any]] = []
         assert isinstance(
             tensor_dict,
             dict), f"Expecting a dictionary, got {type(tensor_dict)}"
@@ -661,7 +661,7 @@ class GroupCoordinator:
         self,
         src: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
-    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
@@ -682,7 +682,7 @@ class GroupCoordinator:
         assert src < self.world_size, f"Invalid src rank ({src})"

         recv_metadata_list = self.recv_object(src=src)
-        tensor_dict: Dict[str, Any] = {}
+        tensor_dict: dict[str, Any] = {}
         for key, value in recv_metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
@@ -764,7 +764,7 @@ class GroupCoordinator:

     def dispatch(
             self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         if self.device_communicator is not None:
             return self.device_communicator.dispatch(hidden_states,
                                                      router_logits)
@@ -782,7 +782,7 @@ def get_world_group() -> GroupCoordinator:
     return _WORLD


-def init_world_group(ranks: List[int], local_rank: int,
+def init_world_group(ranks: list[int], local_rank: int,
                      backend: str) -> GroupCoordinator:
     return GroupCoordinator(
         group_ranks=[ranks],
@@ -794,7 +794,7 @@ def init_world_group(ranks: List[int], local_rank: int,


 def init_model_parallel_group(
-    group_ranks: List[List[int]],
+    group_ranks: list[list[int]],
     local_rank: int,
     backend: str,
     use_message_queue_broadcaster: bool = False,
@@ -1182,7 +1182,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):


 def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
-                        source_rank: int = 0) -> List[bool]:
+                        source_rank: int = 0) -> list[bool]:
     """
     This is a collective operation that returns if each rank is in the same node
     as the source rank. It tests if processes are attached to the same
@@ -10,7 +10,8 @@ import pickle
 import socket
 import time
 from collections import deque
-from typing import Any, Deque, Dict, Optional, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Any, Optional

 import torch
 from torch.distributed import ProcessGroup, TCPStore
@@ -69,7 +70,7 @@ def split_tensor_along_last_dim(


 def get_pp_indices(num_hidden_layers: int, pp_rank: int,
-                   pp_size: int) -> Tuple[int, int]:
+                   pp_size: int) -> tuple[int, int]:
     """Try to evenly distribute layers across partitions.

     If the number of layers is not divisible by the number of partitions,
@@ -132,15 +133,15 @@ class StatelessProcessGroup:
     data_expiration_seconds: int = 3600  # 1 hour

     # dst rank -> counter
-    send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)
     # src rank -> counter
-    recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
     broadcast_send_counter: int = 0
-    broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(
+    broadcast_recv_src_counter: dict[int, int] = dataclasses.field(
         default_factory=dict)

     # A deque to store the data entries, with key and timestamp.
-    entries: Deque[Tuple[str,
+    entries: deque[tuple[str,
                          float]] = dataclasses.field(default_factory=deque)

     def __post_init__(self):