mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:26:12 +08:00
Add GPTQ support for Gemma (#3200)
This commit is contained in:
parent
4cb3b924cd
commit
d3c04b6a39
@ -325,11 +325,17 @@ class GemmaForCausalLM(nn.Module):
|
|||||||
if shard_name not in name:
|
if shard_name not in name:
|
||||||
continue
|
continue
|
||||||
name = name.replace(shard_name, param_name)
|
name = name.replace(shard_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
# GemmaRMSNorm is different from Llama's in that it multiplies
|
# GemmaRMSNorm is different from Llama's in that it multiplies
|
||||||
# (1 + weight) to the output, instead of just weight.
|
# (1 + weight) to the output, instead of just weight.
|
||||||
if "norm.weight" in name:
|
if "norm.weight" in name:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user