From 561b77a0d608a9059318d6cff9f0975439880d77 Mon Sep 17 00:00:00 2001
From: Maximilien de Bayser
Date: Mon, 26 May 2025 03:52:25 -0300
Subject: [PATCH] [Bugfix] Fix the lm_head in gpt_bigcode in lora mode (#6357)

Signed-off-by: Max de Bayser
Signed-off-by: Max de Bayser
---
 vllm/model_executor/models/gpt_bigcode.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 6a1d97bd7b69c..c4ae4fc3c0062 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -272,12 +272,6 @@ class GPTBigCodeModel(nn.Module):
 class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {"c_attn": ["c_attn"]}
 
-    # LoRA specific attributes
-    embedding_modules = {
-        "wte": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -330,8 +324,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
+        skip_prefixes = None
+        if self.config.tie_word_embeddings:
+            skip_prefixes = ["lm_head."]
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=(["lm_head."]),
+            skip_prefixes=skip_prefixes,
         )
-        return loader.load_weights(weights)
\ No newline at end of file
+        return loader.load_weights(weights)