[V1] Enable prefill optimization for Gemma3n (#22628)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>

parent: 7ffbf27239
commit: cb293f6a79
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random
from typing import Optional, Union

import pytest
import torch
@@ -10,12 +9,6 @@ import torch
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.forward_context import get_forward_context
from vllm.model_executor.models.gemma3n_mm import (
    Gemma3nForConditionalGeneration)
from vllm.model_executor.models.registry import ModelRegistry
from vllm.model_executor.models.utils import extract_layer_index
from vllm.sequence import IntermediateTensors

from ...utils import fork_new_process_for_each_test

@@ -23,54 +16,6 @@ from ...utils import fork_new_process_for_each_test
SEED = 42


class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = super().forward(input_ids, positions,
                                        intermediate_tensors, inputs_embeds,
                                        **kwargs)
        attn_metadata = get_forward_context().attn_metadata
        # attn_metadata is None during dummy runs
        if (attn_metadata is not None
                and self.language_model.cache_config.kv_sharing_fast_prefill):
            assert isinstance(attn_metadata, dict)  # true in V1
            # Gemma3n-E2B has 30 layers, with the last 10 layers being
            # cross-decoder layers. Check that the attention metadata is correct.
            for layer_name, metadata in attn_metadata.items():
                layer_idx = extract_layer_index(layer_name)
                if layer_idx >= 20:
                    assert hasattr(metadata, 'logits_indices_padded')
                    assert hasattr(metadata, 'num_logits_indices')
                else:
                    assert not hasattr(metadata, 'logits_indices_padded')
                    assert not hasattr(metadata, 'num_logits_indices')

            # Last layer will be a KV sharing layer
            layer_attn_metadata = attn_metadata[
                self.language_model.model.layers[-1].self_attn.attn.layer_name]
            logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
            assert logits_indices_padded is not None
            num_logits_indices = layer_attn_metadata.num_logits_indices
            assert num_logits_indices > 0
            # Reset hidden states to random values and only restore valid
            # values at logits_indices. Because logits_indices are the only
            # positions used for output token sampling, this still produces
            # the same outputs.
            logits_hs = hidden_states[logits_indices_padded]
            hidden_states = torch.randn_like(hidden_states)
            gen_indices = logits_indices_padded[:num_logits_indices]
            hidden_states[gen_indices] = logits_hs[:num_logits_indices]

        return hidden_states
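
Why the scrambling in this test override cannot change the greedily sampled tokens: only the rows gathered at the logits indices ever reach the LM head. A minimal standalone sketch with toy tensors (plain PyTorch, not vLLM APIs):

import torch

torch.manual_seed(0)
num_tokens, hidden, vocab = 8, 16, 32
lm_head = torch.nn.Linear(hidden, vocab, bias=False)
hidden_states = torch.randn(num_tokens, hidden)
logits_indices = torch.tensor([3, 7])  # the only positions that get sampled

# Reference: compute logits for every position, then pick the sampled ones.
ref_tokens = lm_head(hidden_states).argmax(dim=-1)[logits_indices]

# Scramble every other position, keeping only the sampled rows intact,
# exactly like the test override does.
scrambled = torch.randn_like(hidden_states)
scrambled[logits_indices] = hidden_states[logits_indices]
new_tokens = lm_head(scrambled).argmax(dim=-1)[logits_indices]

assert torch.equal(ref_tokens, new_tokens)
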


@pytest.fixture
def test_prompts():
    """
@@ -124,8 +69,6 @@ def test_kv_sharing_fast_prefill(
    enforce_eager: bool,
    test_prompts: list[str],
):
    ModelRegistry.register_model("Gemma3nForConditionalGeneration",
                                 TestGemma3nForConditionalGeneration)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
    compilation_config = CompilationConfig(
        # This allows vLLM compilation backend to handle allocating and
@@ -145,12 +145,19 @@ class CacheConfig:

        self._verify_cache_dtype()
        self._verify_prefix_caching()
        self._verify_kv_sharing_fast_prefill()

    def metrics_info(self):
        # convert cache_config to dict(key: str, value: str) for prometheus
        # metrics info
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_kv_sharing_fast_prefill(self) -> None:
        if self.kv_sharing_fast_prefill and not envs.VLLM_USE_V1:
            raise NotImplementedError(
                "Fast prefill optimization for KV sharing is not supported "
                "in V0 currently.")

    @model_validator(mode='after')
    def _verify_args(self) -> Self:
        if self.cpu_offload_gb < 0:
@@ -162,11 +169,6 @@ class CacheConfig:
                "GPU memory utilization must be less than 1.0. Got "
                f"{self.gpu_memory_utilization}.")

        if self.kv_sharing_fast_prefill:
            logger.warning_once(
                "--kv-sharing-fast-prefill is currently work in progress "
                "and not functional yet (i.e. no prefill savings)")

        return self

    def _verify_cache_dtype(self) -> None:
@@ -23,9 +23,11 @@ from torch import nn
from transformers.models.gemma3n.configuration_gemma3n import Gemma3nTextConfig

from vllm.attention import Attention
from vllm.compilation.backends import set_model_tag
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
                                                    GeluAndMul,
@@ -45,6 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata

from .interfaces import SupportsQuant
from .utils import (AutoWeightsLoader, extract_layer_index,
@@ -533,7 +536,178 @@ class Gemma3nDecoderLayer(nn.Module):
        return corrected_predictions


@support_torch_compile
# This enables torch.compile if --kv-sharing-fast-prefill is passed
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
                       kv_sharing_fast_prefill)
class Gemma3nSelfDecoder(nn.Module):
    """
    Includes altup embedding and self-decoder layers
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layers: list[Gemma3nDecoderLayer],
        layer_idx_start: int,
        per_layer_model_projection: ColumnParallelLinear,
        embed_scale_per_layer: torch.Tensor,
        embed_tokens_per_layer: VocabParallelEmbedding,
        per_layer_projection_norm: RMSNorm,
        per_layer_input_scale: torch.Tensor,
        altup_projections: nn.ModuleList,
        eps: torch.Tensor,
        embed_tokens: VocabParallelEmbedding,
        embed_scale: torch.Tensor,
    ):
        super().__init__()
        self.decoder_layers = decoder_layers
        self.layer_idx_start = layer_idx_start
        self.per_layer_model_projection = per_layer_model_projection
        self.config = vllm_config.model_config.hf_config
        self.embed_scale_per_layer = embed_scale_per_layer
        self.embed_tokens_per_layer = embed_tokens_per_layer
        self.per_layer_projection_norm = per_layer_projection_norm
        self.per_layer_input_scale = per_layer_input_scale
        self.altup_projections = altup_projections
        self.eps = eps
        self.embed_tokens = embed_tokens
        self.embed_scale = embed_scale

    def get_per_layer_input_embeddings(
            self, input_ids: torch.Tensor) -> torch.Tensor:
        # Deal with the fact that vocab_size_per_layer_input < vocab_size,
        # which causes us to have some out-of-vocab tokens, by setting
        # those token ids to 0. This matches the HF implementation.
        per_layer_inputs_mask = torch.logical_and(
            input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input)
        per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids,
                                              torch.zeros_like(input_ids))
        return self.embed_tokens_per_layer(
            per_layer_inputs_tokens) * self.embed_scale_per_layer

    def get_per_layer_inputs(
        self,
        hidden_states_0: torch.Tensor,
        per_layer_inputs: Optional[torch.Tensor],
    ) -> torch.Tensor:
        per_layer_projection = self.per_layer_model_projection(hidden_states_0)
        per_layer_projection = per_layer_projection.reshape(
            *hidden_states_0.shape[:-1],
            self.config.num_hidden_layers,
            self.config.hidden_size_per_layer_input,
        )
        per_layer_projection = self.per_layer_projection_norm(
            per_layer_projection)
        if per_layer_inputs is not None:
            # Profiling run does not compute per_layer_inputs
            per_layer_inputs = per_layer_projection + per_layer_inputs
            per_layer_inputs *= self.per_layer_input_scale
        else:
            per_layer_inputs = per_layer_projection
        return per_layer_inputs

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids) * self.embed_scale

    def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
        # Altup embed.
        hidden_states = [hidden_states_0] * self.config.altup_num_inputs
        target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
                                      keepdim=True)**0.5
        for i in range(1, self.config.altup_num_inputs):
            hidden_states[i] = self.altup_projections[i - 1](hidden_states[i])
            new_magnitude = torch.mean(hidden_states[i]**2,
                                       dim=-1,
                                       keepdim=True)**0.5
            hidden_states[i] *= target_magnitude / torch.maximum(
                new_magnitude, self.eps)
        hidden_states = torch.stack(hidden_states, dim=-1)
        return hidden_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        per_layer_inputs: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if inputs_embeds is not None:
            hidden_states_0 = inputs_embeds
        else:
            hidden_states_0 = self.get_input_embeddings(input_ids)

        adjusted_per_layer_inputs = self.get_per_layer_inputs(
            hidden_states_0, per_layer_inputs)
        hidden_states = self.altup_embed(hidden_states_0)

        # [altup_num_inputs, num_tokens, hidden_size]
        hidden_states = hidden_states.permute(2, 0, 1)

        for idx, layer in enumerate(self.decoder_layers):
            layer_idx = idx + self.layer_idx_start
            # [altup_num_inputs, num_tokens, hidden_size]
            hidden_states = layer(
                positions=positions,
                hidden_states=hidden_states,
                per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :],
                **kwargs,
            )

        # [num_tokens, hidden_size, altup_num_inputs]
        hidden_states = hidden_states.permute(1, 2, 0)

        return hidden_states, adjusted_per_layer_inputs
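
A quick shape sanity check for the permutes used above (toy tensors only, not model code): stacking the altup streams on the last dim and permuting (2, 0, 1) gives the [altup_num_inputs, num_tokens, hidden_size] layout the decoder layers expect, and permuting (1, 2, 0) afterwards restores the original layout exactly.

import torch

num_tokens, hidden_size, altup_num_inputs = 4, 8, 3
streams = [torch.randn(num_tokens, hidden_size) for _ in range(altup_num_inputs)]
hs = torch.stack(streams, dim=-1)   # [num_tokens, hidden_size, altup_num_inputs]
hs = hs.permute(2, 0, 1)            # [altup_num_inputs, num_tokens, hidden_size]
assert hs.shape == (altup_num_inputs, num_tokens, hidden_size)
hs = hs.permute(1, 2, 0)            # back to [num_tokens, hidden_size, altup_num_inputs]
assert torch.equal(hs, torch.stack(streams, dim=-1))
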
# This enables torch.compile if --kv-sharing-fast-prefill is passed
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
                       kv_sharing_fast_prefill)
class Gemma3nCrossDecoder(nn.Module):
    """
    Cross-decoder layers
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layers: list[Gemma3nDecoderLayer],
        layer_idx_start: int,
    ):
        super().__init__()
        self.decoder_layers = decoder_layers
        self.layer_idx_start = layer_idx_start

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        per_layer_inputs: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        # [altup_num_inputs, num_tokens, hidden_size]
        hidden_states = hidden_states.permute(2, 0, 1)
        for idx, layer in enumerate(self.decoder_layers):
            layer_idx = idx + self.layer_idx_start
            # [altup_num_inputs, num_tokens, hidden_size]
            hidden_states = layer(
                positions=positions,
                hidden_states=hidden_states,
                per_layer_input=per_layer_inputs[:, layer_idx, :],
                **kwargs,
            )
        # [num_tokens, hidden_size, altup_num_inputs]
        hidden_states = hidden_states.permute(1, 2, 0)
        return hidden_states


# This disables torch.compile if --kv-sharing-fast-prefill is passed
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
                       cache_config.kv_sharing_fast_prefill)
class Gemma3nTextModel(nn.Module, SupportsQuant):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -543,7 +717,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
@@ -613,95 +786,211 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
            lambda prefix: Gemma3nDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")

        self.eps = torch.tensor(torch.finfo().min)

        first_kv_shared_layer_idx = (config.num_hidden_layers -
                                     config.num_kv_shared_layers)
        # Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
        with set_model_tag("self_decoder"):
            self.self_decoder = Gemma3nSelfDecoder(
                vllm_config=vllm_config,
                prefix=f"{prefix}.self_decoder",
                decoder_layers=self.layers[:first_kv_shared_layer_idx],
                layer_idx_start=0,
                per_layer_model_projection=self.per_layer_model_projection,
                embed_scale_per_layer=self.embed_scale_per_layer,
                embed_tokens_per_layer=self.embed_tokens_per_layer,
                per_layer_projection_norm=self.per_layer_projection_norm,
                per_layer_input_scale=self.per_layer_input_scale,
                altup_projections=self.altup_projections,
                eps=self.eps,
                embed_tokens=self.embed_tokens,
                embed_scale=self.embed_scale,
            )
        # Layer idx 20-29 are cross-decoder layers in YOCO
        with set_model_tag("cross_decoder"):
            self.cross_decoder = Gemma3nCrossDecoder(
                vllm_config=vllm_config,
                prefix=f"{prefix}.cross_decoder",
                decoder_layers=self.layers[first_kv_shared_layer_idx:],
                layer_idx_start=first_kv_shared_layer_idx,
            )

        self.norm = RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.eps = torch.tensor(torch.finfo().min)

        self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill

        if self.fast_prefill_enabled:
            # Allocate static buffers for CUDAGraph
            # TODO(sarckk): Extract this functionality to interface
            max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
            device = next(self.parameters()).device
            self.positions = torch.zeros(max_num_tokens,
                                         dtype=torch.int64,
                                         device=device)
            self.hidden_states = torch.zeros(
                (max_num_tokens, config.hidden_size,
                 self.config.altup_num_inputs),
                dtype=self.embed_tokens.weight.dtype,
                device=device,
            )
            self.per_layer_inputs = torch.zeros(
                (max_num_tokens, self.config.num_hidden_layers,
                 self.config.hidden_size_per_layer_input),
                dtype=self.embed_tokens.weight.dtype,
                device=device,
            )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids) * self.embed_scale
        return self.self_decoder.get_input_embeddings(input_ids)

    def get_per_layer_input_embeddings(
            self, input_ids: torch.Tensor) -> torch.Tensor:
        # Deal with the fact that vocab_size_per_layer_input < vocab_size,
        # which causes us to have some out-of-vocab tokens, by setting
        # those token ids to 0. This matches the HF implementation.
        per_layer_inputs_mask = torch.logical_and(
            input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input)
        per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids,
                                              torch.zeros_like(input_ids))
        return self.embed_tokens_per_layer(
            per_layer_inputs_tokens) * self.embed_scale_per_layer
    def fast_prefill_forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        per_layer_inputs: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        logits_indices_padded, num_logits_indices = None, None
        attn_metadata = get_forward_context().attn_metadata

        # attn_metadata is None during dummy runs
        if (self.fast_prefill_enabled and attn_metadata is not None):
            assert isinstance(attn_metadata, dict)
            # Last layer is a KV sharing layer
            layer_attn_metadata = attn_metadata[
                self.layers[-1].self_attn.attn.layer_name]
            if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)):
                logits_indices_padded = (
                    layer_attn_metadata.logits_indices_padded)
                num_logits_indices = layer_attn_metadata.num_logits_indices

        # Copy inputs for cudagraph
        batch_size = positions.size(0)
        self.positions[:batch_size].copy_(positions)
        self_decoder_hidden_states, per_layer_inputs_adjusted = \
            self.self_decoder(
                input_ids=input_ids,
                positions=self.positions[:batch_size],
                inputs_embeds=inputs_embeds,
                per_layer_inputs=per_layer_inputs,
                **kwargs,
            )

        if logits_indices_padded is None:
            logits_indices_padded = torch.arange(
                positions.size(0),
                dtype=positions.dtype,
                device=positions.device,
            )

        # NOTE(sarckk): There is currently a bug caused by
        # vLLM converting the output of the last piecewise CUDA graph
        # to a weakref, causing memory to be prematurely freed
        # when there are multiple compilation units.
        # Keep .clone() until the fix in
        # https://github.com/vllm-project/vllm/pull/22282
        hidden_states = self_decoder_hidden_states.clone()

        # Copy inputs for cudagraph
        num_padded_logits_indices = logits_indices_padded.size(0)
        self.positions[:num_padded_logits_indices].copy_(
            positions[logits_indices_padded])
        self.hidden_states[:num_padded_logits_indices].copy_(
            self_decoder_hidden_states[logits_indices_padded])
        self.per_layer_inputs[:num_padded_logits_indices].copy_(
            per_layer_inputs_adjusted[logits_indices_padded])
        cross_decoder_hidden_states = self.cross_decoder(
            positions=self.positions[:num_padded_logits_indices],
            hidden_states=self.hidden_states[:num_padded_logits_indices],
            per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices],
            **kwargs,
        )

        if num_logits_indices is not None:
            assert num_logits_indices > 0
            # Merge cross-decoder and self-decoder hidden states
            hidden_states[logits_indices_padded[:num_logits_indices]] = (
                cross_decoder_hidden_states[:num_logits_indices])
        else:
            hidden_states = cross_decoder_hidden_states

        return hidden_states
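
The essence of fast_prefill_forward, shown with toy tensors and none of the CUDA-graph buffer handling (plain Linear layers stand in for the self- and cross-decoders; not vLLM code): run the first stage on every token, run the second stage only on the rows that will actually be sampled, then scatter those rows back into the full-width hidden states.

import torch

torch.manual_seed(0)
num_tokens, hidden = 10, 16
self_stage = torch.nn.Linear(hidden, hidden)    # stand-in for the self-decoder
cross_stage = torch.nn.Linear(hidden, hidden)   # stand-in for the cross-decoder

x = torch.randn(num_tokens, hidden)
logits_indices = torch.tensor([4, 9])           # only these positions get sampled

self_out = self_stage(x)                        # computed for every token
cross_in = self_out[logits_indices]             # prefill truncation: 2 rows instead of 10
cross_out = cross_stage(cross_in)

hidden_states = self_out.clone()
hidden_states[logits_indices] = cross_out       # merge back; other rows keep self-decoder output
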

    def normal_forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        per_layer_inputs: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        hidden_states, per_layer_inputs = self.self_decoder(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            per_layer_inputs=per_layer_inputs,
            **kwargs,
        )
        hidden_states = self.cross_decoder(
            positions=positions,
            hidden_states=hidden_states,
            per_layer_inputs=per_layer_inputs,
            **kwargs,
        )
        return hidden_states

    def altup_unembed(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # Altup unembed.
        target_magnitude = torch.mean(hidden_states[..., 0]**2,
                                      dim=-1,
                                      keepdim=True)**0.5
        for i in range(1, self.config.altup_num_inputs):
            hidden_states[..., i] = self.altup_unembed_projections[i - 1](
                hidden_states[..., i])
            new_magnitude = torch.mean(hidden_states[..., i]**2,
                                       dim=-1,
                                       keepdim=True)**0.5
            hidden_states[..., i] *= target_magnitude / torch.maximum(
                new_magnitude, self.eps)
        # [num_tokens, hidden_size, altup_num_inputs] -> [num_tokens, hidden_size]
        hidden_states = torch.mean(hidden_states, dim=-1)
        return hidden_states

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        per_layer_inputs: torch.Tensor,
        per_layer_inputs: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if inputs_embeds is not None:
            hidden_states_0 = inputs_embeds
        else:
            hidden_states_0 = self.get_input_embeddings(input_ids)

        per_layer_projection = self.per_layer_model_projection(hidden_states_0)
        per_layer_projection = per_layer_projection.reshape(
            *hidden_states_0.shape[:-1],
            self.config.num_hidden_layers,
            self.config.hidden_size_per_layer_input,
        )
        per_layer_projection = self.per_layer_projection_norm(
            per_layer_projection)

        if per_layer_inputs is not None:
            # Profiling run does not compute per_layer_inputs
            per_layer_inputs = per_layer_projection + per_layer_inputs
            per_layer_inputs *= self.per_layer_input_scale
        else:
            per_layer_inputs = per_layer_projection

        # Altup embed.
        hidden_states = [hidden_states_0] * self.config.altup_num_inputs
        target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
                                      keepdim=True)**0.5
        for i in range(1, self.config.altup_num_inputs):
            hidden_states[i] = self.altup_projections[i - 1](hidden_states[i])
            new_magnitude = torch.mean(hidden_states[i]**2,
                                       dim=-1,
                                       keepdim=True)**0.5
            hidden_states[i] *= target_magnitude / torch.maximum(
                new_magnitude, self.eps)
        hidden_states = torch.stack(hidden_states, dim=0)

        # Transformer blocks.
        for layer_idx, layer in enumerate(self.layers):
            # [altup_num_inputs, num_tokens, hidden_size]
            hidden_states = layer(
                positions=positions,
                hidden_states=hidden_states,
                per_layer_input=per_layer_inputs[:, layer_idx, :],
        if self.fast_prefill_enabled:
            hidden_states = self.fast_prefill_forward(
                input_ids,
                positions,
                inputs_embeds,
                per_layer_inputs,
                **kwargs,
            )

        # Altup unembed.
        target_magnitude = torch.mean(hidden_states[0]**2,
                                      dim=-1,
                                      keepdim=True)**0.5
        for i in range(1, self.config.altup_num_inputs):
            hidden_states[i] = self.altup_unembed_projections[i - 1](
                hidden_states[i])
            new_magnitude = torch.mean(hidden_states[i]**2,
                                       dim=-1,
                                       keepdim=True)**0.5
            hidden_states[i] *= target_magnitude / torch.maximum(
                new_magnitude, self.eps)
        # [altup_num_inputs, num_tokens, hidden_size] -> [num_tokens, hidden_size]
        hidden_states = torch.mean(hidden_states, dim=0)

        else:
            hidden_states = self.normal_forward(
                input_ids,
                positions,
                inputs_embeds,
                per_layer_inputs,
                **kwargs,
            )
        hidden_states = self.altup_unembed(hidden_states)
        return self.norm(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,

@@ -620,7 +620,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
        # NOTE (NickLucche) Each pass needs tokens to compute PLE, so we cache
        # them here, as the model forward only has access to the input_embeds.
        if input_ids is not None:
            per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings(
            per_layer_inputs = self.language_model.model.self_decoder.get_per_layer_input_embeddings(
                input_ids)
            per_layer_inputs = per_layer_inputs.reshape(
                -1, self.config.text_config.num_hidden_layers,

@@ -4,11 +4,13 @@ import abc
import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass, make_dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
from dataclasses import dataclass, fields, make_dataclass
from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol,
                    TypeVar)

import numpy as np
import torch
from typing_extensions import runtime_checkable

from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils import cdiv
@@ -19,7 +21,8 @@ if TYPE_CHECKING:
    from vllm.v1.worker.gpu_input_batch import InputBatch

import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata)
from vllm.attention.layer import Attention
from vllm.distributed.kv_transfer.kv_connector.utils import (
    get_kv_connector_cache_layout)
@@ -65,6 +68,10 @@ class CommonAttentionMetadata:

    causal: bool = True

    # Needed by FastPrefillAttentionBuilder
    logits_indices_padded: Optional[torch.Tensor] = None
    num_logits_indices: Optional[int] = None


@dataclass
class UbatchSlice:
@@ -542,6 +549,69 @@ def make_local_attention_virtual_batches(
    )


def make_kv_sharing_fast_prefill_common_attn_metadata(
    common_attn_metadata: CommonAttentionMetadata,
) -> CommonAttentionMetadata:
    if common_attn_metadata.max_query_len == 1:
        # All requests are decode (assume 1 token for now)
        # Skip computing fast prefill path
        return common_attn_metadata

    assert common_attn_metadata.logits_indices_padded is not None
    assert common_attn_metadata.num_logits_indices is not None

    logits_indices_padded = common_attn_metadata.logits_indices_padded
    num_logits_indices = common_attn_metadata.num_logits_indices
    # Get rid of CUDAGraph padding, if any
    logits_indices = logits_indices_padded[:num_logits_indices]
    num_reqs = common_attn_metadata.num_reqs
    query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens
    # Example inputs
    # num_reqs: 3
    # generation_indices: [14, 18, 19, 27]
    # query_start_loc: [0, 15, 20, 28]
    # seq_lens: [41, 31, 40]

    # Find how many decode indices belong to each request
    # request_ids: [0, 1, 1, 2]
    request_ids = torch.bucketize(logits_indices,
                                  query_start_loc[1:],
                                  right=True)

    # Figure out how many tokens are in each request
    # num_decode_tokens: [1, 2, 1]
    num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)

    # Calculate new query_start_loc with tokens in generation_indices
    # decode_query_start_loc: [0, 1, 3, 4]
    decode_query_start_loc = torch.empty(num_reqs + 1,
                                         device=query_start_loc.device,
                                         dtype=query_start_loc.dtype)

    decode_query_start_loc[0] = 0
    decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
    decode_max_query_len = int(num_decode_tokens.max().item())
    total_num_decode_tokens = int(num_decode_tokens.sum().item())

    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=decode_query_start_loc,
        query_start_loc_cpu=decode_query_start_loc.to("cpu",
                                                      non_blocking=True),
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens.to("cpu", non_blocking=True),
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=total_num_decode_tokens,
        max_query_len=decode_max_query_len,
        max_seq_len=common_attn_metadata.max_seq_len,
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
    )
    return common_attn_metadata
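
The worked example in the comments above can be reproduced standalone. A sketch with plain PyTorch tensors (no vLLM types) that recomputes request_ids, num_decode_tokens, and decode_query_start_loc:

import torch

logits_indices = torch.tensor([14, 18, 19, 27])
query_start_loc = torch.tensor([0, 15, 20, 28])
num_reqs = 3

# Map each sampled token index to the request it belongs to.
request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True)
assert request_ids.tolist() == [0, 1, 1, 2]

# Count sampled tokens per request and build the new cumulative offsets.
num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
assert num_decode_tokens.tolist() == [1, 2, 1]

decode_query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.int64)
decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
assert decode_query_start_loc.tolist() == [0, 1, 3, 4]
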


def subclass_attention_backend(
        name_prefix: str, attention_backend_cls: type[AttentionBackend],
        builder_cls: type[AttentionMetadataBuilder[M]]
@@ -679,13 +749,56 @@ def subclass_attention_metadata(
    return Wrapped


def make_kv_sharing_fast_prefill_attention_metadata(
        metadata_cls: Any, ) -> Any:
    """
    Return a new subclass of `metadata_cls` for fast prefill
    """
    return subclass_attention_metadata(
        name_prefix="KVSharingFastPrefill",
        metadata_cls=metadata_cls,
        fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
    )
@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
    logits_indices_padded: torch.Tensor
    num_logits_indices: int


def create_fast_prefill_custom_backend(
    prefix: str,
    underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:

    underlying_builder = underlying_attn_backend.get_builder_cls()

    class FastPrefillAttentionBuilder(underlying_builder):  # type: ignore

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            new_common_attn_metadata = \
                make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
            metadata = super().build(common_prefix_len,
                                     new_common_attn_metadata, fast_build)

            class KVSharingFastPrefillAttentionMetadata(
                    metadata.__class__,  # type: ignore
                    KVSharingFastPrefillMetadata):

                def __init__(self, metadata, common_attn_metadata):
                    # Shallow copy all fields in metadata cls
                    for field in fields(metadata.__class__):
                        setattr(self, field.name,
                                getattr(metadata, field.name))

                    # Set additional fields that will be used in model code
                    assert (common_attn_metadata.logits_indices_padded
                            is not None
                            and common_attn_metadata.num_logits_indices
                            is not None)
                    self.logits_indices_padded = \
                        common_attn_metadata.logits_indices_padded
                    self.num_logits_indices = \
                        common_attn_metadata.num_logits_indices

            return KVSharingFastPrefillAttentionMetadata(
                metadata, common_attn_metadata)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=FastPrefillAttentionBuilder)

    return attn_backend
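
Because KVSharingFastPrefillMetadata is a runtime_checkable Protocol, the isinstance check in the model code is purely structural: any metadata object carrying the two attributes passes, regardless of which attention backend's metadata class produced it. A minimal illustration with a toy class of my own (not a real backend's metadata):

import torch
from typing import Protocol, runtime_checkable

@runtime_checkable
class HasFastPrefillFields(Protocol):  # same shape as KVSharingFastPrefillMetadata
    logits_indices_padded: torch.Tensor
    num_logits_indices: int

class ToyMetadata:  # hypothetical backend metadata; no inheritance involved
    def __init__(self) -> None:
        self.logits_indices_padded = torch.tensor([0, 3, 7])
        self.num_logits_indices = 3

assert isinstance(ToyMetadata(), HasFastPrefillFields)
assert not isinstance(object(), HasFastPrefillFields)
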

@@ -335,6 +335,13 @@ class AsyncLLM(EngineClient):
        returning the RequestOutput back to the caller.
        """

        if (self.vllm_config.cache_config.kv_sharing_fast_prefill
                and sampling_params.prompt_logprobs):
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs")

        try:
            # We start the output_handler on the first call to generate() so
            # we can call __init__ before the event loop, which enables us

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
import gc
import itertools
import time
@@ -58,7 +57,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        supports_dynamo)
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
    make_kv_sharing_fast_prefill_attention_metadata,
    create_fast_prefill_custom_backend,
    reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -84,9 +83,10 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
    KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
                    gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
                    sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
from .utils import (AttentionGroup, MultiModalBudget,
                    add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,
                    gather_mm_placeholders, sanity_check_mm_encoder_outputs,
                    scatter_mm_placeholders)

if TYPE_CHECKING:
    import xgrammar as xgr
@@ -860,6 +860,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                max_seq_len=max_seq_len,
                block_table_tensor=blk_table_tensor,
                slot_mapping=slot_mapping,
                logits_indices_padded=logits_indices_padded,
                num_logits_indices=logits_indices.size(0),
                causal=True,
            )

@@ -884,28 +886,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                    common_attn_metadata=common_attn_metadata,
                ))

            fast_prefill_metadata = attn_metadata_i
            if (self.cache_config.kv_sharing_fast_prefill
                    and self.kv_sharing_fast_prefill_eligible_layers):
                # Dynamically create a dataclass type that inherits
                # from the attention metadata type but includes the additional
                # fields logits_indices_padded and num_logits_indices,
                # which are required for prefill truncation
                fast_prefill_metadata_type = (
                    make_kv_sharing_fast_prefill_attention_metadata(
                        metadata_cls=type(attn_metadata_i), ))
                fast_prefill_metadata = fast_prefill_metadata_type(
                    **dataclasses.asdict(attn_metadata_i),
                    logits_indices_padded=logits_indices_padded,
                    num_logits_indices=logits_indices.size(0),
                )

            for layer_name in attn_group.layer_names:
                if (self.cache_config.kv_sharing_fast_prefill
                        and layer_name
                        in self.kv_sharing_fast_prefill_eligible_layers):
                    attn_metadata[layer_name] = fast_prefill_metadata
                    continue
                attn_metadata[layer_name] = attn_metadata_i

        # Hot-Swap lora model
@@ -1484,6 +1465,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            return self.kv_connector_no_forward(scheduler_output,
                                                self.vllm_config)

        if self.cache_config.kv_sharing_fast_prefill:
            assert not self.input_batch.num_prompt_logprobs, (
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests "
                "need prompt logprobs")

        # Prepare the decoder inputs.
        (attn_metadata, logits_indices, spec_decode_metadata,
         num_scheduled_tokens_np, spec_decode_common_attn_metadata,
@@ -2742,6 +2729,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # layer.
        for layer_name in layer_names:
            attn_backend = layers[layer_name].get_attn_backend()

            if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
                attn_backend = create_fast_prefill_custom_backend(
                    "FastPrefill",
                    attn_backend,
                )

            key = attn_backend.full_cls_name()
            attn_backends[key] = attn_backend
            attn_backend_layers[key].append(layer_name)
@@ -3074,20 +3068,40 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
                                                   kv_cache_raw_tensors)

        # Setup `kv_cache_config` and `kv_caches` for models
        # with cross-layer KV sharing
        if self.shared_kv_cache_layers:
            initialize_kv_cache_for_kv_sharing(
                self.shared_kv_cache_layers,
                kv_cache_config.kv_cache_groups,
                kv_caches,
                self.attn_groups,
                self.runner_only_attn_layers,
            )
        # Set up cross-layer KV cache sharing
        for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
        ):
            logger.debug("%s reuses KV cache of %s", layer_name,
                         target_layer_name)
            kv_caches[layer_name] = kv_caches[target_layer_name]

        bind_kv_cache(kv_caches,
                      self.compilation_config.static_forward_context,
                      self.kv_caches)
        return kv_caches

    def maybe_add_kv_sharing_layers_to_kv_cache_groups(
            self, kv_cache_config: KVCacheConfig) -> None:
        """
        Add layers that re-use KV cache to the KV cache group of their target layer.
        Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
        """
        if not self.shared_kv_cache_layers:
            # No cross-layer KV sharing, return
            return

        add_kv_sharing_layers_to_kv_cache_groups(
            self.shared_kv_cache_layers,
            kv_cache_config.kv_cache_groups,
            self.runner_only_attn_layers,
        )

        if self.cache_config.kv_sharing_fast_prefill:
            # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
            # similar KV sharing setups, only the layers that generate KV caches
            # are involved in the prefill phase, enabling prefill to early exit.
            attn_layers = get_layers_from_vllm_config(self.vllm_config,
                                                      Attention)
            # Iterate in reversed order and add layers that re-use KV cache,
            # e.g. in YOCO-like KV sharing setups (e.g. Gemma3n)
            for layer_name in reversed(attn_layers):
                if layer_name in self.shared_kv_cache_layers:
                    self.kv_sharing_fast_prefill_eligible_layers.add(
@@ -3095,11 +3109,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                else:
                    break

        bind_kv_cache(kv_caches,
                      self.compilation_config.static_forward_context,
                      self.kv_caches)
        return kv_caches

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.
@@ -3111,6 +3120,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self.kv_cache_config = kv_cache_config
        self.may_reinitialize_input_batch(kv_cache_config)
        self.may_add_encoder_only_layers_to_kv_cache_config()
        self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
        self.initialize_attn_backend(kv_cache_config)
        kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)


@@ -55,9 +55,8 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch

from .utils import (MultiModalBudget, bind_kv_cache,
                    initialize_kv_cache_for_kv_sharing,
                    sanity_check_mm_encoder_outputs)
from .utils import (MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups,
                    bind_kv_cache, sanity_check_mm_encoder_outputs)

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
@@ -1599,6 +1598,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self.encoder_cache.clear()
        gc.collect()

    def maybe_setup_cross_layer_kv_sharing(
        self,
        kv_caches: dict[str, torch.Tensor],
        kv_cache_config: KVCacheConfig,
    ) -> None:
        """
        Add layers that re-use KV cache to the KV cache group of their target layer.
        Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
        """
        if not self.shared_kv_cache_layers:
            # No cross-layer KV sharing, return
            return

        add_kv_sharing_layers_to_kv_cache_groups(
            self.shared_kv_cache_layers,
            kv_cache_config.kv_cache_groups,
        )

        for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
        ):
            logger.debug("%s reuses KV cache of %s", layer_name,
                         target_layer_name)
            kv_caches[layer_name] = kv_caches[target_layer_name]

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.
@@ -1664,14 +1687,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        else:
            raise NotImplementedError

        # Setup `kv_cache_config` and `kv_caches` for models
        # with cross-layer KV sharing
        if self.shared_kv_cache_layers:
            initialize_kv_cache_for_kv_sharing(
                self.shared_kv_cache_layers,
                kv_cache_config.kv_cache_groups,
                kv_caches,
            )
        # Set up cross-layer KV cache sharing if needed
        self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config)

        bind_kv_cache(
            kv_caches,

@@ -203,12 +203,9 @@ def gather_mm_placeholders(
    return placeholders[is_embed]


def initialize_kv_cache_for_kv_sharing(
def add_kv_sharing_layers_to_kv_cache_groups(
    shared_kv_cache_layers: dict[str, str],
    kv_cache_groups: list[KVCacheGroupSpec],
    kv_caches: dict[str, torch.Tensor],
    # Optional for now to avoid breaking TPU
    attn_groups: Optional[list[list[AttentionGroup]]] = None,
    runner_only_attn_layers: Optional[set[str]] = None,
) -> None:
    """
@@ -223,38 +220,15 @@ def initialize_kv_cache_for_kv_sharing(
            means this layer will perform attention using the keys and values
            from the KV cache of `shared_kv_cache_layers[layer_name]`.
        kv_cache_groups: The KV cache groups of the model.
        kv_caches: The allocated kv_caches with layer names as keys.
            Note that layers in shared_kv_cache_layers.keys() are not
            originally included, as it only contains layers that have their
            own KV cache allocation.
        attn_groups: Optional list of attention groups. Layers in the same KV
            cache group may be placed in different attention groups if they
            have different attention backends. Currently only provided by the
            GPU model runner.
    """
    # mapping from layer name to tuple of (kv_cache_group_idx, attn_group_idx)
    layer_to_attn_group_idx: dict[str, tuple[int, int]] = {}
    if attn_groups:
        for kv_cache_group_idx, kv_attn_groups in enumerate(attn_groups):
            for attn_group_idx, attn_group in enumerate(kv_attn_groups):
                for layer_name in attn_group.layer_names:
                    layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx,
                                                           attn_group_idx)
    else:
        for kv_cache_group_idx, kv_cache_group in enumerate(kv_cache_groups):
            for layer_name in kv_cache_group.layer_names:
                # attn group idx defaults to 0 if not provided
                layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, 0)
    layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {}
    for kv_cache_group in kv_cache_groups:
        for layer_name in kv_cache_group.layer_names:
            layer_to_kv_cache_group[layer_name] = kv_cache_group

    for layer_name, target_layer_name in shared_kv_cache_layers.items():
        kv_caches[layer_name] = kv_caches[target_layer_name]
        kv_cache_group_idx = layer_to_attn_group_idx[target_layer_name][0]
        kv_cache_groups[kv_cache_group_idx].layer_names.append(layer_name)

        if attn_groups:
            attn_group_idx = layer_to_attn_group_idx[target_layer_name][1]
            attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append(
                layer_name)
        tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
        tgt_kv_cache_group.layer_names.append(layer_name)

        if runner_only_attn_layers is not None:
            runner_only_attn_layers.add(layer_name)
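
A toy illustration of the re-grouping this function performs (hypothetical layer names and a stand-in Group class instead of KVCacheGroupSpec): the KV-sharing layer is simply appended to the group of the layer whose cache it reads.

from dataclasses import dataclass, field

@dataclass
class Group:  # stand-in for KVCacheGroupSpec
    layer_names: list[str] = field(default_factory=list)

kv_cache_groups = [Group(["layers.0.attn", "layers.1.attn"])]
shared_kv_cache_layers = {"layers.2.attn": "layers.1.attn"}  # layer 2 reads layer 1's cache

layer_to_group = {name: group
                  for group in kv_cache_groups
                  for name in group.layer_names}
for layer_name, target_layer_name in shared_kv_cache_layers.items():
    layer_to_group[target_layer_name].layer_names.append(layer_name)

assert kv_cache_groups[0].layer_names == [
    "layers.0.attn", "layers.1.attn", "layers.2.attn"
]
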