mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 02:04:29 +08:00
[Core] Use CpuGpuBuffer for block table tensors (#24795)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
3059b9cc6b
commit
eeb135eb87
@ -125,7 +125,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
|
||||
return False
|
||||
|
||||
num_blocks = block_table.num_blocks_per_row[req_index]
|
||||
block_table_values = block_table.block_table_np[req_index, :num_blocks]
|
||||
block_table_values = block_table.block_table.np[req_index, :num_blocks]
|
||||
return (block_table_values == req_block_ids).all()
|
||||
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
@ -45,7 +46,7 @@ def _compare_objs(obj1,
|
||||
|
||||
is_same = False
|
||||
if isinstance(a, torch.Tensor):
|
||||
if (a.numel() == 0 or b.numel() == 0):
|
||||
if a.numel() == 0 or b.numel() == 0:
|
||||
is_same = (a.numel() == 0 and b.numel() == 0)
|
||||
elif torch.allclose(a, b):
|
||||
is_same = True
|
||||
@ -61,6 +62,8 @@ def _compare_objs(obj1,
|
||||
is_same = True # if we make it here must be same
|
||||
elif a == b:
|
||||
is_same = True
|
||||
elif isinstance(a, CpuGpuBuffer):
|
||||
is_same = np.allclose(a.np, b.np) and torch.allclose(a.gpu, b.gpu)
|
||||
assert is_same, f"Attribute {attr_name} is different"\
|
||||
f" in {obj1} and {obj2}: {a} != {b}"
|
||||
|
||||
|
||||
@ -165,7 +165,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
|
||||
req_state.block_ids[0]):
|
||||
return False
|
||||
num_blocks = block_table.num_blocks_per_row[req_index]
|
||||
return (block_table.block_table_np[req_index, :num_blocks] ==
|
||||
return (block_table.block_table.np[req_index, :num_blocks] ==
|
||||
req_state.block_ids[0]).all()
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -7,6 +8,7 @@ import torch
|
||||
from vllm.distributed import get_dcp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -29,28 +31,13 @@ class BlockTable:
|
||||
self.pin_memory = pin_memory
|
||||
self.device = device
|
||||
|
||||
self.block_table = torch.zeros(
|
||||
(max_num_reqs, max_num_blocks_per_req),
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
self.block_table_cpu = torch.zeros(
|
||||
(max_num_reqs, max_num_blocks_per_req),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.block_table_np = self.block_table_cpu.numpy()
|
||||
self.block_table = self._make_buffer(max_num_reqs,
|
||||
max_num_blocks_per_req,
|
||||
dtype=torch.int32)
|
||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
|
||||
self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||
self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
self.slot_mapping = self._make_buffer(self.max_num_batched_tokens,
|
||||
dtype=torch.int64)
|
||||
try:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
@ -69,7 +56,7 @@ class BlockTable:
|
||||
num_blocks = len(block_ids)
|
||||
start = self.num_blocks_per_row[row_idx]
|
||||
self.num_blocks_per_row[row_idx] += num_blocks
|
||||
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
|
||||
self.block_table.np[row_idx, start:start + num_blocks] = block_ids
|
||||
|
||||
def add_row(self, block_ids: list[int], row_idx: int) -> None:
|
||||
self.num_blocks_per_row[row_idx] = 0
|
||||
@ -77,17 +64,14 @@ class BlockTable:
|
||||
|
||||
def move_row(self, src: int, tgt: int) -> None:
|
||||
num_blocks = self.num_blocks_per_row[src]
|
||||
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
|
||||
src, :num_blocks]
|
||||
block_table_np = self.block_table.np
|
||||
block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks]
|
||||
self.num_blocks_per_row[tgt] = num_blocks
|
||||
|
||||
def swap_row(self, src: int, tgt: int) -> None:
|
||||
num_blocks_src = self.num_blocks_per_row[src]
|
||||
num_blocks_tgt = self.num_blocks_per_row[tgt]
|
||||
self.num_blocks_per_row[src] = num_blocks_tgt
|
||||
self.num_blocks_per_row[tgt] = num_blocks_src
|
||||
|
||||
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
|
||||
src_tgt, tgt_src = [src, tgt], [tgt, src]
|
||||
self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
|
||||
self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
|
||||
|
||||
def compute_slot_mapping(self, req_indices: np.ndarray,
|
||||
positions: np.ndarray) -> None:
|
||||
@ -107,7 +91,7 @@ class BlockTable:
|
||||
virtual_block_size = self.block_size * self.dcp_world_size
|
||||
block_table_indices = (req_indices * self.max_num_blocks_per_req +
|
||||
positions // virtual_block_size)
|
||||
block_numbers = self.block_table_np.ravel()[block_table_indices]
|
||||
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
||||
# Use virtual_block_size for mask calculation, which marks local
|
||||
# tokens.
|
||||
virtual_block_offsets = positions % virtual_block_size
|
||||
@ -117,40 +101,45 @@ class BlockTable:
|
||||
# Calculate slot_mapping
|
||||
slot_mapping = block_numbers * self.block_size + block_offsets
|
||||
# Write final slots, use -1 for not-local
|
||||
self.slot_mapping_np[:req_indices.shape[0]] = np.where(
|
||||
self.slot_mapping.np[:req_indices.shape[0]] = np.where(
|
||||
mask, slot_mapping, -1)
|
||||
else:
|
||||
block_table_indices = (req_indices * self.max_num_blocks_per_req +
|
||||
positions // self.block_size)
|
||||
block_numbers = self.block_table_np.ravel()[block_table_indices]
|
||||
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
||||
block_offsets = positions % self.block_size
|
||||
np.add(block_numbers * self.block_size,
|
||||
block_offsets,
|
||||
out=self.slot_mapping_np[:req_indices.shape[0]])
|
||||
out=self.slot_mapping.np[:req_indices.shape[0]])
|
||||
|
||||
def commit_block_table(self, num_reqs: int) -> None:
|
||||
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
|
||||
non_blocking=True)
|
||||
self.block_table.copy_to_gpu(num_reqs)
|
||||
|
||||
def commit_slot_mapping(self, num_tokens: int) -> None:
|
||||
self.slot_mapping[:num_tokens].copy_(
|
||||
self.slot_mapping_cpu[:num_tokens], non_blocking=True)
|
||||
self.slot_mapping.copy_to_gpu(num_tokens)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.block_table.fill_(0)
|
||||
self.block_table_cpu.fill_(0)
|
||||
self.block_table.gpu.fill_(0)
|
||||
self.block_table.cpu.fill_(0)
|
||||
|
||||
def get_device_tensor(self) -> torch.Tensor:
|
||||
def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
|
||||
"""Returns the device tensor of the block table."""
|
||||
return self.block_table
|
||||
return self.block_table.gpu[:num_reqs]
|
||||
|
||||
def get_cpu_tensor(self) -> torch.Tensor:
|
||||
"""Returns the CPU tensor of the block table."""
|
||||
return self.block_table_cpu
|
||||
return self.block_table.cpu
|
||||
|
||||
def get_numpy_array(self) -> np.ndarray:
|
||||
"""Returns the numpy array of the block table."""
|
||||
return self.block_table_np
|
||||
return self.block_table.np
|
||||
|
||||
def _make_buffer(self, *size: Union[int, torch.SymInt],
|
||||
dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(*size,
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory)
|
||||
|
||||
|
||||
class MultiGroupBlockTable:
|
||||
|
||||
@ -89,7 +89,7 @@ class CPUModelRunner(GPUModelRunner):
|
||||
assert isinstance(device_tensor, torch.Tensor)
|
||||
setattr(obj, device_attr_name, cpu_tensor)
|
||||
|
||||
for k, v in vars(self).items():
|
||||
for v in vars(self).values():
|
||||
if isinstance(v, CpuGpuBuffer):
|
||||
v.gpu = v.cpu
|
||||
|
||||
@ -98,9 +98,9 @@ class CPUModelRunner(GPUModelRunner):
|
||||
replace_tensor(self.input_batch, k, k[:-11])
|
||||
|
||||
for block_table in self.input_batch.block_table.block_tables:
|
||||
for k, v in vars(block_table).items():
|
||||
if k.endswith("_cpu") and isinstance(v, torch.Tensor):
|
||||
replace_tensor(block_table, k, k[:-4])
|
||||
for v in vars(block_table).values():
|
||||
if isinstance(v, CpuGpuBuffer):
|
||||
v.gpu = v.cpu
|
||||
|
||||
def load_model(self, eep_scale_up: bool = False) -> None:
|
||||
logger.info("Starting to load model %s...", self.model_config.model)
|
||||
|
||||
@ -427,9 +427,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
*size: Union[int, torch.SymInt],
|
||||
dtype: torch.dtype,
|
||||
numpy: bool = True) -> CpuGpuBuffer:
|
||||
# Bfloat16 torch tensors cannot be directly cast to a numpy array, so
|
||||
# if a bfloat16 buffer is needed without a corresponding numpy array,
|
||||
# don't bother instantiating the numpy array.
|
||||
return CpuGpuBuffer(*size,
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
@ -1062,13 +1059,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_common_prefix_blocks = 0
|
||||
else:
|
||||
blk_table = self.input_batch.block_table[kv_cache_group_id]
|
||||
blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
|
||||
slot_mapping = blk_table.slot_mapping[:
|
||||
total_num_scheduled_tokens]
|
||||
blk_table_tensor = blk_table.get_device_tensor(num_reqs)
|
||||
slot_mapping = blk_table.slot_mapping.gpu[:
|
||||
total_num_scheduled_tokens]
|
||||
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||||
# graph mode.
|
||||
blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
|
||||
blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(
|
||||
-1)
|
||||
num_common_prefix_blocks = (
|
||||
scheduler_output.
|
||||
num_common_prefix_blocks[kv_cache_group_id])
|
||||
@ -2903,10 +2901,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=self.max_model_len,
|
||||
block_table_tensor=self.input_batch.block_table[
|
||||
kv_cache_group_id].get_device_tensor()[:num_reqs],
|
||||
slot_mapping=self.input_batch.
|
||||
block_table[kv_cache_group_id].slot_mapping[:num_tokens],
|
||||
block_table_tensor=self.input_batch.
|
||||
block_table[kv_cache_group_id].get_device_tensor(num_reqs),
|
||||
slot_mapping=self.input_batch.block_table[
|
||||
kv_cache_group_id].slot_mapping.gpu[:num_tokens],
|
||||
causal=True)
|
||||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
||||
if ubatch_slices is not None:
|
||||
@ -3265,8 +3263,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
uniform_decode=False)
|
||||
|
||||
# Capture full cudagraph for uniform decode batches if we have
|
||||
# dont already have full mixed prefill-decode cudagraphs
|
||||
# Capture full cudagraph for uniform decode batches if we
|
||||
# don't already have full mixed prefill-decode cudagraphs.
|
||||
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \
|
||||
cudagraph_mode.separate_routine():
|
||||
max_num_tokens = self.scheduler_config.max_num_seqs * \
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user