[Bugfix] Fix shape mismatch assertion error when loading Gemma3n model with BitsAndBytes quantization (#21808)
Signed-off-by: sydarb <areebsyed237@gmail.com>
commit fdde18229e
parent b917da442b
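The patch below threads quant_config through the Gemma3n text-model submodules so that BitsAndBytes-quantized checkpoints load with parameter shapes that match what the weight loader expects. For reference, here is a minimal reproduction sketch using vLLM's offline API; the checkpoint id, dtype, and sampling settings are illustrative assumptions, not taken from this commit.

# Minimal reproduction sketch: load Gemma3n with in-flight BitsAndBytes
# quantization, the path that previously hit the shape-mismatch assertion
# during weight loading. The model id below is an assumption for illustration.
from vllm import LLM, SamplingParams

llm = LLM(
    model="google/gemma-3n-E2B-it",  # assumed Gemma3n checkpoint id
    quantization="bitsandbytes",     # exercises the BitsAndBytes loader path
    dtype="bfloat16",
    max_model_len=2048,
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=16),
)
print(outputs[0].outputs[0].text)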
@@ -167,22 +167,33 @@ class Gemma3nAltUp(nn.Module):
 class Gemma3nLaurelBlock(nn.Module):
     """Learned Augmented Residual Layer"""

-    def __init__(self, hidden_size: int, laurel_rank: int, rms_norm_eps: float,
-                 prefix: str):
+    def __init__(
+        self,
+        hidden_size: int,
+        laurel_rank: int,
+        rms_norm_eps: float,
+        *,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str,
+    ) -> None:
         super().__init__()

         self.linear_left = ColumnParallelLinear(
             hidden_size,
             laurel_rank,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.linear_left",
             return_bias=False,
         )
-        self.linear_right = RowParallelLinear(laurel_rank,
-                                              hidden_size,
-                                              bias=False,
-                                              prefix=f"{prefix}.linear_right",
-                                              return_bias=False)
+        self.linear_right = RowParallelLinear(
+            laurel_rank,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.linear_right",
+            return_bias=False,
+        )
         self.post_laurel_norm = RMSNorm(
             hidden_size=hidden_size,
             eps=rms_norm_eps,
@@ -417,6 +428,7 @@ class Gemma3nDecoderLayer(nn.Module):
             hidden_size=config.hidden_size,
             laurel_rank=config.laurel_rank,
             rms_norm_eps=config.rms_norm_eps,
+            quant_config=quant_config,
             prefix=f"{prefix}.laurel",
         )

@@ -427,6 +439,7 @@ class Gemma3nDecoderLayer(nn.Module):
             config.hidden_size,
             config.hidden_size_per_layer_input,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_input_gate",
             return_bias=False,
         )
@@ -434,6 +447,7 @@ class Gemma3nDecoderLayer(nn.Module):
             config.hidden_size_per_layer_input,
             config.hidden_size,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_projection",
             return_bias=False,
         )
@@ -547,6 +561,7 @@ class Gemma3nTextModel(nn.Module):
             bias=False,
             gather_output=True,
             return_bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_model_projection",
         )
         self.per_layer_projection_norm = RMSNorm(
@@ -566,6 +581,7 @@ class Gemma3nTextModel(nn.Module):
                 bias=False,
                 gather_output=True,
                 return_bias=False,
+                quant_config=quant_config,
                 prefix=f"{prefix}.{idx-1}.altup_projections",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])
@@ -576,6 +592,7 @@ class Gemma3nTextModel(nn.Module):
                 bias=False,
                 gather_output=True,
                 return_bias=False,
+                quant_config=quant_config,
                 prefix=f"{prefix}.{idx-1}.altup_unembed_projections",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])
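The common thread across the hunks is that every ColumnParallelLinear and RowParallelLinear built inside Gemma3n now receives the model's quant_config instead of silently defaulting to None. The sketch below uses placeholder classes, not vLLM's real internals, to illustrate why that matters for BitsAndBytes: a submodule built without the quant_config allocates a full-precision weight, while the quantized checkpoint stores a packed tensor of a different shape, so the loader's shape assertion fires.

# Illustrative sketch only (placeholder classes, assumed packing layout):
# why quant_config has to reach every linear submodule.
from typing import Optional


class SketchQuantConfig:
    """Stands in for vllm's QuantizationConfig in this sketch."""


class SketchLinear:
    def __init__(self, in_features: int, out_features: int,
                 quant_config: Optional[SketchQuantConfig] = None) -> None:
        if quant_config is None:
            # Unquantized path: one bf16 element per weight.
            self.weight_shape = (out_features, in_features)
        else:
            # 4-bit BitsAndBytes packs two values per byte, so the stored
            # tensor has roughly half as many elements (flattened here).
            self.weight_shape = (out_features * in_features // 2, 1)


# Before the fix: quant_config was dropped, so the layer allocated a
# full-precision weight that cannot accept the packed checkpoint tensor.
unquantized = SketchLinear(2048, 256)
# After the fix: forwarding quant_config keeps the allocated shape consistent
# with what the BitsAndBytes checkpoint actually contains.
quantized = SketchLinear(2048, 256, quant_config=SketchQuantConfig())
print(unquantized.weight_shape, quantized.weight_shape)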