Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 09:55:02 +08:00)
[CI] Fix mypy for vllm/distributed (#26593)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

Parent: d2a7938582
Commit: 314285d4f2
@@ -26,6 +26,7 @@ import regex as re
 FILES = [
     "vllm/*.py",
     "vllm/assets",
+    "vllm/distributed",
     "vllm/entrypoints",
     "vllm/inputs",
     "vllm/logging_utils",
@@ -42,7 +43,6 @@ SEPARATE_GROUPS = [
     "tests",
     "vllm/attention",
     "vllm/compilation",
-    "vllm/distributed",
     "vllm/engine",
     "vllm/executor",
     "vllm/inputs",
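For context, a minimal standalone sketch (not the real tool) of how a wrapper like this typically feeds the two lists to mypy; the run_mypy helper and the abbreviated lists below are assumptions for illustration only:

import subprocess
import sys

# Abbreviated copies of the lists edited above; the real tool keeps many more entries.
FILES = ["vllm/assets", "vllm/distributed"]
SEPARATE_GROUPS = ["tests", "vllm/attention", "vllm/compilation"]

def run_mypy(targets: list[str]) -> int:
    # Hypothetical helper: run mypy once over a list of targets.
    return subprocess.run([sys.executable, "-m", "mypy", *targets]).returncode

# Strictly checked paths share one invocation; each separate group gets its own run.
rc = run_mypy(FILES)
for group in SEPARATE_GROUPS:
    rc |= run_mypy([group])
sys.exit(rc)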
@@ -27,7 +27,7 @@ class KVTransferConfig:
     engine_id: str | None = None
     """The engine id for KV transfers."""

-    kv_buffer_device: str | None = "cuda"
+    kv_buffer_device: str = "cuda"
     """The device used by kv connector to buffer the KV cache. Choices are
     'cuda' and 'cpu'."""

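For illustration only (not code from this commit): dropping `| None` from a field whose default is already a concrete string spares every consumer a None check. A small sketch with a hypothetical stand-in class:

from dataclasses import dataclass

@dataclass
class BufferConfig:  # illustrative stand-in, not the real KVTransferConfig
    kv_buffer_device: str = "cuda"  # was `str | None = "cuda"`

def pick_memory_kind(cfg: BufferConfig) -> str:
    # With a plain `str`, mypy accepts the comparison directly; with
    # `str | None` it would require narrowing away None first.
    return "DRAM" if cfg.kv_buffer_device == "cpu" else "VRAM"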
@@ -15,9 +15,11 @@ from vllm.utils.flashinfer import has_flashinfer_all2all
 from .base_device_communicator import All2AllManagerBase, Cache

 if has_flashinfer_all2all():
-    from flashinfer.comm import Mapping
-    from flashinfer.comm.mnnvl import MnnvlConfig
-    from flashinfer.comm.trtllm_alltoall import MnnvlMoe
+    from flashinfer.comm import Mapping  # type: ignore[import-not-found]
+    from flashinfer.comm.mnnvl import MnnvlConfig  # type: ignore[import-not-found]
+    from flashinfer.comm.trtllm_alltoall import (
+        MnnvlMoe,  # type: ignore[import-not-found]
+    )

 logger = init_logger(__name__)

@@ -65,6 +67,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
     ) -> tuple[torch.Tensor, torch.Tensor]:
         sp_size = self.tp_group.world_size if is_sequence_parallel else 1
         dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
         cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

         hidden_states = self.naive_multicast(
@@ -81,6 +84,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
         ep_rank = self.rank if is_sequence_parallel else self.dp_rank

         dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
         sp_size = self.tp_group.world_size if is_sequence_parallel else 1
         cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

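For illustration only (not code from this commit): the added `assert dp_metadata is not None` lines are the standard mypy idiom for narrowing an Optional before attribute access. A self-contained sketch with a hypothetical stand-in type:

from dataclasses import dataclass

@dataclass
class DPMetadata:  # illustrative stand-in for the real forward-context metadata
    tokens: list[int]

def total_tokens(dp_metadata: DPMetadata | None) -> int:
    # Without the assert, mypy reports: Item "None" of "DPMetadata | None"
    # has no attribute "tokens".
    assert dp_metadata is not None
    return sum(dp_metadata.tokens)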
@@ -113,7 +117,10 @@ class AgRsAll2AllManager(All2AllManagerBase):
         """
         Gather hidden_states and router_logits from all dp ranks.
         """
-        sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
+        dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
+        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+        assert sizes is not None

         dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
         assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
@@ -130,7 +137,10 @@ class AgRsAll2AllManager(All2AllManagerBase):
         """
         Reduce-scatter hidden_states across all dp ranks.
         """
-        sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
+        dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
+        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+        assert sizes is not None

         dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
         hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
@@ -155,7 +165,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
         if self.internode:
             # inter-node communication needs nvshmem,
             # intra-node communication uses p2p mapping directly
-            from pplx_kernels.nvshmem import (
+            from pplx_kernels.nvshmem import (  # type: ignore[import-not-found]
                 nvshmem_alloc_empty_unique_id,
                 nvshmem_get_unique_id,
                 nvshmem_init,
@@ -182,7 +192,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
         self.handle_cache = Cache()

     def get_handle(self, kwargs):
-        import pplx_kernels as pplx
+        import pplx_kernels as pplx  # type: ignore[import-not-found]

         return self.handle_cache.get_or_create(
             kwargs,
@@ -208,7 +218,9 @@ class PPLXAll2AllManager(All2AllManagerBase):
                 handle.destroy()

         if self.internode:
-            from pplx_kernels.nvshmem import nvshmem_finalize
+            from pplx_kernels.nvshmem import (
+                nvshmem_finalize,  # type: ignore[import-not-found]
+            )

             logger.debug("PPLX NVSHMEM finalize")
             nvshmem_finalize()
@@ -288,7 +300,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
             "args are computed in the Manager itself."
         )

-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

         buffer_kwargs = self._make_all2all_kwargs()
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
@@ -298,7 +310,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
         return handle

     def set_num_sms(self, num_sms: int):
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

         # Right now the buffers are sized for only what the kernels were
         # created with. So we can only reduce the number of SMS used
@@ -332,7 +344,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
            num_global_experts: Number of experts in the model.
            num_local_experts: Number of experts in an EP rank.
        """
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

        # Defaults for internode and intranode are taken from DeepEP tests.
        num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
@@ -358,7 +370,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
        The kwargs for DeepEPLLAll2AllManager is dictated by
        _make_all2all_kwargs.
        """
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

        buffer_kwargs = self._make_all2all_kwargs(**kwargs)
        logger.debug("DeepEP all2all args %s", buffer_kwargs)
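For illustration only (not code from this commit): `# type: ignore[import-not-found]` silences mypy for optional dependencies that have no installable stubs in CI, while the runtime guard stays in charge of whether the import runs at all. A generic sketch of the pattern, with a hypothetical package name:

import importlib.util

def has_optional_dep(name: str) -> bool:
    # Runtime availability check, similar in spirit to helpers like has_flashinfer_all2all().
    return importlib.util.find_spec(name) is not None

if has_optional_dep("some_optional_kernel_pkg"):  # hypothetical package name
    import some_optional_kernel_pkg  # type: ignore[import-not-found]  # noqa: F401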
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from contextlib import contextmanager
+from typing import cast

 import torch
 import torch.distributed as dist
@@ -118,15 +119,18 @@ class CustomAllreduce:
         # now `device` is a `torch.device` object
         assert isinstance(device, torch.device)
         self.device = device
-        device_capability = current_platform.get_device_capability().as_version_str()
+        device_capability = current_platform.get_device_capability()
         if (
             current_platform.is_cuda()
             and symm_mem_enabled
-            and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES
+            and device_capability is not None
         ):
-            max_size = min(
-                CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], max_size
-            )
+            device_capability_str = device_capability.as_version_str()
+            if device_capability_str in CUSTOM_ALL_REDUCE_MAX_SIZES:
+                max_size = min(
+                    CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability_str][world_size],
+                    max_size,
+                )
         cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
         if cuda_visible_devices:
            device_ids = list(map(int, cuda_visible_devices.split(",")))
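For illustration only (not code from this commit): calling `.as_version_str()` directly on a possibly-None capability fails type checking, so the rewrite keeps the raw value, narrows it with an `is not None` check, and only then derives the string key used for the lookup. A condensed sketch with hypothetical names:

MAX_SIZES: dict[str, int] = {"9.0": 2**20}  # illustrative table, not the real constants

def get_capability() -> tuple[int, int] | None:  # stand-in for the platform query
    return (9, 0)

def max_size_for_platform(default: int) -> int:
    cap = get_capability()
    if cap is not None:                  # narrow the Optional first
        key = f"{cap[0]}.{cap[1]}"       # derive the lookup key only afterwards
        if key in MAX_SIZES:
            return min(MAX_SIZES[key], default)
    return default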
@@ -213,6 +217,7 @@ class CustomAllreduce:
         # We cannot directly use `dist.all_gather_object` here
         # because it is incompatible with `gloo` backend under inference mode.
         # see https://github.com/pytorch/pytorch/issues/126032 for details.
+        all_data: list[list[list[int] | None]]
         all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
         all_data[self.rank] = [handle, offset]
         ranks = sorted(dist.get_process_group_ranks(group=self.group))
@@ -221,8 +226,8 @@ class CustomAllreduce:
                 all_data[i], src=rank, group=self.group, device="cpu"
             )
         # Unpack list of tuples to tuple of lists.
-        handles = [d[0] for d in all_data]  # type: ignore
-        offsets = [d[1] for d in all_data]  # type: ignore
+        handles = cast(list[list[int]], [d[0] for d in all_data])
+        offsets = cast(list[list[int]], [d[1] for d in all_data])
         ops.register_graph_buffers(self._ptr, handles, offsets)

     def should_custom_ar(self, inp: torch.Tensor):
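For illustration only (not code from this commit): replacing blanket `# type: ignore` comments with an explicit annotation plus `typing.cast` keeps the checker active over the surrounding code. A minimal standalone sketch:

from typing import cast

world_size = 2
# Placeholder slots stay None until each rank contributes its (handle, offset) pair.
all_data: list[list[list[int] | None]] = [[None, None] for _ in range(world_size)]
all_data[0] = [[1, 2, 3], [0]]
all_data[1] = [[4, 5, 6], [8]]

# Once every slot is filled, cast tells mypy the None case is gone (it is not checked at runtime).
handles = cast(list[list[int]], [d[0] for d in all_data])
offsets = cast(list[list[int]], [d[1] for d in all_data])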
@@ -52,9 +52,14 @@ class SymmMemCommunicator:
         self.device = device
         self.group = group
         self.world_size = dist.get_world_size(self.group)
-        self.device_capability = (
-            current_platform.get_device_capability().as_version_str()
-        )
+        capability = current_platform.get_device_capability()
+        if capability is None:
+            logger.warning(
+                "SymmMemCommunicator: device capability is unknown, "
+                "communicator is not available."
+            )
+            return
+        self.device_capability = capability.as_version_str()
         if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
             logger.warning(
                 "SymmMemCommunicator: Device capability %s not supported, "
@@ -3,7 +3,7 @@

 import importlib
 from collections.abc import Callable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast

 import vllm.envs as envs
 from vllm.distributed.kv_transfer.kv_connector.base import (
@@ -48,6 +48,8 @@ class KVConnectorFactory:
         )

         kv_transfer_config = config.kv_transfer_config
+        if kv_transfer_config is None:
+            raise ValueError("kv_transfer_config must be set to create a connector")
         connector_cls = cls.get_connector_class(kv_transfer_config)
         logger.info(
             "Creating v1 connector with name: %s and engine_id: %s",
@@ -70,6 +72,8 @@ class KVConnectorFactory:
     ) -> type[KVConnectorBaseType]:
         """Get the connector class by name."""
         connector_name = kv_transfer_config.kv_connector
+        if connector_name is None:
+            raise ValueError("Connector name is not set in KVTransferConfig")
         if connector_name in cls._registry:
             connector_cls = cls._registry[connector_name]()
         else:
@@ -77,7 +81,13 @@ class KVConnectorFactory:
             if connector_module_path is None:
                 raise ValueError(f"Unsupported connector type: {connector_name}")
             connector_module = importlib.import_module(connector_module_path)
-            connector_cls = getattr(connector_module, connector_name)
+            try:
+                connector_cls = getattr(connector_module, connector_name)
+            except AttributeError as e:
+                raise AttributeError(
+                    f"Class {connector_name} not found in {connector_module_path}"
+                ) from e
+            connector_cls = cast(type[KVConnectorBaseType], connector_cls)
         return connector_cls

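For illustration only (not code from this commit): the factory change combines three mypy-friendly moves: reject a None config or name early, wrap the dynamic getattr in try/except so a missing class fails with a clear error, and cast the dynamically resolved attribute to the expected base type. A generic sketch with hypothetical names:

import importlib
from typing import cast

class ConnectorBase:  # illustrative stand-in for KVConnectorBaseType
    pass

def load_connector_class(module_path: str | None, class_name: str) -> type[ConnectorBase]:
    if module_path is None:
        raise ValueError("connector module path must be set")
    module = importlib.import_module(module_path)
    try:
        cls = getattr(module, class_name)
    except AttributeError as e:
        raise AttributeError(f"Class {class_name} not found in {module_path}") from e
    # getattr() returns Any; cast documents (but does not verify) the expected type.
    return cast(type[ConnectorBase], cls)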
@@ -151,21 +151,21 @@ class KVOutputAggregator:
         aggregated_kv_connector_stats = None
         invalid_block_ids = set[int]()
         for model_runner_output in outputs:
-            output = model_runner_output.kv_connector_output
-            if not output:
+            kv_output = model_runner_output.kv_connector_output
+            if not kv_output:
                 continue
             update_finished_set(
-                output.finished_sending, self._send_remaining_count, finished_sending
+                kv_output.finished_sending, self._send_remaining_count, finished_sending
             )
             update_finished_set(
-                output.finished_recving, self._recv_remaining_count, finished_recving
+                kv_output.finished_recving, self._recv_remaining_count, finished_recving
             )

             # Aggregate kv_connector_stats from all workers.
             if aggregated_kv_connector_stats is None:
                 # Use the first worker's kv_connector_stats as accumulator.
-                aggregated_kv_connector_stats = output.kv_connector_stats
-            elif kv_connector_stats := output.kv_connector_stats:
+                aggregated_kv_connector_stats = kv_output.kv_connector_stats
+            elif kv_connector_stats := kv_output.kv_connector_stats:
                 if aggregated_kv_connector_stats is None:
                     aggregated_kv_connector_stats = kv_connector_stats
                 else:
@@ -176,7 +176,7 @@ class KVOutputAggregator:
                         aggregated_kv_connector_stats.aggregate(kv_connector_stats)
                     )

-            invalid_block_ids |= output.invalid_block_ids
+            invalid_block_ids |= kv_output.invalid_block_ids

         # select output of the worker specified by output_rank
         output = outputs[output_rank]
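For illustration only (not code from this commit): the rename to kv_output matters because the name `output` is reused later for a different type, and mypy pins a variable to the type of its first assignment. A reduced sketch of why keeping distinct names type-checks:

def aggregate(outputs: list[dict[str, int]]) -> dict[str, int]:
    total = 0
    for model_runner_output in outputs:
        # Distinct name: `kv_output` is inferred as `int | None` from .get().
        kv_output = model_runner_output.get("kv")
        if kv_output is not None:
            total += kv_output
    # `output` only ever holds a dict here; reusing the same name above for the
    # `int | None` values would make mypy reject this assignment as incompatible.
    output = outputs[-1]
    output["kv_total"] = total
    return output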
@@ -95,6 +95,10 @@ class KVConnectorBase_V1(ABC):
         )
         self._connector_metadata: KVConnectorMetadata | None = None
         self._vllm_config = vllm_config
+        if vllm_config.kv_transfer_config is not None:
+            self._kv_transfer_config = vllm_config.kv_transfer_config
+        else:
+            raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
         self._role = role

     @property
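For illustration only (not code from this commit): validating the Optional kv_transfer_config once in the base constructor lets every subclass use a non-Optional attribute afterwards. A reduced sketch of that pattern with stand-in classes:

class TransferConfig:  # illustrative stand-in
    engine_id: str | None = None

class ConnectorBase:
    def __init__(self, kv_transfer_config: TransferConfig | None) -> None:
        if kv_transfer_config is None:
            raise ValueError("kv_transfer_config must be set")
        # From here on, mypy sees a plain TransferConfig, not an Optional.
        self._kv_transfer_config: TransferConfig = kv_transfer_config

class MultiConnectorLike(ConnectorBase):
    def engine_id(self) -> str | None:
        # No None check on the config itself is needed in subclasses.
        return self._kv_transfer_config.engine_id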
@@ -86,13 +86,11 @@ class MultiConnector(KVConnectorBase_V1):
         super().__init__(vllm_config=vllm_config, role=role)
         self._connectors: list[KVConnectorBase_V1] = []
         self._ktc_kv_transfer_config = []
-        ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
-            "connectors"
-        )
+        ktcs = self._kv_transfer_config.kv_connector_extra_config.get("connectors")
         assert ktcs is not None
         for ktc in ktcs:
             temp_config = copy.copy(vllm_config)
-            engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id)
+            engine_id = ktc.get("engine_id", self._kv_transfer_config.engine_id)
             temp_config.kv_transfer_config = KVTransferConfig(
                 **ktc, engine_id=engine_id
             )
@@ -296,6 +294,7 @@ class MultiConnector(KVConnectorBase_V1):
             str: the required KV cache layout. e.g. HND, or NHD.
             None if the connector does not require a specific layout.
         """
+        assert vllm_config.kv_transfer_config is not None
         ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "connectors"
         )
@@ -297,6 +297,7 @@ class NixlConnectorScheduler:
             + vllm_config.parallel_config.data_parallel_rank
             * vllm_config.parallel_config.tensor_parallel_size
         )
+        assert vllm_config.kv_transfer_config is not None
         self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
         logger.info("Initializing NIXL Scheduler %s", engine_id)

@@ -340,7 +341,8 @@ class NixlConnectorScheduler:

         if params is not None and params.get("do_remote_prefill"):
             # Remote prefill: get all prompt blocks from remote.
-            count = len(request.prompt_token_ids) - num_computed_tokens
+            token_ids = request.prompt_token_ids or []
+            count = len(token_ids) - num_computed_tokens
             if count > 0:
                 return count, True

@@ -521,6 +523,9 @@ class NixlConnectorWorker:
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size

+        if vllm_config.kv_transfer_config is None:
+            raise ValueError("kv_transfer_config must be set for NixlConnector")
+
         self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
             "backends", ["UCX"]
         )
@@ -577,17 +582,18 @@ class NixlConnectorWorker:
         self.use_host_buffer = self.kv_buffer_device == "cpu"
         # support for oot platform which can't register nixl memory
         # type based on kv_buffer_device
-        self.nixl_memory_type = current_platform.get_nixl_memory_type()
-        if self.nixl_memory_type is None:
+        nixl_memory_type = current_platform.get_nixl_memory_type()
+        if nixl_memory_type is None:
             if self.kv_buffer_device == "cuda":
-                self.nixl_memory_type = "VRAM"
+                nixl_memory_type = "VRAM"
             elif self.kv_buffer_device == "cpu":
-                self.nixl_memory_type = "DRAM"
-        if self.nixl_memory_type is None:
+                nixl_memory_type = "DRAM"
+        if nixl_memory_type is None:
             raise RuntimeError(
                 f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
                 "is not supported."
             )
+        self.nixl_memory_type = nixl_memory_type

         # Note: host xfer buffer ops when use_host_buffer is True
         self.copy_blocks: CopyBlocksOp | None = None
@@ -75,9 +75,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
         super().__init__(vllm_config=vllm_config, role=role)
         self._block_size = vllm_config.cache_config.block_size
         self._requests_need_load: dict[str, Any] = {}
-        self.config = vllm_config.kv_transfer_config
-        self.is_producer = self.config.is_kv_producer
-        self.chunked_prefill: dict[str, Any] = {}
+        self.is_producer = self._kv_transfer_config.is_kv_producer
+        self.chunked_prefill: dict[str, tuple[list[int], list[int] | None]] = {}

         self._rank = get_world_group().rank if role == KVConnectorRole.WORKER else 0
         self._local_rank = (
@@ -87,7 +86,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
         self.p2p_nccl_engine = (
             P2pNcclEngine(
                 local_rank=self._local_rank,
-                config=self.config,
+                config=self._kv_transfer_config,
                 hostname="",
                 port_offset=self._rank,
             )
@@ -346,7 +345,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
         if self.is_producer:
             return 0, False

-        num_external_tokens = len(request.prompt_token_ids) - 1 - num_computed_tokens
+        prompt_token_ids = request.prompt_token_ids or []
+        num_external_tokens = len(prompt_token_ids) - 1 - num_computed_tokens

         if num_external_tokens < 0:
             num_external_tokens = 0
@@ -387,7 +387,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
                 ]
                 num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
                 # the request's prompt is chunked prefill
-                if num_tokens < len(new_req.prompt_token_ids):
+                if num_tokens < len(new_req.prompt_token_ids or []):
                     # 'CachedRequestData' has no attribute 'prompt_token_ids'
                     self.chunked_prefill[new_req.req_id] = (
                         new_req.block_ids[0],
@@ -397,7 +397,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
                 # the request's prompt is not chunked prefill
                 meta.add_request(
                     request_id=new_req.req_id,
-                    token_ids=new_req.prompt_token_ids,
+                    token_ids=new_req.prompt_token_ids or [],
                     block_ids=new_req.block_ids[0],
                     block_size=self._block_size,
                 )
@@ -405,7 +405,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
             if new_req.req_id in self._requests_need_load:
                 meta.add_request(
                     request_id=new_req.req_id,
-                    token_ids=new_req.prompt_token_ids,
+                    token_ids=new_req.prompt_token_ids or [],
                     block_ids=new_req.block_ids[0],
                     block_size=self._block_size,
                 )
@@ -421,10 +421,12 @@ class P2pNcclConnector(KVConnectorBase_V1):
                 num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[req_id]
                 num_tokens = num_scheduled_tokens + num_computed_tokens
                 assert req_id in self.chunked_prefill
+                assert new_block_ids is not None
                 block_ids = new_block_ids[0]
                 if not resumed_from_preemption:
                     block_ids = self.chunked_prefill[req_id][0] + block_ids
                 prompt_token_ids = self.chunked_prefill[req_id][1]
+                assert prompt_token_ids is not None
                 # the request's prompt is chunked prefill again
                 if num_tokens < len(prompt_token_ids):
                     self.chunked_prefill[req_id] = (block_ids, prompt_token_ids)
@@ -450,6 +452,7 @@ class P2pNcclConnector(KVConnectorBase_V1):

             # NOTE(rob): For resumed req, new_block_ids is all
             # of the block_ids for the request.
+            assert new_block_ids is not None
             block_ids = new_block_ids[0]

             meta.add_request(
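For illustration only (not code from this commit): `request.prompt_token_ids or []` is the recurring fix for Optional sequences in these connectors; it gives len() and slicing a concrete list while treating a missing prompt as empty. A minimal sketch:

def num_external_tokens(prompt_token_ids: list[int] | None, num_computed: int) -> int:
    token_ids = prompt_token_ids or []   # list[int] | None -> list[int]
    return max(len(token_ids) - 1 - num_computed, 0)

assert num_external_tokens(None, 0) == 0
assert num_external_tokens([1, 2, 3, 4], 1) == 2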
@@ -90,11 +90,10 @@ class SharedStorageConnector(KVConnectorBase_V1):
         super().__init__(vllm_config=vllm_config, role=role)
         self._block_size = vllm_config.cache_config.block_size
         self._requests_need_load: dict[str, Request] = {}
-        transfer_config = vllm_config.kv_transfer_config
-        self._storage_path = transfer_config.get_from_extra_config(
+        self._storage_path = self._kv_transfer_config.get_from_extra_config(
             "shared_storage_path", "/tmp"
         )
-        logger.info(vllm_config.kv_transfer_config)
+        logger.info(self._kv_transfer_config)
         logger.info("Shared storage path is %s", self._storage_path)

     def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
@@ -277,9 +276,8 @@ class SharedStorageConnector(KVConnectorBase_V1):

         # Now, first num_tokens_to_check tokens are hit, we need to prepare
         # the metadata for the worker connector to correctly load the KV
-        num_tokens_to_check = align_to_block_size(
-            len(request.prompt_token_ids) - 1, self._block_size
-        )
+        token_ids = request.prompt_token_ids or []
+        num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size)

         return num_tokens_to_check - num_computed_tokens, False

@@ -311,13 +309,15 @@ class SharedStorageConnector(KVConnectorBase_V1):

         total_need_load = 0
         for new_req in scheduler_output.scheduled_new_reqs:
+            token_ids = new_req.prompt_token_ids or []
+            mm_hashes = [f.identifier for f in new_req.mm_features]
             if new_req.req_id in self._requests_need_load:
                 meta.add_request(
-                    token_ids=new_req.prompt_token_ids,
+                    token_ids=token_ids,
                     block_ids=new_req.block_ids[0],
                     block_size=self._block_size,
                     is_store=False,
-                    mm_hashes=[f.identifier for f in new_req.mm_features],
+                    mm_hashes=mm_hashes,
                 )
                 total_need_load += 1
             else:
@@ -325,13 +325,13 @@ class SharedStorageConnector(KVConnectorBase_V1):
                 # but a single request can have both store and load.
                 # NOTE(rob): for this debug implementation, we only cache
                 # the original prompt tokens.
-                if not self._found_match_for_request(new_req):
+                if not self._found_match_for_prompt(token_ids, mm_hashes):
                     meta.add_request(
-                        token_ids=new_req.prompt_token_ids,
+                        token_ids=token_ids,
                         block_ids=new_req.block_ids[0],
                         block_size=self._block_size,
                         is_store=True,
-                        mm_hashes=[f.identifier for f in new_req.mm_features],
+                        mm_hashes=mm_hashes,
                    )

         cached_reqs = scheduler_output.scheduled_cached_reqs
@@ -355,6 +355,7 @@ class SharedStorageConnector(KVConnectorBase_V1):

             # NOTE(rob): For resumed req, new_block_ids is all
             # of the block_ids for the request.
+            assert new_block_ids is not None
             block_ids = new_block_ids[0]

             meta.add_request(
@@ -379,12 +380,22 @@ class SharedStorageConnector(KVConnectorBase_V1):
         request: "Request",
     ) -> bool:
         """Check if the cache is hit for the request."""
+        return self._found_match_for_prompt(
+            list(request.prompt_token_ids or []),
+            [f.identifier for f in request.mm_features],
+        )
+
+    def _found_match_for_prompt(
+        self,
+        prompt_token_ids: list[int],
+        mm_hashes: list[str],
+    ) -> bool:
         num_tokens_to_check = align_to_block_size(
-            len(request.prompt_token_ids) - 1, self._block_size
+            len(prompt_token_ids) - 1, self._block_size
         )
         foldername = self._generate_foldername_debug(
-            torch.tensor(request.prompt_token_ids)[:num_tokens_to_check],
-            [f.identifier for f in request.mm_features],
+            torch.tensor(prompt_token_ids)[:num_tokens_to_check],
+            mm_hashes,
             create_folder=False,
         )
         return os.path.exists(foldername)
@@ -236,6 +236,7 @@ class MooncakePipe(KVPipeBase):
         self.config = config
         self.local_rank = local_rank
         self.kv_rank = self.config.kv_rank
+        assert self.kv_rank is not None
         if device is None:
             self.device = self._select_device(self.config.kv_buffer_device)
         else:
@@ -53,6 +53,7 @@ class PyNcclPipe(KVPipeBase):
         self.config = config
         self.local_rank = local_rank
         self.kv_rank = self.config.kv_rank
+        assert self.kv_rank is not None
         self.kv_parallel_size = self.config.kv_parallel_size
         if device is None:
             self.device = self._select_device(self.config.kv_buffer_device)