mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 21:05:01 +08:00
[Bugfix] Fix the lm_head in gpt_bigcode in lora mode (#6357)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
This commit is contained in:
parent
abd4030d94
commit
561b77a0d6
@ -272,12 +272,6 @@ class GPTBigCodeModel(nn.Module):
|
||||
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {"c_attn": ["c_attn"]}
|
||||
|
||||
# LoRA specific attributes
|
||||
embedding_modules = {
|
||||
"wte": "input_embeddings",
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
@ -330,8 +324,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
skip_prefixes = None
|
||||
if self.config.tie_word_embeddings:
|
||||
skip_prefixes = ["lm_head."]
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]),
|
||||
skip_prefixes=skip_prefixes,
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
Loading…
x
Reference in New Issue
Block a user