diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 6a1d97bd7b69c..c4ae4fc3c0062 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -272,12 +272,6 @@ class GPTBigCodeModel(nn.Module):
 class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {"c_attn": ["c_attn"]}
 
-    # LoRA specific attributes
-    embedding_modules = {
-        "wte": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -330,8 +324,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
+        skip_prefixes = None
+        if self.config.tie_word_embeddings:
+            skip_prefixes = ["lm_head."]
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=(["lm_head."]),
+            skip_prefixes=skip_prefixes,
         )
-        return loader.load_weights(weights)
\ No newline at end of file
+        return loader.load_weights(weights)
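
Note on the behavior change (not part of the patch): previously load_weights unconditionally skipped checkpoint tensors under the lm_head. prefix, which discards the output-embedding weight for checkpoints with tie_word_embeddings=False; after the patch the prefix is skipped only when the embeddings are tied. The standalone filter_weights helper below is a hypothetical sketch of that conditional skip, not vLLM code; the actual filtering happens inside AutoWeightsLoader.

# Hypothetical stand-in for AutoWeightsLoader's prefix filtering; it only
# illustrates the conditional skip introduced by this diff.
from collections.abc import Iterable


def filter_weights(
    weights: Iterable[tuple[str, object]],
    tie_word_embeddings: bool,
) -> list[tuple[str, object]]:
    # With tied embeddings, lm_head.* duplicates the input embedding, so those
    # tensors can be dropped; with untied embeddings they must be kept.
    skip_prefixes = ["lm_head."] if tie_word_embeddings else []
    return [(name, t) for name, t in weights
            if not any(name.startswith(p) for p in skip_prefixes)]


# Tied: the lm_head weight is skipped; untied: it is kept and loaded.
assert filter_weights([("lm_head.weight", None)], tie_word_embeddings=True) == []
assert filter_weights([("lm_head.weight", None)], tie_word_embeddings=False) == [
    ("lm_head.weight", None)
]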