[Model] [Quantization] Support quantization for Gemma3n (#21974)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
Kyle Sayers 2025-08-01 01:45:15 -04:00 committed by GitHub
parent e1a7fe4af5
commit 0f46a780d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -46,6 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsQuant
from .utils import (AutoWeightsLoader, extract_layer_index, from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter, make_layers, maybe_prefix) is_pp_missing_parameter, make_layers, maybe_prefix)
@ -68,6 +69,7 @@ class Gemma3nAltUp(nn.Module):
altup_num_inputs: int, altup_num_inputs: int,
altup_coef_clip: float, altup_coef_clip: float,
altup_active_idx: int, altup_active_idx: int,
quant_config: QuantizationConfig,
prefix: str, prefix: str,
): ):
super().__init__() super().__init__()
@ -80,6 +82,7 @@ class Gemma3nAltUp(nn.Module):
altup_num_inputs, altup_num_inputs,
altup_num_inputs, altup_num_inputs,
bias=False, bias=False,
quant_config=quant_config,
prefix=f"{prefix}.correction_coefs", prefix=f"{prefix}.correction_coefs",
return_bias=False, return_bias=False,
) )
@ -87,6 +90,7 @@ class Gemma3nAltUp(nn.Module):
altup_num_inputs, altup_num_inputs,
altup_num_inputs**2, altup_num_inputs**2,
bias=False, bias=False,
quant_config=quant_config,
prefix=f"{prefix}.prediction_coefs", prefix=f"{prefix}.prediction_coefs",
return_bias=False, return_bias=False,
) )
@ -94,6 +98,7 @@ class Gemma3nAltUp(nn.Module):
hidden_size, hidden_size,
altup_num_inputs, altup_num_inputs,
bias=False, bias=False,
quant_config=quant_config,
prefix=f"{prefix}.modality_router", prefix=f"{prefix}.modality_router",
return_bias=False, return_bias=False,
) )
@ -400,6 +405,7 @@ class Gemma3nDecoderLayer(nn.Module):
altup_num_inputs=config.altup_num_inputs, altup_num_inputs=config.altup_num_inputs,
altup_coef_clip=config.altup_coef_clip, altup_coef_clip=config.altup_coef_clip,
altup_active_idx=config.altup_active_idx, altup_active_idx=config.altup_active_idx,
quant_config=quant_config,
prefix=f"{prefix}.altup", prefix=f"{prefix}.altup",
) )
self.self_attn = Gemma3nAttention( self.self_attn = Gemma3nAttention(
@ -527,7 +533,7 @@ class Gemma3nDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class Gemma3nTextModel(nn.Module): class Gemma3nTextModel(nn.Module, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
@ -540,6 +546,7 @@ class Gemma3nTextModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens", prefix=f"{prefix}.embed_tokens",
) )
self.embed_scale = torch.tensor( self.embed_scale = torch.tensor(
@ -549,6 +556,7 @@ class Gemma3nTextModel(nn.Module):
self.embed_tokens_per_layer = VocabParallelEmbedding( self.embed_tokens_per_layer = VocabParallelEmbedding(
config.vocab_size_per_layer_input, config.vocab_size_per_layer_input,
config.num_hidden_layers * config.hidden_size_per_layer_input, config.num_hidden_layers * config.hidden_size_per_layer_input,
quant_config=quant_config,
prefix=f"{prefix}.per_layer_embed_tokens", prefix=f"{prefix}.per_layer_embed_tokens",
) )
self.embed_scale_per_layer = torch.tensor( self.embed_scale_per_layer = torch.tensor(
@ -582,7 +590,7 @@ class Gemma3nTextModel(nn.Module):
gather_output=True, gather_output=True,
return_bias=False, return_bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.{idx-1}.altup_projections", prefix=f"{prefix}.altup_projections.{idx-1}",
) for idx in range(1, self.config.altup_num_inputs) ) for idx in range(1, self.config.altup_num_inputs)
]) ])
self.altup_unembed_projections = nn.ModuleList([ self.altup_unembed_projections = nn.ModuleList([
@ -593,7 +601,7 @@ class Gemma3nTextModel(nn.Module):
gather_output=True, gather_output=True,
return_bias=False, return_bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.{idx-1}.altup_unembed_projections", prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
) for idx in range(1, self.config.altup_num_inputs) ) for idx in range(1, self.config.altup_num_inputs)
]) ])
@ -774,7 +782,7 @@ class Gemma3nModel(nn.Module):
**kwargs) **kwargs)
class Gemma3nForConditionalGeneration(nn.Module): class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",