[v1] Pass BlockTable and KVCacheSpec to AttentionMetadataBuilders (#17483)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
commit 950751a987
parent 4c31218f80
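This commit threads each KV-cache group's `KVCacheSpec` and the `InputBatch`'s `BlockTable` into the v1 attention metadata builders at construction time, so builders read the block size, KV head geometry, and slot-mapping buffers from those objects instead of reaching back into the runner. Below is a minimal sketch of the resulting constructor contract; it is illustrative only and not part of the diff, and `MyBackendMetadataBuilder` is a made-up name.

```python
# Illustrative sketch only -- not part of the diff. It mirrors the builder
# signature introduced by this commit, assuming the vllm v1 types below.
import weakref

from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable


class MyBackendMetadataBuilder:  # hypothetical backend, for illustration
    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        self.runner = runner
        # Per-KV-cache-group geometry now comes from the spec, not the runner.
        self.kv_cache_spec = kv_cache_spec
        self.block_size = kv_cache_spec.block_size
        # The BlockTable owns its own slot_mapping buffers.
        self.block_table = block_table


# Construction moves into GPUModelRunner.initialize_kv_cache(), roughly:
#   builder_cls = runner.attn_backend.get_builder_cls()
#   runner.attn_metadata_builder = builder_cls(
#       weakref.proxy(runner),
#       kv_cache_config.kv_cache_groups[0].kv_cache_spec,
#       runner.input_batch.block_table)
```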
@@ -221,6 +221,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
         max_num_reqs=batch_size,
         max_model_len=1024,
         max_num_blocks_per_req=10,
+        max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
@@ -310,6 +311,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
         max_num_reqs=batch_size,
         max_model_len=1024,
         max_num_blocks_per_req=10,
+        max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
@@ -318,6 +320,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
         max_num_reqs=batch_size,
         max_model_len=1024,
         max_num_blocks_per_req=10,
+        max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
@@ -1,14 +1,31 @@
 # SPDX-License-Identifier: Apache-2.0
+import weakref
+
 import pytest
 import torch

 from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
+from vllm.v1.kv_cache_interface import FullAttentionSpec
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner


+def initialize_kv_cache(runner: GPUModelRunner):
+    """
+    Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
+    """
+    kv_cache_spec = FullAttentionSpec(block_size=16,
+                                      num_kv_heads=1,
+                                      head_size=64,
+                                      dtype=torch.float16,
+                                      use_mla=False)
+    runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()(
+        weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table)
+
+
 @pytest.fixture
 def model_runner():
     scheduler_config = SchedulerConfig(
@@ -38,7 +55,9 @@ def model_runner():
     )

     device = "cuda"
-    return GPUModelRunner(vllm_config, device)
+    runner = GPUModelRunner(vllm_config, device)
+    initialize_kv_cache(runner)
+    return runner


 def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
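Because builders now require a spec and a block table, the test fixture constructs them through the `initialize_kv_cache` helper above. The snippet below is a hypothetical test, not part of this commit; it assumes pytest, a CUDA device, and that the FlashAttention v1 backend (whose builder stores `kv_cache_spec` and `block_table` per the diff further down) is selected.

```python
# Hypothetical test, not part of this commit. Assumes the model_runner
# fixture above and the FlashAttention v1 backend.
def test_builder_is_initialized(model_runner):
    builder = model_runner.attn_metadata_builder
    # The spec created in initialize_kv_cache() uses block_size=16.
    assert builder.kv_cache_spec.block_size == 16
    # The builder shares the InputBatch's BlockTable instance.
    assert builder.block_table is model_runner.input_batch.block_table
```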
@@ -19,6 +19,8 @@ from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -167,7 +169,7 @@ def make_local_attention_virtual_batches(
     query_start_loc_np: np.ndarray,
     seq_lens_np: np.ndarray,
     block_table: torch.Tensor,
-    page_size: int = 0,
+    block_size: int = 0,
 ) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
     q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
     actual_batch_size = seq_lens_np.shape[0]
@@ -238,14 +240,14 @@ def make_local_attention_virtual_batches(
     # For the example the local attention blocks start at:
     #                        _b0_  _____b1_____  _b2_
     # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
-    block_starts = k_seqstarts_absolute // page_size
-    assert attn_chunk_size % page_size == 0, \
+    block_starts = k_seqstarts_absolute // block_size
+    assert attn_chunk_size % block_size == 0, \
         f"attn_chunk_size {attn_chunk_size} is not " \
-        f"divisible by page_size {page_size}"
-    pages_per_local_batch = attn_chunk_size // page_size
+        f"divisible by block_size {block_size}"
+    pages_per_local_batch = attn_chunk_size // block_size

     # Create a block_table for the local attention blocks
-    # For out example if we have a block-table like (assuming page_size=2):
+    # For out example if we have a block-table like (assuming block_size=2):
     #   block_table = [
     #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9], < batch 0
     #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
@@ -289,7 +291,8 @@ def _get_sliding_window_configs(

 class FlashAttentionMetadataBuilder:

-    def __init__(self, runner: "GPUModelRunner"):
+    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
+                 block_table: BlockTable):
         model_config = runner.model_config
         compilation_config = runner.vllm_config.compilation_config

@@ -299,7 +302,9 @@ class FlashAttentionMetadataBuilder:
         self.num_heads_kv = model_config.get_num_kv_heads(
             runner.parallel_config)
         self.headdim = model_config.get_head_size()
-        self.page_size = self.runner.block_size
+        self.block_size = kv_cache_spec.block_size
+        self.kv_cache_spec = kv_cache_spec
+        self.block_table = block_table

         if get_flash_attn_version() == 3:
             self.aot_schedule = not compilation_config.full_cuda_graph
@@ -323,9 +328,17 @@ class FlashAttentionMetadataBuilder:
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
+        block_table = self.block_table
+        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
+
+        block_table.slot_mapping[:num_actual_tokens].copy_(
+            block_table.slot_mapping_cpu[:num_actual_tokens],
+            non_blocking=True)
+        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
+        # mode.
+        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
+
+        slot_mapping = block_table.slot_mapping[:num_actual_tokens]

         if self.aot_sliding_window is None:
             self.aot_sliding_window = (-1, -1)
@@ -354,7 +367,7 @@ class FlashAttentionMetadataBuilder:
                 num_heads_q=self.num_heads_q,
                 num_heads_kv=self.num_heads_kv,
                 headdim=self.headdim,
-                page_size=self.page_size,
+                page_size=self.block_size,
                 cu_seqlens_q=cu_query_lens,
                 causal=causal,
                 window_size=self.aot_sliding_window,
@@ -365,12 +378,12 @@ class FlashAttentionMetadataBuilder:
         local_attn_metadata = None
         if self.runner.attention_chunk_size is not None:
             seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
-                virt_block_table = make_local_attention_virtual_batches(
+                virt_block_table_tensor = make_local_attention_virtual_batches(
                     self.runner.attention_chunk_size,
                     self.runner.query_start_loc_np[:num_reqs + 1],
                     self.runner.seq_lens_np[:num_reqs],
-                    block_table,
-                    self.runner.block_size,
+                    block_table_tensor,
+                    self.block_size,
                 )
             local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
                 self.runner.device, non_blocking=True)
@@ -389,7 +402,7 @@ class FlashAttentionMetadataBuilder:
             local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
                 local_query_start_loc=local_query_start_loc,
                 local_seqused_k=local_seqused_k,
-                local_block_table=virt_block_table,
+                local_block_table=virt_block_table_tensor,
                 local_max_query_len=local_max_query_len,
                 local_max_seq_len=local_max_seq_len,
                 local_scheduler_metadata=local_scheduler_metadata,
@@ -440,7 +453,7 @@ class FlashAttentionMetadataBuilder:
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
-            block_table=block_table,
+            block_table=block_table_tensor,
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
@@ -19,6 +19,8 @@ from vllm.config import (VllmConfig, get_current_vllm_config,
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -202,7 +204,8 @@ class FlashInferMetadata:

 class FlashInferMetadataBuilder:

-    def __init__(self, runner: GPUModelRunner):
+    def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
+                 block_table: BlockTable):
         self.runner = runner
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
@@ -213,6 +216,8 @@ class FlashInferMetadataBuilder:
         self.global_hyperparameters: Optional[PerLayerParameters] = None

         self.vllm_config = get_current_vllm_config()
+        self.kv_cache_spec = kv_cache_spec
+        self.block_table = block_table

     def reorder_batch(self, input_batch: InputBatch,
                       scheduler_output: SchedulerOutput) -> bool:
@@ -400,13 +405,12 @@ class FlashInferMetadataBuilder:
         assert self._num_decodes + self._num_prefills == num_reqs
         assert (self._num_decode_tokens +
                 self._num_prefill_tokens == num_actual_tokens)
-        page_size = self.runner.block_size
+        page_size = self.kv_cache_spec.block_size
         device = self.runner.device
         qo_indptr = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
+        block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
+        slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()

         block_table_bounds = (seq_lens + page_size - 1) // page_size
@@ -422,12 +426,13 @@ class FlashInferMetadataBuilder:
             shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
                                                  dtype=torch.int32,
                                                  device=device)
-            shared_kv_page_indices = block_table[0, :num_common_kv_blocks]
+            shared_kv_page_indices = block_table_tensor[
+                0, :num_common_kv_blocks]
             shared_kv_last_page_len = torch.tensor([page_size],
                                                    dtype=torch.int32,
                                                    device=device)
             # Remove the blocks of the shared prefix from all requests.
-            block_table = block_table[:, num_common_kv_blocks:]
+            block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
             block_table_bounds -= num_common_kv_blocks
         else:
             shared_qo_indptr = None
@@ -435,11 +440,11 @@ class FlashInferMetadataBuilder:
             shared_kv_page_indices = None
             shared_kv_last_page_len = None

-        mask = (torch.arange(block_table.size(1),
-                             dtype=block_table.dtype,
-                             device=block_table.device).unsqueeze(0)
+        mask = (torch.arange(block_table_tensor.size(1),
+                             dtype=block_table_tensor.dtype,
+                             device=block_table_tensor.device).unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table[mask]
+        paged_kv_indices = block_table_tensor[mask]

         paged_kv_indptr = torch.cat([
             torch.zeros(1,
@@ -459,10 +464,10 @@ class FlashInferMetadataBuilder:
             paged_kv_indices=paged_kv_indices,
             paged_kv_last_page_len=paged_kv_last_page_len,
             num_qo_heads=self.runner.num_query_heads,
-            num_kv_heads=self.runner.num_kv_heads,
-            head_dim=self.runner.head_size,
+            num_kv_heads=self.kv_cache_spec.num_kv_heads,
+            head_dim=self.kv_cache_spec.head_size,
             page_size=page_size,
-            data_type=self.runner.kv_cache_dtype,
+            data_type=self.kv_cache_spec.dtype,
             q_data_type=self.runner.dtype,
             slot_mapping=slot_mapping,
             num_decodes=self._num_decodes,
@@ -481,7 +486,7 @@ class FlashInferMetadataBuilder:
         return attn_metadata

     def use_cascade_attention(self, *args, **kwargs) -> bool:
-        if self.runner.kv_cache_dtype != self.runner.model_config.dtype:
+        if self.kv_cache_spec.dtype != self.runner.model_config.dtype:
             # TODO: The cascade wrapper currently does not support setting
             # kv cache dtype to something different from query dtype.
             return False
@@ -207,6 +207,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable

 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -334,6 +336,8 @@ class MLACommonMetadataBuilder(Generic[M]):

     def __init__(self,
                  runner: "GPUModelRunner",
+                 kv_cache_spec: AttentionSpec,
+                 block_table: BlockTable,
                  metadata_cls: Optional[type[M]] = None):
         self.metadata_cls = metadata_cls \
             if metadata_cls is not None else MLACommonMetadata
@@ -346,10 +350,11 @@ class MLACommonMetadataBuilder(Generic[M]):
             runner.parallel_config)
         self.mla_dims = get_mla_dims(model_config)
         self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3)
+        self.kv_cache_spec = kv_cache_spec

         # Dont try to access the runner on AMD
         if self.aot_schedule:
-            self.page_size = self.runner.block_size
+            self.page_size = self.kv_cache_spec.block_size

         if self.chunked_prefill_enabled:
             self.chunked_prefill_workspace_size = min(
@@ -375,6 +380,7 @@ class MLACommonMetadataBuilder(Generic[M]):
                 dtype=model_config.dtype,
                 device=runner.device,
             )
+        self.block_table = block_table

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -436,9 +442,10 @@ class MLACommonMetadataBuilder(Generic[M]):

         return modified_batch

-    def _build_decode(self, block_table: torch.Tensor, seq_lens: torch.Tensor):
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens: torch.Tensor):
         return MLACommonDecodeMetadata(
-            block_table=block_table,
+            block_table=block_table_tensor,
            seq_lens=seq_lens,
        )

@@ -451,9 +458,9 @@ class MLACommonMetadataBuilder(Generic[M]):
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
+        block_table = self.block_table
+        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
+        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()

         query_start_loc = common_attn_metadata.query_start_loc
@@ -530,7 +537,7 @@ class MLACommonMetadataBuilder(Generic[M]):
                 self.chunked_prefill_workspace_size

             prefill_metadata = MLACommonPrefillMetadata(
-                block_table=block_table[reqs_start:, ...],
+                block_table=block_table_tensor[reqs_start:, ...],
                 query_start_loc=prefill_query_start_loc,
                 max_query_len=max_query_len,
                 chunked_context=chunked_context_metadata,
@@ -539,7 +546,7 @@ class MLACommonMetadataBuilder(Generic[M]):
         decode_metadata = None
         if self._num_decodes > 0:
             decode_metadata = self._build_decode(
-                block_table=block_table[:self._num_decodes, ...],
+                block_table_tensor=block_table_tensor[:self._num_decodes, ...],
                 seq_lens=seq_lens[:self._num_decodes],
             )

@@ -16,6 +16,8 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable

 logger = init_logger(__name__)

@@ -52,13 +54,14 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):

 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):

-    def __init__(self, runner):
-        super().__init__(runner)
+    def __init__(self, runner, kv_cache_spec: AttentionSpec,
+                 block_table: BlockTable):
+        super().__init__(runner, kv_cache_spec, block_table)

         self.num_q_heads = self.runner.model_config.get_num_attention_heads(
             self.runner.parallel_config)

-    def _build_decode(self, block_table: torch.Tensor,
+    def _build_decode(self, block_table_tensor: torch.Tensor,
                       seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
         tile_scheduler_metadata, num_splits = \
             get_mla_metadata(
@@ -68,7 +71,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
             )

         return FlashMLADecodeMetadata(
-            block_table=block_table,
+            block_table=block_table_tensor,
             seq_lens=seq_lens,
             tile_scheduler_metadata=tile_scheduler_metadata,
             num_splits=num_splits,
@@ -14,6 +14,8 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable

 # yapf: enable

@@ -59,8 +61,9 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):

 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):

-    def __init__(self, runner):
-        super().__init__(runner)
+    def __init__(self, runner, kv_cache_spec: AttentionSpec,
+                 block_table: BlockTable):
+        super().__init__(runner, kv_cache_spec, block_table)
         max_model_len = self.runner.model_config.max_model_len
         assert max_model_len == 32768,\
             "AITER MLA requires max_model_len=32768"
@@ -14,11 +14,13 @@ class BlockTable:
         self,
         max_num_reqs: int,
         max_num_blocks_per_req: int,
+        max_num_batched_tokens: int,
         pin_memory: bool,
         device: torch.device,
     ):
         self.max_num_reqs = max_num_reqs
         self.max_num_blocks_per_req = max_num_blocks_per_req
+        self.max_num_batched_tokens = max_num_batched_tokens
         self.pin_memory = pin_memory
         self.device = device

@@ -36,6 +38,15 @@ class BlockTable:
         self.block_table_np = self.block_table_cpu.numpy()
         self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

+        self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
+                                            dtype=torch.int64,
+                                            device="cpu",
+                                            pin_memory=self.pin_memory)
+        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
+        self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
+                                        dtype=torch.int64,
+                                        device=self.device)
+
     def append_row(
         self,
         block_ids: list[int],
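With `max_num_batched_tokens` passed through, each `BlockTable` now owns the slot-mapping buffers (pinned CPU tensor, numpy view, and device tensor) that previously lived on the model runner. Below is a rough sketch of the fill/copy pattern the builders rely on; it is illustrative only, with made-up sizes, and mirrors the `BlockTable` and FlashAttention hunks in this commit.

```python
# Rough sketch of the slot_mapping flow after this change -- illustrative
# only, with made-up sizes; mirrors the BlockTable and flash_attn hunks.
import torch

from vllm.v1.worker.block_table import BlockTable

block_table = BlockTable(max_num_reqs=4,
                         max_num_blocks_per_req=10,
                         max_num_batched_tokens=64,
                         pin_memory=False,
                         device=torch.device("cpu"))

num_actual_tokens = 7
# The runner writes slot indices into the CPU numpy view...
block_table.slot_mapping_np[:num_actual_tokens] = range(num_actual_tokens)
# ...and the metadata builder copies them to the device tensor,
# padding the unused tail with -1 for reshape_and_cache.
block_table.slot_mapping[:num_actual_tokens].copy_(
    block_table.slot_mapping_cpu[:num_actual_tokens], non_blocking=True)
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
```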
@@ -59,6 +59,7 @@ class InputBatch:
         max_num_reqs: int,
         max_model_len: int,
         max_num_blocks_per_req: int,
+        max_num_batched_tokens: int,
         device: torch.device,
         pin_memory: bool,
         vocab_size: int,
@@ -66,6 +67,7 @@ class InputBatch:
         self.max_num_reqs = max_num_reqs
         self.max_model_len = max_model_len
         self.max_num_blocks_per_req = max_num_blocks_per_req
+        self.max_num_batched_tokens = max_num_batched_tokens
         self.device = device
         self.pin_memory = pin_memory
         self.vocab_size = vocab_size
@@ -100,6 +102,7 @@ class InputBatch:
         self.block_table = BlockTable(
             max_num_reqs=max_num_reqs,
             max_num_blocks_per_req=max_num_blocks_per_req,
+            max_num_batched_tokens=max_num_batched_tokens,
             pin_memory=pin_memory,
             device=device,
         )
@@ -150,8 +150,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 f"FA3. Current attention backend is {attn_backend_name}, "
                 f"FlashAttention version is {flash_attn_version}.")

-        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
-            weakref.proxy(self))
         self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

         # Multi-modal data support
@@ -174,6 +172,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Initialize in initialize_kv_cache
         self.kv_caches: list[torch.Tensor] = []
         # self.kv_cache_config: KVCacheConfig
+        # self.attn_metadata_builder: type[AttentionMetadataBuilder]

         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -203,6 +202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
             max_num_blocks_per_req=self.max_num_blocks_per_req,
+            max_num_batched_tokens=self.max_num_tokens,
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=model_config.get_vocab_size(),
@@ -291,11 +291,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                             device="cpu",
                                             pin_memory=self.pin_memory)
         self.positions_np = self.positions_cpu.numpy()
-        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int64,
-                                            device="cpu",
-                                            pin_memory=self.pin_memory)
-        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
         self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                                dtype=torch.int32,
                                                device="cpu",
@@ -586,7 +581,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
                block_offsets,
-               out=self.slot_mapping_np[:total_num_scheduled_tokens])
+               out=self.input_batch.block_table.
+               slot_mapping_np[:total_num_scheduled_tokens])

         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
@@ -614,12 +610,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
         self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                        non_blocking=True)
-        self.slot_mapping[:total_num_scheduled_tokens].copy_(
-            self.slot_mapping_cpu[:total_num_scheduled_tokens],
-            non_blocking=True)

         # Fill unused with -1. Needed for reshape_and_cache
-        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
         self.seq_lens[num_reqs:].fill_(0)
         self.query_start_loc[num_reqs + 1:].fill_(-1)

@@ -1821,6 +1813,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.vllm_config.compilation_config.static_forward_context,
             self.kv_caches)

+        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
+            weakref.proxy(self),
+            kv_cache_config.kv_cache_groups[0].kv_cache_spec,
+            self.input_batch.block_table)
+
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
@@ -179,6 +179,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
             max_num_blocks_per_req=self.max_num_blocks_per_req,
+            max_num_batched_tokens=self.max_num_tokens,
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=self.vocab_size,
@@ -197,10 +198,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                                          device="cpu")
         self.positions_np = self.positions_cpu.numpy()

-        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int64,
-                                            device="cpu")
-        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
         self.block_table_cpu = torch.zeros(
             (self.max_num_reqs, self.max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
@@ -533,7 +530,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
                block_offsets,
-               out=self.slot_mapping_np[:total_num_scheduled_tokens])
+               out=self.input_batch.block_table.
+               slot_mapping_cpu[:total_num_scheduled_tokens])

         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
@@ -557,10 +555,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.position_ids = self.positions_cpu[:
                                                padded_total_num_scheduled_tokens].to(
                                                    self.device)
-        self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID
-        slot_mapping = self.slot_mapping_cpu[:
-                                             padded_total_num_scheduled_tokens].to(
-                                                 self.device)
+        self.input_batch.block_table.slot_mapping_cpu[
+            total_num_scheduled_tokens:] = _PAD_SLOT_ID
+        slot_mapping = (
+            self.input_batch.block_table.
+            slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
+                self.device))
         block_tables = self.block_table_cpu[:self.max_num_reqs]
         block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
             self.input_batch.block_table.get_cpu_tensor()[:num_reqs])