Override attention metadata for fast prefill in some KV sharing setups (#21590)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
commit ad510309ee
parent 366f6b3a4d
tests/v1/e2e/test_kv_sharing_fast_prefill.py (new file, 143 lines)
@@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import gc
import random
from typing import Optional, Union

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.forward_context import get_forward_context
from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration
from vllm.model_executor.models.registry import ModelRegistry
from vllm.model_executor.models.utils import extract_layer_index
from vllm.sequence import IntermediateTensors

from ...utils import fork_new_process_for_each_test


class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds, **kwargs)
        attn_metadata = get_forward_context().attn_metadata
        # attn_metadata is None during dummy runs
        if (attn_metadata is not None
                and self.cache_config.kv_sharing_fast_prefill):
            assert isinstance(attn_metadata, dict)  # true in V1
            # Gemma3n-E2B has 30 layers, with last 20 layers being
            # cross-decoder layers. Check attention metadata is correct
            for layer_name, metadata in attn_metadata.items():
                layer_idx = extract_layer_index(layer_name)
                if layer_idx >= 20:
                    assert hasattr(metadata, 'logits_indices_padded')
                    assert hasattr(metadata, 'num_logits_indices')
                else:
                    assert not hasattr(metadata, 'logits_indices_padded')
                    assert not hasattr(metadata, 'num_logits_indices')

            # Last layer will be a KV sharing layer
            layer_attn_metadata = attn_metadata[
                self.model.language_model.layers[-1].self_attn.attn.layer_name]
            logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
            assert logits_indices_padded is not None
            num_logits_indices = layer_attn_metadata.num_logits_indices
            assert num_logits_indices > 0

            # Reset hidden states to random values and
            # only set logits at logits_indices to valid values
            # Because logits_indices are the only positions that are used
            # for output token sampling, this still produces same outputs
            logits_hs = hidden_states[logits_indices_padded]
            hidden_states = torch.randn_like(hidden_states)
            gen_indices = logits_indices_padded[:num_logits_indices]
            hidden_states[gen_indices] = logits_hs[:num_logits_indices]

        return hidden_states


@pytest.fixture
def test_prompts():
    """
    Adapted from tests/v1/e2e/test_spec_decode.py
    """
    prompt_types = ["repeat", "sentence"]
    # Setting higher num prompts increases the chance of numerics mismatch
    # due to matrix multiplication numerics depending on batch dimension
    num_prompts = 10
    prompts = []

    random.seed(0)
    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)

    for kind in random_prompt_type_choices:
        word_choices = ["test", "temp", "hello", "where"]
        word = random.choice(word_choices)
        if kind == "repeat":
            prompt = f"""please repeat the word '{word}' 10 times."""
        elif kind == "sentence":
            prompt = f"""please give a ten-word sentence that
            uses the word {word} at least once."""
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append(prompt)

    return prompts


@fork_new_process_for_each_test
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_kv_sharing_fast_prefill(
    monkeypatch: pytest.MonkeyPatch,
    enforce_eager: bool,
    test_prompts: list[str],
):
    ModelRegistry.register_model("Gemma3nForConditionalGeneration",
                                 TestGemma3nForConditionalGeneration)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
    compilation_config = CompilationConfig(
        # This allows vLLM compilation backend to handle allocating and
        # managing buffers for cudagraph
        cudagraph_copy_inputs=True,
        level=CompilationLevel.PIECEWISE
        if not enforce_eager else CompilationLevel.NO_COMPILATION)

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        llm = LLM(
            model="google/gemma-3n-E2B-it",
            enforce_eager=enforce_eager,
            compilation_config=compilation_config,
        )
        ref_responses = llm.generate(test_prompts, sampling_params)

        del llm
        gc.collect()
        torch.cuda.empty_cache()

        llm = LLM(model="google/gemma-3n-E2B-it",
                  enforce_eager=enforce_eager,
                  compilation_config=compilation_config,
                  kv_sharing_fast_prefill=True)
        optimized_responses = llm.generate(test_prompts, sampling_params)

        misses = 0

        for ref_response, optimized_response in zip(ref_responses,
                                                    optimized_responses):
            if ref_response.outputs[0].text != optimized_response.outputs[
                    0].text:
                misses += 1

        assert misses == 0
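The override in the test above relies on the fact that sampling only ever reads the hidden states gathered at logits_indices, so corrupting every other row cannot change the outputs. A self-contained sketch of that invariant with toy tensors (hypothetical sizes, plain torch, not vLLM code):

import torch

torch.manual_seed(0)
hidden_states = torch.randn(8, 16)      # [num_tokens, hidden_size]
logits_indices = torch.tensor([3, 7])   # last scheduled token of each request
lm_head = torch.nn.Linear(16, 32, bias=False)

# Sampling only consumes the rows gathered at logits_indices.
ref_logits = lm_head(hidden_states[logits_indices])

# Corrupt every row, then restore just the gathered rows:
# the logits used for sampling are unchanged.
corrupted = torch.randn_like(hidden_states)
corrupted[logits_indices] = hidden_states[logits_indices]
assert torch.equal(lm_head(corrupted[logits_indices]), ref_logits)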
vllm/config.py
@@ -1795,6 +1795,16 @@ class CacheConfig:
    num_cpu_blocks: Optional[int] = field(default=None, init=False)
    """The number of blocks to allocate for CPU memory."""

    kv_sharing_fast_prefill: bool = False
    """This feature is work in progress and no prefill optimization takes place
    with this flag enabled currently.

    In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
    some layers can skip tokens corresponding to prefill. This flag enables
    attention metadata for eligible layers to be overridden with metadata
    necessary for implementing this optimization in some models (e.g. Gemma3n).
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
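For orientation (illustration, not part of the commit): the new field lands on CacheConfig and is plumbed through EngineArgs below, so it can be toggled either with the --kv-sharing-fast-prefill CLI flag registered in the EngineArgs hunks or directly on the LLM constructor, as the end-to-end test above does. A minimal sketch, reusing the model name from the test:

from vllm import LLM, SamplingParams

# Per the docstring above, enabling the flag currently only overrides
# attention metadata for eligible layers; no prefill savings yet.
llm = LLM(model="google/gemma-3n-E2B-it", kv_sharing_fast_prefill=True)
outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=16))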
@@ -1836,6 +1846,11 @@ class CacheConfig:
                "GPU memory utilization must be less than 1.0. Got "
                f"{self.gpu_memory_utilization}.")

        if self.kv_sharing_fast_prefill:
            logger.warning_once(
                "--kv-sharing-fast-prefill is currently work in progress "
                "and not functional yet (i.e. no prefill savings)")

        return self

    def _verify_cache_dtype(self) -> None:
vllm/engine/arg_utils.py
@@ -445,6 +445,9 @@ class EngineArgs:
    # DEPRECATED
    enable_prompt_adapter: bool = False

    kv_sharing_fast_prefill: bool = \
        CacheConfig.kv_sharing_fast_prefill

    def __post_init__(self):
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
@@ -697,6 +700,8 @@ class EngineArgs:
                                 **cache_kwargs["cpu_offload_gb"])
        cache_group.add_argument("--calculate-kv-scales",
                                 **cache_kwargs["calculate_kv_scales"])
        cache_group.add_argument("--kv-sharing-fast-prefill",
                                 **cache_kwargs["kv_sharing_fast_prefill"])

        # Multimodal related configs
        multimodal_kwargs = get_kwargs(MultiModalConfig)
@@ -1069,6 +1074,7 @@ class EngineArgs:
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
            cpu_offload_gb=self.cpu_offload_gb,
            calculate_kv_scales=self.calculate_kv_scales,
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
        )

        # Get the current placement group if Ray is initialized and
vllm/model_executor/models/gemma3n.py
@@ -793,6 +793,7 @@ class Gemma3nForConditionalGeneration(nn.Module):
        del lora_config  # Unused.
        super().__init__()
        self.config = config
        self.cache_config = vllm_config.cache_config
        self.model = Gemma3nModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))
        self.logits_processor = LogitsProcessor(
vllm/v1/attention/backends/utils.py
@@ -3,8 +3,8 @@
import abc
import functools
from abc import abstractmethod
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar
+from dataclasses import dataclass, make_dataclass
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar

import numpy as np
import torch
@@ -508,3 +508,34 @@ def reorder_batch_to_split_decodes_and_prefills(
            modified_batch = True

    return modified_batch


KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
    ('logits_indices_padded', Optional[torch.Tensor], None),
    ('num_logits_indices', int, 0),
]


def subclass_attention_metadata(
    name_prefix: str,
    metadata_cls: Any,
    fields: list[tuple[str, Any, Any]],
) -> Any:
    """
    Return a new subclass of `metadata_cls` with additional fields
    """
    name: str = name_prefix + metadata_cls.__name__  # type: ignore
    Wrapped = make_dataclass(name, fields, bases=(metadata_cls, ))
    return Wrapped


def make_kv_sharing_fast_prefill_attention_metadata(
    metadata_cls: Any, ) -> Any:
    """
    Return a new subclass of `metadata_cls` for fast prefill
    """
    return subclass_attention_metadata(
        name_prefix="KVSharingFastPrefill",
        metadata_cls=metadata_cls,
        fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
    )
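As a quick illustration of what these helpers produce, here is a hedged sketch using a made-up stand-in metadata dataclass (DummyAttentionMetadata is not a real vLLM backend class; the helper import matches the model runner import below). The helper simply builds a dynamic dataclass subclass carrying the two extra fields:

from dataclasses import dataclass, fields

from vllm.v1.attention.backends.utils import (
    make_kv_sharing_fast_prefill_attention_metadata)


@dataclass
class DummyAttentionMetadata:
    # Stand-in fields; real backend metadata classes carry many more.
    num_actual_tokens: int
    max_query_len: int


FastPrefillMetadata = make_kv_sharing_fast_prefill_attention_metadata(
    metadata_cls=DummyAttentionMetadata)

# The generated class keeps the base fields and appends the two new ones.
assert FastPrefillMetadata.__name__ == "KVSharingFastPrefillDummyAttentionMetadata"
assert issubclass(FastPrefillMetadata, DummyAttentionMetadata)
assert [f.name for f in fields(FastPrefillMetadata)] == [
    "num_actual_tokens", "max_query_len",
    "logits_indices_padded", "num_logits_indices",
]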
vllm/v1/worker/gpu_model_runner.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
import gc
import time
from contextlib import contextmanager
@@ -47,6 +48,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata,
    make_kv_sharing_fast_prefill_attention_metadata,
    make_local_attention_virtual_batches)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -320,6 +322,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # means this layer will perform attention using the keys and values
        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
        self.shared_kv_cache_layers: dict[str, str] = {}
        self.kv_sharing_fast_prefill_eligible_layers: set[str] = set()

        self.kv_sharing_fast_prefill_logits_indices = None
        if self.cache_config.kv_sharing_fast_prefill:
            self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
                self.max_num_tokens, dtype=torch.int32, device=self.device)

    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
        """
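For context, a hypothetical shape of the shared_kv_cache_layers mapping described by the comment above (layer names are made up): each key performs attention against the KV cache written by its value.

# Illustrative only; real names come from the model's attention layers.
shared_kv_cache_layers: dict[str, str] = {
    "model.layers.20.self_attn.attn": "model.layers.19.self_attn.attn",
    "model.layers.21.self_attn.attn": "model.layers.19.self_attn.attn",
}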
@@ -735,6 +743,55 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        spec_decode_common_attn_metadata = None

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
        if not use_spec_decode:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = query_start_loc[1:] - 1
            spec_decode_metadata = None
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since not all
            # requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for req_id, draft_token_ids in (
                    scheduler_output.scheduled_spec_decode_tokens.items()):
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)

            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens)
            logits_indices = spec_decode_metadata.logits_indices

        logits_indices_padded = None
        if self.cache_config.kv_sharing_fast_prefill:
            assert self.kv_sharing_fast_prefill_logits_indices is not None
            num_logits = logits_indices.shape[0]
            assert num_logits > 0
            self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
                logits_indices)
            # There might be leftover indices in logits_indices[num_logits:]
            # from previous iterations, whose values may be greater than the
            # batch size in the current iteration. To ensure indices are always
            # valid, we fill the padded indices with the last index.
            self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
                logits_indices[-1].item())
            if (self.use_cuda_graph
                    and num_logits <= self.cudagraph_batch_sizes[-1]):
                # Use piecewise CUDA graphs.
                # Add padding to the batch size.
                num_logits_padded = self.vllm_config.pad_for_cudagraph(
                    num_logits)
            else:
                num_logits_padded = num_logits
            logits_indices_padded = (
                self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
            )

        attn_metadata: dict[str, Any] = {}

        # Prepare encoder attention metadata separately
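To make the padding-with-the-last-index trick above concrete, here is a self-contained sketch with made-up sizes (toy tensors, not the runner code itself):

import torch

# Persistent buffer sized to max_num_tokens (toy value), possibly holding
# stale indices from a previous, larger batch.
buf = torch.tensor([9, 10, 11, 12, 13, 14, 15, 16], dtype=torch.int32)

# This iteration samples from three positions only.
logits_indices = torch.tensor([2, 5, 6], dtype=torch.int32)
num_logits = logits_indices.shape[0]

buf[:num_logits].copy_(logits_indices)
# Overwrite the stale tail with the last valid index so every padded slot
# still points at a row that exists in the current batch.
buf[num_logits:].fill_(logits_indices[-1].item())

num_logits_padded = 4  # e.g. rounded up to a CUDA graph capture size
logits_indices_padded = buf[:num_logits_padded]
print(logits_indices_padded)  # tensor([2, 5, 6, 6], dtype=torch.int32)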
@@ -806,7 +863,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                common_attn_metadata=common_attn_metadata,
            ))

            fast_prefill_metadata = attn_metadata_i
            if (self.cache_config.kv_sharing_fast_prefill
                    and self.kv_sharing_fast_prefill_eligible_layers):
                # Dynamically create a dataclass type that inherits
                # from the attention metadata type but includes additional
                # fields logits_indices_padded and num_logits_indices
                # which are required for prefill truncation
                fast_prefill_metadata_type = (
                    make_kv_sharing_fast_prefill_attention_metadata(
                        metadata_cls=type(attn_metadata_i), ))
                fast_prefill_metadata = fast_prefill_metadata_type(
                    **dataclasses.asdict(attn_metadata_i),
                    logits_indices_padded=logits_indices_padded,
                    num_logits_indices=logits_indices.size(0),
                )

            for layer_name in kv_cache_group_spec.layer_names:
                if (self.cache_config.kv_sharing_fast_prefill and layer_name
                        in self.kv_sharing_fast_prefill_eligible_layers):
                    attn_metadata[layer_name] = fast_prefill_metadata
                    continue

                attn_metadata[layer_name] = attn_metadata_i

            # Hack for now to fix chunked local attention + no hybrid kv cache
@@ -838,30 +916,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            b.can_run_in_cudagraph(common_attn_metadata)
            for b in self.attn_metadata_builders)

-        use_spec_decode = len(
-            scheduler_output.scheduled_spec_decode_tokens) > 0
-        if not use_spec_decode:
-            # NOTE(woosuk): Due to chunked prefills, the batch may contain
-            # partial requests. While we should not sample any token
-            # from these partial requests, we do so for simplicity.
-            # We will ignore the sampled tokens from the partial requests.
-            # TODO: Support prompt logprobs.
-            logits_indices = query_start_loc[1:] - 1
-            spec_decode_metadata = None
-        else:
-            # Get the number of draft tokens for each request.
-            # Iterate over the dictionary rather than all requests since not all
-            # requests have draft tokens.
-            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
-            for req_id, draft_token_ids in (
-                    scheduler_output.scheduled_spec_decode_tokens.items()):
-                req_idx = self.input_batch.req_id_to_index[req_id]
-                num_draft_tokens[req_idx] = len(draft_token_ids)
-
-            spec_decode_metadata = self._calc_spec_decode_metadata(
-                num_draft_tokens, cu_num_tokens)
-            logits_indices = spec_decode_metadata.logits_indices
-
        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)
@@ -1433,6 +1487,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         spec_decode_metadata, num_scheduled_tokens_np,
         spec_decode_common_attn_metadata) = (
             self._prepare_inputs(scheduler_output))

        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if (self.use_cuda_graph
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@@ -2814,6 +2869,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                kv_cache_config.kv_cache_groups,
                kv_caches,
            )
        attn_layers = get_layers_from_vllm_config(self.vllm_config,
                                                  Attention)
        # Iterate in reversed order and add layers that re-use KV cache
        # e.g. in YOCO-like KV sharing setups (e.g. Gemma3n)
        for layer_name in reversed(attn_layers):
            if layer_name in self.shared_kv_cache_layers:
                self.kv_sharing_fast_prefill_eligible_layers.add(
                    layer_name)
            else:
                break

        bind_kv_cache(kv_caches,
                      self.compilation_config.static_forward_context,
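The reversed scan above stops at the first layer that owns its own KV cache, so only the trailing run of KV-reusing (cross-decoder) layers becomes eligible for the fast-prefill metadata override. A toy sketch of that selection logic with made-up layer names (not the runner code):

# Hypothetical 6-layer model whose last two layers re-use layer 3's KV cache.
attn_layers = [f"model.layers.{i}.self_attn.attn" for i in range(6)]
shared_kv_cache_layers = {
    "model.layers.4.self_attn.attn": "model.layers.3.self_attn.attn",
    "model.layers.5.self_attn.attn": "model.layers.3.self_attn.attn",
}

eligible: set[str] = set()
for layer_name in reversed(attn_layers):
    if layer_name in shared_kv_cache_layers:
        eligible.add(layer_name)
    else:
        break  # stop at the first layer that allocates its own KV cache

assert eligible == {
    "model.layers.4.self_attn.attn",
    "model.layers.5.self_attn.attn",
}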