[Bugfix] Fix the lm_head in gpt_bigcode in lora mode (#6357)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Maximilien de Bayser authored on 2025-05-26 03:52:25 -03:00, committed by GitHub
parent abd4030d94
commit 561b77a0d6

@@ -272,12 +272,6 @@ class GPTBigCodeModel(nn.Module):
 class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {"c_attn": ["c_attn"]}
 
-    # LoRA specific attributes
-    embedding_modules = {
-        "wte": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -330,8 +324,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
+        skip_prefixes = None
+        if self.config.tie_word_embeddings:
+            skip_prefixes = ["lm_head."]
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=(["lm_head."]),
+            skip_prefixes=skip_prefixes,
         )
         return loader.load_weights(weights)
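
The behaviour change can be summarised outside of the loader machinery. The sketch below is a hypothetical, self-contained illustration (the should_skip helper is not a vLLM API) of the rule the new load_weights applies: checkpoint tensors under the "lm_head." prefix are skipped only when tie_word_embeddings is set, since only then does the LM head reuse the input embeddings (wte); with untied embeddings the lm_head weights must be loaded so the model, and any LoRA targeting the output embeddings, uses the real head.

    # Hypothetical, self-contained sketch of the skip rule in the diff above;
    # `should_skip` is an illustrative helper, not part of vLLM.
    def should_skip(tensor_name: str, tie_word_embeddings: bool) -> bool:
        # Skip lm_head.* only when the LM head is tied to the input embeddings.
        return tie_word_embeddings and tensor_name.startswith("lm_head.")

    # Tied embeddings: the checkpoint's lm_head entries are redundant and skipped.
    assert should_skip("lm_head.weight", tie_word_embeddings=True)
    # Untied embeddings: lm_head must be loaded; always skipping it left the
    # output projection uninitialised, which is the bug this commit fixes.
    assert not should_skip("lm_head.weight", tie_word_embeddings=False)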