[V1] Enable prefill optimization for Gemma3n (#22628)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Yong Hoon Shin, 2025-08-28 14:54:30 -07:00, committed by GitHub
parent 7ffbf27239
commit cb293f6a79
9 changed files with 591 additions and 236 deletions
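For context, the option being wired up here is the existing cache-config flag --kv-sharing-fast-prefill. A minimal offline usage sketch might look like the following; the model id and the keyword spelling of the engine argument are assumptions rather than something this diff shows:

from vllm import LLM, SamplingParams

# Sketch only: the diff exposes the CacheConfig field `kv_sharing_fast_prefill`
# and the CLI flag --kv-sharing-fast-prefill; the keyword form below and the
# model id are assumed for illustration.
llm = LLM(
    model="google/gemma-3n-E2B-it",
    kv_sharing_fast_prefill=True,
)
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=16),
)
print(outputs[0].outputs[0].text)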

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from typing import Optional, Union
import pytest
import torch
@ -10,12 +9,6 @@ import torch
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.forward_context import get_forward_context
from vllm.model_executor.models.gemma3n_mm import (
Gemma3nForConditionalGeneration)
from vllm.model_executor.models.registry import ModelRegistry
from vllm.model_executor.models.utils import extract_layer_index
from vllm.sequence import IntermediateTensors
from ...utils import fork_new_process_for_each_test
@ -23,54 +16,6 @@ from ...utils import fork_new_process_for_each_test
SEED = 42
class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = super().forward(input_ids, positions,
intermediate_tensors, inputs_embeds,
**kwargs)
attn_metadata = get_forward_context().attn_metadata
# attn_metadata is None during dummy runs
if (attn_metadata is not None
and self.language_model.cache_config.kv_sharing_fast_prefill):
assert isinstance(attn_metadata, dict) # true in V1
# Gemma3n-E2B has 30 layers, with the last 10 layers being
# cross-decoder layers. Check that the attention metadata is correct
for layer_name, metadata in attn_metadata.items():
layer_idx = extract_layer_index(layer_name)
if layer_idx >= 20:
assert hasattr(metadata, 'logits_indices_padded')
assert hasattr(metadata, 'num_logits_indices')
else:
assert not hasattr(metadata, 'logits_indices_padded')
assert not hasattr(metadata, 'num_logits_indices')
# Last layer will be a KV sharing layer
layer_attn_metadata = attn_metadata[
self.language_model.model.layers[-1].self_attn.attn.layer_name]
logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
assert logits_indices_padded is not None
num_logits_indices = layer_attn_metadata.num_logits_indices
assert num_logits_indices > 0
# Reset hidden states to random values and restore only the
# positions in logits_indices to their valid values.
# Because logits_indices are the only positions used for
# output token sampling, this still produces the same outputs
logits_hs = hidden_states[logits_indices_padded]
hidden_states = torch.randn_like(hidden_states)
gen_indices = logits_indices_padded[:num_logits_indices]
hidden_states[gen_indices] = logits_hs[:num_logits_indices]
return hidden_states
@pytest.fixture
def test_prompts():
"""
@ -124,8 +69,6 @@ def test_kv_sharing_fast_prefill(
enforce_eager: bool,
test_prompts: list[str],
):
ModelRegistry.register_model("Gemma3nForConditionalGeneration",
TestGemma3nForConditionalGeneration)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
compilation_config = CompilationConfig(
# This allows vLLM compilation backend to handle allocating and

View File

@ -145,12 +145,19 @@ class CacheConfig:
self._verify_cache_dtype()
self._verify_prefix_caching()
self._verify_kv_sharing_fast_prefill()
def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus
# metrics info
return {key: str(value) for key, value in self.__dict__.items()}
def _verify_kv_sharing_fast_prefill(self) -> None:
if self.kv_sharing_fast_prefill and not envs.VLLM_USE_V1:
raise NotImplementedError(
"Fast prefill optimization for KV sharing is not supported "
"in V0 currently.")
@model_validator(mode='after')
def _verify_args(self) -> Self:
if self.cpu_offload_gb < 0:
@ -162,11 +169,6 @@ class CacheConfig:
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")
if self.kv_sharing_fast_prefill:
logger.warning_once(
"--kv-sharing-fast-prefill is currently work in progress "
"and not functional yet (i.e. no prefill savings)")
return self
def _verify_cache_dtype(self) -> None:

View File

@ -23,9 +23,11 @@ from torch import nn
from transformers.models.gemma3n.configuration_gemma3n import Gemma3nTextConfig
from vllm.attention import Attention
from vllm.compilation.backends import set_model_tag
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
GeluAndMul,
@ -45,6 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
from .interfaces import SupportsQuant
from .utils import (AutoWeightsLoader, extract_layer_index,
@ -533,7 +536,178 @@ class Gemma3nDecoderLayer(nn.Module):
return corrected_predictions
@support_torch_compile
# This enables torch.compile only if --kv-sharing-fast-prefill is passed
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
kv_sharing_fast_prefill)
class Gemma3nSelfDecoder(nn.Module):
"""
Includes altup embedding and self decoder layers
"""
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layers: list[Gemma3nDecoderLayer],
layer_idx_start: int,
per_layer_model_projection: ColumnParallelLinear,
embed_scale_per_layer: torch.Tensor,
embed_tokens_per_layer: VocabParallelEmbedding,
per_layer_projection_norm: RMSNorm,
per_layer_input_scale: torch.Tensor,
altup_projections: nn.ModuleList,
eps: torch.Tensor,
embed_tokens: VocabParallelEmbedding,
embed_scale: torch.Tensor,
):
super().__init__()
self.decoder_layers = decoder_layers
self.layer_idx_start = layer_idx_start
self.per_layer_model_projection = per_layer_model_projection
self.config = vllm_config.model_config.hf_config
self.embed_scale_per_layer = embed_scale_per_layer
self.embed_tokens_per_layer = embed_tokens_per_layer
self.per_layer_projection_norm = per_layer_projection_norm
self.per_layer_input_scale = per_layer_input_scale
self.altup_projections = altup_projections
self.eps = eps
self.embed_tokens = embed_tokens
self.embed_scale = embed_scale
def get_per_layer_input_embeddings(
self, input_ids: torch.Tensor) -> torch.Tensor:
# Deal with the fact that vocab_size_per_layer_input < vocab_size
# which causes us to have some out of vocab tokens by setting
# those token ids to 0. This matches the HF implementation.
per_layer_inputs_mask = torch.logical_and(
input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input)
per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids,
torch.zeros_like(input_ids))
return self.embed_tokens_per_layer(
per_layer_inputs_tokens) * self.embed_scale_per_layer
def get_per_layer_inputs(
self,
hidden_states_0: torch.Tensor,
per_layer_inputs: Optional[torch.Tensor],
) -> torch.Tensor:
per_layer_projection = self.per_layer_model_projection(hidden_states_0)
per_layer_projection = per_layer_projection.reshape(
*hidden_states_0.shape[:-1],
self.config.num_hidden_layers,
self.config.hidden_size_per_layer_input,
)
per_layer_projection = self.per_layer_projection_norm(
per_layer_projection)
if per_layer_inputs is not None:
# Profiling run does not compute per_layer_inputs
per_layer_inputs = per_layer_projection + per_layer_inputs
per_layer_inputs *= self.per_layer_input_scale
else:
per_layer_inputs = per_layer_projection
return per_layer_inputs
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.embed_scale
def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
# Altup embed.
hidden_states = [hidden_states_0] * self.config.altup_num_inputs
target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
keepdim=True)**0.5
for i in range(1, self.config.altup_num_inputs):
hidden_states[i] = self.altup_projections[i - 1](hidden_states[i])
new_magnitude = torch.mean(hidden_states[i]**2,
dim=-1,
keepdim=True)**0.5
hidden_states[i] *= target_magnitude / torch.maximum(
new_magnitude, self.eps)
hidden_states = torch.stack(hidden_states, dim=-1)
return hidden_states
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
per_layer_inputs: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
if inputs_embeds is not None:
hidden_states_0 = inputs_embeds
else:
hidden_states_0 = self.get_input_embeddings(input_ids)
adjusted_per_layer_inputs = self.get_per_layer_inputs(
hidden_states_0, per_layer_inputs)
hidden_states = self.altup_embed(hidden_states_0)
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states = hidden_states.permute(2, 0, 1)
for idx, layer in enumerate(self.decoder_layers):
layer_idx = idx + self.layer_idx_start
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states = layer(
positions=positions,
hidden_states=hidden_states,
per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :],
**kwargs,
)
# [num_tokens, hidden_size, altup_num_inputs]
hidden_states = hidden_states.permute(1, 2, 0)
return hidden_states, adjusted_per_layer_inputs
# This enables torch.compile only if --kv-sharing-fast-prefill is passed
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
kv_sharing_fast_prefill)
class Gemma3nCrossDecoder(nn.Module):
"""
Cross-decoder layers
"""
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layers: list[Gemma3nDecoderLayer],
layer_idx_start: int,
):
super().__init__()
self.decoder_layers = decoder_layers
self.layer_idx_start = layer_idx_start
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
per_layer_inputs: torch.Tensor,
**kwargs,
) -> torch.Tensor:
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states = hidden_states.permute(2, 0, 1)
for idx, layer in enumerate(self.decoder_layers):
layer_idx = idx + self.layer_idx_start
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states = layer(
positions=positions,
hidden_states=hidden_states,
per_layer_input=per_layer_inputs[:, layer_idx, :],
**kwargs,
)
# [num_tokens, hidden_size, altup_num_inputs]
hidden_states = hidden_states.permute(1, 2, 0)
return hidden_states
# This disables torch.compile if --kv-sharing-fast-prefill is passed
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
cache_config.kv_sharing_fast_prefill)
class Gemma3nTextModel(nn.Module, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@ -543,7 +717,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
@ -613,95 +786,211 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
lambda prefix: Gemma3nDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.eps = torch.tensor(torch.finfo().min)
first_kv_shared_layer_idx = (config.num_hidden_layers -
config.num_kv_shared_layers)
# Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
with set_model_tag("self_decoder"):
self.self_decoder = Gemma3nSelfDecoder(
vllm_config=vllm_config,
prefix=f"{prefix}.self_decoder",
decoder_layers=self.layers[:first_kv_shared_layer_idx],
layer_idx_start=0,
per_layer_model_projection=self.per_layer_model_projection,
embed_scale_per_layer=self.embed_scale_per_layer,
embed_tokens_per_layer=self.embed_tokens_per_layer,
per_layer_projection_norm=self.per_layer_projection_norm,
per_layer_input_scale=self.per_layer_input_scale,
altup_projections=self.altup_projections,
eps=self.eps,
embed_tokens=self.embed_tokens,
embed_scale=self.embed_scale,
)
# Layer idx 20-29 are cross-decoder layers in YOCO
with set_model_tag("cross_decoder"):
self.cross_decoder = Gemma3nCrossDecoder(
vllm_config=vllm_config,
prefix=f"{prefix}.cross_decoder",
decoder_layers=self.layers[first_kv_shared_layer_idx:],
layer_idx_start=first_kv_shared_layer_idx,
)
self.norm = RMSNorm(
config.hidden_size,
eps=config.rms_norm_eps,
)
self.eps = torch.tensor(torch.finfo().min)
self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
if self.fast_prefill_enabled:
# Allocate static buffers for CUDAGraph
# TODO(sarckk): Extract this functionality to interface
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
device = next(self.parameters()).device
self.positions = torch.zeros(max_num_tokens,
dtype=torch.int64,
device=device)
self.hidden_states = torch.zeros(
(max_num_tokens, config.hidden_size,
self.config.altup_num_inputs),
dtype=self.embed_tokens.weight.dtype,
device=device,
)
self.per_layer_inputs = torch.zeros(
(max_num_tokens, self.config.num_hidden_layers,
self.config.hidden_size_per_layer_input),
dtype=self.embed_tokens.weight.dtype,
device=device,
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.embed_scale
return self.self_decoder.get_input_embeddings(input_ids)
def get_per_layer_input_embeddings(
self, input_ids: torch.Tensor) -> torch.Tensor:
# Deal with the fact that vocab_size_per_layer_input < vocab_size
# which causes us to have some out of vocab tokens by setting
# those token ids to 0. This matches the HF implementation.
per_layer_inputs_mask = torch.logical_and(
input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input)
per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids,
torch.zeros_like(input_ids))
return self.embed_tokens_per_layer(
per_layer_inputs_tokens) * self.embed_scale_per_layer
def fast_prefill_forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
per_layer_inputs: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
logits_indices_padded, num_logits_indices = None, None
attn_metadata = get_forward_context().attn_metadata
# attn_metadata is None during dummy runs
if (self.fast_prefill_enabled and attn_metadata is not None):
assert isinstance(attn_metadata, dict)
# Last layer is a KV sharing layer
layer_attn_metadata = attn_metadata[
self.layers[-1].self_attn.attn.layer_name]
if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)):
logits_indices_padded = (
layer_attn_metadata.logits_indices_padded)
num_logits_indices = layer_attn_metadata.num_logits_indices
# Copy inputs for cudagraph
batch_size = positions.size(0)
self.positions[:batch_size].copy_(positions)
self_decoder_hidden_states, per_layer_inputs_adjusted = \
self.self_decoder(
input_ids=input_ids,
positions=self.positions[:batch_size],
inputs_embeds=inputs_embeds,
per_layer_inputs=per_layer_inputs,
**kwargs,
)
if logits_indices_padded is None:
logits_indices_padded = torch.arange(
positions.size(0),
dtype=positions.dtype,
device=positions.device,
)
# NOTE(sarckk): There is currently a bug caused by
# vLLM converting output of last piecewise CUDA graph
# to weakref, causing memory to be prematurely freed
# when there are multiple compilation units
# Keep .clone() until fix in
# https://github.com/vllm-project/vllm/pull/22282
hidden_states = self_decoder_hidden_states.clone()
# Copy inputs for cudagraph
num_padded_logits_indices = logits_indices_padded.size(0)
self.positions[:num_padded_logits_indices].copy_(
positions[logits_indices_padded])
self.hidden_states[:num_padded_logits_indices].copy_(
self_decoder_hidden_states[logits_indices_padded])
self.per_layer_inputs[:num_padded_logits_indices].copy_(
per_layer_inputs_adjusted[logits_indices_padded])
cross_decoder_hidden_states = self.cross_decoder(
positions=self.positions[:num_padded_logits_indices],
hidden_states=self.hidden_states[:num_padded_logits_indices],
per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices],
**kwargs,
)
if num_logits_indices is not None:
assert num_logits_indices > 0
# Merge cross-decoder and self-decoder hidden states
hidden_states[logits_indices_padded[:num_logits_indices]] = (
cross_decoder_hidden_states[:num_logits_indices])
else:
hidden_states = cross_decoder_hidden_states
return hidden_states
def normal_forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
per_layer_inputs: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states, per_layer_inputs = self.self_decoder(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
per_layer_inputs=per_layer_inputs,
**kwargs,
)
hidden_states = self.cross_decoder(
positions=positions,
hidden_states=hidden_states,
per_layer_inputs=per_layer_inputs,
**kwargs,
)
return hidden_states
def altup_unembed(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
# Altup unembed.
target_magnitude = torch.mean(hidden_states[..., 0]**2,
dim=-1,
keepdim=True)**0.5
for i in range(1, self.config.altup_num_inputs):
hidden_states[..., i] = self.altup_unembed_projections[i - 1](
hidden_states[..., i])
new_magnitude = torch.mean(hidden_states[..., i]**2,
dim=-1,
keepdim=True)**0.5
hidden_states[..., i] *= target_magnitude / torch.maximum(
new_magnitude, self.eps)
# [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size]
hidden_states = torch.mean(hidden_states, dim=-1)
return hidden_states
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
per_layer_inputs: torch.Tensor,
per_layer_inputs: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is not None:
hidden_states_0 = inputs_embeds
else:
hidden_states_0 = self.get_input_embeddings(input_ids)
per_layer_projection = self.per_layer_model_projection(hidden_states_0)
per_layer_projection = per_layer_projection.reshape(
*hidden_states_0.shape[:-1],
self.config.num_hidden_layers,
self.config.hidden_size_per_layer_input,
)
per_layer_projection = self.per_layer_projection_norm(
per_layer_projection)
if per_layer_inputs is not None:
# Profiling run does not compute per_layer_inputs
per_layer_inputs = per_layer_projection + per_layer_inputs
per_layer_inputs *= self.per_layer_input_scale
else:
per_layer_inputs = per_layer_projection
# Altup embed.
hidden_states = [hidden_states_0] * self.config.altup_num_inputs
target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
keepdim=True)**0.5
for i in range(1, self.config.altup_num_inputs):
hidden_states[i] = self.altup_projections[i - 1](hidden_states[i])
new_magnitude = torch.mean(hidden_states[i]**2,
dim=-1,
keepdim=True)**0.5
hidden_states[i] *= target_magnitude / torch.maximum(
new_magnitude, self.eps)
hidden_states = torch.stack(hidden_states, dim=0)
# Transformer blocks.
for layer_idx, layer in enumerate(self.layers):
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states = layer(
positions=positions,
hidden_states=hidden_states,
per_layer_input=per_layer_inputs[:, layer_idx, :],
if self.fast_prefill_enabled:
hidden_states = self.fast_prefill_forward(
input_ids,
positions,
inputs_embeds,
per_layer_inputs,
**kwargs,
)
# Altup unembed.
target_magnitude = torch.mean(hidden_states[0]**2,
dim=-1,
keepdim=True)**0.5
for i in range(1, self.config.altup_num_inputs):
hidden_states[i] = self.altup_unembed_projections[i - 1](
hidden_states[i])
new_magnitude = torch.mean(hidden_states[i]**2,
dim=-1,
keepdim=True)**0.5
hidden_states[i] *= target_magnitude / torch.maximum(
new_magnitude, self.eps)
# [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
hidden_states = torch.mean(hidden_states, dim=0)
else:
hidden_states = self.normal_forward(
input_ids,
positions,
inputs_embeds,
per_layer_inputs,
**kwargs,
)
hidden_states = self.altup_unembed(hidden_states)
return self.norm(hidden_states)
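To make the prefill truncation in fast_prefill_forward concrete, here is a small standalone sketch (shapes and index values are illustrative, not taken from the model) of how the cross-decoder output, computed only at the padded logits indices, is merged back into the full self-decoder hidden states:

import torch

num_tokens, hidden_size = 8, 4
self_decoder_out = torch.randn(num_tokens, hidden_size)

# Positions whose logits will actually be sampled (e.g. the last scheduled
# token of each request), padded for CUDA graphs; values are illustrative.
logits_indices_padded = torch.tensor([2, 5, 7, 7])  # final entry is padding
num_logits_indices = 3

# The cross-decoder only runs on the gathered rows.
cross_in = self_decoder_out[logits_indices_padded]
cross_out = cross_in * 2.0  # stand-in for the cross-decoder layers

# Merge: only the unpadded logits positions are overwritten; every other row
# keeps the self-decoder hidden states, which is fine because those rows are
# never used for sampling.
merged = self_decoder_out.clone()
keep = logits_indices_padded[:num_logits_indices]
merged[keep] = cross_out[:num_logits_indices]
assert merged.shape == (num_tokens, hidden_size)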
def load_weights(self, weights: Iterable[tuple[str,

View File

@ -620,7 +620,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
# them here, as the model forward has only access to the input_embeds.
if input_ids is not None:
per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings(
per_layer_inputs = self.language_model.model.self_decoder.get_per_layer_input_embeddings(
input_ids)
per_layer_inputs = per_layer_inputs.reshape(
-1, self.config.text_config.num_hidden_layers,

View File

@ -4,11 +4,13 @@ import abc
import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass, make_dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
from dataclasses import dataclass, fields, make_dataclass
from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol,
TypeVar)
import numpy as np
import torch
from typing_extensions import runtime_checkable
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils import cdiv
@ -19,7 +21,8 @@ if TYPE_CHECKING:
from vllm.v1.worker.gpu_input_batch import InputBatch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.layer import Attention
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
@ -65,6 +68,10 @@ class CommonAttentionMetadata:
causal: bool = True
# Needed by FastPrefillAttentionBuilder
logits_indices_padded: Optional[torch.Tensor] = None
num_logits_indices: Optional[int] = None
@dataclass
class UbatchSlice:
@ -542,6 +549,69 @@ def make_local_attention_virtual_batches(
)
def make_kv_sharing_fast_prefill_common_attn_metadata(
common_attn_metadata: CommonAttentionMetadata,
) -> CommonAttentionMetadata:
if common_attn_metadata.max_query_len == 1:
# All requests are decode (assume 1 token for now)
# Skip computing fast prefill path
return common_attn_metadata
assert common_attn_metadata.logits_indices_padded is not None
assert common_attn_metadata.num_logits_indices is not None
logits_indices_padded = common_attn_metadata.logits_indices_padded
num_logits_indices = common_attn_metadata.num_logits_indices
# Get rid of CUDAGraph padding, if any
logits_indices = logits_indices_padded[:num_logits_indices]
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
# Example inputs
# num_reqs: 3
# logits_indices: [14, 18, 19, 27]
# query_start_loc: [0, 15, 20, 28]
# seq_lens: [41, 31, 40]
# Find how many decode indices belong to each request
# request_ids: [0, 1, 1, 2]
request_ids = torch.bucketize(logits_indices,
query_start_loc[1:],
right=True)
# Figure out how many tokens are in each request
# num_decode_tokens: [1, 2, 1]
num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
# Calculate new query_start_loc with only the tokens in logits_indices
# decode_query_start_loc: [0, 1, 3, 4]
decode_query_start_loc = torch.empty(num_reqs + 1,
device=query_start_loc.device,
dtype=query_start_loc.dtype)
decode_query_start_loc[0] = 0
decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
decode_max_query_len = int(num_decode_tokens.max().item())
total_num_decode_tokens = int(num_decode_tokens.sum().item())
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=decode_query_start_loc,
query_start_loc_cpu=decode_query_start_loc.to("cpu",
non_blocking=True),
seq_lens=seq_lens,
seq_lens_cpu=seq_lens.to("cpu", non_blocking=True),
num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=total_num_decode_tokens,
max_query_len=decode_max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping,
causal=True,
)
return common_attn_metadata
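The bucketize/bincount/cumsum arithmetic in make_kv_sharing_fast_prefill_common_attn_metadata can be checked in isolation with the example values from the comments above:

import torch

logits_indices = torch.tensor([14, 18, 19, 27])
query_start_loc = torch.tensor([0, 15, 20, 28])
num_reqs = 3

# Which request each kept index falls into.
request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True)
assert request_ids.tolist() == [0, 1, 1, 2]

# How many kept tokens each request contributes.
num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
assert num_decode_tokens.tolist() == [1, 2, 1]

# New query_start_loc over only the kept tokens.
decode_query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.long)
decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
assert decode_query_start_loc.tolist() == [0, 1, 3, 4]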
def subclass_attention_backend(
name_prefix: str, attention_backend_cls: type[AttentionBackend],
builder_cls: type[AttentionMetadataBuilder[M]]
@ -679,13 +749,56 @@ def subclass_attention_metadata(
return Wrapped
def make_kv_sharing_fast_prefill_attention_metadata(
metadata_cls: Any, ) -> Any:
"""
Return a new subclass of `metadata_cls` for fast prefill
"""
return subclass_attention_metadata(
name_prefix="KVSharingFastPrefill",
metadata_cls=metadata_cls,
fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
)
@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
logits_indices_padded: torch.Tensor
num_logits_indices: int
def create_fast_prefill_custom_backend(
prefix: str,
underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
underlying_builder = underlying_attn_backend.get_builder_cls()
class FastPrefillAttentionBuilder(underlying_builder): # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> AttentionMetadata:
new_common_attn_metadata =\
make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
metadata = super().build(common_prefix_len,
new_common_attn_metadata, fast_build)
class KVSharingFastPrefillAttentionMetadata(
metadata.__class__, # type: ignore
KVSharingFastPrefillMetadata):
def __init__(self, metadata, common_attn_metadata):
# Shallow copy all fields in metadata cls
for field in fields(metadata.__class__):
setattr(self, field.name,
getattr(metadata, field.name))
# Set additional fields that will be used in model code
assert (common_attn_metadata.logits_indices_padded
is not None
and common_attn_metadata.num_logits_indices
is not None)
self.logits_indices_padded = \
common_attn_metadata.logits_indices_padded
self.num_logits_indices = \
common_attn_metadata.num_logits_indices
return KVSharingFastPrefillAttentionMetadata(
metadata, common_attn_metadata)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=FastPrefillAttentionBuilder)
return attn_backend
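A note on the isinstance check used later in the model code: KVSharingFastPrefillMetadata is a @runtime_checkable Protocol with only data members, so isinstance merely verifies that the attributes are present on the object. A tiny self-contained illustration (the two backend-metadata classes below are made up for the example):

from typing import Protocol, runtime_checkable

import torch


@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):  # mirrors the Protocol above
    logits_indices_padded: torch.Tensor
    num_logits_indices: int


class BackendMetadataWithFastPrefill:  # made-up stand-in with the extra fields
    def __init__(self) -> None:
        self.logits_indices_padded = torch.tensor([0, 3, 7])
        self.num_logits_indices = 3


class BackendMetadataWithout:  # made-up stand-in without the extra fields
    pass


# Structural check: attribute presence is all that matters, not inheritance.
assert isinstance(BackendMetadataWithFastPrefill(), KVSharingFastPrefillMetadata)
assert not isinstance(BackendMetadataWithout(), KVSharingFastPrefillMetadata)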

View File

@ -335,6 +335,13 @@ class AsyncLLM(EngineClient):
returning the RequestOutput back to the caller.
"""
if (self.vllm_config.cache_config.kv_sharing_fast_prefill
and sampling_params.prompt_logprobs):
raise ValueError(
"--kv-sharing-fast-prefill produces incorrect logprobs for "
"prompt tokens, please disable it when the requests need "
"prompt logprobs")
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import gc
import itertools
import time
@ -58,7 +57,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
supports_dynamo)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata,
create_fast_prefill_custom_backend,
reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (AttentionSpec,
@ -84,9 +83,10 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
from .utils import (AttentionGroup, MultiModalBudget,
add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,
gather_mm_placeholders, sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
if TYPE_CHECKING:
import xgrammar as xgr
@ -860,6 +860,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_seq_len=max_seq_len,
block_table_tensor=blk_table_tensor,
slot_mapping=slot_mapping,
logits_indices_padded=logits_indices_padded,
num_logits_indices=logits_indices.size(0),
causal=True,
)
@ -884,28 +886,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
common_attn_metadata=common_attn_metadata,
))
fast_prefill_metadata = attn_metadata_i
if (self.cache_config.kv_sharing_fast_prefill
and self.kv_sharing_fast_prefill_eligible_layers):
# Dynamically create a dataclass type that inherits
# from attention metadata type but includes additional
# fields logits_indices_padded and num_logits_indices
# which are required for prefill truncation
fast_prefill_metadata_type = (
make_kv_sharing_fast_prefill_attention_metadata(
metadata_cls=type(attn_metadata_i), ))
fast_prefill_metadata = fast_prefill_metadata_type(
**dataclasses.asdict(attn_metadata_i),
logits_indices_padded=logits_indices_padded,
num_logits_indices=logits_indices.size(0),
)
for layer_name in attn_group.layer_names:
if (self.cache_config.kv_sharing_fast_prefill
and layer_name
in self.kv_sharing_fast_prefill_eligible_layers):
attn_metadata[layer_name] = fast_prefill_metadata
continue
attn_metadata[layer_name] = attn_metadata_i
# Hot-Swap lora model
@ -1484,6 +1465,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return self.kv_connector_no_forward(scheduler_output,
self.vllm_config)
if self.cache_config.kv_sharing_fast_prefill:
assert not self.input_batch.num_prompt_logprobs, (
"--kv-sharing-fast-prefill produces incorrect logprobs for "
"prompt tokens, tokens, please disable it when the requests "
"need prompt logprobs")
# Prepare the decoder inputs.
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
@ -2742,6 +2729,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# layer.
for layer_name in layer_names:
attn_backend = layers[layer_name].get_attn_backend()
if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
attn_backend = create_fast_prefill_custom_backend(
"FastPrefill",
attn_backend,
)
key = attn_backend.full_cls_name()
attn_backends[key] = attn_backend
attn_backend_layers[key].append(layer_name)
@ -3074,20 +3068,40 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
kv_cache_raw_tensors)
# Setup `kv_cache_config` and `kv_caches` for models
# with cross-layer KV sharing
if self.shared_kv_cache_layers:
initialize_kv_cache_for_kv_sharing(
self.shared_kv_cache_layers,
kv_cache_config.kv_cache_groups,
kv_caches,
self.attn_groups,
self.runner_only_attn_layers,
)
# Set up cross-layer KV cache sharing
for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
):
logger.debug("%s reuses KV cache of %s", layer_name,
target_layer_name)
kv_caches[layer_name] = kv_caches[target_layer_name]
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches)
return kv_caches
def maybe_add_kv_sharing_layers_to_kv_cache_groups(
self, kv_cache_config: KVCacheConfig) -> None:
"""
Add layers that re-use KV cache to the KV cache group of their target layer.
The mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`.
"""
if not self.shared_kv_cache_layers:
# No cross-layer KV sharing, return
return
add_kv_sharing_layers_to_kv_cache_groups(
self.shared_kv_cache_layers,
kv_cache_config.kv_cache_groups,
self.runner_only_attn_layers,
)
if self.cache_config.kv_sharing_fast_prefill:
# In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
# similar KV sharing setups, only the layers that generate KV caches
# are involved in the prefill phase, enabling prefill to early exit.
attn_layers = get_layers_from_vllm_config(self.vllm_config,
Attention)
# Iterate in reverse order and add layers that re-use KV cache,
# e.g. in YOCO-like KV sharing setups such as Gemma3n
for layer_name in reversed(attn_layers):
if layer_name in self.shared_kv_cache_layers:
self.kv_sharing_fast_prefill_eligible_layers.add(
@ -3095,11 +3109,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
break
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches)
return kv_caches
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
@ -3111,6 +3120,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_cache_config = kv_cache_config
self.may_reinitialize_input_batch(kv_cache_config)
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
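The reverse iteration in maybe_add_kv_sharing_layers_to_kv_cache_groups collects only the trailing block of KV-reusing layers, stopping at the first layer that allocates its own KV cache. A standalone sketch with illustrative layer names and an assumed 20/10 split:

# Illustrative layer names and sharing map; the real names come from the
# model's static forward context, and the mapping below is only Gemma3n-like.
attn_layers = [f"model.layers.{i}.self_attn.attn" for i in range(30)]
shared_kv_cache_layers = {
    f"model.layers.{i}.self_attn.attn": f"model.layers.{i - 10}.self_attn.attn"
    for i in range(20, 30)
}

# Walk backwards and stop at the first layer that owns its KV cache: only the
# trailing block of KV-reusing layers is eligible to skip prefill computation.
kv_sharing_fast_prefill_eligible_layers = set()
for layer_name in reversed(attn_layers):
    if layer_name in shared_kv_cache_layers:
        kv_sharing_fast_prefill_eligible_layers.add(layer_name)
    else:
        break

assert len(kv_sharing_fast_prefill_eligible_layers) == 10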

View File

@ -55,9 +55,8 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
from .utils import (MultiModalBudget, bind_kv_cache,
initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs)
from .utils import (MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache, sanity_check_mm_encoder_outputs)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@ -1599,6 +1598,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.encoder_cache.clear()
gc.collect()
def maybe_setup_cross_layer_kv_sharing(
self,
kv_caches: dict[str, torch.Tensor],
kv_cache_config: KVCacheConfig,
) -> None:
"""
Add layers that re-use KV cache to the KV cache group of their target layer,
and map each such layer's KV cache tensor to that of its target layer.
"""
if not self.shared_kv_cache_layers:
# No cross-layer KV sharing, return
return
add_kv_sharing_layers_to_kv_cache_groups(
self.shared_kv_cache_layers,
kv_cache_config.kv_cache_groups,
)
for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
):
logger.debug("%s reuses KV cache of %s", layer_name,
target_layer_name)
kv_caches[layer_name] = kv_caches[target_layer_name]
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
@ -1664,14 +1687,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
raise NotImplementedError
# Setup `kv_cache_config` and `kv_caches` for models
# with cross-layer KV sharing
if self.shared_kv_cache_layers:
initialize_kv_cache_for_kv_sharing(
self.shared_kv_cache_layers,
kv_cache_config.kv_cache_groups,
kv_caches,
)
# Set up cross-layer KV cache sharing if needed
self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config)
bind_kv_cache(
kv_caches,

View File

@ -203,12 +203,9 @@ def gather_mm_placeholders(
return placeholders[is_embed]
def initialize_kv_cache_for_kv_sharing(
def add_kv_sharing_layers_to_kv_cache_groups(
shared_kv_cache_layers: dict[str, str],
kv_cache_groups: list[KVCacheGroupSpec],
kv_caches: dict[str, torch.Tensor],
# Optional for now to avoid breaking TPU
attn_groups: Optional[list[list[AttentionGroup]]] = None,
runner_only_attn_layers: Optional[set[str]] = None,
) -> None:
"""
@ -223,38 +220,15 @@ def initialize_kv_cache_for_kv_sharing(
means this layer will perform attention using the keys and values
from the KV cache of `shared_kv_cache_layers[layer_name]`.
kv_cache_groups: The KV cache groups of the model.
kv_caches: The allocated kv_caches with layer names as keys.
Note that layers in shared_kv_cache_layers.keys() are not
originally included as it only contains layers which have its own
KV cache allocation.
attn_groups: Optional list of attention groups. Layers in the same KV
cache group may be placed in different attention groups if they
have different attention backends. Currently only provided by
GPU model runner.
"""
# mapping from layer name to tuple of (kv_cache_group_idx, attn_group_idx)
layer_to_attn_group_idx: dict[str, tuple[int, int]] = {}
if attn_groups:
for kv_cache_group_idx, kv_attn_groups in enumerate(attn_groups):
for attn_group_idx, attn_group in enumerate(kv_attn_groups):
for layer_name in attn_group.layer_names:
layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx,
attn_group_idx)
else:
for kv_cache_group_idx, kv_cache_group in enumerate(kv_cache_groups):
for layer_name in kv_cache_group.layer_names:
# attn group idx default to 0 if not provided
layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, 0)
layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {}
for kv_cache_group in kv_cache_groups:
for layer_name in kv_cache_group.layer_names:
layer_to_kv_cache_group[layer_name] = kv_cache_group
for layer_name, target_layer_name in shared_kv_cache_layers.items():
kv_caches[layer_name] = kv_caches[target_layer_name]
kv_cache_group_idx = layer_to_attn_group_idx[target_layer_name][0]
kv_cache_groups[kv_cache_group_idx].layer_names.append(layer_name)
if attn_groups:
attn_group_idx = layer_to_attn_group_idx[target_layer_name][1]
attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append(
layer_name)
tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
tgt_kv_cache_group.layer_names.append(layer_name)
if runner_only_attn_layers is not None:
runner_only_attn_layers.add(layer_name)
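For reference, the simplified helper boils down to appending each KV-reusing layer to the KV cache group of its target layer; a minimal sketch, using a stand-in for KVCacheGroupSpec since only layer_names matters here:

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeKVCacheGroupSpec:
    # Stand-in for vllm.v1.kv_cache_interface.KVCacheGroupSpec.
    layer_names: list[str] = field(default_factory=list)


def add_shared_layers(
    shared_kv_cache_layers: dict[str, str],
    kv_cache_groups: list[FakeKVCacheGroupSpec],
    runner_only_attn_layers: Optional[set[str]] = None,
) -> None:
    # Map each layer that owns a KV cache to its group, then append every
    # KV-reusing layer to the group of the layer it shares with.
    layer_to_group = {
        name: group for group in kv_cache_groups for name in group.layer_names
    }
    for layer_name, target_layer_name in shared_kv_cache_layers.items():
        layer_to_group[target_layer_name].layer_names.append(layer_name)
        if runner_only_attn_layers is not None:
            runner_only_attn_layers.add(layer_name)


groups = [FakeKVCacheGroupSpec(["layers.0.attn", "layers.1.attn"])]
add_shared_layers({"layers.2.attn": "layers.1.attn"}, groups)
assert groups[0].layer_names == ["layers.0.attn", "layers.1.attn", "layers.2.attn"]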