[CI] Fix mypy for vllm/distributed (#26593)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Wentao Ye 2025-10-13 16:02:24 -04:00 committed by GitHub
parent d2a7938582
commit 314285d4f2
14 changed files with 122 additions and 65 deletions
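The hunks below repeatedly apply three mypy fixes: narrowing `X | None` values with an assert or explicit check before use, replacing blanket `# type: ignore` comments with `typing.cast`, and tagging imports of optional kernel packages (flashinfer, pplx_kernels, deep_ep) with `# type: ignore[import-not-found]` since mypy cannot resolve them when they are not installed. A minimal standalone sketch of the first two patterns (hypothetical names, not vLLM code):

    from typing import cast

    def chunk_sizes() -> list[int] | None:  # stands in for the dp_metadata accessors below
        return [4, 4]

    def total_tokens() -> int:
        sizes = chunk_sizes()
        assert sizes is not None  # narrows list[int] | None to list[int] for mypy
        return sum(sizes)

    # cast() is type-level only; it performs no runtime conversion or check.
    pairs: list[list[int | None]] = [[1, None], [2, 3]]
    firsts = cast(list[int], [p[0] for p in pairs])

    print(total_tokens(), firsts)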

View File

@@ -26,6 +26,7 @@ import regex as re
 FILES = [
     "vllm/*.py",
     "vllm/assets",
+    "vllm/distributed",
     "vllm/entrypoints",
     "vllm/inputs",
     "vllm/logging_utils",
@@ -42,7 +43,6 @@ SEPARATE_GROUPS = [
     "tests",
     "vllm/attention",
     "vllm/compilation",
-    "vllm/distributed",
     "vllm/engine",
     "vllm/executor",
     "vllm/inputs",

View File

@@ -27,7 +27,7 @@ class KVTransferConfig:
     engine_id: str | None = None
     """The engine id for KV transfers."""
-    kv_buffer_device: str | None = "cuda"
+    kv_buffer_device: str = "cuda"
     """The device used by kv connector to buffer the KV cache. Choices are
     'cuda' and 'cpu'."""
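Dropping `None` from the `kv_buffer_device` default means downstream consumers (for example the NIXL and P2P connectors below, which compare it against "cpu" and "cuda") see a plain `str` and need no narrowing. A tiny illustration with a hypothetical config class:

    from dataclasses import dataclass

    @dataclass
    class BufferCfg:
        kv_buffer_device: str = "cuda"  # non-Optional: usable as a str everywhere

    def is_host_buffer(cfg: BufferCfg) -> bool:
        # With `str | None`, any use that requires a real str (dict lookups,
        # string methods) would first need a None check at every call site.
        return cfg.kv_buffer_device == "cpu"

    print(is_host_buffer(BufferCfg()))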

View File

@@ -15,9 +15,11 @@ from vllm.utils.flashinfer import has_flashinfer_all2all
 from .base_device_communicator import All2AllManagerBase, Cache

 if has_flashinfer_all2all():
-    from flashinfer.comm import Mapping
-    from flashinfer.comm.mnnvl import MnnvlConfig
-    from flashinfer.comm.trtllm_alltoall import MnnvlMoe
+    from flashinfer.comm import Mapping  # type: ignore[import-not-found]
+    from flashinfer.comm.mnnvl import MnnvlConfig  # type: ignore[import-not-found]
+    from flashinfer.comm.trtllm_alltoall import (
+        MnnvlMoe,  # type: ignore[import-not-found]
+    )

 logger = init_logger(__name__)
@@ -65,6 +67,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
     ) -> tuple[torch.Tensor, torch.Tensor]:
         sp_size = self.tp_group.world_size if is_sequence_parallel else 1
         dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
         cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

         hidden_states = self.naive_multicast(
@@ -81,6 +84,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
         ep_rank = self.rank if is_sequence_parallel else self.dp_rank
         dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
         sp_size = self.tp_group.world_size if is_sequence_parallel else 1
         cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
@@ -113,7 +117,10 @@ class AgRsAll2AllManager(All2AllManagerBase):
         """
         Gather hidden_states and router_logits from all dp ranks.
         """
-        sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
+        dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
+        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+        assert sizes is not None
         dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
         assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
@@ -130,7 +137,10 @@ class AgRsAll2AllManager(All2AllManagerBase):
         """
         Reduce-scatter hidden_states across all dp ranks.
         """
-        sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
+        dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
+        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+        assert sizes is not None
         dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
         hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
@@ -155,7 +165,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
         if self.internode:
             # inter-node communication needs nvshmem,
             # intra-node communication uses p2p mapping directly
-            from pplx_kernels.nvshmem import (
+            from pplx_kernels.nvshmem import (  # type: ignore[import-not-found]
                 nvshmem_alloc_empty_unique_id,
                 nvshmem_get_unique_id,
                 nvshmem_init,
@@ -182,7 +192,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
         self.handle_cache = Cache()

     def get_handle(self, kwargs):
-        import pplx_kernels as pplx
+        import pplx_kernels as pplx  # type: ignore[import-not-found]

         return self.handle_cache.get_or_create(
             kwargs,
@@ -208,7 +218,9 @@ class PPLXAll2AllManager(All2AllManagerBase):
             handle.destroy()

         if self.internode:
-            from pplx_kernels.nvshmem import nvshmem_finalize
+            from pplx_kernels.nvshmem import (
+                nvshmem_finalize,  # type: ignore[import-not-found]
+            )

             logger.debug("PPLX NVSHMEM finalize")
             nvshmem_finalize()
@@ -288,7 +300,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
             "args are computed in the Manager itself."
         )

-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

         buffer_kwargs = self._make_all2all_kwargs()
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
@@ -298,7 +310,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
         return handle

     def set_num_sms(self, num_sms: int):
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

         # Right now the buffers are sized for only what the kernels were
         # created with. So we can only reduce the number of SMS used
@@ -332,7 +344,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
             num_global_experts: Number of experts in the model.
             num_local_experts: Number of experts in an EP rank.
         """
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

         # Defaults for internode and intranode are taken from DeepEP tests.
         num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
@@ -358,7 +370,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
         The kwargs for DeepEPLLAll2AllManager is dictated by
         _make_all2all_kwargs.
         """
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

         buffer_kwargs = self._make_all2all_kwargs(**kwargs)
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
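`get_forward_context().dp_metadata` is typed as Optional, so each dispatch and combine path above asserts it (and the chunk sizes it returns) before dereferencing. A minimal sketch of that narrowing, using hypothetical stand-in classes rather than the real vLLM types:

    from dataclasses import dataclass

    @dataclass
    class DPMetadata:
        sizes: list[int]

        def cu_tokens_across_sp(self, sp_size: int) -> list[int]:
            return [s * sp_size for s in self.sizes]

    @dataclass
    class ForwardContext:
        dp_metadata: DPMetadata | None = None

    def dispatch(ctx: ForwardContext, sp_size: int) -> list[int]:
        dp_metadata = ctx.dp_metadata
        assert dp_metadata is not None  # narrows the Optional for the call below
        return dp_metadata.cu_tokens_across_sp(sp_size)

    print(dispatch(ForwardContext(DPMetadata([2, 3])), sp_size=2))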

View File

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from contextlib import contextmanager
+from typing import cast

 import torch
 import torch.distributed as dist
@@ -118,14 +119,17 @@ class CustomAllreduce:
         # now `device` is a `torch.device` object
         assert isinstance(device, torch.device)
         self.device = device
-        device_capability = current_platform.get_device_capability().as_version_str()
+        device_capability = current_platform.get_device_capability()
         if (
             current_platform.is_cuda()
             and symm_mem_enabled
-            and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES
+            and device_capability is not None
         ):
-            max_size = min(
-                CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], max_size
-            )
+            device_capability_str = device_capability.as_version_str()
+            if device_capability_str in CUSTOM_ALL_REDUCE_MAX_SIZES:
+                max_size = min(
+                    CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability_str][world_size],
+                    max_size,
+                )
         cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
         if cuda_visible_devices:
@@ -213,6 +217,7 @@ class CustomAllreduce:
         # We cannot directly use `dist.all_gather_object` here
         # because it is incompatible with `gloo` backend under inference mode.
         # see https://github.com/pytorch/pytorch/issues/126032 for details.
+        all_data: list[list[list[int] | None]]
         all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
         all_data[self.rank] = [handle, offset]
         ranks = sorted(dist.get_process_group_ranks(group=self.group))
@@ -221,8 +226,8 @@ class CustomAllreduce:
                 all_data[i], src=rank, group=self.group, device="cpu"
             )
         # Unpack list of tuples to tuple of lists.
-        handles = [d[0] for d in all_data]  # type: ignore
-        offsets = [d[1] for d in all_data]  # type: ignore
+        handles = cast(list[list[int]], [d[0] for d in all_data])
+        offsets = cast(list[list[int]], [d[1] for d in all_data])
         ops.register_graph_buffers(self._ptr, handles, offsets)

     def should_custom_ar(self, inp: torch.Tensor):
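Two patterns here: `all_data` gets an explicit annotation so its `None` placeholders and the later real entries share one declared type, and the blanket `# type: ignore` is replaced with `typing.cast` once every slot is known to be filled. A hedged standalone sketch of the same shape, with the broadcast faked by a plain assignment and all values hypothetical:

    from typing import cast

    world_size = 2
    rank = 0
    handle, offset = [11, 22], [0, 4]

    all_data: list[list[list[int] | None]]
    all_data = [[None, None] for _ in range(world_size)]
    all_data[rank] = [handle, offset]
    all_data[1] = [[33, 44], [8, 12]]  # stands in for the broadcast from rank 1

    # After the "broadcast", every slot holds real data, so cast away the None.
    handles = cast(list[list[int]], [d[0] for d in all_data])
    offsets = cast(list[list[int]], [d[1] for d in all_data])
    print(handles, offsets)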

View File

@@ -52,9 +52,14 @@ class SymmMemCommunicator:
         self.device = device
         self.group = group
         self.world_size = dist.get_world_size(self.group)
-        self.device_capability = (
-            current_platform.get_device_capability().as_version_str()
-        )
+        capability = current_platform.get_device_capability()
+        if capability is None:
+            logger.warning(
+                "SymmMemCommunicator: device capability is unknown, "
+                "communicator is not available."
+            )
+            return
+        self.device_capability = capability.as_version_str()
         if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
             logger.warning(
                 "SymmMemCommunicator: Device capability %s not supported, "

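Instead of calling `.as_version_str()` on a possibly-`None` capability, the constructor now bails out early with a warning, so `self.device_capability` is only ever assigned a `str`. A sketch of that guard with a hypothetical `Capability` class:

    import logging
    from dataclasses import dataclass

    logger = logging.getLogger(__name__)

    @dataclass
    class Capability:
        major: int
        minor: int

        def as_version_str(self) -> str:
            return f"{self.major}.{self.minor}"

    class Communicator:
        def __init__(self, capability: Capability | None) -> None:
            self.available = False
            if capability is None:
                logger.warning("device capability is unknown, communicator disabled")
                return  # early return: never touch a None capability
            self.device_capability = capability.as_version_str()
            self.available = True

    print(Communicator(Capability(9, 0)).available, Communicator(None).available)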
View File

@@ -3,7 +3,7 @@
 import importlib
 from collections.abc import Callable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast

 import vllm.envs as envs
 from vllm.distributed.kv_transfer.kv_connector.base import (
@@ -48,6 +48,8 @@ class KVConnectorFactory:
         )

         kv_transfer_config = config.kv_transfer_config
+        if kv_transfer_config is None:
+            raise ValueError("kv_transfer_config must be set to create a connector")
         connector_cls = cls.get_connector_class(kv_transfer_config)
         logger.info(
             "Creating v1 connector with name: %s and engine_id: %s",
@@ -70,6 +72,8 @@ class KVConnectorFactory:
     ) -> type[KVConnectorBaseType]:
         """Get the connector class by name."""
         connector_name = kv_transfer_config.kv_connector
+        if connector_name is None:
+            raise ValueError("Connector name is not set in KVTransferConfig")
         if connector_name in cls._registry:
             connector_cls = cls._registry[connector_name]()
         else:
@@ -77,7 +81,13 @@ class KVConnectorFactory:
             if connector_module_path is None:
                 raise ValueError(f"Unsupported connector type: {connector_name}")
             connector_module = importlib.import_module(connector_module_path)
-            connector_cls = getattr(connector_module, connector_name)
+            try:
+                connector_cls = getattr(connector_module, connector_name)
+            except AttributeError as e:
+                raise AttributeError(
+                    f"Class {connector_name} not found in {connector_module_path}"
+                ) from e
+        connector_cls = cast(type[KVConnectorBaseType], connector_cls)
         return connector_cls
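The factory changes convert silent `Optional` propagation into hard errors and cast the `getattr` lookup (which mypy types as `Any`) to the expected class type. A hypothetical sketch of the same shape, loading a class from the standard library instead of a connector module:

    import importlib
    from typing import cast

    def load_class(module_path: str, class_name: str | None) -> type:
        if class_name is None:
            raise ValueError("connector name is not set")  # fail fast, keep types non-Optional
        module = importlib.import_module(module_path)
        try:
            obj = getattr(module, class_name)
        except AttributeError as e:
            raise AttributeError(f"Class {class_name} not found in {module_path}") from e
        return cast(type, obj)  # getattr returns Any; cast documents the expected type

    print(load_class("collections", "OrderedDict"))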

View File

@@ -151,21 +151,21 @@ class KVOutputAggregator:
         aggregated_kv_connector_stats = None
         invalid_block_ids = set[int]()
         for model_runner_output in outputs:
-            output = model_runner_output.kv_connector_output
-            if not output:
+            kv_output = model_runner_output.kv_connector_output
+            if not kv_output:
                 continue
             update_finished_set(
-                output.finished_sending, self._send_remaining_count, finished_sending
+                kv_output.finished_sending, self._send_remaining_count, finished_sending
             )
             update_finished_set(
-                output.finished_recving, self._recv_remaining_count, finished_recving
+                kv_output.finished_recving, self._recv_remaining_count, finished_recving
             )

             # Aggregate kv_connector_stats from all workers.
             if aggregated_kv_connector_stats is None:
                 # Use the first worker's kv_connector_stats as accumulator.
-                aggregated_kv_connector_stats = output.kv_connector_stats
-            elif kv_connector_stats := output.kv_connector_stats:
+                aggregated_kv_connector_stats = kv_output.kv_connector_stats
+            elif kv_connector_stats := kv_output.kv_connector_stats:
                 if aggregated_kv_connector_stats is None:
                     aggregated_kv_connector_stats = kv_connector_stats
                 else:
@@ -176,7 +176,7 @@ class KVOutputAggregator:
                         aggregated_kv_connector_stats.aggregate(kv_connector_stats)
                     )

-            invalid_block_ids |= output.invalid_block_ids
+            invalid_block_ids |= kv_output.invalid_block_ids

         # select output of the worker specified by output_rank
         output = outputs[output_rank]
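The rename from `output` to `kv_output` exists because the same name was later reassigned to `outputs[output_rank]`, a different type: mypy infers an unannotated variable's type from its first assignment, so reusing one name for two unrelated values produces assignment errors. A sketch with hypothetical stand-in types:

    from dataclasses import dataclass

    @dataclass
    class KVOutput:
        invalid_block_ids: set[int]

    @dataclass
    class RunnerOutput:
        kv_connector_output: KVOutput | None

    def aggregate(outputs: list[RunnerOutput], output_rank: int) -> RunnerOutput:
        invalid: set[int] = set()
        for runner_output in outputs:
            kv_output = runner_output.kv_connector_output  # distinct name, distinct type
            if not kv_output:
                continue
            invalid |= kv_output.invalid_block_ids
        output = outputs[output_rank]  # `output` only ever holds a RunnerOutput
        return output

    print(aggregate([RunnerOutput(KVOutput({1, 2}))], 0))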

View File

@@ -95,6 +95,10 @@ class KVConnectorBase_V1(ABC):
         )
         self._connector_metadata: KVConnectorMetadata | None = None
         self._vllm_config = vllm_config
+        if vllm_config.kv_transfer_config is not None:
+            self._kv_transfer_config = vllm_config.kv_transfer_config
+        else:
+            raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
         self._role = role

     @property

View File

@@ -86,13 +86,11 @@ class MultiConnector(KVConnectorBase_V1):
         super().__init__(vllm_config=vllm_config, role=role)
         self._connectors: list[KVConnectorBase_V1] = []
         self._ktc_kv_transfer_config = []
-        ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
-            "connectors"
-        )
+        ktcs = self._kv_transfer_config.kv_connector_extra_config.get("connectors")
         assert ktcs is not None
         for ktc in ktcs:
             temp_config = copy.copy(vllm_config)
-            engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id)
+            engine_id = ktc.get("engine_id", self._kv_transfer_config.engine_id)
             temp_config.kv_transfer_config = KVTransferConfig(
                 **ktc, engine_id=engine_id
             )
@@ -296,6 +294,7 @@ class MultiConnector(KVConnectorBase_V1):
             str: the required KV cache layout. e.g. HND, or NHD.
             None if the connector does not require a specific layout.
         """
+        assert vllm_config.kv_transfer_config is not None
         ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "connectors"
         )

View File

@@ -297,6 +297,7 @@ class NixlConnectorScheduler:
             + vllm_config.parallel_config.data_parallel_rank
             * vllm_config.parallel_config.tensor_parallel_size
         )
+        assert vllm_config.kv_transfer_config is not None
         self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
         logger.info("Initializing NIXL Scheduler %s", engine_id)
@@ -340,7 +341,8 @@ class NixlConnectorScheduler:
         if params is not None and params.get("do_remote_prefill"):
             # Remote prefill: get all prompt blocks from remote.
-            count = len(request.prompt_token_ids) - num_computed_tokens
+            token_ids = request.prompt_token_ids or []
+            count = len(token_ids) - num_computed_tokens
             if count > 0:
                 return count, True
@@ -521,6 +523,9 @@ class NixlConnectorWorker:
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size

+        if vllm_config.kv_transfer_config is None:
+            raise ValueError("kv_transfer_config must be set for NixlConnector")
+
         self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
             "backends", ["UCX"]
         )
@@ -577,17 +582,18 @@ class NixlConnectorWorker:
         self.use_host_buffer = self.kv_buffer_device == "cpu"
         # support for oot platform which can't register nixl memory
         # type based on kv_buffer_device
-        self.nixl_memory_type = current_platform.get_nixl_memory_type()
-        if self.nixl_memory_type is None:
+        nixl_memory_type = current_platform.get_nixl_memory_type()
+        if nixl_memory_type is None:
             if self.kv_buffer_device == "cuda":
-                self.nixl_memory_type = "VRAM"
+                nixl_memory_type = "VRAM"
             elif self.kv_buffer_device == "cpu":
-                self.nixl_memory_type = "DRAM"
+                nixl_memory_type = "DRAM"
-        if self.nixl_memory_type is None:
+        if nixl_memory_type is None:
             raise RuntimeError(
                 f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
                 "is not supported."
             )
+        self.nixl_memory_type = nixl_memory_type

         # Note: host xfer buffer ops when use_host_buffer is True
         self.copy_blocks: CopyBlocksOp | None = None
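Here the narrowing happens on a local `nixl_memory_type` rather than on the attribute: the local is checked, defaulted, or raised on while it is still `str | None`, and only the narrowed `str` is stored, so `self.nixl_memory_type` never carries `None`. A hedged sketch of the pattern, with hypothetical arguments in place of the platform lookup:

    class Worker:
        nixl_memory_type: str  # attribute declared non-Optional

        def __init__(self, detected: str | None, kv_buffer_device: str) -> None:
            nixl_memory_type = detected
            if nixl_memory_type is None:
                if kv_buffer_device == "cuda":
                    nixl_memory_type = "VRAM"
                elif kv_buffer_device == "cpu":
                    nixl_memory_type = "DRAM"
            if nixl_memory_type is None:
                raise RuntimeError(f"{kv_buffer_device} kv_buffer is not supported")
            self.nixl_memory_type = nixl_memory_type  # str, not str | None

    print(Worker(None, "cpu").nixl_memory_type)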

View File

@@ -75,9 +75,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
         super().__init__(vllm_config=vllm_config, role=role)
         self._block_size = vllm_config.cache_config.block_size
         self._requests_need_load: dict[str, Any] = {}
-        self.config = vllm_config.kv_transfer_config
-        self.is_producer = self.config.is_kv_producer
-        self.chunked_prefill: dict[str, Any] = {}
+        self.is_producer = self._kv_transfer_config.is_kv_producer
+        self.chunked_prefill: dict[str, tuple[list[int], list[int] | None]] = {}

         self._rank = get_world_group().rank if role == KVConnectorRole.WORKER else 0
         self._local_rank = (
@@ -87,7 +86,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
         self.p2p_nccl_engine = (
             P2pNcclEngine(
                 local_rank=self._local_rank,
-                config=self.config,
+                config=self._kv_transfer_config,
                 hostname="",
                 port_offset=self._rank,
             )
@@ -346,7 +345,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
         if self.is_producer:
             return 0, False

-        num_external_tokens = len(request.prompt_token_ids) - 1 - num_computed_tokens
+        prompt_token_ids = request.prompt_token_ids or []
+        num_external_tokens = len(prompt_token_ids) - 1 - num_computed_tokens

         if num_external_tokens < 0:
             num_external_tokens = 0
@@ -387,7 +387,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
                 ]
                 num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
                 # the request's prompt is chunked prefill
-                if num_tokens < len(new_req.prompt_token_ids):
+                if num_tokens < len(new_req.prompt_token_ids or []):
                     # 'CachedRequestData' has no attribute 'prompt_token_ids'
                     self.chunked_prefill[new_req.req_id] = (
                         new_req.block_ids[0],
@@ -397,7 +397,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
                 # the request's prompt is not chunked prefill
                 meta.add_request(
                     request_id=new_req.req_id,
-                    token_ids=new_req.prompt_token_ids,
+                    token_ids=new_req.prompt_token_ids or [],
                     block_ids=new_req.block_ids[0],
                     block_size=self._block_size,
                 )
@@ -405,7 +405,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
                 if new_req.req_id in self._requests_need_load:
                     meta.add_request(
                         request_id=new_req.req_id,
-                        token_ids=new_req.prompt_token_ids,
+                        token_ids=new_req.prompt_token_ids or [],
                         block_ids=new_req.block_ids[0],
                         block_size=self._block_size,
                     )
@@ -421,10 +421,12 @@ class P2pNcclConnector(KVConnectorBase_V1):
                 num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[req_id]
                 num_tokens = num_scheduled_tokens + num_computed_tokens
                 assert req_id in self.chunked_prefill
+                assert new_block_ids is not None
                 block_ids = new_block_ids[0]
                 if not resumed_from_preemption:
                     block_ids = self.chunked_prefill[req_id][0] + block_ids
                 prompt_token_ids = self.chunked_prefill[req_id][1]
+                assert prompt_token_ids is not None
                 # the request's prompt is chunked prefill again
                 if num_tokens < len(prompt_token_ids):
                     self.chunked_prefill[req_id] = (block_ids, prompt_token_ids)
@@ -450,6 +452,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
                 # NOTE(rob): For resumed req, new_block_ids is all
                 # of the block_ids for the request.
+                assert new_block_ids is not None
                 block_ids = new_block_ids[0]
                 meta.add_request(
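`request.prompt_token_ids` and `new_req.prompt_token_ids` are Optional in these interfaces, so the connector uses the `or []` idiom to get a concrete `list[int]` for `len()` and for `meta.add_request`, plus asserts on `new_block_ids`. A small sketch of the idiom with hypothetical arguments:

    def num_external_tokens(prompt_token_ids: list[int] | None, num_computed: int) -> int:
        token_ids = prompt_token_ids or []  # list[int] even when the field is None
        return max(len(token_ids) - 1 - num_computed, 0)

    print(num_external_tokens([1, 2, 3, 4], 1), num_external_tokens(None, 0))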

View File

@@ -90,11 +90,10 @@ class SharedStorageConnector(KVConnectorBase_V1):
         super().__init__(vllm_config=vllm_config, role=role)
         self._block_size = vllm_config.cache_config.block_size
         self._requests_need_load: dict[str, Request] = {}
-        transfer_config = vllm_config.kv_transfer_config
-        self._storage_path = transfer_config.get_from_extra_config(
+        self._storage_path = self._kv_transfer_config.get_from_extra_config(
             "shared_storage_path", "/tmp"
         )
-        logger.info(vllm_config.kv_transfer_config)
+        logger.info(self._kv_transfer_config)
         logger.info("Shared storage path is %s", self._storage_path)

     def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
@@ -277,9 +276,8 @@ class SharedStorageConnector(KVConnectorBase_V1):
         # Now, first num_tokens_to_check tokens are hit, we need to prepare
         # the metadata for the worker connector to correctly load the KV
-        num_tokens_to_check = align_to_block_size(
-            len(request.prompt_token_ids) - 1, self._block_size
-        )
+        token_ids = request.prompt_token_ids or []
+        num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size)

         return num_tokens_to_check - num_computed_tokens, False
@@ -311,13 +309,15 @@ class SharedStorageConnector(KVConnectorBase_V1):
         total_need_load = 0
         for new_req in scheduler_output.scheduled_new_reqs:
+            token_ids = new_req.prompt_token_ids or []
+            mm_hashes = [f.identifier for f in new_req.mm_features]
             if new_req.req_id in self._requests_need_load:
                 meta.add_request(
-                    token_ids=new_req.prompt_token_ids,
+                    token_ids=token_ids,
                     block_ids=new_req.block_ids[0],
                     block_size=self._block_size,
                     is_store=False,
-                    mm_hashes=[f.identifier for f in new_req.mm_features],
+                    mm_hashes=mm_hashes,
                 )
                 total_need_load += 1
             else:
@@ -325,13 +325,13 @@ class SharedStorageConnector(KVConnectorBase_V1):
                 # but a single request can have both store and load.
                 # NOTE(rob): for this debug implementation, we only cache
                 # the original prompt tokens.
-                if not self._found_match_for_request(new_req):
+                if not self._found_match_for_prompt(token_ids, mm_hashes):
                     meta.add_request(
-                        token_ids=new_req.prompt_token_ids,
+                        token_ids=token_ids,
                         block_ids=new_req.block_ids[0],
                         block_size=self._block_size,
                         is_store=True,
-                        mm_hashes=[f.identifier for f in new_req.mm_features],
+                        mm_hashes=mm_hashes,
                     )

         cached_reqs = scheduler_output.scheduled_cached_reqs
@@ -355,6 +355,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
                 # NOTE(rob): For resumed req, new_block_ids is all
                 # of the block_ids for the request.
+                assert new_block_ids is not None
                 block_ids = new_block_ids[0]
                 meta.add_request(
@@ -379,12 +380,22 @@ class SharedStorageConnector(KVConnectorBase_V1):
         request: "Request",
     ) -> bool:
         """Check if the cache is hit for the request."""
+        return self._found_match_for_prompt(
+            list(request.prompt_token_ids or []),
+            [f.identifier for f in request.mm_features],
+        )
+
+    def _found_match_for_prompt(
+        self,
+        prompt_token_ids: list[int],
+        mm_hashes: list[str],
+    ) -> bool:
         num_tokens_to_check = align_to_block_size(
-            len(request.prompt_token_ids) - 1, self._block_size
+            len(prompt_token_ids) - 1, self._block_size
         )
         foldername = self._generate_foldername_debug(
-            torch.tensor(request.prompt_token_ids)[:num_tokens_to_check],
-            [f.identifier for f in request.mm_features],
+            torch.tensor(prompt_token_ids)[:num_tokens_to_check],
+            mm_hashes,
             create_folder=False,
         )
         return os.path.exists(foldername)

View File

@@ -236,6 +236,7 @@ class MooncakePipe(KVPipeBase):
         self.config = config
         self.local_rank = local_rank
         self.kv_rank = self.config.kv_rank
+        assert self.kv_rank is not None
         if device is None:
             self.device = self._select_device(self.config.kv_buffer_device)
         else:

View File

@@ -53,6 +53,7 @@ class PyNcclPipe(KVPipeBase):
         self.config = config
         self.local_rank = local_rank
         self.kv_rank = self.config.kv_rank
+        assert self.kv_rank is not None
         self.kv_parallel_size = self.config.kv_parallel_size
         if device is None:
             self.device = self._select_device(self.config.kv_buffer_device)