[CI] Fix mypy for vllm/distributed (#26593)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Wentao Ye 2025-10-13 16:02:24 -04:00 committed by GitHub
parent d2a7938582
commit 314285d4f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 122 additions and 65 deletions

View File

@ -26,6 +26,7 @@ import regex as re
FILES = [
"vllm/*.py",
"vllm/assets",
"vllm/distributed",
"vllm/entrypoints",
"vllm/inputs",
"vllm/logging_utils",
@ -42,7 +43,6 @@ SEPARATE_GROUPS = [
"tests",
"vllm/attention",
"vllm/compilation",
"vllm/distributed",
"vllm/engine",
"vllm/executor",
"vllm/inputs",

View File

@ -27,7 +27,7 @@ class KVTransferConfig:
engine_id: str | None = None
"""The engine id for KV transfers."""
kv_buffer_device: str | None = "cuda"
kv_buffer_device: str = "cuda"
"""The device used by kv connector to buffer the KV cache. Choices are
'cuda' and 'cpu'."""

View File

@ -15,9 +15,11 @@ from vllm.utils.flashinfer import has_flashinfer_all2all
from .base_device_communicator import All2AllManagerBase, Cache
if has_flashinfer_all2all():
from flashinfer.comm import Mapping
from flashinfer.comm.mnnvl import MnnvlConfig
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
from flashinfer.comm import Mapping # type: ignore[import-not-found]
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
from flashinfer.comm.trtllm_alltoall import (
MnnvlMoe, # type: ignore[import-not-found]
)
logger = init_logger(__name__)
@ -65,6 +67,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]:
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(
@ -81,6 +84,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
@ -113,7 +117,10 @@ class AgRsAll2AllManager(All2AllManagerBase):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
@ -130,7 +137,10 @@ class AgRsAll2AllManager(All2AllManagerBase):
"""
Reduce-scatter hidden_states across all dp ranks.
"""
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
@ -155,7 +165,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import (
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init,
@ -182,7 +192,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
self.handle_cache = Cache()
def get_handle(self, kwargs):
import pplx_kernels as pplx
import pplx_kernels as pplx # type: ignore[import-not-found]
return self.handle_cache.get_or_create(
kwargs,
@ -208,7 +218,9 @@ class PPLXAll2AllManager(All2AllManagerBase):
handle.destroy()
if self.internode:
from pplx_kernels.nvshmem import nvshmem_finalize
from pplx_kernels.nvshmem import (
nvshmem_finalize, # type: ignore[import-not-found]
)
logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
@ -288,7 +300,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
"args are computed in the Manager itself."
)
import deep_ep
import deep_ep # type: ignore[import-not-found]
buffer_kwargs = self._make_all2all_kwargs()
logger.debug("DeepEP all2all args %s", buffer_kwargs)
@ -298,7 +310,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
return handle
def set_num_sms(self, num_sms: int):
import deep_ep
import deep_ep # type: ignore[import-not-found]
# Right now the buffers are sized for only what the kernels were
# created with. So we can only reduce the number of SMS used
@ -332,7 +344,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_global_experts: Number of experts in the model.
num_local_experts: Number of experts in an EP rank.
"""
import deep_ep
import deep_ep # type: ignore[import-not-found]
# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
@ -358,7 +370,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
The kwargs for DeepEPLLAll2AllManager is dictated by
_make_all2all_kwargs.
"""
import deep_ep
import deep_ep # type: ignore[import-not-found]
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
logger.debug("DeepEP all2all args %s", buffer_kwargs)

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import cast
import torch
import torch.distributed as dist
@ -118,15 +119,18 @@ class CustomAllreduce:
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
device_capability = current_platform.get_device_capability().as_version_str()
device_capability = current_platform.get_device_capability()
if (
current_platform.is_cuda()
and symm_mem_enabled
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES
and device_capability is not None
):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], max_size
)
device_capability_str = device_capability.as_version_str()
if device_capability_str in CUSTOM_ALL_REDUCE_MAX_SIZES:
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability_str][world_size],
max_size,
)
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
@ -213,6 +217,7 @@ class CustomAllreduce:
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data: list[list[list[int] | None]]
all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
@ -221,8 +226,8 @@ class CustomAllreduce:
all_data[i], src=rank, group=self.group, device="cpu"
)
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
handles = cast(list[list[int]], [d[0] for d in all_data])
offsets = cast(list[list[int]], [d[1] for d in all_data])
ops.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):

