Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 12:55:55 +08:00)
[Bugfix] Fix the lm_head in gpt_bigcode in lora mode (#6357)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
commit 561b77a0d6
parent abd4030d94
@@ -272,12 +272,6 @@ class GPTBigCodeModel(nn.Module):
 class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {"c_attn": ["c_attn"]}
 
-    # LoRA specific attributes
-    embedding_modules = {
-        "wte": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
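Reading the first hunk: it drops the class-level `embedding_modules` mapping, so vLLM's LoRA machinery apparently no longer treats gpt_bigcode's `lm_head` as a LoRA output-embedding module. A minimal standalone sketch of the idea; the `is_embedding_target` helper is hypothetical and not vLLM's actual API:

    # Hypothetical sketch, not vLLM code: how a LoRA manager might consult a
    # model's `embedding_modules` mapping to decide which modules receive
    # embedding-style LoRA handling.
    EMBEDDING_MODULES_BEFORE = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    EMBEDDING_MODULES_AFTER: dict = {}  # the attribute is removed by this commit

    def is_embedding_target(module_name: str, embedding_modules: dict) -> bool:
        # Does this module name match an entry in the mapping?
        return any(module_name.endswith(key) for key in embedding_modules)

    print(is_embedding_target("lm_head", EMBEDDING_MODULES_BEFORE))  # True
    print(is_embedding_target("lm_head", EMBEDDING_MODULES_AFTER))   # False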
@@ -330,8 +324,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
+        skip_prefixes = None
+        if self.config.tie_word_embeddings:
+            skip_prefixes = ["lm_head."]
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=(["lm_head."]),
+            skip_prefixes=skip_prefixes,
         )
         return loader.load_weights(weights)
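The second hunk is the fix proper: `lm_head.` weights used to be skipped unconditionally during loading, which is only correct when the output projection is tied to the input embeddings. A standalone sketch of why the unconditional skip loses an untied head; `load_checkpoint` here is a hypothetical stand-in for AutoWeightsLoader's prefix filtering, not the real loader:

    def load_checkpoint(weights, skip_prefixes):
        # Keep only entries whose names match none of the skip prefixes.
        skip_prefixes = skip_prefixes or []
        return {
            name: tensor
            for name, tensor in weights
            if not any(name.startswith(p) for p in skip_prefixes)
        }

    checkpoint = [
        ("transformer.wte.weight", "..."),  # input embeddings
        ("lm_head.weight", "..."),          # untied output projection
    ]

    # Old behavior: lm_head.weight is always dropped, so a model with
    # tie_word_embeddings=False never loads its output projection.
    assert "lm_head.weight" not in load_checkpoint(checkpoint, ["lm_head."])

    # New behavior: skip only when the head is tied to the input embeddings.
    tie_word_embeddings = False
    skip = ["lm_head."] if tie_word_embeddings else None
    assert "lm_head.weight" in load_checkpoint(checkpoint, skip)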