diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py
index 168665cc29655..d0880103d4e86 100644
--- a/vllm/model_executor/models/gemma3n.py
+++ b/vllm/model_executor/models/gemma3n.py
@@ -167,22 +167,33 @@ class Gemma3nAltUp(nn.Module):
 class Gemma3nLaurelBlock(nn.Module):
     """Learned Augmented Residual Layer"""
 
-    def __init__(self, hidden_size: int, laurel_rank: int, rms_norm_eps: float,
-                 prefix: str):
+    def __init__(
+        self,
+        hidden_size: int,
+        laurel_rank: int,
+        rms_norm_eps: float,
+        *,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str,
+    ) -> None:
         super().__init__()
 
         self.linear_left = ColumnParallelLinear(
             hidden_size,
             laurel_rank,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.linear_left",
             return_bias=False,
         )
-        self.linear_right = RowParallelLinear(laurel_rank,
-                                              hidden_size,
-                                              bias=False,
-                                              prefix=f"{prefix}.linear_right",
-                                              return_bias=False)
+        self.linear_right = RowParallelLinear(
+            laurel_rank,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.linear_right",
+            return_bias=False,
+        )
         self.post_laurel_norm = RMSNorm(
             hidden_size=hidden_size,
             eps=rms_norm_eps,
@@ -417,6 +428,7 @@ class Gemma3nDecoderLayer(nn.Module):
             hidden_size=config.hidden_size,
             laurel_rank=config.laurel_rank,
             rms_norm_eps=config.rms_norm_eps,
+            quant_config=quant_config,
             prefix=f"{prefix}.laurel",
         )
 
@@ -427,6 +439,7 @@ class Gemma3nDecoderLayer(nn.Module):
             config.hidden_size,
             config.hidden_size_per_layer_input,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_input_gate",
             return_bias=False,
         )
@@ -434,6 +447,7 @@ class Gemma3nDecoderLayer(nn.Module):
             config.hidden_size_per_layer_input,
             config.hidden_size,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_projection",
             return_bias=False,
         )
@@ -547,6 +561,7 @@ class Gemma3nTextModel(nn.Module):
             bias=False,
             gather_output=True,
             return_bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_model_projection",
         )
         self.per_layer_projection_norm = RMSNorm(
@@ -566,6 +581,7 @@
                 bias=False,
                 gather_output=True,
                 return_bias=False,
+                quant_config=quant_config,
                 prefix=f"{prefix}.{idx-1}.altup_projections",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])
@@ -576,6 +592,7 @@
                 bias=False,
                 gather_output=True,
                 return_bias=False,
+                quant_config=quant_config,
                 prefix=f"{prefix}.{idx-1}.altup_unembed_projections",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])
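
Every hunk applies the same fix: a parallel linear layer that previously dropped `quant_config` now forwards it to its constructor. Below is a minimal sketch of why that matters, modeled loosely on the dispatch vLLM's linear layers perform but using illustrative stand-in classes (`QuantizationConfig`, `UnquantizedLinearMethod`, and `ParallelLinear` here are simplified assumptions, not vLLM's real implementations):

from typing import Optional

# Illustrative stand-ins only -- not vLLM's actual classes.
class QuantizationConfig:
    """Base config; real subclasses (FP8, GPTQ, ...) return a quantized kernel."""
    def get_quant_method(self, layer, prefix: str):
        raise NotImplementedError

class UnquantizedLinearMethod:
    """Plain full-precision matmul path used when no quant_config is given."""

class ParallelLinear:
    def __init__(self, input_size: int, output_size: int, *,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # The compute kernel is chosen at construction time: with no
        # quant_config, the layer silently falls back to the unquantized
        # path. That is the bug the diff fixes -- gemma3n's laurel,
        # per-layer, and altup projections never received quant_config,
        # so they stayed in full precision even when the model was loaded
        # from a quantized checkpoint.
        if quant_config is None:
            self.quant_method = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)

Because the selection happens once in `__init__`, the only way a sublayer can participate in quantization is for its parent module to thread `quant_config` down at construction, which is exactly what each `+            quant_config=quant_config,` line above does.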