View File

@ -52,9 +52,14 @@ class SymmMemCommunicator:
self.device = device
self.group = group
self.world_size = dist.get_world_size(self.group)
self.device_capability = (
current_platform.get_device_capability().as_version_str()
)
capability = current_platform.get_device_capability()
if capability is None:
logger.warning(
"SymmMemCommunicator: device capability is unknown, "
"communicator is not available."
)
return
self.device_capability = capability.as_version_str()
if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
logger.warning(
"SymmMemCommunicator: Device capability %s not supported, "

View File

@ -3,7 +3,7 @@
import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast
import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.base import (
@ -48,6 +48,8 @@ class KVConnectorFactory:
)
kv_transfer_config = config.kv_transfer_config
if kv_transfer_config is None:
raise ValueError("kv_transfer_config must be set to create a connector")
connector_cls = cls.get_connector_class(kv_transfer_config)
logger.info(
"Creating v1 connector with name: %s and engine_id: %s",
@ -70,6 +72,8 @@ class KVConnectorFactory:
) -> type[KVConnectorBaseType]:
"""Get the connector class by name."""
connector_name = kv_transfer_config.kv_connector
if connector_name is None:
raise ValueError("Connector name is not set in KVTransferConfig")
if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
@ -77,7 +81,13 @@ class KVConnectorFactory:
if connector_module_path is None:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
try:
connector_cls = getattr(connector_module, connector_name)
except AttributeError as e:
raise AttributeError(
f"Class {connector_name} not found in {connector_module_path}"
) from e
connector_cls = cast(type[KVConnectorBaseType], connector_cls)
return connector_cls

View File

@ -151,21 +151,21 @@ class KVOutputAggregator:
aggregated_kv_connector_stats = None
invalid_block_ids = set[int]()
for model_runner_output in outputs:
output = model_runner_output.kv_connector_output
if not output:
kv_output = model_runner_output.kv_connector_output
if not kv_output:
continue
update_finished_set(
output.finished_sending, self._send_remaining_count, finished_sending
kv_output.finished_sending, self._send_remaining_count, finished_sending
)
update_finished_set(
output.finished_recving, self._recv_remaining_count, finished_recving
kv_output.finished_recving, self._recv_remaining_count, finished_recving
)
# Aggregate kv_connector_stats from all workers.
if aggregated_kv_connector_stats is None:
# Use the first worker's kv_connector_stats as accumulator.
aggregated_kv_connector_stats = output.kv_connector_stats
elif kv_connector_stats := output.kv_connector_stats:
aggregated_kv_connector_stats = kv_output.kv_connector_stats
elif kv_connector_stats := kv_output.kv_connector_stats:
if aggregated_kv_connector_stats is None:
aggregated_kv_connector_stats = kv_connector_stats
else:
@ -176,7 +176,7 @@ class KVOutputAggregator:
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
)
invalid_block_ids |= output.invalid_block_ids
invalid_block_ids |= kv_output.invalid_block_ids
# select output of the worker specified by output_rank
output = outputs[output_rank]

View File

@ -95,6 +95,10 @@ class KVConnectorBase_V1(ABC):
)
self._connector_metadata: KVConnectorMetadata | None = None
self._vllm_config = vllm_config
if vllm_config.kv_transfer_config is not None:
self._kv_transfer_config = vllm_config.kv_transfer_config
else:
raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
self._role = role
@property

View File

@ -86,13 +86,11 @@ class MultiConnector(KVConnectorBase_V1):
super().__init__(vllm_config=vllm_config, role=role)
self._connectors: list[KVConnectorBase_V1] = []
self._ktc_kv_transfer_config = []
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors"
)
ktcs = self._kv_transfer_config.kv_connector_extra_config.get("connectors")
assert ktcs is not None
for ktc in ktcs:
temp_config = copy.copy(vllm_config)
engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id)
engine_id = ktc.get("engine_id", self._kv_transfer_config.engine_id)
temp_config.kv_transfer_config = KVTransferConfig(
**ktc, engine_id=engine_id
)
@ -296,6 +294,7 @@ class MultiConnector(KVConnectorBase_V1):
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
assert vllm_config.kv_transfer_config is not None
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors"
)

