Add GPTQ support for Gemma (#3200)

This commit is contained in:
TechxGenus 2024-03-07 08:19:14 +08:00 committed by GitHub
parent 4cb3b924cd
commit d3c04b6a39
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -325,11 +325,17 @@ class GemmaForCausalLM(nn.Module):
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# GemmaRMSNorm is different from Llama's in that it multiplies
# (1 + weight) to the output, instead of just weight.
if "norm.weight" in name: