[CI] Fix mypy for vllm/distributed (#26593)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent d2a7938582
commit 314285d4f2
@@ -26,6 +26,7 @@ import regex as re
 FILES = [
     "vllm/*.py",
     "vllm/assets",
+    "vllm/distributed",
     "vllm/entrypoints",
     "vllm/inputs",
     "vllm/logging_utils",
@@ -42,7 +43,6 @@ SEPARATE_GROUPS = [
     "tests",
     "vllm/attention",
     "vllm/compilation",
-    "vllm/distributed",
     "vllm/engine",
     "vllm/executor",
     "vllm/inputs",
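The two hunks above move `vllm/distributed` out of SEPARATE_GROUPS and into FILES, i.e. out of the leniently checked per-group passes and into the strictly checked set. As a rough sketch of the mechanics (the runner below is an assumption for illustration, not the actual tools/pre_commit script):

# Hypothetical sketch of a pre-commit wrapper driving mypy over the two lists.
# The helper name and flags are illustrative assumptions.
import subprocess
import sys

FILES = ["vllm/*.py", "vllm/assets", "vllm/distributed"]
SEPARATE_GROUPS = ["tests", "vllm/attention"]

def run_mypy(targets: list[str]) -> int:
    # One invocation per group, so errors in a not-yet-clean group
    # do not fail the strictly checked set.
    return subprocess.run(["mypy", "--follow-imports", "skip", *targets]).returncode

if __name__ == "__main__":
    rc = run_mypy(FILES)           # strict pass over the clean set
    for group in SEPARATE_GROUPS:  # best-effort passes, one group at a time
        rc |= run_mypy([group])
    sys.exit(rc)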
@@ -27,7 +27,7 @@ class KVTransferConfig:
     engine_id: str | None = None
     """The engine id for KV transfers."""
 
-    kv_buffer_device: str | None = "cuda"
+    kv_buffer_device: str = "cuda"
     """The device used by kv connector to buffer the KV cache. Choices are
     'cuda' and 'cpu'."""
 
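The `str | None` to `str` change above removes a needless Optional: the default is never None, but the old annotation forced a None check at every str-typed use site. A standalone illustration:

from dataclasses import dataclass

@dataclass
class Cfg:
    kv_buffer_device: str | None = "cuda"  # the old annotation

def select_device(name: str) -> str:
    return name.upper()

cfg = Cfg()
# mypy (old annotation): incompatible type "str | None"; expected "str"
# select_device(cfg.kv_buffer_device)
assert cfg.kv_buffer_device is not None  # the guard the old type demanded
print(select_device(cfg.kv_buffer_device))  # with plain `str`, no guard is needed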
@@ -15,9 +15,11 @@ from vllm.utils.flashinfer import has_flashinfer_all2all
 from .base_device_communicator import All2AllManagerBase, Cache
 
 if has_flashinfer_all2all():
-    from flashinfer.comm import Mapping
-    from flashinfer.comm.mnnvl import MnnvlConfig
-    from flashinfer.comm.trtllm_alltoall import MnnvlMoe
+    from flashinfer.comm import Mapping  # type: ignore[import-not-found]
+    from flashinfer.comm.mnnvl import MnnvlConfig  # type: ignore[import-not-found]
+    from flashinfer.comm.trtllm_alltoall import (
+        MnnvlMoe,  # type: ignore[import-not-found]
+    )
 
 logger = init_logger(__name__)
 
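flashinfer is an optional dependency that is absent in the CI environment, so mypy cannot resolve the import; the narrow `import-not-found` error code silences exactly that diagnostic and nothing else. The same pattern, with a made-up package name:

# Minimal sketch of the optional-dependency import pattern used above.
# `fancy_kernels` is a made-up package standing in for flashinfer.
def has_fancy_kernels() -> bool:
    try:
        import fancy_kernels  # type: ignore[import-not-found]  # noqa: F401
        return True
    except ImportError:
        return False

if has_fancy_kernels():
    # The targeted error code keeps mypy quiet about the missing package
    # while still catching every other typing problem on this line.
    from fancy_kernels import Mapping  # type: ignore[import-not-found]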
@@ -65,6 +67,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
     ) -> tuple[torch.Tensor, torch.Tensor]:
         sp_size = self.tp_group.world_size if is_sequence_parallel else 1
         dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
         cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
 
         hidden_states = self.naive_multicast(
@@ -81,6 +84,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
         ep_rank = self.rank if is_sequence_parallel else self.dp_rank
 
         dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
         sp_size = self.tp_group.world_size if is_sequence_parallel else 1
         cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
 
@@ -113,7 +117,10 @@ class AgRsAll2AllManager(All2AllManagerBase):
         """
         Gather hidden_states and router_logits from all dp ranks.
         """
-        sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
+        dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
+        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+        assert sizes is not None
 
         dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
         assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
@@ -130,7 +137,10 @@ class AgRsAll2AllManager(All2AllManagerBase):
         """
         Reduce-scatter hidden_states across all dp ranks.
        """
-        sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
+        dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
+        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+        assert sizes is not None
 
         dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
         hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
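The repeated `assert dp_metadata is not None` additions rely on mypy's type narrowing: after the assert, the Optional is gone for the rest of the scope. Reduced to its essentials:

# How assert-narrowing works in isolation: after `assert x is not None`,
# mypy treats x as the non-None type for the remainder of the block.
from dataclasses import dataclass

@dataclass
class DPMetadata:
    def cu_tokens_across_sp(self, sp_size: int) -> list[int]:
        return [sp_size]

def get_dp_metadata() -> DPMetadata | None:
    return DPMetadata()  # would be None outside a forward pass

dp_metadata = get_dp_metadata()
assert dp_metadata is not None  # narrows `DPMetadata | None` -> `DPMetadata`
print(dp_metadata.cu_tokens_across_sp(2))  # no mypy error now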
@@ -155,7 +165,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
         if self.internode:
             # inter-node communication needs nvshmem,
             # intra-node communication uses p2p mapping directly
-            from pplx_kernels.nvshmem import (
+            from pplx_kernels.nvshmem import (  # type: ignore[import-not-found]
                 nvshmem_alloc_empty_unique_id,
                 nvshmem_get_unique_id,
                 nvshmem_init,
@@ -182,7 +192,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
         self.handle_cache = Cache()
 
     def get_handle(self, kwargs):
-        import pplx_kernels as pplx
+        import pplx_kernels as pplx  # type: ignore[import-not-found]
 
         return self.handle_cache.get_or_create(
             kwargs,
@@ -208,7 +218,9 @@ class PPLXAll2AllManager(All2AllManagerBase):
             handle.destroy()
 
         if self.internode:
-            from pplx_kernels.nvshmem import nvshmem_finalize
+            from pplx_kernels.nvshmem import (
+                nvshmem_finalize,  # type: ignore[import-not-found]
+            )
 
             logger.debug("PPLX NVSHMEM finalize")
             nvshmem_finalize()
@@ -288,7 +300,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
             "args are computed in the Manager itself."
         )
 
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]
 
         buffer_kwargs = self._make_all2all_kwargs()
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
@@ -298,7 +310,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
         return handle
 
     def set_num_sms(self, num_sms: int):
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]
 
         # Right now the buffers are sized for only what the kernels were
         # created with. So we can only reduce the number of SMS used
@@ -332,7 +344,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
            num_global_experts: Number of experts in the model.
            num_local_experts: Number of experts in an EP rank.
        """
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]
 
        # Defaults for internode and intranode are taken from DeepEP tests.
        num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
@@ -358,7 +370,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
         The kwargs for DeepEPLLAll2AllManager is dictated by
         _make_all2all_kwargs.
         """
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]
 
         buffer_kwargs = self._make_all2all_kwargs(**kwargs)
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from contextlib import contextmanager
+from typing import cast
 
 import torch
 import torch.distributed as dist
@@ -118,14 +119,17 @@ class CustomAllreduce:
         # now `device` is a `torch.device` object
         assert isinstance(device, torch.device)
         self.device = device
-        device_capability = current_platform.get_device_capability().as_version_str()
+        device_capability = current_platform.get_device_capability()
         if (
             current_platform.is_cuda()
             and symm_mem_enabled
-            and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES
+            and device_capability is not None
         ):
-            max_size = min(
-                CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], max_size
-            )
+            device_capability_str = device_capability.as_version_str()
+            if device_capability_str in CUSTOM_ALL_REDUCE_MAX_SIZES:
+                max_size = min(
+                    CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability_str][world_size],
+                    max_size,
+                )
         cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
         if cuda_visible_devices:
@@ -213,6 +217,7 @@ class CustomAllreduce:
         # We cannot directly use `dist.all_gather_object` here
         # because it is incompatible with `gloo` backend under inference mode.
         # see https://github.com/pytorch/pytorch/issues/126032 for details.
+        all_data: list[list[list[int] | None]]
         all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
         all_data[self.rank] = [handle, offset]
         ranks = sorted(dist.get_process_group_ranks(group=self.group))
@@ -221,8 +226,8 @@ class CustomAllreduce:
                 all_data[i], src=rank, group=self.group, device="cpu"
             )
         # Unpack list of tuples to tuple of lists.
-        handles = [d[0] for d in all_data]  # type: ignore
-        offsets = [d[1] for d in all_data]  # type: ignore
+        handles = cast(list[list[int]], [d[0] for d in all_data])
+        offsets = cast(list[list[int]], [d[1] for d in all_data])
         ops.register_graph_buffers(self._ptr, handles, offsets)
 
     def should_custom_ar(self, inp: torch.Tensor):
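The `cast` change replaces two bare `# type: ignore` comments with a precise claim: the list is declared with `None` placeholders, and after the broadcast loop every slot is filled, so `cast` records that invariant with no runtime cost. In isolation:

# Why `cast` fits here: the list is built with None placeholders, then filled
# in by collective ops, so its static type keeps the `| None` forever.
from typing import cast

world_size = 2
all_data: list[list[list[int] | None]] = [[None, None] for _ in range(world_size)]

# ... pretend every rank broadcast its [handle, offset] pair into all_data ...
for i in range(world_size):
    all_data[i] = [[i], [i * 10]]

handles = cast(list[list[int]], [d[0] for d in all_data])
offsets = cast(list[list[int]], [d[1] for d in all_data])
print(handles, offsets)  # [[0], [1]] [[0], [10]]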
@@ -52,9 +52,14 @@ class SymmMemCommunicator:
         self.device = device
         self.group = group
         self.world_size = dist.get_world_size(self.group)
-        self.device_capability = (
-            current_platform.get_device_capability().as_version_str()
-        )
+        capability = current_platform.get_device_capability()
+        if capability is None:
+            logger.warning(
+                "SymmMemCommunicator: device capability is unknown, "
+                "communicator is not available."
+            )
+            return
+        self.device_capability = capability.as_version_str()
         if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
             logger.warning(
                 "SymmMemCommunicator: Device capability %s not supported, "
@@ -3,7 +3,7 @@
 
 import importlib
 from collections.abc import Callable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
 
 import vllm.envs as envs
 from vllm.distributed.kv_transfer.kv_connector.base import (
@@ -48,6 +48,8 @@ class KVConnectorFactory:
         )
 
         kv_transfer_config = config.kv_transfer_config
+        if kv_transfer_config is None:
+            raise ValueError("kv_transfer_config must be set to create a connector")
         connector_cls = cls.get_connector_class(kv_transfer_config)
         logger.info(
             "Creating v1 connector with name: %s and engine_id: %s",
@@ -70,6 +72,8 @@ class KVConnectorFactory:
     ) -> type[KVConnectorBaseType]:
         """Get the connector class by name."""
         connector_name = kv_transfer_config.kv_connector
+        if connector_name is None:
+            raise ValueError("Connector name is not set in KVTransferConfig")
         if connector_name in cls._registry:
             connector_cls = cls._registry[connector_name]()
         else:
@@ -77,7 +81,13 @@ class KVConnectorFactory:
             if connector_module_path is None:
                 raise ValueError(f"Unsupported connector type: {connector_name}")
             connector_module = importlib.import_module(connector_module_path)
-            connector_cls = getattr(connector_module, connector_name)
+            try:
+                connector_cls = getattr(connector_module, connector_name)
+            except AttributeError as e:
+                raise AttributeError(
+                    f"Class {connector_name} not found in {connector_module_path}"
+                ) from e
+            connector_cls = cast(type[KVConnectorBaseType], connector_cls)
         return connector_cls
 
 
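`getattr` on a module returns `Any`, so the factory wraps the lookup in an explicit error and pins the result with `cast` to match the declared return type. A self-contained sketch with stand-in names:

# Sketch of the dynamic-loading pattern: importlib + getattr yields Any, so a
# cast is needed to satisfy the declared return type. Names are illustrative,
# not the real connector registry.
import importlib
from typing import cast

class ConnectorBase: ...

def load_class(module_path: str, class_name: str) -> type[ConnectorBase]:
    module = importlib.import_module(module_path)
    try:
        cls = getattr(module, class_name)
    except AttributeError as e:
        raise AttributeError(f"Class {class_name} not found in {module_path}") from e
    # getattr() is typed as Any; cast documents and enforces the expectation.
    return cast(type[ConnectorBase], cls)

print(load_class("collections", "OrderedDict"))  # cast is a no-op at runtime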
@@ -151,21 +151,21 @@ class KVOutputAggregator:
         aggregated_kv_connector_stats = None
         invalid_block_ids = set[int]()
         for model_runner_output in outputs:
-            output = model_runner_output.kv_connector_output
-            if not output:
+            kv_output = model_runner_output.kv_connector_output
+            if not kv_output:
                 continue
             update_finished_set(
-                output.finished_sending, self._send_remaining_count, finished_sending
+                kv_output.finished_sending, self._send_remaining_count, finished_sending
             )
             update_finished_set(
-                output.finished_recving, self._recv_remaining_count, finished_recving
+                kv_output.finished_recving, self._recv_remaining_count, finished_recving
             )
 
             # Aggregate kv_connector_stats from all workers.
             if aggregated_kv_connector_stats is None:
                 # Use the first worker's kv_connector_stats as accumulator.
-                aggregated_kv_connector_stats = output.kv_connector_stats
-            elif kv_connector_stats := output.kv_connector_stats:
+                aggregated_kv_connector_stats = kv_output.kv_connector_stats
+            elif kv_connector_stats := kv_output.kv_connector_stats:
                 if aggregated_kv_connector_stats is None:
                     aggregated_kv_connector_stats = kv_connector_stats
                 else:
@@ -176,7 +176,7 @@ class KVOutputAggregator:
                         aggregated_kv_connector_stats.aggregate(kv_connector_stats)
                     )
 
-            invalid_block_ids |= output.invalid_block_ids
+            invalid_block_ids |= kv_output.invalid_block_ids
 
         # select output of the worker specified by output_rank
         output = outputs[output_rank]
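The rename from `output` to `kv_output` is not cosmetic: the same name was later rebound to `outputs[output_rank]`, which has a different type, and mypy rejects rebinding a variable to an incompatible type. A reduced illustration with stand-in classes:

class KVConnectorOutput: ...

class ModelRunnerOutput:
    kv_connector_output = KVConnectorOutput()

outputs = [ModelRunnerOutput(), ModelRunnerOutput()]

for model_runner_output in outputs:
    kv_output = model_runner_output.kv_connector_output  # KVConnectorOutput

# Before the rename, this line rebound `output` (already inferred as
# KVConnectorOutput) to a ModelRunnerOutput, which mypy flags as an
# incompatible assignment. Distinct names give each a single stable type.
output = outputs[0]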
@@ -95,6 +95,10 @@ class KVConnectorBase_V1(ABC):
         )
         self._connector_metadata: KVConnectorMetadata | None = None
         self._vllm_config = vllm_config
+        if vllm_config.kv_transfer_config is not None:
+            self._kv_transfer_config = vllm_config.kv_transfer_config
+        else:
+            raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
         self._role = role
 
     @property
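This is the validate-once pattern: the base class checks the Optional config a single time in `__init__` and stores it under a non-Optional attribute, so the subclasses below (MultiConnector, P2pNcclConnector, SharedStorageConnector) can use `self._kv_transfer_config` without repeated None checks. A minimal sketch with simplified types:

class KVTransferConfig: ...

class VllmConfig:
    def __init__(self) -> None:
        self.kv_transfer_config: KVTransferConfig | None = None

class ConnectorBase:
    def __init__(self, vllm_config: VllmConfig) -> None:
        if vllm_config.kv_transfer_config is not None:
            # The attribute is inferred as plain KVTransferConfig, so
            # subclasses never need their own None checks.
            self._kv_transfer_config = vllm_config.kv_transfer_config
        else:
            raise ValueError("kv_transfer_config must be set for ConnectorBase")

cfg = VllmConfig()
cfg.kv_transfer_config = KVTransferConfig()
connector = ConnectorBase(cfg)  # raises ValueError if the config is missing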
@@ -86,13 +86,11 @@ class MultiConnector(KVConnectorBase_V1):
         super().__init__(vllm_config=vllm_config, role=role)
         self._connectors: list[KVConnectorBase_V1] = []
         self._ktc_kv_transfer_config = []
-        ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
-            "connectors"
-        )
+        ktcs = self._kv_transfer_config.kv_connector_extra_config.get("connectors")
         assert ktcs is not None
         for ktc in ktcs:
             temp_config = copy.copy(vllm_config)
-            engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id)
+            engine_id = ktc.get("engine_id", self._kv_transfer_config.engine_id)
             temp_config.kv_transfer_config = KVTransferConfig(
                 **ktc, engine_id=engine_id
             )
@@ -296,6 +294,7 @@ class MultiConnector(KVConnectorBase_V1):
             str: the required KV cache layout. e.g. HND, or NHD.
             None if the connector does not require a specific layout.
         """
+        assert vllm_config.kv_transfer_config is not None
         ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "connectors"
         )
@@ -297,6 +297,7 @@ class NixlConnectorScheduler:
             + vllm_config.parallel_config.data_parallel_rank
             * vllm_config.parallel_config.tensor_parallel_size
         )
+        assert vllm_config.kv_transfer_config is not None
         self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
         logger.info("Initializing NIXL Scheduler %s", engine_id)
 
@@ -340,7 +341,8 @@ class NixlConnectorScheduler:
 
         if params is not None and params.get("do_remote_prefill"):
             # Remote prefill: get all prompt blocks from remote.
-            count = len(request.prompt_token_ids) - num_computed_tokens
+            token_ids = request.prompt_token_ids or []
+            count = len(token_ids) - num_computed_tokens
             if count > 0:
                 return count, True
 
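`request.prompt_token_ids` is typed as Optional, so calling `len()` on it directly is a mypy error; `or []` substitutes an empty, correctly typed list. Reduced:

# The `or []` idiom in isolation: len() on an Optional list is a mypy error,
# while `tokens or []` is always a list[int].
def count_remaining(prompt_token_ids: list[int] | None, num_computed: int) -> int:
    token_ids = prompt_token_ids or []  # [] both for None and for empty input
    return len(token_ids) - num_computed

print(count_remaining([1, 2, 3], 1))  # 2
print(count_remaining(None, 0))       # 0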
@@ -521,6 +523,9 @@ class NixlConnectorWorker:
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
 
+        if vllm_config.kv_transfer_config is None:
+            raise ValueError("kv_transfer_config must be set for NixlConnector")
+
         self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
             "backends", ["UCX"]
         )
@@ -577,17 +582,18 @@ class NixlConnectorWorker:
         self.use_host_buffer = self.kv_buffer_device == "cpu"
         # support for oot platform which can't register nixl memory
         # type based on kv_buffer_device
-        self.nixl_memory_type = current_platform.get_nixl_memory_type()
-        if self.nixl_memory_type is None:
+        nixl_memory_type = current_platform.get_nixl_memory_type()
+        if nixl_memory_type is None:
             if self.kv_buffer_device == "cuda":
-                self.nixl_memory_type = "VRAM"
+                nixl_memory_type = "VRAM"
             elif self.kv_buffer_device == "cpu":
-                self.nixl_memory_type = "DRAM"
-        if self.nixl_memory_type is None:
+                nixl_memory_type = "DRAM"
+        if nixl_memory_type is None:
             raise RuntimeError(
                 f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
                 "is not supported."
             )
+        self.nixl_memory_type = nixl_memory_type
 
         # Note: host xfer buffer ops when use_host_buffer is True
         self.copy_blocks: CopyBlocksOp | None = None
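The worker hunk above switches from mutating `self.nixl_memory_type` to building the value in a local first. mypy narrows locals through if/elif chains far more reliably than attributes, so the attribute only ever receives the narrowed `str`. A reduced sketch:

# Local-then-assign pattern: Optional handling happens on a local variable,
# and every path that reaches the return has narrowed it to plain str.
def resolve_memory_type(platform_hint: str | None, kv_buffer_device: str) -> str:
    nixl_memory_type = platform_hint
    if nixl_memory_type is None:
        if kv_buffer_device == "cuda":
            nixl_memory_type = "VRAM"
        elif kv_buffer_device == "cpu":
            nixl_memory_type = "DRAM"
    if nixl_memory_type is None:
        raise RuntimeError(f"{kv_buffer_device} kv_buffer is not supported.")
    return nixl_memory_type  # narrowed to str here

print(resolve_memory_type(None, "cuda"))  # VRAM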
@@ -75,9 +75,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
         super().__init__(vllm_config=vllm_config, role=role)
         self._block_size = vllm_config.cache_config.block_size
         self._requests_need_load: dict[str, Any] = {}
-        self.config = vllm_config.kv_transfer_config
-        self.is_producer = self.config.is_kv_producer
-        self.chunked_prefill: dict[str, Any] = {}
+        self.is_producer = self._kv_transfer_config.is_kv_producer
+        self.chunked_prefill: dict[str, tuple[list[int], list[int] | None]] = {}
 
         self._rank = get_world_group().rank if role == KVConnectorRole.WORKER else 0
         self._local_rank = (
@@ -87,7 +86,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
         self.p2p_nccl_engine = (
             P2pNcclEngine(
                 local_rank=self._local_rank,
-                config=self.config,
+                config=self._kv_transfer_config,
                 hostname="",
                 port_offset=self._rank,
             )
@@ -346,7 +345,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
         if self.is_producer:
             return 0, False
 
-        num_external_tokens = len(request.prompt_token_ids) - 1 - num_computed_tokens
+        prompt_token_ids = request.prompt_token_ids or []
+        num_external_tokens = len(prompt_token_ids) - 1 - num_computed_tokens
 
         if num_external_tokens < 0:
             num_external_tokens = 0
@@ -387,7 +387,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
                 ]
                 num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
                 # the request's prompt is chunked prefill
-                if num_tokens < len(new_req.prompt_token_ids):
+                if num_tokens < len(new_req.prompt_token_ids or []):
                     # 'CachedRequestData' has no attribute 'prompt_token_ids'
                     self.chunked_prefill[new_req.req_id] = (
                         new_req.block_ids[0],
@@ -397,7 +397,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
                 # the request's prompt is not chunked prefill
                 meta.add_request(
                     request_id=new_req.req_id,
-                    token_ids=new_req.prompt_token_ids,
+                    token_ids=new_req.prompt_token_ids or [],
                     block_ids=new_req.block_ids[0],
                     block_size=self._block_size,
                 )
@@ -405,7 +405,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
             if new_req.req_id in self._requests_need_load:
                 meta.add_request(
                     request_id=new_req.req_id,
-                    token_ids=new_req.prompt_token_ids,
+                    token_ids=new_req.prompt_token_ids or [],
                     block_ids=new_req.block_ids[0],
                     block_size=self._block_size,
                 )
@@ -421,10 +421,12 @@ class P2pNcclConnector(KVConnectorBase_V1):
             num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[req_id]
             num_tokens = num_scheduled_tokens + num_computed_tokens
             assert req_id in self.chunked_prefill
+            assert new_block_ids is not None
             block_ids = new_block_ids[0]
             if not resumed_from_preemption:
                 block_ids = self.chunked_prefill[req_id][0] + block_ids
             prompt_token_ids = self.chunked_prefill[req_id][1]
+            assert prompt_token_ids is not None
             # the request's prompt is chunked prefill again
             if num_tokens < len(prompt_token_ids):
                 self.chunked_prefill[req_id] = (block_ids, prompt_token_ids)
@@ -450,6 +452,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
 
             # NOTE(rob): For resumed req, new_block_ids is all
             # of the block_ids for the request.
+            assert new_block_ids is not None
             block_ids = new_block_ids[0]
 
             meta.add_request(
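`chunked_prefill` also gains a precise value type in place of `Any`, which is what makes the later `assert prompt_token_ids is not None` both necessary and sufficient: indexing and unpacking become checkable. In isolation:

# Effect of replacing Any with a precise value type: the tuple elements get
# real types, and the Optional second element must be narrowed before use.
chunked_prefill: dict[str, tuple[list[int], list[int] | None]] = {}
chunked_prefill["req-0"] = ([0, 1], [10, 11, 12])

block_ids, prompt_token_ids = chunked_prefill["req-0"]
assert prompt_token_ids is not None  # required before len()/indexing
print(len(prompt_token_ids))  # 3; with dict[str, Any], none of this is checked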
@@ -90,11 +90,10 @@ class SharedStorageConnector(KVConnectorBase_V1):
         super().__init__(vllm_config=vllm_config, role=role)
         self._block_size = vllm_config.cache_config.block_size
         self._requests_need_load: dict[str, Request] = {}
-        transfer_config = vllm_config.kv_transfer_config
-        self._storage_path = transfer_config.get_from_extra_config(
+        self._storage_path = self._kv_transfer_config.get_from_extra_config(
             "shared_storage_path", "/tmp"
         )
-        logger.info(vllm_config.kv_transfer_config)
+        logger.info(self._kv_transfer_config)
         logger.info("Shared storage path is %s", self._storage_path)
 
     def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
@@ -277,9 +276,8 @@ class SharedStorageConnector(KVConnectorBase_V1):
 
         # Now, first num_tokens_to_check tokens are hit, we need to prepare
         # the metadata for the worker connector to correctly load the KV
-        num_tokens_to_check = align_to_block_size(
-            len(request.prompt_token_ids) - 1, self._block_size
-        )
+        token_ids = request.prompt_token_ids or []
+        num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size)
 
         return num_tokens_to_check - num_computed_tokens, False
 
@@ -311,13 +309,15 @@ class SharedStorageConnector(KVConnectorBase_V1):
 
         total_need_load = 0
         for new_req in scheduler_output.scheduled_new_reqs:
+            token_ids = new_req.prompt_token_ids or []
+            mm_hashes = [f.identifier for f in new_req.mm_features]
             if new_req.req_id in self._requests_need_load:
                 meta.add_request(
-                    token_ids=new_req.prompt_token_ids,
+                    token_ids=token_ids,
                     block_ids=new_req.block_ids[0],
                     block_size=self._block_size,
                     is_store=False,
-                    mm_hashes=[f.identifier for f in new_req.mm_features],
+                    mm_hashes=mm_hashes,
                 )
                 total_need_load += 1
             else:
@@ -325,13 +325,13 @@ class SharedStorageConnector(KVConnectorBase_V1):
                 # but a single request can have both store and load.
                 # NOTE(rob): for this debug implementation, we only cache
                 # the original prompt tokens.
-                if not self._found_match_for_request(new_req):
+                if not self._found_match_for_prompt(token_ids, mm_hashes):
                     meta.add_request(
-                        token_ids=new_req.prompt_token_ids,
+                        token_ids=token_ids,
                         block_ids=new_req.block_ids[0],
                         block_size=self._block_size,
                         is_store=True,
-                        mm_hashes=[f.identifier for f in new_req.mm_features],
+                        mm_hashes=mm_hashes,
                     )
 
         cached_reqs = scheduler_output.scheduled_cached_reqs
@@ -355,6 +355,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
 
             # NOTE(rob): For resumed req, new_block_ids is all
             # of the block_ids for the request.
+            assert new_block_ids is not None
             block_ids = new_block_ids[0]
 
             meta.add_request(
@@ -379,12 +380,22 @@ class SharedStorageConnector(KVConnectorBase_V1):
         request: "Request",
     ) -> bool:
         """Check if the cache is hit for the request."""
+        return self._found_match_for_prompt(
+            list(request.prompt_token_ids or []),
+            [f.identifier for f in request.mm_features],
+        )
+
+    def _found_match_for_prompt(
+        self,
+        prompt_token_ids: list[int],
+        mm_hashes: list[str],
+    ) -> bool:
         num_tokens_to_check = align_to_block_size(
-            len(request.prompt_token_ids) - 1, self._block_size
+            len(prompt_token_ids) - 1, self._block_size
         )
         foldername = self._generate_foldername_debug(
-            torch.tensor(request.prompt_token_ids)[:num_tokens_to_check],
-            [f.identifier for f in request.mm_features],
+            torch.tensor(prompt_token_ids)[:num_tokens_to_check],
+            mm_hashes,
             create_folder=False,
         )
         return os.path.exists(foldername)
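The final hunk above splits the Request-taking method from a helper that takes already-narrowed arguments, so the `prompt_token_ids or []` handling happens exactly once. The shape of the refactor, with simplified types:

# Helper-extraction pattern: the public method absorbs the Optional once,
# and the private helper works on fully narrowed types.
class Request:
    def __init__(self, prompt_token_ids: list[int] | None) -> None:
        self.prompt_token_ids = prompt_token_ids

class Connector:
    def found_match_for_request(self, request: Request) -> bool:
        return self._found_match_for_prompt(list(request.prompt_token_ids or []))

    def _found_match_for_prompt(self, prompt_token_ids: list[int]) -> bool:
        # No None checks needed here; the type is plain list[int].
        return len(prompt_token_ids) > 0

print(Connector().found_match_for_request(Request(None)))  # False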
@@ -236,6 +236,7 @@ class MooncakePipe(KVPipeBase):
         self.config = config
         self.local_rank = local_rank
         self.kv_rank = self.config.kv_rank
+        assert self.kv_rank is not None
         if device is None:
             self.device = self._select_device(self.config.kv_buffer_device)
         else:

@@ -53,6 +53,7 @@ class PyNcclPipe(KVPipeBase):
         self.config = config
         self.local_rank = local_rank
         self.kv_rank = self.config.kv_rank
+        assert self.kv_rank is not None
         self.kv_parallel_size = self.config.kv_parallel_size
         if device is None:
             self.device = self._select_device(self.config.kv_buffer_device)