diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py
index a58b32793dbe..e16c03c8d3b5 100644
--- a/vllm/model_executor/models/gemma3n.py
+++ b/vllm/model_executor/models/gemma3n.py
@@ -46,6 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from .interfaces import SupportsQuant
 from .utils import (AutoWeightsLoader, extract_layer_index,
                     is_pp_missing_parameter, make_layers, maybe_prefix)
@@ -68,6 +69,7 @@ class Gemma3nAltUp(nn.Module):
         altup_num_inputs: int,
         altup_coef_clip: float,
         altup_active_idx: int,
+        quant_config: QuantizationConfig,
         prefix: str,
     ):
         super().__init__()
@@ -80,6 +82,7 @@ class Gemma3nAltUp(nn.Module):
             altup_num_inputs,
             altup_num_inputs,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.correction_coefs",
             return_bias=False,
         )
@@ -87,6 +90,7 @@ class Gemma3nAltUp(nn.Module):
             altup_num_inputs,
             altup_num_inputs**2,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.prediction_coefs",
             return_bias=False,
         )
@@ -94,6 +98,7 @@ class Gemma3nAltUp(nn.Module):
             hidden_size,
             altup_num_inputs,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.modality_router",
             return_bias=False,
         )
@@ -400,6 +405,7 @@ class Gemma3nDecoderLayer(nn.Module):
             altup_num_inputs=config.altup_num_inputs,
             altup_coef_clip=config.altup_coef_clip,
             altup_active_idx=config.altup_active_idx,
+            quant_config=quant_config,
             prefix=f"{prefix}.altup",
         )
         self.self_attn = Gemma3nAttention(
@@ -527,7 +533,7 @@
 
 
 @support_torch_compile
-class Gemma3nTextModel(nn.Module):
+class Gemma3nTextModel(nn.Module, SupportsQuant):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -540,6 +546,7 @@ class Gemma3nTextModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
             prefix=f"{prefix}.embed_tokens",
         )
         self.embed_scale = torch.tensor(
@@ -549,6 +556,7 @@ class Gemma3nTextModel(nn.Module):
         self.embed_tokens_per_layer = VocabParallelEmbedding(
             config.vocab_size_per_layer_input,
             config.num_hidden_layers * config.hidden_size_per_layer_input,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_embed_tokens",
         )
         self.embed_scale_per_layer = torch.tensor(
@@ -582,7 +590,7 @@ class Gemma3nTextModel(nn.Module):
                 gather_output=True,
                 return_bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.{idx-1}.altup_projections",
+                prefix=f"{prefix}.altup_projections.{idx-1}",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])
         self.altup_unembed_projections = nn.ModuleList([
@@ -593,7 +601,7 @@ class Gemma3nTextModel(nn.Module):
                 gather_output=True,
                 return_bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.{idx-1}.altup_unembed_projections",
+                prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])
 
@@ -774,7 +782,7 @@ class Gemma3nModel(nn.Module):
                                      **kwargs)
 
 
-class Gemma3nForConditionalGeneration(nn.Module):
+class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
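
For context, here is a minimal sketch of the pattern the diff applies: thread `quant_config` into every quantizable submodule, give each one a dotted `prefix` that matches the checkpoint weight name (for `nn.ModuleList` entries the index follows the attribute name, e.g. `altup_projections.0`, which is what the two prefix fixes restore), and tag the model class with the `SupportsQuant` marker. The `ToyAltUp`/`ToyModel` classes below are hypothetical stand-ins, not part of the diff; the vLLM imports and the `RowParallelLinear` keyword arguments mirror those used in the patch.

```python
import torch.nn as nn

from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.interfaces import SupportsQuant


class ToyAltUp(nn.Module):
    """Hypothetical AltUp-like block showing the quant_config plumbing."""

    def __init__(self, hidden_size: int, altup_num_inputs: int,
                 quant_config: QuantizationConfig, prefix: str):
        super().__init__()
        # Without quant_config, this layer would silently stay unquantized
        # even when the model is loaded with a quantization method.
        self.modality_router = RowParallelLinear(
            hidden_size,
            altup_num_inputs,
            bias=False,
            quant_config=quant_config,
            # The prefix must mirror the checkpoint parameter name, e.g.
            # "model.layers.0.altup.modality_router", so that prefix-keyed
            # quantization schemes can look this layer up.
            prefix=f"{prefix}.modality_router",
            return_bias=False,
        )


class ToyModel(nn.Module, SupportsQuant):
    """Declaring SupportsQuant advertises quantization support to vLLM."""

    def __init__(self, quant_config: QuantizationConfig,
                 prefix: str = "model"):
        super().__init__()
        # Note: instantiating parallel layers assumes vLLM's distributed
        # environment has already been initialized.
        self.altup = ToyAltUp(64, 4, quant_config, prefix=f"{prefix}.altup")
```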