View File

@ -297,6 +297,7 @@ class NixlConnectorScheduler:
+ vllm_config.parallel_config.data_parallel_rank
* vllm_config.parallel_config.tensor_parallel_size
)
assert vllm_config.kv_transfer_config is not None
self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
logger.info("Initializing NIXL Scheduler %s", engine_id)
@ -340,7 +341,8 @@ class NixlConnectorScheduler:
if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
count = len(request.prompt_token_ids) - num_computed_tokens
token_ids = request.prompt_token_ids or []
count = len(token_ids) - num_computed_tokens
if count > 0:
return count, True
@ -521,6 +523,9 @@ class NixlConnectorWorker:
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
if vllm_config.kv_transfer_config is None:
raise ValueError("kv_transfer_config must be set for NixlConnector")
self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
"backends", ["UCX"]
)
@ -577,17 +582,18 @@ class NixlConnectorWorker:
self.use_host_buffer = self.kv_buffer_device == "cpu"
# support for oot platform which can't register nixl memory
# type based on kv_buffer_device
self.nixl_memory_type = current_platform.get_nixl_memory_type()
if self.nixl_memory_type is None:
nixl_memory_type = current_platform.get_nixl_memory_type()
if nixl_memory_type is None:
if self.kv_buffer_device == "cuda":
self.nixl_memory_type = "VRAM"
nixl_memory_type = "VRAM"
elif self.kv_buffer_device == "cpu":
self.nixl_memory_type = "DRAM"
if self.nixl_memory_type is None:
nixl_memory_type = "DRAM"
if nixl_memory_type is None:
raise RuntimeError(
f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
"is not supported."
)
self.nixl_memory_type = nixl_memory_type
# Note: host xfer buffer ops when use_host_buffer is True
self.copy_blocks: CopyBlocksOp | None = None

View File

