[Core] Use CpuGpuBuffer for block table tensors (#24795)

Signed-off-by: Nick Hill <nhill@redhat.com>
Author: Nick Hill
Date: 2025-09-16 19:18:06 -07:00 (committed via GitHub)
commit eeb135eb87 (parent 3059b9cc6b)
6 changed files with 53 additions and 63 deletions

View File

@ -125,7 +125,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
return False
num_blocks = block_table.num_blocks_per_row[req_index]
- block_table_values = block_table.block_table_np[req_index, :num_blocks]
+ block_table_values = block_table.block_table.np[req_index, :num_blocks]
return (block_table_values == req_block_ids).all()
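Note: the test now indexes block_table.block_table.np because BlockTable no longer keeps separate block_table / block_table_cpu / block_table_np attributes; it holds a single CpuGpuBuffer whose .cpu, .gpu and .np views (and copy_to_gpu()) are used throughout this diff. A minimal sketch of that interface — not vLLM's actual implementation, just the shape of the API this change relies on:

import torch

class CpuGpuBufferSketch:
    """Paired pinned-CPU / GPU tensors with a zero-copy numpy view."""

    def __init__(self, *size: int, dtype: torch.dtype,
                 device: torch.device, pin_memory: bool):
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu",
                               pin_memory=pin_memory)
        self.np = self.cpu.numpy()        # fast indexed writes on the host
        self.gpu = self.cpu.to(device)    # same shape/dtype on the device

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        # Async H2D copy of the first n rows, mirroring commit_block_table().
        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)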

View File

@ -15,6 +15,7 @@ from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
+ from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@ -45,7 +46,7 @@ def _compare_objs(obj1,
is_same = False
if isinstance(a, torch.Tensor):
- if (a.numel() == 0 or b.numel() == 0):
+ if a.numel() == 0 or b.numel() == 0:
is_same = (a.numel() == 0 and b.numel() == 0)
elif torch.allclose(a, b):
is_same = True
@ -61,6 +62,8 @@ def _compare_objs(obj1,
is_same = True # if we make it here must be same
elif a == b:
is_same = True
+ elif isinstance(a, CpuGpuBuffer):
+ is_same = np.allclose(a.np, b.np) and torch.allclose(a.gpu, b.gpu)
assert is_same, f"Attribute {attr_name} is different"\
f" in {obj1} and {obj2}: {a} != {b}"

View File

@ -165,7 +165,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
req_state.block_ids[0]):
return False
num_blocks = block_table.num_blocks_per_row[req_index]
- return (block_table.block_table_np[req_index, :num_blocks] ==
+ return (block_table.block_table.np[req_index, :num_blocks] ==
req_state.block_ids[0]).all()

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ from typing import Union
import numpy as np
import torch
@ -7,6 +8,7 @@ import torch
from vllm.distributed import get_dcp_group
from vllm.logger import init_logger
from vllm.utils import cdiv
+ from vllm.v1.utils import CpuGpuBuffer
logger = init_logger(__name__)
@ -29,28 +31,13 @@ class BlockTable:
self.pin_memory = pin_memory
self.device = device
- self.block_table = torch.zeros(
- (max_num_reqs, max_num_blocks_per_req),
- device=self.device,
- dtype=torch.int32,
- )
- self.block_table_cpu = torch.zeros(
- (max_num_reqs, max_num_blocks_per_req),
- device="cpu",
- dtype=torch.int32,
- pin_memory=pin_memory,
- )
- self.block_table_np = self.block_table_cpu.numpy()
+ self.block_table = self._make_buffer(max_num_reqs,
+ max_num_blocks_per_req,
+ dtype=torch.int32)
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
- self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
- dtype=torch.int64,
- device="cpu",
- pin_memory=self.pin_memory)
- self.slot_mapping_np = self.slot_mapping_cpu.numpy()
- self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
- dtype=torch.int64,
- device=self.device)
+ self.slot_mapping = self._make_buffer(self.max_num_batched_tokens,
+ dtype=torch.int64)
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
@ -69,7 +56,7 @@ class BlockTable:
num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx]
self.num_blocks_per_row[row_idx] += num_blocks
- self.block_table_np[row_idx, start:start + num_blocks] = block_ids
+ self.block_table.np[row_idx, start:start + num_blocks] = block_ids
def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
@ -77,17 +64,14 @@ class BlockTable:
def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
- self.block_table_np[tgt, :num_blocks] = self.block_table_np[
- src, :num_blocks]
+ block_table_np = self.block_table.np
+ block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks
def swap_row(self, src: int, tgt: int) -> None:
- num_blocks_src = self.num_blocks_per_row[src]
- num_blocks_tgt = self.num_blocks_per_row[tgt]
- self.num_blocks_per_row[src] = num_blocks_tgt
- self.num_blocks_per_row[tgt] = num_blocks_src
- self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
+ src_tgt, tgt_src = [src, tgt], [tgt, src]
+ self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
+ self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
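The rewritten swap_row relies on NumPy advanced indexing: the right-hand side arr[[tgt, src]] materializes a copy before assignment, so both rows (and both num_blocks_per_row entries) are exchanged without explicit temporaries. A tiny standalone illustration (values are made up):

import numpy as np

table = np.arange(12, dtype=np.int32).reshape(4, 3)
src, tgt = 0, 2
table[[src, tgt]] = table[[tgt, src]]  # rows 0 and 2 are now swapped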
def compute_slot_mapping(self, req_indices: np.ndarray,
positions: np.ndarray) -> None:
@ -107,7 +91,7 @@ class BlockTable:
virtual_block_size = self.block_size * self.dcp_world_size
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions // virtual_block_size)
- block_numbers = self.block_table_np.ravel()[block_table_indices]
+ block_numbers = self.block_table.np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
@ -117,40 +101,45 @@ class BlockTable:
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
- self.slot_mapping_np[:req_indices.shape[0]] = np.where(
+ self.slot_mapping.np[:req_indices.shape[0]] = np.where(
mask, slot_mapping, -1)
else:
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions // self.block_size)
- block_numbers = self.block_table_np.ravel()[block_table_indices]
+ block_numbers = self.block_table.np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
- out=self.slot_mapping_np[:req_indices.shape[0]])
+ out=self.slot_mapping.np[:req_indices.shape[0]])
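For readers new to this code path, the non-DCP branch flattens (request, position) pairs into physical KV-cache slots through the block table. A small worked example with made-up sizes (block_size=4, max_num_blocks_per_req=3):

import numpy as np

block_size, max_num_blocks_per_req = 4, 3
block_table = np.array([[7, 9, 0],        # request 0 owns blocks 7 and 9
                        [2, 5, 6]],       # request 1 owns blocks 2, 5 and 6
                       dtype=np.int32)
req_indices = np.array([0, 0, 1])         # request each token belongs to
positions = np.array([3, 4, 9])           # token positions within each request
indices = req_indices * max_num_blocks_per_req + positions // block_size
block_numbers = block_table.ravel()[indices]                  # [7, 9, 6]
slot_mapping = block_numbers * block_size + positions % block_size
print(slot_mapping)                                           # [31 36 25]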
def commit_block_table(self, num_reqs: int) -> None:
- self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
- non_blocking=True)
+ self.block_table.copy_to_gpu(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
- self.slot_mapping[:num_tokens].copy_(
- self.slot_mapping_cpu[:num_tokens], non_blocking=True)
+ self.slot_mapping.copy_to_gpu(num_tokens)
def clear(self) -> None:
- self.block_table.fill_(0)
- self.block_table_cpu.fill_(0)
+ self.block_table.gpu.fill_(0)
+ self.block_table.cpu.fill_(0)
- def get_device_tensor(self) -> torch.Tensor:
+ def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
"""Returns the device tensor of the block table."""
- return self.block_table
+ return self.block_table.gpu[:num_reqs]
def get_cpu_tensor(self) -> torch.Tensor:
"""Returns the CPU tensor of the block table."""
- return self.block_table_cpu
+ return self.block_table.cpu
def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table."""
- return self.block_table_np
+ return self.block_table.np
+ def _make_buffer(self, *size: Union[int, torch.SymInt],
+ dtype: torch.dtype) -> CpuGpuBuffer:
+ return CpuGpuBuffer(*size,
+ dtype=dtype,
+ device=self.device,
+ pin_memory=self.pin_memory)
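The net effect is one staging-and-commit pattern for both the block table and the slot mapping. A rough usage sketch of that pattern with the CpuGpuBufferSketch from the note above (shapes are invented, and a CUDA device is assumed):

import torch

buf = CpuGpuBufferSketch(4, 8, dtype=torch.int32,
                         device=torch.device("cuda"), pin_memory=True)
buf.np[0, :2] = [3, 7]      # stage block IDs through the zero-copy numpy view
buf.copy_to_gpu(1)          # async H2D copy of just the first row (commit_block_table)
rows_on_gpu = buf.gpu[:1]   # what the new get_device_tensor(num_reqs=1) returns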
class MultiGroupBlockTable:

View File

@ -89,7 +89,7 @@ class CPUModelRunner(GPUModelRunner):
assert isinstance(device_tensor, torch.Tensor)
setattr(obj, device_attr_name, cpu_tensor)
- for k, v in vars(self).items():
+ for v in vars(self).values():
if isinstance(v, CpuGpuBuffer):
v.gpu = v.cpu
@ -98,9 +98,9 @@ class CPUModelRunner(GPUModelRunner):
replace_tensor(self.input_batch, k, k[:-11])
for block_table in self.input_batch.block_table.block_tables:
- for k, v in vars(block_table).items():
- if k.endswith("_cpu") and isinstance(v, torch.Tensor):
- replace_tensor(block_table, k, k[:-4])
+ for v in vars(block_table).values():
+ if isinstance(v, CpuGpuBuffer):
+ v.gpu = v.cpu
def load_model(self, eep_scale_up: bool = False) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
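On the CPU-only runner the two halves of each buffer can simply share storage, so the GPU-path code runs unchanged. A hedged illustration of what the loop above achieves, again using the CpuGpuBufferSketch from earlier:

import torch

cpu_buf = CpuGpuBufferSketch(2, 3, dtype=torch.int32,
                             device=torch.device("cpu"), pin_memory=False)
cpu_buf.gpu = cpu_buf.cpu        # alias the "device" tensor to the CPU tensor
cpu_buf.np[0, 0] = 42            # a value staged on the host...
assert cpu_buf.gpu[0, 0] == 42   # ...is immediately visible to the "device" side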

View File

@ -427,9 +427,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
*size: Union[int, torch.SymInt],
dtype: torch.dtype,
numpy: bool = True) -> CpuGpuBuffer:
- # Bfloat16 torch tensors cannot be directly cast to a numpy array, so
- # if a bfloat16 buffer is needed without a corresponding numpy array,
- # don't bother instantiating the numpy array.
return CpuGpuBuffer(*size,
dtype=dtype,
device=self.device,
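Context for the numpy flag the removed comment referred to: NumPy has no bfloat16 dtype, so calling .numpy() on a bf16 CPU tensor raises, which is why a buffer can be created without the numpy view. A quick demonstration (the exact exception and message may vary by torch version):

import torch

t = torch.zeros(2, dtype=torch.bfloat16)
try:
    t.numpy()
except (TypeError, RuntimeError) as e:
    print(e)   # e.g. "Got unsupported ScalarType BFloat16"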
@ -1062,13 +1059,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_common_prefix_blocks = 0
else:
blk_table = self.input_batch.block_table[kv_cache_group_id]
- blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
- slot_mapping = blk_table.slot_mapping[:
- total_num_scheduled_tokens]
+ blk_table_tensor = blk_table.get_device_tensor(num_reqs)
+ slot_mapping = blk_table.slot_mapping.gpu[:
+ total_num_scheduled_tokens]
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode.
- blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+ blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(
+ -1)
num_common_prefix_blocks = (
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id])
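The -1 fill matters under full CUDA graphs because the captured batch is padded to a fixed size, and padded positions must carry a sentinel slot so cache-write kernels skip them. Roughly (sizes invented):

import torch

num_scheduled, padded_size = 5, 8
slots = torch.full((padded_size,), -1, dtype=torch.int64)
slots[:num_scheduled] = torch.tensor([31, 36, 25, 40, 41])
# entries 5..7 stay -1, marking "do not write to the KV cache"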
@ -2903,10 +2901,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
max_seq_len=self.max_model_len,
- block_table_tensor=self.input_batch.block_table[
- kv_cache_group_id].get_device_tensor()[:num_reqs],
- slot_mapping=self.input_batch.
- block_table[kv_cache_group_id].slot_mapping[:num_tokens],
+ block_table_tensor=self.input_batch.
+ block_table[kv_cache_group_id].get_device_tensor(num_reqs),
+ slot_mapping=self.input_batch.block_table[
+ kv_cache_group_id].slot_mapping.gpu[:num_tokens],
causal=True)
for attn_group in self.attn_groups[kv_cache_group_id]:
if ubatch_slices is not None:
@ -3265,8 +3263,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=False)
- # Capture full cudagraph for uniform decode batches if we have
- # dont already have full mixed prefill-decode cudagraphs
+ # Capture full cudagraph for uniform decode batches if we
+ # don't already have full mixed prefill-decode cudagraphs.
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \
cudagraph_mode.separate_routine():
max_num_tokens = self.scheduler_config.max_num_seqs * \