[Bugfix] Fix shape mismatch assertion error when loading Gemma3n model with BitsAndBytes quantization (#21808)

Signed-off-by: sydarb <areebsyed237@gmail.com>
Author: Areeb Syed
Date:   2025-07-30 09:05:21 +05:30
Committed by: GitHub
parent b917da442b
commit fdde18229e

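This change threads quant_config through the Gemma3n Laurel, per-layer, and AltUp projection layers so their weights are created with the quantization method's layout, which is presumably what previously triggered the shape-mismatch assertion when the BitsAndBytes loader met full-precision weight shapes. A minimal sketch of how the fixed path can be exercised is below; the checkpoint id, prompt, and sampling settings are illustrative assumptions and not part of this commit, and bitsandbytes must be installed alongside vLLM.

# Sketch: load Gemma3n with BitsAndBytes quantization via vLLM's offline API.
# Assumes "google/gemma-3n-E2B-it" is the intended checkpoint id.
from vllm import LLM, SamplingParams

llm = LLM(
    model="google/gemma-3n-E2B-it",
    quantization="bitsandbytes",  # selects the quantized loading path fixed here
    dtype="bfloat16",
)
outputs = llm.generate(["Why is the sky blue?"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)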

@@ -167,22 +167,33 @@ class Gemma3nAltUp(nn.Module):
 class Gemma3nLaurelBlock(nn.Module):
     """Learned Augmented Residual Layer"""

-    def __init__(self, hidden_size: int, laurel_rank: int, rms_norm_eps: float,
-                 prefix: str):
+    def __init__(
+        self,
+        hidden_size: int,
+        laurel_rank: int,
+        rms_norm_eps: float,
+        *,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str,
+    ) -> None:
         super().__init__()

         self.linear_left = ColumnParallelLinear(
             hidden_size,
             laurel_rank,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.linear_left",
             return_bias=False,
         )
-        self.linear_right = RowParallelLinear(laurel_rank,
-                                              hidden_size,
-                                              bias=False,
-                                              prefix=f"{prefix}.linear_right",
-                                              return_bias=False)
+        self.linear_right = RowParallelLinear(
+            laurel_rank,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.linear_right",
+            return_bias=False,
+        )
         self.post_laurel_norm = RMSNorm(
             hidden_size=hidden_size,
             eps=rms_norm_eps,
@@ -417,6 +428,7 @@ class Gemma3nDecoderLayer(nn.Module):
             hidden_size=config.hidden_size,
             laurel_rank=config.laurel_rank,
             rms_norm_eps=config.rms_norm_eps,
+            quant_config=quant_config,
             prefix=f"{prefix}.laurel",
         )
@@ -427,6 +439,7 @@ class Gemma3nDecoderLayer(nn.Module):
             config.hidden_size,
             config.hidden_size_per_layer_input,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_input_gate",
             return_bias=False,
         )
@@ -434,6 +447,7 @@ class Gemma3nDecoderLayer(nn.Module):
             config.hidden_size_per_layer_input,
             config.hidden_size,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_projection",
             return_bias=False,
         )
@@ -547,6 +561,7 @@ class Gemma3nTextModel(nn.Module):
             bias=False,
             gather_output=True,
             return_bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_model_projection",
         )
         self.per_layer_projection_norm = RMSNorm(
@@ -566,6 +581,7 @@ class Gemma3nTextModel(nn.Module):
                 bias=False,
                 gather_output=True,
                 return_bias=False,
+                quant_config=quant_config,
                 prefix=f"{prefix}.{idx-1}.altup_projections",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])
@@ -576,6 +592,7 @@ class Gemma3nTextModel(nn.Module):
                 bias=False,
                 gather_output=True,
                 return_bias=False,
+                quant_config=quant_config,
                 prefix=f"{prefix}.{idx-1}.altup_unembed_projections",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])