Override attention metadata for fast prefill in some KV sharing setups (#21590)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Author: Yong Hoon Shin, 2025-07-30 08:54:15 -07:00 (committed by GitHub)
parent 366f6b3a4d
commit ad510309ee
6 changed files with 287 additions and 26 deletions
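Before the per-file diffs, a minimal usage sketch mirroring the end-to-end test added in this commit (the model name, flag, and sampling settings are all taken from that test; the flag is still work in progress, so it is not expected to change outputs or yield prefill savings yet):

from vllm import LLM, SamplingParams

# Requires a GPU and access to the Gemma3n checkpoint used by the new test.
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
llm = LLM(
    model="google/gemma-3n-E2B-it",
    kv_sharing_fast_prefill=True,  # flag introduced by this commit
)
outputs = llm.generate(["please repeat the word 'test' 10 times."],
                       sampling_params)
print(outputs[0].outputs[0].text)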


@@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import random
from typing import Optional, Union
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.forward_context import get_forward_context
from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration
from vllm.model_executor.models.registry import ModelRegistry
from vllm.model_executor.models.utils import extract_layer_index
from vllm.sequence import IntermediateTensors
from ...utils import fork_new_process_for_each_test
class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, **kwargs)
attn_metadata = get_forward_context().attn_metadata
# attn_metadata is None during dummy runs
if (attn_metadata is not None
and self.cache_config.kv_sharing_fast_prefill):
assert isinstance(attn_metadata, dict) # true in V1
# Gemma3n-E2B has 30 layers, with last 20 layers being
# cross-decoder layers. Check attention metadata is correct
for layer_name, metadata in attn_metadata.items():
layer_idx = extract_layer_index(layer_name)
if layer_idx >= 20:
assert hasattr(metadata, 'logits_indices_padded')
assert hasattr(metadata, 'num_logits_indices')
else:
assert not hasattr(metadata, 'logits_indices_padded')
assert not hasattr(metadata, 'num_logits_indices')
# Last layer will be a KV sharing layer
layer_attn_metadata = attn_metadata[
self.model.language_model.layers[-1].self_attn.attn.layer_name]
logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
assert logits_indices_padded is not None
num_logits_indices = layer_attn_metadata.num_logits_indices
assert num_logits_indices > 0
# Reset hidden states to random values and only restore the valid
# values at logits_indices. Because logits_indices are the only
# positions used for output token sampling, this still produces
# the same outputs
logits_hs = hidden_states[logits_indices_padded]
hidden_states = torch.randn_like(hidden_states)
gen_indices = logits_indices_padded[:num_logits_indices]
hidden_states[gen_indices] = logits_hs[:num_logits_indices]
return hidden_states
@pytest.fixture
def test_prompts():
"""
Adapted from tests/v1/e2e/test_spec_decode.py
"""
prompt_types = ["repeat", "sentence"]
# A higher number of prompts increases the chance of a numerics
# mismatch, since matrix multiplication numerics depend on the batch
# dimension
num_prompts = 10
prompts = []
random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
for kind in random_prompt_type_choices:
word_choices = ["test", "temp", "hello", "where"]
word = random.choice(word_choices)
if kind == "repeat":
prompt = f"""please repeat the word '{word}' 10 times."""
elif kind == "sentence":
prompt = f"""please give a ten-word sentence that
uses the word {word} at least once."""
else:
raise ValueError(f"Unknown prompt type: {kind}")
prompts.append(prompt)
return prompts
@fork_new_process_for_each_test
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_kv_sharing_fast_prefill(
monkeypatch: pytest.MonkeyPatch,
enforce_eager: bool,
test_prompts: list[str],
):
ModelRegistry.register_model("Gemma3nForConditionalGeneration",
TestGemma3nForConditionalGeneration)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
compilation_config = CompilationConfig(
# This allows vLLM compilation backend to handle allocating and
# managing buffers for cudagraph
cudagraph_copy_inputs=True,
level=CompilationLevel.PIECEWISE
if not enforce_eager else CompilationLevel.NO_COMPILATION)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
llm = LLM(
model="google/gemma-3n-E2B-it",
enforce_eager=enforce_eager,
compilation_config=compilation_config,
)
ref_responses = llm.generate(test_prompts, sampling_params)
del llm
gc.collect()
torch.cuda.empty_cache()
llm = LLM(model="google/gemma-3n-E2B-it",
enforce_eager=enforce_eager,
compilation_config=compilation_config,
kv_sharing_fast_prefill=True)
optimized_responses = llm.generate(test_prompts, sampling_params)
misses = 0
for ref_response, optimized_response in zip(ref_responses,
optimized_responses):
if ref_response.outputs[0].text != optimized_response.outputs[
0].text:
misses += 1
assert misses == 0


@@ -1795,6 +1795,16 @@ class CacheConfig:
num_cpu_blocks: Optional[int] = field(default=None, init=False)
"""The number of blocks to allocate for CPU memory."""
kv_sharing_fast_prefill: bool = False
"""This feature is work in progress and no prefill optimization takes place
with this flag enabled currently.
In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
some layers can skip tokens corresponding to prefill. This flag enables
attention metadata for eligible layers to be overriden with metadata
necessary for implementating this optimization in some models (e.g. Gemma3n)
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -1836,6 +1846,11 @@ class CacheConfig:
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")
if self.kv_sharing_fast_prefill:
logger.warning_once(
"--kv-sharing-fast-prefill is currently work in progress "
"and not functional yet (i.e. no prefill savings)")
return self
def _verify_cache_dtype(self) -> None:


@@ -445,6 +445,9 @@ class EngineArgs:
# DEPRECATED
enable_prompt_adapter: bool = False
kv_sharing_fast_prefill: bool = \
CacheConfig.kv_sharing_fast_prefill
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
@@ -697,6 +700,8 @@ class EngineArgs:
**cache_kwargs["cpu_offload_gb"])
cache_group.add_argument("--calculate-kv-scales",
**cache_kwargs["calculate_kv_scales"])
cache_group.add_argument("--kv-sharing-fast-prefill",
**cache_kwargs["kv_sharing_fast_prefill"])
# Multimodal related configs
multimodal_kwargs = get_kwargs(MultiModalConfig)
@@ -1069,6 +1074,7 @@ class EngineArgs:
prefix_caching_hash_algo=self.prefix_caching_hash_algo,
cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales,
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
)
# Get the current placement group if Ray is initialized and
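A hedged sketch of the plumbing above: the flag lives on EngineArgs, is exposed as --kv-sharing-fast-prefill in the cache argument group, and is forwarded into CacheConfig when the engine config is built. The snippet below only exercises the dataclass field and is illustrative, not canonical usage:

from vllm.engine.arg_utils import EngineArgs

# Programmatic counterpart of passing --kv-sharing-fast-prefill on the CLI;
# the field defaults to CacheConfig.kv_sharing_fast_prefill (False).
engine_args = EngineArgs(model="google/gemma-3n-E2B-it",
                         kv_sharing_fast_prefill=True)
assert engine_args.kv_sharing_fast_prefill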


@@ -793,6 +793,7 @@ class Gemma3nForConditionalGeneration(nn.Module):
del lora_config # Unused.
super().__init__()
self.config = config
self.cache_config = vllm_config.cache_config
self.model = Gemma3nModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor(


@@ -3,8 +3,8 @@
import abc
import functools
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar
from dataclasses import dataclass, make_dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
import numpy as np
import torch
@@ -508,3 +508,34 @@ def reorder_batch_to_split_decodes_and_prefills(
modified_batch = True
return modified_batch
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
('logits_indices_padded', Optional[torch.Tensor], None),
('num_logits_indices', int, 0),
]
def subclass_attention_metadata(
name_prefix: str,
metadata_cls: Any,
fields: list[tuple[str, Any, Any]],
) -> Any:
"""
Return a new subclass of `metadata_cls` with additional fields
"""
name: str = name_prefix + metadata_cls.__name__ # type: ignore
Wrapped = make_dataclass(name, fields, bases=(metadata_cls, ))
return Wrapped
def make_kv_sharing_fast_prefill_attention_metadata(
metadata_cls: Any, ) -> Any:
"""
Return a new subclass of `metadata_cls` for fast prefill
"""
return subclass_attention_metadata(
name_prefix="KVSharingFastPrefill",
metadata_cls=metadata_cls,
fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
)
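A self-contained illustration of the new helper, using a stand-in metadata class (DummyAttentionMetadata is hypothetical, not a real vLLM backend class; the import path matches the one used by the GPU model runner later in this diff):

from dataclasses import dataclass
from typing import Optional

import torch

from vllm.v1.attention.backends.utils import (
    make_kv_sharing_fast_prefill_attention_metadata)


@dataclass
class DummyAttentionMetadata:  # stand-in for a real backend metadata class
    num_actual_tokens: int
    query_start_loc: Optional[torch.Tensor] = None


FastPrefillMetadata = make_kv_sharing_fast_prefill_attention_metadata(
    metadata_cls=DummyAttentionMetadata)

meta = FastPrefillMetadata(num_actual_tokens=8,
                           logits_indices_padded=torch.tensor([3, 7]),
                           num_logits_indices=2)
assert isinstance(meta, DummyAttentionMetadata)
assert type(meta).__name__ == "KVSharingFastPrefillDummyAttentionMetadata"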


@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import gc
import time
from contextlib import contextmanager
@@ -47,6 +48,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata,
make_local_attention_virtual_batches)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -320,6 +322,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {}
self.kv_sharing_fast_prefill_eligible_layers: set[str] = set()
self.kv_sharing_fast_prefill_logits_indices = None
if self.cache_config.kv_sharing_fast_prefill:
self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device=self.device)
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
@@ -735,6 +743,55 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_common_attn_metadata = None
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
spec_decode_metadata = None
else:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
logits_indices = spec_decode_metadata.logits_indices
logits_indices_padded = None
if self.cache_config.kv_sharing_fast_prefill:
assert self.kv_sharing_fast_prefill_logits_indices is not None
num_logits = logits_indices.shape[0]
assert num_logits > 0
self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
logits_indices)
# There might be leftover indices from previous iterations in the
# persistent buffer beyond num_logits, whose values may be greater
# than the batch size in the current iteration. To ensure the indices
# are always valid, we fill the padded entries with the last valid index.
self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
logits_indices[-1].item())
if (self.use_cuda_graph
and num_logits <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_logits_padded = self.vllm_config.pad_for_cudagraph(
num_logits)
else:
num_logits_padded = num_logits
logits_indices_padded = (
self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
)
attn_metadata: dict[str, Any] = {}
# Prepare encoder attention metadata separately
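A toy sketch (made-up shapes, no CUDA graph capture) of the index-padding scheme above: stale tail entries of the persistent buffer are overwritten with the last valid index so that gathering with the padded slice always stays in bounds for the current batch:

import torch

max_num_tokens = 16
buffer = torch.zeros(max_num_tokens, dtype=torch.int32)  # persistent buffer

logits_indices = torch.tensor([2, 5, 9], dtype=torch.int32)  # sampled positions
num_logits = logits_indices.shape[0]
num_logits_padded = 4  # e.g. padded up to a captured cudagraph batch size

buffer[:num_logits].copy_(logits_indices)
buffer[num_logits:].fill_(logits_indices[-1].item())  # clamp stale entries

logits_indices_padded = buffer[:num_logits_padded]
hidden_states = torch.randn(12, 8)  # 12 scheduled tokens, hidden size 8
gathered = hidden_states[logits_indices_padded.long()]
assert gathered.shape == (num_logits_padded, 8)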
@@ -806,7 +863,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
common_attn_metadata=common_attn_metadata,
))
fast_prefill_metadata = attn_metadata_i
if (self.cache_config.kv_sharing_fast_prefill
and self.kv_sharing_fast_prefill_eligible_layers):
# Dynamically create a dataclass type that inherits from the
# attention metadata type but includes the additional fields
# logits_indices_padded and num_logits_indices, which are
# required for prefill truncation
fast_prefill_metadata_type = (
make_kv_sharing_fast_prefill_attention_metadata(
metadata_cls=type(attn_metadata_i), ))
fast_prefill_metadata = fast_prefill_metadata_type(
**dataclasses.asdict(attn_metadata_i),
logits_indices_padded=logits_indices_padded,
num_logits_indices=logits_indices.size(0),
)
for layer_name in kv_cache_group_spec.layer_names:
if (self.cache_config.kv_sharing_fast_prefill and layer_name
in self.kv_sharing_fast_prefill_eligible_layers):
attn_metadata[layer_name] = fast_prefill_metadata
continue
attn_metadata[layer_name] = attn_metadata_i
# Hack for now to fix chunked local attention + no hybrid kv cache
@@ -838,30 +916,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
b.can_run_in_cudagraph(common_attn_metadata)
for b in self.attn_metadata_builders)
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
spec_decode_metadata = None
else:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
logits_indices = spec_decode_metadata.logits_indices
# Hot-Swap lora model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
@@ -1433,6 +1487,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_metadata, num_scheduled_tokens_np,
spec_decode_common_attn_metadata) = (
self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@@ -2814,6 +2869,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_config.kv_cache_groups,
kv_caches,
)
attn_layers = get_layers_from_vllm_config(self.vllm_config,
Attention)
# Iterate in reverse order and add the trailing layers that re-use
# another layer's KV cache, as in YOCO-like KV sharing setups (e.g. Gemma3n)
for layer_name in reversed(attn_layers):
if layer_name in self.shared_kv_cache_layers:
self.kv_sharing_fast_prefill_eligible_layers.add(
layer_name)
else:
break
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
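Finally, a toy sketch of the eligibility scan added above, with made-up layer names: attention layers are walked in reverse definition order and only the contiguous trailing run of layers that re-use another layer's KV cache (the cross-decoder layers in a YOCO-style setup) is collected:

attn_layers = [f"model.layers.{i}.self_attn.attn" for i in range(6)]
shared_kv_cache_layers = {
    # cross-decoder layer -> layer whose KV cache it re-uses
    "model.layers.4.self_attn.attn": "model.layers.3.self_attn.attn",
    "model.layers.5.self_attn.attn": "model.layers.3.self_attn.attn",
}

eligible: set[str] = set()
for layer_name in reversed(attn_layers):
    if layer_name in shared_kv_cache_layers:
        eligible.add(layer_name)
    else:
        break  # stop at the first layer that owns its own KV cache

assert eligible == {"model.layers.4.self_attn.attn",
                    "model.layers.5.self_attn.attn"}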