@ -75,9 +75,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Any] = {}
self.config = vllm_config.kv_transfer_config
self.is_producer = self.config.is_kv_producer
self.chunked_prefill: dict[str, Any] = {}
self.is_producer = self._kv_transfer_config.is_kv_producer
self.chunked_prefill: dict[str, tuple[list[int], list[int] | None]] = {}
self._rank = get_world_group().rank if role == KVConnectorRole.WORKER else 0
self._local_rank = (
@ -87,7 +86,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
self.p2p_nccl_engine = (
P2pNcclEngine(
local_rank=self._local_rank,
config=self.config,
config=self._kv_transfer_config,
hostname="",
port_offset=self._rank,
)
@ -346,7 +345,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
if self.is_producer:
return 0, False
num_external_tokens = len(request.prompt_token_ids) - 1 - num_computed_tokens
prompt_token_ids = request.prompt_token_ids or []
num_external_tokens = len(prompt_token_ids) - 1 - num_computed_tokens
if num_external_tokens < 0:
num_external_tokens = 0
@ -387,7 +387,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
]
num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
# the request's prompt is chunked prefill
if num_tokens < len(new_req.prompt_token_ids):
if num_tokens < len(new_req.prompt_token_ids or []):
# 'CachedRequestData' has no attribute 'prompt_token_ids'
self.chunked_prefill[new_req.req_id] = (
new_req.block_ids[0],
@ -397,7 +397,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
# the request's prompt is not chunked prefill
meta.add_request(
request_id=new_req.req_id,
token_ids=new_req.prompt_token_ids,
token_ids=new_req.prompt_token_ids or [],
block_ids=new_req.block_ids[0],
block_size=self._block_size,
)
@ -405,7 +405,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
if new_req.req_id in self._requests_need_load:
meta.add_request(
request_id=new_req.req_id,
token_ids=new_req.prompt_token_ids,
token_ids=new_req.prompt_token_ids or [],
block_ids=new_req.block_ids[0],
block_size=self._block_size,
)
@ -421,10 +421,12 @@ class P2pNcclConnector(KVConnectorBase_V1):
num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[req_id]
num_tokens = num_scheduled_tokens + num_computed_tokens
assert req_id in self.chunked_prefill
assert new_block_ids is not None
block_ids = new_block_ids[0]
if not resumed_from_preemption:
block_ids = self.chunked_prefill[req_id][0] + block_ids
prompt_token_ids = self.chunked_prefill[req_id][1]
assert prompt_token_ids is not None
# the request's prompt is chunked prefill again
if num_tokens < len(prompt_token_ids):
self.chunked_prefill[req_id] = (block_ids, prompt_token_ids)
@ -450,6 +452,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
assert new_block_ids is not None
block_ids = new_block_ids[0]
meta.add_request(

View File

@ -90,11 +90,10 @@ class SharedStorageConnector(KVConnectorBase_V1):
super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Request] = {}
transfer_config = vllm_config.kv_transfer_config
self._storage_path = transfer_config.get_from_extra_config(
self._storage_path = self._kv_transfer_config.get_from_extra_config(
"shared_storage_path", "/tmp"
)
logger.info(vllm_config.kv_transfer_config)
logger.info(self._kv_transfer_config)
logger.info("Shared storage path is %s", self._storage_path)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
@ -277,9 +276,8 @@ class SharedStorageConnector(KVConnectorBase_V1):
# Now, first num_tokens_to_check tokens are hit, we need to prepare
# the metadata for the worker connector to correctly load the KV
num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size
)
token_ids = request.prompt_token_ids or []
num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size)
return num_tokens_to_check - num_computed_tokens, False
@ -311,13 +309,15 @@ class SharedStorageConnector(KVConnectorBase_V1):
total_need_load = 0
for new_req in scheduler_output.scheduled_new_reqs:
token_ids = new_req.prompt_token_ids or []
mm_hashes = [f.identifier for f in new_req.mm_features]
if new_req.req_id in self._requests_need_load:
meta.add_request(
token_ids=new_req.prompt_token_ids,
token_ids=token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=False,
mm_hashes=[f.identifier for f in new_req.mm_features],
mm_hashes=mm_hashes,
)
total_need_load += 1
else:
@ -325,13 +325,13 @@ class SharedStorageConnector(KVConnectorBase_V1):
# but a single request can have both store and load.
# NOTE(rob): for this debug implementation, we only cache
# the original prompt tokens.
if not self._found_match_for_request(new_req):
if not self._found_match_for_prompt(token_ids, mm_hashes):
meta.add_request(
token_ids=new_req.prompt_token_ids,
token_ids=token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=True,
mm_hashes=[f.identifier for f in new_req.mm_features],
mm_hashes=mm_hashes,
)
cached_reqs = scheduler_output.scheduled_cached_reqs
@ -355,6 +355,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
assert new_block_ids is not None
block_ids = new_block_ids[0]
meta.add_request(
@ -379,12 +380,22 @@ class SharedStorageConnector(KVConnectorBase_V1):
request: "Request",
) -> bool:
"""Check if the cache is hit for the request."""
return self._found_match_for_prompt(
list(request.prompt_token_ids or []),
[f.identifier for f in request.mm_features],
)
def _found_match_for_prompt(
self,
prompt_token_ids: list[int],
mm_hashes: list[str],
) -> bool:
num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size
len(prompt_token_ids) - 1, self._block_size
)
foldername = self._generate_foldername_debug(
torch.tensor(request.prompt_token_ids)[:num_tokens_to_check],
[f.identifier for f in request.mm_features],
torch.tensor(prompt_token_ids)[:num_tokens_to_check],
mm_hashes,
create_folder=False,
)
return os.path.exists(foldername)

View File

@ -236,6 +236,7 @@ class MooncakePipe(KVPipeBase):
self.config = config
self.local_rank = local_rank
self.kv_rank = self.config.kv_rank
assert self.kv_rank is not None
if device is None:
self.device = self._select_device(self.config.kv_buffer_device)
else:

View File

@ -53,6 +53,7 @@ class PyNcclPipe(KVPipeBase):
self.config = config
self.local_rank = local_rank
self.kv_rank = self.config.kv_rank
assert self.kv_rank is not None
self.kv_parallel_size = self.config.kv_parallel_size
if device is None:
self.device = self._select_device(self.config.kv_buffer_device)