From 0e54bbe108519076a025c3dc6215fc2b339b4aad Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com>
Date: Tue, 23 Sep 2025 19:25:34 -0700
Subject: [PATCH] [KV sharing] Re-land Gemma3n model changes from #22628 (#24357)

Signed-off-by: Yong Hoon Shin
Signed-off-by: yewentao256
---
 vllm/model_executor/models/gemma3n.py | 402 ++++++++++++++++++++++----
 1 file changed, 344 insertions(+), 58 deletions(-)

diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py
index f4d288fd887e9..0b6bccb334982 100644
--- a/vllm/model_executor/models/gemma3n.py
+++ b/vllm/model_executor/models/gemma3n.py
@@ -26,6 +26,7 @@ from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
                                                    GeluAndMul,
@@ -44,6 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.sequence import IntermediateTensors
+from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
 from .interfaces import SupportsQuant
 from .utils import (AutoWeightsLoader, extract_layer_index,
@@ -51,6 +53,8 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
 
 logger = init_logger(__name__)
 
+EPS = torch.tensor(torch.finfo().min)
+
 
 class Gemma3nAltUp(nn.Module):
     """Alternating updates (Altup)
@@ -532,16 +536,29 @@ class Gemma3nDecoderLayer(nn.Module):
         return corrected_predictions
 
 
-@support_torch_compile
-class Gemma3nTextModel(nn.Module, SupportsQuant):
+# This enables torch.compile if --kv-sharing-fast-prefill is passed
+@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
+                       kv_sharing_fast_prefill)
+class Gemma3nSelfDecoder(nn.Module):
+    """
+    Includes the altup embedding and the self-decoder layers.
+    """
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        decoder_layers: list[Gemma3nDecoderLayer],
+        layer_idx_start: int,
+    ):
         super().__init__()
+        self.decoder_layers = decoder_layers
+        self.layer_idx_start = layer_idx_start
+
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
         self.config = config
-        self.quant_config = quant_config
+        quant_config = vllm_config.quant_config
 
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
@@ -594,32 +611,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
                 prefix=f"{prefix}.altup_projections.{idx-1}",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])
-        self.altup_unembed_projections = nn.ModuleList([
-            ColumnParallelLinear(
-                config.hidden_size,
-                config.hidden_size,
-                bias=False,
-                gather_output=True,
-                return_bias=False,
-                quant_config=quant_config,
-                prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
-            ) for idx in range(1, self.config.altup_num_inputs)
-        ])
-
-        # Transformer blocks.
-        self.start_layer, self.end_layer, self.layers = make_layers(
-            config.num_hidden_layers,
-            lambda prefix: Gemma3nDecoderLayer(
-                config, cache_config, quant_config, prefix=prefix),
-            prefix=f"{prefix}.layers")
-        self.norm = RMSNorm(
-            config.hidden_size,
-            eps=config.rms_norm_eps,
-        )
-        self.eps = torch.tensor(torch.finfo().min)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.embed_tokens(input_ids) * self.embed_scale
 
     def get_per_layer_input_embeddings(
             self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -633,20 +624,11 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
         return self.embed_tokens_per_layer(
             per_layer_inputs_tokens) * self.embed_scale_per_layer
 
-    def forward(
+    def get_per_layer_inputs(
         self,
-        input_ids: Optional[torch.Tensor],
-        positions: torch.Tensor,
-        per_layer_inputs: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        if inputs_embeds is not None:
-            hidden_states_0 = inputs_embeds
-        else:
-            hidden_states_0 = self.get_input_embeddings(input_ids)
-
+        hidden_states_0: torch.Tensor,
+        per_layer_inputs: Optional[torch.Tensor],
+    ) -> torch.Tensor:
         per_layer_projection = self.per_layer_model_projection(hidden_states_0)
         per_layer_projection = per_layer_projection.reshape(
             *hidden_states_0.shape[:-1],
@@ -655,14 +637,18 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
         )
         per_layer_projection = self.per_layer_projection_norm(
             per_layer_projection)
-
         if per_layer_inputs is not None:
             # Profiling run does not compute per_layer_inputs
             per_layer_inputs = per_layer_projection + per_layer_inputs
             per_layer_inputs *= self.per_layer_input_scale
         else:
             per_layer_inputs = per_layer_projection
+        return per_layer_inputs
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids) * self.embed_scale
+
+    def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
         # Altup embed.
         hidden_states = [hidden_states_0] * self.config.altup_num_inputs
         target_magnitude = torch.mean(hidden_states_0**2,
                                       dim=-1,
@@ -673,11 +659,77 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
                                           dim=-1,
                                           keepdim=True)**0.5
             hidden_states[i] *= target_magnitude / torch.maximum(
-                new_magnitude, self.eps)
-        hidden_states = torch.stack(hidden_states, dim=0)
+                new_magnitude, EPS)
+        hidden_states = torch.stack(hidden_states, dim=-1)
+        return hidden_states
 
-        # Transformer blocks.
-        for layer_idx, layer in enumerate(self.layers):
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        per_layer_inputs: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if inputs_embeds is not None:
+            hidden_states_0 = inputs_embeds
+        else:
+            hidden_states_0 = self.get_input_embeddings(input_ids)
+
+        adjusted_per_layer_inputs = self.get_per_layer_inputs(
+            hidden_states_0, per_layer_inputs)
+        hidden_states = self.altup_embed(hidden_states_0)
+
+        # [altup_num_inputs, num_tokens, hidden_size]
+        hidden_states = hidden_states.permute(2, 0, 1)
+
+        for idx, layer in enumerate(self.decoder_layers):
+            layer_idx = idx + self.layer_idx_start
+            # [altup_num_inputs, num_tokens, hidden_size]
+            hidden_states = layer(
+                positions=positions,
+                hidden_states=hidden_states,
+                per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :],
+                **kwargs,
+            )
+
+        # [num_tokens, hidden_size, altup_num_inputs]
+        hidden_states = hidden_states.permute(1, 2, 0)
+
+        return hidden_states, adjusted_per_layer_inputs
+
+
+# This enables torch.compile if --kv-sharing-fast-prefill is passed
+@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
+                       kv_sharing_fast_prefill)
+class Gemma3nCrossDecoder(nn.Module):
+    """
+    Cross-decoder layers.
+    """
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        decoder_layers: list[Gemma3nDecoderLayer],
+        layer_idx_start: int,
+    ):
+        super().__init__()
+        self.decoder_layers = decoder_layers
+        self.layer_idx_start = layer_idx_start
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        per_layer_inputs: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
+        # [altup_num_inputs, num_tokens, hidden_size]
+        hidden_states = hidden_states.permute(2, 0, 1)
+        for idx, layer in enumerate(self.decoder_layers):
+            layer_idx = idx + self.layer_idx_start
             # [altup_num_inputs, num_tokens, hidden_size]
             hidden_states = layer(
                 positions=positions,
@@ -685,22 +737,249 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
                 hidden_states=hidden_states,
                 per_layer_input=per_layer_inputs[:, layer_idx, :],
                 **kwargs,
             )
+        # [num_tokens, hidden_size, altup_num_inputs]
+        hidden_states = hidden_states.permute(1, 2, 0)
+        return hidden_states
 
+
+# This disables torch.compile if --kv-sharing-fast-prefill is passed
+@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
+                      cache_config.kv_sharing_fast_prefill)
+class Gemma3nTextModel(nn.Module, SupportsQuant):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+
+        self.altup_unembed_projections = nn.ModuleList([
+            ColumnParallelLinear(
+                config.hidden_size,
+                config.hidden_size,
+                bias=False,
+                gather_output=True,
+                return_bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
+            ) for idx in range(1, self.config.altup_num_inputs)
+        ])
+
+        # All decoder layers; the last config.num_kv_shared_layers of them
+        # share (re-use) the KV cache written by the self-decoder layers
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: Gemma3nDecoderLayer(
+                config, cache_config, quant_config, prefix=prefix),
+            prefix=f"{prefix}.layers")
+
+        first_kv_shared_layer_idx = (config.num_hidden_layers -
+                                     config.num_kv_shared_layers)
+
+        # NOTE(sarckk): importing this at top level seems to cause issues
+        # when running tests.
+        from vllm.compilation.backends import set_model_tag
+
+        # Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
+        with set_model_tag("self_decoder"):
+            self.self_decoder = Gemma3nSelfDecoder(
+                vllm_config=vllm_config,
+                prefix=f"{prefix}.self_decoder",
+                decoder_layers=self.layers[:first_kv_shared_layer_idx],
+                layer_idx_start=0,
+            )
+        # Layer idx 20-29 are cross-decoder layers in YOCO
+        with set_model_tag("cross_decoder"):
+            self.cross_decoder = Gemma3nCrossDecoder(
+                vllm_config=vllm_config,
+                prefix=f"{prefix}.cross_decoder",
+                decoder_layers=self.layers[first_kv_shared_layer_idx:],
+                layer_idx_start=first_kv_shared_layer_idx,
+            )
+
+        self.norm = RMSNorm(
+            config.hidden_size,
+            eps=config.rms_norm_eps,
+        )
+
+        self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
+
+        if self.fast_prefill_enabled:
+            # Allocate static buffers for CUDA graphs
+            # TODO(sarckk): Extract this functionality to an interface
+            max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+            device = next(self.parameters()).device
+            self.positions = torch.zeros(max_num_tokens,
+                                         dtype=torch.int64,
+                                         device=device)
+            self.hidden_states = torch.zeros(
+                (max_num_tokens, config.hidden_size,
+                 self.config.altup_num_inputs),
+                dtype=self.embed_tokens.weight.dtype,
+                device=device,
+            )
+            self.per_layer_inputs = torch.zeros(
+                (max_num_tokens, self.config.num_hidden_layers,
+                 self.config.hidden_size_per_layer_input),
+                dtype=self.embed_tokens.weight.dtype,
+                device=device,
+            )
+
+    @property
+    def embed_tokens(self):
+        return self.self_decoder.embed_tokens
+
+    def get_per_layer_input_embeddings(
+            self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.self_decoder.get_per_layer_input_embeddings(input_ids)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.self_decoder.get_input_embeddings(input_ids)
+
+    def fast_prefill_forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        per_layer_inputs: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        logits_indices_padded, num_logits_indices = None, None
+        attn_metadata = get_forward_context().attn_metadata
+
+        # attn_metadata is None during dummy runs
+        if (self.fast_prefill_enabled and attn_metadata is not None):
+            assert isinstance(attn_metadata, dict)
+            # The last layer is a KV-sharing layer
+            layer_attn_metadata = attn_metadata[
+                self.layers[-1].self_attn.attn.layer_name]
+            if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)):
+                logits_indices_padded = (
+                    layer_attn_metadata.logits_indices_padded)
+                num_logits_indices = layer_attn_metadata.num_logits_indices
+
+        # Copy inputs for CUDA graphs
+        batch_size = positions.size(0)
+        self.positions[:batch_size].copy_(positions)
+        self_decoder_hidden_states, per_layer_inputs_adjusted = \
+            self.self_decoder(
+                input_ids=input_ids,
+                positions=self.positions[:batch_size],
+                inputs_embeds=inputs_embeds,
+                per_layer_inputs=per_layer_inputs,
+                **kwargs,
+            )
+
+        if logits_indices_padded is None:
+            logits_indices_padded = torch.arange(
+                positions.size(0),
+                dtype=positions.dtype,
+                device=positions.device,
+            )
+
+        # NOTE(sarckk): There is currently a bug where vLLM converts the
+        # output of the last piecewise CUDA graph to a weakref, causing
+        # memory to be freed prematurely when there are multiple
+        # compilation units. Keep .clone() until the fix in
+        # https://github.com/vllm-project/vllm/pull/22282 lands.
+        hidden_states = self_decoder_hidden_states.clone()
+
+        # Copy inputs for CUDA graphs
+        num_padded_logits_indices = logits_indices_padded.size(0)
+        self.positions[:num_padded_logits_indices].copy_(
+            positions[logits_indices_padded])
+        self.hidden_states[:num_padded_logits_indices].copy_(
+            self_decoder_hidden_states[logits_indices_padded])
+        self.per_layer_inputs[:num_padded_logits_indices].copy_(
+            per_layer_inputs_adjusted[logits_indices_padded])
+        cross_decoder_hidden_states = self.cross_decoder(
+            positions=self.positions[:num_padded_logits_indices],
+            hidden_states=self.hidden_states[:num_padded_logits_indices],
+            per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices],
+            **kwargs,
+        )
+
+        if num_logits_indices is not None:
+            assert num_logits_indices > 0
+            # Merge cross-decoder and self-decoder hidden states
+            hidden_states[logits_indices_padded[:num_logits_indices]] = (
+                cross_decoder_hidden_states[:num_logits_indices])
+        else:
+            hidden_states = cross_decoder_hidden_states
+
+        return hidden_states
+
+    def normal_forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        per_layer_inputs: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        hidden_states, per_layer_inputs = self.self_decoder(
+            input_ids=input_ids,
+            positions=positions,
+            inputs_embeds=inputs_embeds,
+            per_layer_inputs=per_layer_inputs,
+            **kwargs,
+        )
+        hidden_states = self.cross_decoder(
+            positions=positions,
+            hidden_states=hidden_states,
+            per_layer_inputs=per_layer_inputs,
+            **kwargs,
+        )
+        return hidden_states
+
+    def altup_unembed(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
         # Altup unembed.
-        target_magnitude = torch.mean(hidden_states[0]**2,
+        target_magnitude = torch.mean(hidden_states[..., 0]**2,
                                       dim=-1,
                                       keepdim=True)**0.5
         for i in range(1, self.config.altup_num_inputs):
-            hidden_states[i] = self.altup_unembed_projections[i - 1](
-                hidden_states[i])
-            new_magnitude = torch.mean(hidden_states[i]**2,
+            hidden_states[..., i] = self.altup_unembed_projections[i - 1](
+                hidden_states[..., i])
+            new_magnitude = torch.mean(hidden_states[..., i]**2,
                                        dim=-1,
                                        keepdim=True)**0.5
-            hidden_states[i] *= target_magnitude / torch.maximum(
-                new_magnitude, self.eps)
-        # [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
-        hidden_states = torch.mean(hidden_states, dim=0)
+            hidden_states[..., i] *= target_magnitude / torch.maximum(
+                new_magnitude, EPS)
+        # [num_tokens,hidden_size,altup_num_inputs] -> [num_tokens,hidden_size]
+        hidden_states = torch.mean(hidden_states, dim=-1)
+        return hidden_states
 
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        per_layer_inputs: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if self.fast_prefill_enabled:
+            hidden_states = self.fast_prefill_forward(
+                input_ids,
+                positions,
+                inputs_embeds,
+                per_layer_inputs,
+                **kwargs,
+            )
+        else:
+            hidden_states = self.normal_forward(
+                input_ids,
+                positions,
+                inputs_embeds,
+                per_layer_inputs,
+                **kwargs,
+            )
+        hidden_states = self.altup_unembed(hidden_states)
         return self.norm(hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str,
@@ -716,6 +995,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            # Decoder layer weights, altup_unembed_projections and norm are
+            # kept on the text model; remap all other names to self_decoder.
+            if (not name.startswith('layers')
+                    and not name.startswith('altup_unembed_projections')
+                    and not name.startswith('norm')):
+                name = f"self_decoder.{name}"
+
             if (self.quant_config is not None
                     and (scale_name := self.quant_config.get_cache_scale(name))):
                 # Loading kv cache scales for compressed-tensors quantization
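
The control flow in fast_prefill_forward() is easier to see in isolation: with --kv-sharing-fast-prefill, the cross-decoder layers re-read the KV cache written by the self-decoder, so during prefill they only need to run on the (padded) positions whose logits will actually be sampled, and their outputs are merged back afterwards. A minimal, self-contained sketch of that gather/merge step follows; the function name and the doubling that stands in for the real cross-decoder are illustrative assumptions, not vLLM APIs:

import torch


def fast_prefill_merge_sketch(
        self_decoder_out: torch.Tensor,  # [num_tokens, hidden_size]
        logits_indices_padded: torch.Tensor,  # indices, padded for CUDA graphs
        num_logits_indices: int,  # number of valid (unpadded) indices
) -> torch.Tensor:
    # Run the "cross-decoder" only on the gathered rows; a doubling stands
    # in for the real cross-decoder layers in this sketch.
    gathered = self_decoder_out[logits_indices_padded]
    cross_out = gathered * 2.0

    # Scatter only the valid results back into the self-decoder output,
    # mirroring the merge at the end of fast_prefill_forward().
    merged = self_decoder_out.clone()
    merged[logits_indices_padded[:num_logits_indices]] = (
        cross_out[:num_logits_indices])
    return merged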
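
The static tensors allocated in Gemma3nTextModel.__init__() when fast prefill is enabled follow the usual CUDA-graph discipline: a captured graph replays fixed memory addresses, so variable-size batches are copied into a prefix slice of a preallocated buffer rather than passed in directly. A stripped-down sketch of that pattern, with buffer sizes that are arbitrary illustrations rather than the model's real dimensions:

import torch

MAX_NUM_TOKENS, HIDDEN_SIZE = 8192, 2048  # assumed sizes for illustration

positions_buf = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int64)
hidden_buf = torch.zeros(MAX_NUM_TOKENS, HIDDEN_SIZE)


def run_on_static_buffers(positions: torch.Tensor,
                          hidden: torch.Tensor) -> torch.Tensor:
    n = positions.size(0)
    # In-place copies keep the buffer addresses stable across calls, which
    # is what allows a captured CUDA graph to be replayed on new data.
    positions_buf[:n].copy_(positions)
    hidden_buf[:n].copy_(hidden)
    # ... the compiled / graph-captured region would consume these slices ...
    return hidden_buf[:n]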
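
Finally, the load_weights() change can be read as a pure renaming rule on checkpoint keys: only the decoder layers, the altup unembed projections, and the final norm still live on Gemma3nTextModel; every other weight moved into the self_decoder submodule. A sketch of the rule, where the helper name is hypothetical:

def remap_checkpoint_name(name: str) -> str:
    # Prefix match, exactly as in load_weights() above.
    kept_on_text_model = ('layers', 'altup_unembed_projections', 'norm')
    if name.startswith(kept_on_text_model):
        return name
    return f"self_decoder.{name}"


assert remap_checkpoint_name("embed_tokens.weight") == (
    "self_decoder.embed_tokens.weight")
assert remap_checkpoint_name("norm.weight") == "norm.weight"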