[V1] Enable prefill optimization for Gemma3n (#22628)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
parent: 7ffbf27239    commit: cb293f6a79
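Note: the short usage sketch below is not part of the commit. It only illustrates how the new CacheConfig.kv_sharing_fast_prefill flag added in this diff (referred to as --kv-sharing-fast-prefill in the error messages) might be enabled; the exact engine-argument plumbing and the model id are assumptions.

# Hedged usage sketch, assuming the flag is reachable through the usual
# engine arguments and a Gemma3n checkpoint is available locally.
from vllm import LLM, SamplingParams

llm = LLM(
    model="google/gemma-3n-E2B-it",   # placeholder model id
    kv_sharing_fast_prefill=True,     # V1 only; prompt_logprobs is rejected
)
outputs = llm.generate(
    ["Explain KV cache sharing in one paragraph."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)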
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import random
-from typing import Optional, Union
 
 import pytest
 import torch
@@ -10,12 +9,6 @@ import torch
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationLevel
 from vllm.distributed import cleanup_dist_env_and_memory
-from vllm.forward_context import get_forward_context
-from vllm.model_executor.models.gemma3n_mm import (
-    Gemma3nForConditionalGeneration)
-from vllm.model_executor.models.registry import ModelRegistry
-from vllm.model_executor.models.utils import extract_layer_index
-from vllm.sequence import IntermediateTensors
 
 from ...utils import fork_new_process_for_each_test
 
@@ -23,54 +16,6 @@ from ...utils import fork_new_process_for_each_test
 SEED = 42
 
 
-class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = super().forward(input_ids, positions,
-                                        intermediate_tensors, inputs_embeds,
-                                        **kwargs)
-        attn_metadata = get_forward_context().attn_metadata
-        # attn_metadata is None during dummy runs
-        if (attn_metadata is not None
-                and self.language_model.cache_config.kv_sharing_fast_prefill):
-            assert isinstance(attn_metadata, dict)  # true in V1
-            # Gemma3n-E2B has 30 layers, with last 20 layers being
-            # cross-decoder layers. Check attention metadata is correct
-            for layer_name, metadata in attn_metadata.items():
-                layer_idx = extract_layer_index(layer_name)
-                if layer_idx >= 20:
-                    assert hasattr(metadata, 'logits_indices_padded')
-                    assert hasattr(metadata, 'num_logits_indices')
-                else:
-                    assert not hasattr(metadata, 'logits_indices_padded')
-                    assert not hasattr(metadata, 'num_logits_indices')
-
-            # Last layer will be a KV sharing layer
-            layer_attn_metadata = attn_metadata[
-                self.language_model.model.layers[-1].self_attn.attn.layer_name]
-            logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
-            assert logits_indices_padded is not None
-            num_logits_indices = layer_attn_metadata.num_logits_indices
-            assert num_logits_indices > 0
-            # Reset hidden states to random values and
-            # only set logits at logits_indices to valid values
-            # Because logits_indices are the only positions that are used
-            # for output token sampling, this still produces same outputs
-            logits_hs = hidden_states[logits_indices_padded]
-            hidden_states = torch.randn_like(hidden_states)
-            gen_indices = logits_indices_padded[:num_logits_indices]
-            hidden_states[gen_indices] = logits_hs[:num_logits_indices]
-
-        return hidden_states
-
-
 @pytest.fixture
 def test_prompts():
     """
@@ -124,8 +69,6 @@ def test_kv_sharing_fast_prefill(
     enforce_eager: bool,
     test_prompts: list[str],
 ):
-    ModelRegistry.register_model("Gemma3nForConditionalGeneration",
-                                 TestGemma3nForConditionalGeneration)
     sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
     compilation_config = CompilationConfig(
         # This allows vLLM compilation backend to handle allocating and
@@ -145,12 +145,19 @@ class CacheConfig:
 
         self._verify_cache_dtype()
         self._verify_prefix_caching()
+        self._verify_kv_sharing_fast_prefill()
 
     def metrics_info(self):
        # convert cache_config to dict(key: str, value: str) for prometheus
        # metrics info
        return {key: str(value) for key, value in self.__dict__.items()}
 
+    def _verify_kv_sharing_fast_prefill(self) -> None:
+        if self.kv_sharing_fast_prefill and not envs.VLLM_USE_V1:
+            raise NotImplementedError(
+                "Fast prefill optimization for KV sharing is not supported "
+                "in V0 currently.")
+
     @model_validator(mode='after')
     def _verify_args(self) -> Self:
         if self.cpu_offload_gb < 0:
@@ -162,11 +169,6 @@ class CacheConfig:
                 "GPU memory utilization must be less than 1.0. Got "
                 f"{self.gpu_memory_utilization}.")
 
-        if self.kv_sharing_fast_prefill:
-            logger.warning_once(
-                "--kv-sharing-fast-prefill is currently work in progress "
-                "and not functional yet (i.e. no prefill savings)")
-
         return self
 
     def _verify_cache_dtype(self) -> None:
@@ -23,9 +23,11 @@ from torch import nn
 from transformers.models.gemma3n.configuration_gemma3n import Gemma3nTextConfig
 
 from vllm.attention import Attention
+from vllm.compilation.backends import set_model_tag
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
                                                    GeluAndMul,
@@ -45,6 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
 
 from .interfaces import SupportsQuant
 from .utils import (AutoWeightsLoader, extract_layer_index,
@@ -533,7 +536,178 @@ class Gemma3nDecoderLayer(nn.Module):
         return corrected_predictions
 
 
-@support_torch_compile
+# This enables torch.compile if --kv-sharing-fast-prefill passed
+@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
+                       kv_sharing_fast_prefill)
+class Gemma3nSelfDecoder(nn.Module):
+    """
+    Includes altup embedding and self decoder layers
+    """
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        decoder_layers: list[Gemma3nDecoderLayer],
+        layer_idx_start: int,
+        per_layer_model_projection: ColumnParallelLinear,
+        embed_scale_per_layer: torch.Tensor,
+        embed_tokens_per_layer: VocabParallelEmbedding,
+        per_layer_projection_norm: RMSNorm,
+        per_layer_input_scale: torch.Tensor,
+        altup_projections: nn.ModuleList,
+        eps: torch.Tensor,
+        embed_tokens: VocabParallelEmbedding,
+        embed_scale: torch.Tensor,
+    ):
+        super().__init__()
+        self.decoder_layers = decoder_layers
+        self.layer_idx_start = layer_idx_start
+        self.per_layer_model_projection = per_layer_model_projection
+        self.config = vllm_config.model_config.hf_config
+        self.embed_scale_per_layer = embed_scale_per_layer
+        self.embed_tokens_per_layer = embed_tokens_per_layer
+        self.per_layer_projection_norm = per_layer_projection_norm
+        self.per_layer_input_scale = per_layer_input_scale
+        self.altup_projections = altup_projections
+        self.eps = eps
+        self.embed_tokens = embed_tokens
+        self.embed_scale = embed_scale
+
+    def get_per_layer_input_embeddings(
+            self, input_ids: torch.Tensor) -> torch.Tensor:
+        # Deal with the fact that vocab_size_per_layer_input < vocab_size
+        # which causes us to have some out of vocab tokens by setting
+        # those token ids to 0. This matches the HF implementation.
+        per_layer_inputs_mask = torch.logical_and(
+            input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input)
+        per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids,
+                                              torch.zeros_like(input_ids))
+        return self.embed_tokens_per_layer(
+            per_layer_inputs_tokens) * self.embed_scale_per_layer
+
+    def get_per_layer_inputs(
+        self,
+        hidden_states_0: torch.Tensor,
+        per_layer_inputs: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        per_layer_projection = self.per_layer_model_projection(hidden_states_0)
+        per_layer_projection = per_layer_projection.reshape(
+            *hidden_states_0.shape[:-1],
+            self.config.num_hidden_layers,
+            self.config.hidden_size_per_layer_input,
+        )
+        per_layer_projection = self.per_layer_projection_norm(
+            per_layer_projection)
+        if per_layer_inputs is not None:
+            # Profiling run does not compute per_layer_inputs
+            per_layer_inputs = per_layer_projection + per_layer_inputs
+            per_layer_inputs *= self.per_layer_input_scale
+        else:
+            per_layer_inputs = per_layer_projection
+        return per_layer_inputs
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids) * self.embed_scale
+
+    def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
+        # Altup embed.
+        hidden_states = [hidden_states_0] * self.config.altup_num_inputs
+        target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
+                                      keepdim=True)**0.5
+        for i in range(1, self.config.altup_num_inputs):
+            hidden_states[i] = self.altup_projections[i - 1](hidden_states[i])
+            new_magnitude = torch.mean(hidden_states[i]**2,
+                                       dim=-1,
+                                       keepdim=True)**0.5
+            hidden_states[i] *= target_magnitude / torch.maximum(
+                new_magnitude, self.eps)
+        hidden_states = torch.stack(hidden_states, dim=-1)
+        return hidden_states
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        per_layer_inputs: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if inputs_embeds is not None:
+            hidden_states_0 = inputs_embeds
+        else:
+            hidden_states_0 = self.get_input_embeddings(input_ids)
+
+        adjusted_per_layer_inputs = self.get_per_layer_inputs(
+            hidden_states_0, per_layer_inputs)
+        hidden_states = self.altup_embed(hidden_states_0)
+
+        # [altnum_inputs, num_tokens, hidden_size]
+        hidden_states = hidden_states.permute(2, 0, 1)
+
+        for idx, layer in enumerate(self.decoder_layers):
+            layer_idx = idx + self.layer_idx_start
+            # [altup_num_inputs, num_tokens, hidden_size]
+            hidden_states = layer(
+                positions=positions,
+                hidden_states=hidden_states,
+                per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :],
+                **kwargs,
+            )
+
+        # [num_tokens, hidden_size, altnum_inputs]
+        hidden_states = hidden_states.permute(1, 2, 0)
+
+        return hidden_states, adjusted_per_layer_inputs
+
+
+# This enables torch.compile if --kv-sharing-fast-prefill passed
+@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
+                       kv_sharing_fast_prefill)
+class Gemma3nCrossDecoder(nn.Module):
+    """
+    Cross-decoder layers
+    """
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        decoder_layers: list[Gemma3nDecoderLayer],
+        layer_idx_start: int,
+    ):
+        super().__init__()
+        self.decoder_layers = decoder_layers
+        self.layer_idx_start = layer_idx_start
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        per_layer_inputs: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
+        # [altnum_inputs, num_tokens, hidden_size]
+        hidden_states = hidden_states.permute(2, 0, 1)
+        for idx, layer in enumerate(self.decoder_layers):
+            layer_idx = idx + self.layer_idx_start
+            # [altup_num_inputs, num_tokens, hidden_size]
+            hidden_states = layer(
+                positions=positions,
+                hidden_states=hidden_states,
+                per_layer_input=per_layer_inputs[:, layer_idx, :],
+                **kwargs,
+            )
+        # [num_tokens, hidden_size, altnum_inputs]
+        hidden_states = hidden_states.permute(1, 2, 0)
+        return hidden_states
+
+
+# This disables torch.compile if --kv-sharing-fast-prefill passed
+@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
+                       cache_config.kv_sharing_fast_prefill)
 class Gemma3nTextModel(nn.Module, SupportsQuant):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -543,7 +717,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
-
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
@@ -613,95 +786,211 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
             lambda prefix: Gemma3nDecoderLayer(
                 config, cache_config, quant_config, prefix=prefix),
             prefix=f"{prefix}.layers")
 
+        self.eps = torch.tensor(torch.finfo().min)
+
+        first_kv_shared_layer_idx = (config.num_hidden_layers -
+                                     config.num_kv_shared_layers)
+        # Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
+        with set_model_tag("self_decoder"):
+            self.self_decoder = Gemma3nSelfDecoder(
+                vllm_config=vllm_config,
+                prefix=f"{prefix}.self_decoder",
+                decoder_layers=self.layers[:first_kv_shared_layer_idx],
+                layer_idx_start=0,
+                per_layer_model_projection=self.per_layer_model_projection,
+                embed_scale_per_layer=self.embed_scale_per_layer,
+                embed_tokens_per_layer=self.embed_tokens_per_layer,
+                per_layer_projection_norm=self.per_layer_projection_norm,
+                per_layer_input_scale=self.per_layer_input_scale,
+                altup_projections=self.altup_projections,
+                eps=self.eps,
+                embed_tokens=self.embed_tokens,
+                embed_scale=self.embed_scale,
+            )
+        # Layer idx 20-30 are cross-decoder layers in YOCO
+        with set_model_tag("cross_decoder"):
+            self.cross_decoder = Gemma3nCrossDecoder(
+                vllm_config=vllm_config,
+                prefix=f"{prefix}.cross_decoder",
+                decoder_layers=self.layers[first_kv_shared_layer_idx:],
+                layer_idx_start=first_kv_shared_layer_idx,
+            )
+
         self.norm = RMSNorm(
             config.hidden_size,
             eps=config.rms_norm_eps,
         )
-        self.eps = torch.tensor(torch.finfo().min)
+        self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
+
+        if self.fast_prefill_enabled:
+            # Allocate static buffers for CUDAGraph
+            # TODO(sarckk): Extract this functionality to interface
+            max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+            device = next(self.parameters()).device
+            self.positions = torch.zeros(max_num_tokens,
+                                         dtype=torch.int64,
+                                         device=device)
+            self.hidden_states = torch.zeros(
+                (max_num_tokens, config.hidden_size,
+                 self.config.altup_num_inputs),
+                dtype=self.embed_tokens.weight.dtype,
+                device=device,
+            )
+            self.per_layer_inputs = torch.zeros(
+                (max_num_tokens, self.config.num_hidden_layers,
+                 self.config.hidden_size_per_layer_input),
+                dtype=self.embed_tokens.weight.dtype,
+                device=device,
+            )
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.embed_tokens(input_ids) * self.embed_scale
+        return self.self_decoder.get_input_embeddings(input_ids)
 
-    def get_per_layer_input_embeddings(
-            self, input_ids: torch.Tensor) -> torch.Tensor:
-        # Deal with the fact that vocab_size_per_layer_input < vocab_size
-        # which causes us to have some out of vocab tokens by setting
-        # those token ids to 0. This matches the HF implementation.
-        per_layer_inputs_mask = torch.logical_and(
-            input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input)
-        per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids,
-                                              torch.zeros_like(input_ids))
-        return self.embed_tokens_per_layer(
-            per_layer_inputs_tokens) * self.embed_scale_per_layer
+    def fast_prefill_forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        per_layer_inputs: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        logits_indices_padded, num_logits_indices = None, None
+        attn_metadata = get_forward_context().attn_metadata
+        # attn_metadata is None during dummy runs
+        if (self.fast_prefill_enabled and attn_metadata is not None):
+            assert isinstance(attn_metadata, dict)
+            # Last layer is a KV sharing layer
+            layer_attn_metadata = attn_metadata[
+                self.layers[-1].self_attn.attn.layer_name]
+            if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)):
+                logits_indices_padded = (
+                    layer_attn_metadata.logits_indices_padded)
+                num_logits_indices = layer_attn_metadata.num_logits_indices
+
+        # Copy inputs for cudagraph
+        batch_size = positions.size(0)
+        self.positions[:batch_size].copy_(positions)
+        self_decoder_hidden_states, per_layer_inputs_adjusted = \
+            self.self_decoder(
+                input_ids=input_ids,
+                positions=self.positions[:batch_size],
+                inputs_embeds=inputs_embeds,
+                per_layer_inputs=per_layer_inputs,
+                **kwargs,
+            )
+
+        if logits_indices_padded is None:
+            logits_indices_padded = torch.arange(
+                positions.size(0),
+                dtype=positions.dtype,
+                device=positions.device,
+            )
+
+        # NOTE(sarckk): There is currently a bug caused by
+        # vLLM converting output of last piecewise CUDA graph
+        # to weakref, causing memory to be prematurely freed
+        # when there are multiple compilation units
+        # Keep .clone() until fix in
+        # https://github.com/vllm-project/vllm/pull/22282
+        hidden_states = self_decoder_hidden_states.clone()
+
+        # Copy inputs for cudagraph
+        num_padded_logits_indices = logits_indices_padded.size(0)
+        self.positions[:num_padded_logits_indices].copy_(
+            positions[logits_indices_padded])
+        self.hidden_states[:num_padded_logits_indices].copy_(
+            self_decoder_hidden_states[logits_indices_padded])
+        self.per_layer_inputs[:num_padded_logits_indices].copy_(
+            per_layer_inputs_adjusted[logits_indices_padded])
+        cross_decoder_hidden_states = self.cross_decoder(
+            positions=self.positions[:num_padded_logits_indices],
+            hidden_states=self.hidden_states[:num_padded_logits_indices],
+            per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices],
+            **kwargs,
+        )
+
+        if num_logits_indices is not None:
+            assert num_logits_indices > 0
+            # Merge cross-decoder and self-decoder hidden states
+            hidden_states[logits_indices_padded[:num_logits_indices]] = (
+                cross_decoder_hidden_states[:num_logits_indices])
+        else:
+            hidden_states = cross_decoder_hidden_states
+
+        return hidden_states
+
+    def normal_forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        per_layer_inputs: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        hidden_states, per_layer_inputs = self.self_decoder(
+            input_ids=input_ids,
+            positions=positions,
+            inputs_embeds=inputs_embeds,
+            per_layer_inputs=per_layer_inputs,
+            **kwargs,
+        )
+        hidden_states = self.cross_decoder(
+            positions=positions,
+            hidden_states=hidden_states,
+            per_layer_inputs=per_layer_inputs,
+            **kwargs,
+        )
+        return hidden_states
+
+    def altup_unembed(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        # Altup unembed.
+        target_magnitude = torch.mean(hidden_states[..., 0]**2,
+                                      dim=-1,
+                                      keepdim=True)**0.5
+        for i in range(1, self.config.altup_num_inputs):
+            hidden_states[..., i] = self.altup_unembed_projections[i - 1](
+                hidden_states[..., i])
+            new_magnitude = torch.mean(hidden_states[..., i]**2,
+                                       dim=-1,
+                                       keepdim=True)**0.5
+            hidden_states[..., i] *= target_magnitude / torch.maximum(
+                new_magnitude, self.eps)
+        # [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size]
+        hidden_states = torch.mean(hidden_states, dim=-1)
+        return hidden_states
 
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
-        per_layer_inputs: torch.Tensor,
+        per_layer_inputs: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        if inputs_embeds is not None:
-            hidden_states_0 = inputs_embeds
-        else:
-            hidden_states_0 = self.get_input_embeddings(input_ids)
-
-        per_layer_projection = self.per_layer_model_projection(hidden_states_0)
-        per_layer_projection = per_layer_projection.reshape(
-            *hidden_states_0.shape[:-1],
-            self.config.num_hidden_layers,
-            self.config.hidden_size_per_layer_input,
-        )
-        per_layer_projection = self.per_layer_projection_norm(
-            per_layer_projection)
-
-        if per_layer_inputs is not None:
-            # Profiling run does not compute per_layer_inputs
-            per_layer_inputs = per_layer_projection + per_layer_inputs
-            per_layer_inputs *= self.per_layer_input_scale
-        else:
-            per_layer_inputs = per_layer_projection
-
-        # Altup embed.
-        hidden_states = [hidden_states_0] * self.config.altup_num_inputs
-        target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
-                                      keepdim=True)**0.5
-        for i in range(1, self.config.altup_num_inputs):
-            hidden_states[i] = self.altup_projections[i - 1](hidden_states[i])
-            new_magnitude = torch.mean(hidden_states[i]**2,
-                                       dim=-1,
-                                       keepdim=True)**0.5
-            hidden_states[i] *= target_magnitude / torch.maximum(
-                new_magnitude, self.eps)
-        hidden_states = torch.stack(hidden_states, dim=0)
-
-        # Transformer blocks.
-        for layer_idx, layer in enumerate(self.layers):
-            # [altup_num_inputs, num_tokens, hidden_size]
-            hidden_states = layer(
-                positions=positions,
-                hidden_states=hidden_states,
-                per_layer_input=per_layer_inputs[:, layer_idx, :],
-                **kwargs,
-            )
-
-        # Altup unembed.
-        target_magnitude = torch.mean(hidden_states[0]**2,
-                                      dim=-1,
-                                      keepdim=True)**0.5
-        for i in range(1, self.config.altup_num_inputs):
-            hidden_states[i] = self.altup_unembed_projections[i - 1](
-                hidden_states[i])
-            new_magnitude = torch.mean(hidden_states[i]**2,
-                                       dim=-1,
-                                       keepdim=True)**0.5
-            hidden_states[i] *= target_magnitude / torch.maximum(
-                new_magnitude, self.eps)
-        # [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
-        hidden_states = torch.mean(hidden_states, dim=0)
+        if self.fast_prefill_enabled:
+            hidden_states = self.fast_prefill_forward(
+                input_ids,
+                positions,
+                inputs_embeds,
+                per_layer_inputs,
+                **kwargs,
+            )
+        else:
+            hidden_states = self.normal_forward(
+                input_ids,
+                positions,
+                inputs_embeds,
+                per_layer_inputs,
+                **kwargs,
+            )
+        hidden_states = self.altup_unembed(hidden_states)
 
         return self.norm(hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str,
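Note: the core of the new fast-prefill path above is to run every token through the self-decoder, run only the (possibly CUDA-graph padded) logits indices through the cross-decoder, and scatter the cross-decoder outputs back into the full hidden-state buffer. The standalone sketch below illustrates just that gather/scatter step with made-up shapes and a dummy stand-in for the cross-decoder; the names are illustrative, not the vLLM API.

import torch

def fast_prefill_sketch(hidden, logits_indices_padded, num_logits_indices,
                        cross_decoder):
    # hidden: [num_tokens, hidden_size] output of the self-decoder.
    # Only positions that will actually be sampled go through the
    # cross-decoder; all other positions keep their self-decoder states.
    gathered = hidden[logits_indices_padded]        # [num_padded, hidden_size]
    cross_out = cross_decoder(gathered)             # stand-in for the cross-decoder
    merged = hidden.clone()
    valid = logits_indices_padded[:num_logits_indices]
    merged[valid] = cross_out[:num_logits_indices]  # drop CUDA-graph padding
    return merged

hidden = torch.randn(8, 4)
out = fast_prefill_sketch(hidden, torch.tensor([2, 5, 7, 7]), 3,
                          cross_decoder=lambda x: x * 2.0)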
@@ -620,7 +620,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
         # NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
         # them here, as the model forward has only access to the input_embeds.
         if input_ids is not None:
-            per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings(
+            per_layer_inputs = self.language_model.model.self_decoder.get_per_layer_input_embeddings(
                 input_ids)
             per_layer_inputs = per_layer_inputs.reshape(
                 -1, self.config.text_config.num_hidden_layers,
@@ -4,11 +4,13 @@ import abc
 import enum
 import functools
 from abc import abstractmethod
-from dataclasses import dataclass, make_dataclass
-from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
+from dataclasses import dataclass, fields, make_dataclass
+from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol,
+                    TypeVar)
 
 import numpy as np
 import torch
+from typing_extensions import runtime_checkable
 
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.utils import cdiv
@@ -19,7 +21,8 @@ if TYPE_CHECKING:
     from vllm.v1.worker.gpu_input_batch import InputBatch
 
 import vllm.envs as envs
-from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.abstract import (AttentionBackend,
+                                              AttentionMetadata)
 from vllm.attention.layer import Attention
 from vllm.distributed.kv_transfer.kv_connector.utils import (
     get_kv_connector_cache_layout)
@@ -65,6 +68,10 @@ class CommonAttentionMetadata:
 
     causal: bool = True
+
+    # Needed by FastPrefillAttentionBuilder
+    logits_indices_padded: Optional[torch.Tensor] = None
+    num_logits_indices: Optional[int] = None
 
 
 @dataclass
 class UbatchSlice:
@@ -542,6 +549,69 @@ def make_local_attention_virtual_batches(
     )
 
 
+def make_kv_sharing_fast_prefill_common_attn_metadata(
+    common_attn_metadata: CommonAttentionMetadata,
+) -> CommonAttentionMetadata:
+    if common_attn_metadata.max_query_len == 1:
+        # All requests are decode (assume 1 token for now)
+        # Skip computing fast prefill path
+        return common_attn_metadata
+
+    assert common_attn_metadata.logits_indices_padded is not None
+    assert common_attn_metadata.num_logits_indices is not None
+
+    logits_indices_padded = common_attn_metadata.logits_indices_padded
+    num_logits_indices = common_attn_metadata.num_logits_indices
+    # Get rid of CUDAGraph padding, if any
+    logits_indices = logits_indices_padded[:num_logits_indices]
+    num_reqs = common_attn_metadata.num_reqs
+    query_start_loc = common_attn_metadata.query_start_loc
+    seq_lens = common_attn_metadata.seq_lens
+    # Example inputs
+    # num_reqs: 3
+    # generation_indices: [14, 18, 19, 27]
+    # query_start_loc: [0, 15, 20, 28]
+    # seq_lens: [41, 31, 40]
+
+    # Find how many decode indices belong to each request
+    # request_ids: [0, 1, 1, 2]
+    request_ids = torch.bucketize(logits_indices,
+                                  query_start_loc[1:],
+                                  right=True)
+
+    # Figure out how many tokens are in each request
+    # num_decode_tokens: [1, 2, 1]
+    num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
+
+    # Calculate new query_start_loc with tokens in generation_indices
+    # decode_query_start_loc: [0, 1, 3, 4]
+    decode_query_start_loc = torch.empty(num_reqs + 1,
+                                         device=query_start_loc.device,
+                                         dtype=query_start_loc.dtype)
+
+    decode_query_start_loc[0] = 0
+    decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
+    decode_max_query_len = int(num_decode_tokens.max().item())
+    total_num_decode_tokens = int(num_decode_tokens.sum().item())
+
+    common_attn_metadata = CommonAttentionMetadata(
+        query_start_loc=decode_query_start_loc,
+        query_start_loc_cpu=decode_query_start_loc.to("cpu",
+                                                      non_blocking=True),
+        seq_lens=seq_lens,
+        seq_lens_cpu=seq_lens.to("cpu", non_blocking=True),
+        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
+        num_reqs=num_reqs,
+        num_actual_tokens=total_num_decode_tokens,
+        max_query_len=decode_max_query_len,
+        max_seq_len=common_attn_metadata.max_seq_len,
+        block_table_tensor=common_attn_metadata.block_table_tensor,
+        slot_mapping=common_attn_metadata.slot_mapping,
+        causal=True,
+    )
+    return common_attn_metadata
+
+
 def subclass_attention_backend(
         name_prefix: str, attention_backend_cls: type[AttentionBackend],
         builder_cls: type[AttentionMetadataBuilder[M]]
@@ -679,13 +749,56 @@ def subclass_attention_metadata(
     return Wrapped
 
 
-def make_kv_sharing_fast_prefill_attention_metadata(
-    metadata_cls: Any, ) -> Any:
-    """
-    Return a new subclass of `metadata_cls` for fast prefill
-    """
-    return subclass_attention_metadata(
-        name_prefix="KVSharingFastPrefill",
-        metadata_cls=metadata_cls,
-        fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
-    )
+@runtime_checkable
+class KVSharingFastPrefillMetadata(Protocol):
+    logits_indices_padded: torch.Tensor
+    num_logits_indices: int
+
+
+def create_fast_prefill_custom_backend(
+    prefix: str,
+    underlying_attn_backend: AttentionBackend,
+) -> type[AttentionBackend]:
+
+    underlying_builder = underlying_attn_backend.get_builder_cls()
+
+    class FastPrefillAttentionBuilder(underlying_builder):  # type: ignore
+
+        def build(self,
+                  common_prefix_len: int,
+                  common_attn_metadata: CommonAttentionMetadata,
+                  fast_build: bool = False) -> AttentionMetadata:
+            new_common_attn_metadata =\
+                make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
+            metadata = super().build(common_prefix_len,
                                     new_common_attn_metadata, fast_build)
+
+            class KVSharingFastPrefillAttentionMetadata(
+                    metadata.__class__,  # type: ignore
+                    KVSharingFastPrefillMetadata):
+
+                def __init__(self, metadata, common_attn_metadata):
+                    # Shallow copy all fields in metadata cls
+                    for field in fields(metadata.__class__):
+                        setattr(self, field.name,
+                                getattr(metadata, field.name))
+
+                    # Set additional fields that will be used in model code
+                    assert (common_attn_metadata.logits_indices_padded
+                            is not None
+                            and common_attn_metadata.num_logits_indices
+                            is not None)
+                    self.logits_indices_padded = \
+                        common_attn_metadata.logits_indices_padded
+                    self.num_logits_indices = \
+                        common_attn_metadata.num_logits_indices
+
+            return KVSharingFastPrefillAttentionMetadata(
+                metadata, common_attn_metadata)
+
+    attn_backend = subclass_attention_backend(
+        name_prefix=prefix,
+        attention_backend_cls=underlying_attn_backend,
+        builder_cls=FastPrefillAttentionBuilder)
+
+    return attn_backend
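Note: the worked example in the comments of make_kv_sharing_fast_prefill_common_attn_metadata above can be reproduced in isolation. The sketch below only exercises the index arithmetic (bucketize/bincount/cumsum), not the real CommonAttentionMetadata class.

import torch

logits_indices = torch.tensor([14, 18, 19, 27])   # generation indices
query_start_loc = torch.tensor([0, 15, 20, 28])
num_reqs = 3

# Which request does each generation index fall into?  -> [0, 1, 1, 2]
request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True)
# How many sampled tokens per request?                 -> [1, 2, 1]
num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
# New query_start_loc over just those tokens           -> [0, 1, 3, 4]
decode_query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.int64)
decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
print(request_ids.tolist(), num_decode_tokens.tolist(),
      decode_query_start_loc.tolist())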
@@ -335,6 +335,13 @@ class AsyncLLM(EngineClient):
         returning the RequestOutput back to the caller.
         """
 
+        if (self.vllm_config.cache_config.kv_sharing_fast_prefill
+                and sampling_params.prompt_logprobs):
+            raise ValueError(
+                "--kv-sharing-fast-prefill produces incorrect logprobs for "
+                "prompt tokens, please disable it when the requests need "
+                "prompt logprobs")
+
         try:
             # We start the output_handler on the first call to generate() so
             # we can call __init__ before the event loop, which enables us
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import dataclasses
 import gc
 import itertools
 import time
@@ -58,7 +57,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         supports_dynamo)
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
-    make_kv_sharing_fast_prefill_attention_metadata,
+    create_fast_prefill_custom_backend,
     reorder_batch_to_split_decodes_and_prefills)
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -84,9 +83,10 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
-from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
-                    gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
-                    sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
+from .utils import (AttentionGroup, MultiModalBudget,
+                    add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,
+                    gather_mm_placeholders, sanity_check_mm_encoder_outputs,
+                    scatter_mm_placeholders)
 
 if TYPE_CHECKING:
     import xgrammar as xgr
@@ -860,6 +860,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             max_seq_len=max_seq_len,
             block_table_tensor=blk_table_tensor,
             slot_mapping=slot_mapping,
+            logits_indices_padded=logits_indices_padded,
+            num_logits_indices=logits_indices.size(0),
             causal=True,
         )
 
@@ -884,28 +886,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     common_attn_metadata=common_attn_metadata,
                 ))
 
-            fast_prefill_metadata = attn_metadata_i
-            if (self.cache_config.kv_sharing_fast_prefill
-                    and self.kv_sharing_fast_prefill_eligible_layers):
-                # Dynamically create a a dataclass type that inherits
-                # from attention metadata type but includes additional
-                # fields logits_indices_padded and num_logits_indices
-                # which are required for prefill truncation
-                fast_prefill_metadata_type = (
-                    make_kv_sharing_fast_prefill_attention_metadata(
-                        metadata_cls=type(attn_metadata_i), ))
-                fast_prefill_metadata = fast_prefill_metadata_type(
-                    **dataclasses.asdict(attn_metadata_i),
-                    logits_indices_padded=logits_indices_padded,
-                    num_logits_indices=logits_indices.size(0),
-                )
-
             for layer_name in attn_group.layer_names:
-                if (self.cache_config.kv_sharing_fast_prefill
-                        and layer_name
-                        in self.kv_sharing_fast_prefill_eligible_layers):
-                    attn_metadata[layer_name] = fast_prefill_metadata
-                    continue
                 attn_metadata[layer_name] = attn_metadata_i
 
         # Hot-Swap lora model
@@ -1484,6 +1465,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             return self.kv_connector_no_forward(scheduler_output,
                                                 self.vllm_config)
 
+        if self.cache_config.kv_sharing_fast_prefill:
+            assert not self.input_batch.num_prompt_logprobs, (
+                "--kv-sharing-fast-prefill produces incorrect logprobs for "
+                "prompt tokens, please disable it when the requests "
+                "need prompt logprobs")
+
         # Prepare the decoder inputs.
         (attn_metadata, logits_indices, spec_decode_metadata,
          num_scheduled_tokens_np, spec_decode_common_attn_metadata,
@@ -2742,6 +2729,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # layer.
         for layer_name in layer_names:
             attn_backend = layers[layer_name].get_attn_backend()
+
+            if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
+                attn_backend = create_fast_prefill_custom_backend(
+                    "FastPrefill",
+                    attn_backend,
+                )
+
             key = attn_backend.full_cls_name()
             attn_backends[key] = attn_backend
             attn_backend_layers[key].append(layer_name)
@@ -3074,20 +3068,40 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
                                                    kv_cache_raw_tensors)
 
-        # Setup `kv_cache_config` and `kv_caches` for models
-        # with cross-layer KV sharing
-        if self.shared_kv_cache_layers:
-            initialize_kv_cache_for_kv_sharing(
-                self.shared_kv_cache_layers,
-                kv_cache_config.kv_cache_groups,
-                kv_caches,
-                self.attn_groups,
-                self.runner_only_attn_layers,
-            )
+        # Set up cross-layer KV cache sharing
+        for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
+        ):
+            logger.debug("%s reuses KV cache of %s", layer_name,
+                         target_layer_name)
+            kv_caches[layer_name] = kv_caches[target_layer_name]
+
+        bind_kv_cache(kv_caches,
+                      self.compilation_config.static_forward_context,
+                      self.kv_caches)
+        return kv_caches
+
+    def maybe_add_kv_sharing_layers_to_kv_cache_groups(
+            self, kv_cache_config: KVCacheConfig) -> None:
+        """
+        Add layers that re-use KV cache to KV cache group of its target layer.
+        Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
+        """
+        if not self.shared_kv_cache_layers:
+            # No cross-layer KV sharing, return
+            return
+
+        add_kv_sharing_layers_to_kv_cache_groups(
+            self.shared_kv_cache_layers,
+            kv_cache_config.kv_cache_groups,
+            self.runner_only_attn_layers,
+        )
+
+        if self.cache_config.kv_sharing_fast_prefill:
+            # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
+            # similar KV sharing setups, only the layers that generate KV caches
+            # are involved in the prefill phase, enabling prefill to early exit.
             attn_layers = get_layers_from_vllm_config(self.vllm_config,
                                                       Attention)
-            # Iterate in reversed order and add layers that re-use KV cache
-            # e.g. in YOCO-like KV sharing setups (e.g. Gemma3n)
             for layer_name in reversed(attn_layers):
                 if layer_name in self.shared_kv_cache_layers:
                     self.kv_sharing_fast_prefill_eligible_layers.add(
@@ -3095,11 +3109,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 else:
                     break
 
-        bind_kv_cache(kv_caches,
-                      self.compilation_config.static_forward_context,
-                      self.kv_caches)
-        return kv_caches
-
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -3111,6 +3120,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.kv_cache_config = kv_cache_config
         self.may_reinitialize_input_batch(kv_cache_config)
         self.may_add_encoder_only_layers_to_kv_cache_config()
+        self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
         self.initialize_attn_backend(kv_cache_config)
         kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
 
@@ -55,9 +55,8 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
 
-from .utils import (MultiModalBudget, bind_kv_cache,
-                    initialize_kv_cache_for_kv_sharing,
-                    sanity_check_mm_encoder_outputs)
+from .utils import (MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups,
+                    bind_kv_cache, sanity_check_mm_encoder_outputs)
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -1599,6 +1598,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.encoder_cache.clear()
         gc.collect()
 
+    def maybe_setup_cross_layer_kv_sharing(
+        self,
+        kv_caches: dict[str, torch.Tensor],
+        kv_cache_config: KVCacheConfig,
+    ) -> None:
+        """
+        Add layers that re-use KV cache to KV cache group of its target layer.
+        Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
+        """
+        if not self.shared_kv_cache_layers:
+            # No cross-layer KV sharing, return
+            return
+
+        add_kv_sharing_layers_to_kv_cache_groups(
+            self.shared_kv_cache_layers,
+            kv_cache_config.kv_cache_groups,
+        )
+
+        for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
+        ):
+            logger.debug("%s reuses KV cache of %s", layer_name,
+                         target_layer_name)
+            kv_caches[layer_name] = kv_caches[target_layer_name]
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -1664,14 +1687,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         else:
             raise NotImplementedError
 
-        # Setup `kv_cache_config` and `kv_caches` for models
-        # with cross-layer KV sharing
-        if self.shared_kv_cache_layers:
-            initialize_kv_cache_for_kv_sharing(
-                self.shared_kv_cache_layers,
-                kv_cache_config.kv_cache_groups,
-                kv_caches,
-            )
+        # Set up cross-layer KV cache sharing if needed
+        self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config)
 
         bind_kv_cache(
             kv_caches,
@@ -203,12 +203,9 @@ def gather_mm_placeholders(
     return placeholders[is_embed]
 
 
-def initialize_kv_cache_for_kv_sharing(
+def add_kv_sharing_layers_to_kv_cache_groups(
     shared_kv_cache_layers: dict[str, str],
     kv_cache_groups: list[KVCacheGroupSpec],
-    kv_caches: dict[str, torch.Tensor],
-    # Optional for now to avoid breaking TPU
-    attn_groups: Optional[list[list[AttentionGroup]]] = None,
     runner_only_attn_layers: Optional[set[str]] = None,
 ) -> None:
     """
@@ -223,38 +220,15 @@ def initialize_kv_cache_for_kv_sharing(
             means this layer will perform attention using the keys and values
            from the KV cache of `shared_kv_cache_layers[layer_name]`.
         kv_cache_groups: The KV cache groups of the model.
-        kv_caches: The allocated kv_caches with layer names as keys.
-            Note that layers in shared_kv_cache_layers.keys() are not
-            originally included as it only contains layers which have its own
-            KV cache allocation.
-        attn_groups: Optional list of attention groups. Layers in the same KV
-            cache group may be placed in different attention groups if they
-            have different attention backends. Currently only provided by
-            GPU model runner.
     """
-    # mapping from layer name to tuple of (kv_cache_group_idx, attn_group_idx)
-    layer_to_attn_group_idx: dict[str, tuple[int, int]] = {}
-    if attn_groups:
-        for kv_cache_group_idx, kv_attn_groups in enumerate(attn_groups):
-            for attn_group_idx, attn_group in enumerate(kv_attn_groups):
-                for layer_name in attn_group.layer_names:
-                    layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx,
-                                                           attn_group_idx)
-    else:
-        for kv_cache_group_idx, kv_cache_group in enumerate(kv_cache_groups):
-            for layer_name in kv_cache_group.layer_names:
-                # attn group idx default to 0 if not provided
-                layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, 0)
+    layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {}
+    for kv_cache_group in kv_cache_groups:
+        for layer_name in kv_cache_group.layer_names:
+            layer_to_kv_cache_group[layer_name] = kv_cache_group
 
     for layer_name, target_layer_name in shared_kv_cache_layers.items():
-        kv_caches[layer_name] = kv_caches[target_layer_name]
-        kv_cache_group_idx = layer_to_attn_group_idx[target_layer_name][0]
-        kv_cache_groups[kv_cache_group_idx].layer_names.append(layer_name)
-
-        if attn_groups:
-            attn_group_idx = layer_to_attn_group_idx[target_layer_name][1]
-            attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append(
-                layer_name)
+        tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
+        tgt_kv_cache_group.layer_names.append(layer_name)
 
         if runner_only_attn_layers is not None:
             runner_only_attn_layers.add(layer_name)
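Note: the rewritten helper above only appends re-using layers to the KV cache group of their target layer; the actual tensor aliasing now happens in the model runners. A toy illustration follows, using a simplified stand-in for KVCacheGroupSpec (the real vLLM class has more fields), so treat it as a sketch rather than the library API.

from dataclasses import dataclass, field

@dataclass
class ToyKVCacheGroupSpec:          # simplified stand-in, not the vLLM class
    layer_names: list = field(default_factory=list)

def add_kv_sharing_layers_to_kv_cache_groups(shared, groups,
                                             runner_only_attn_layers=None):
    # Map each layer that owns a KV cache to its group, then append the
    # re-using layers to the group of their target layer.
    layer_to_group = {
        name: group for group in groups for name in group.layer_names
    }
    for layer_name, target_layer_name in shared.items():
        layer_to_group[target_layer_name].layer_names.append(layer_name)
        if runner_only_attn_layers is not None:
            runner_only_attn_layers.add(layer_name)

groups = [ToyKVCacheGroupSpec(["layers.0.attn", "layers.1.attn"])]
add_kv_sharing_layers_to_kv_cache_groups({"layers.2.attn": "layers.1.attn"},
                                         groups)
print(groups[0].layer_names)  # ['layers.0.attn', 'layers.1.attn', 'layers.2.attn']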