[Bugfix] Fix Mamba model initialization and MLP Speculator weights loading (#10456)

Signed-off-by: Isotr0py <2037008807@qq.com>
Author: Isotr0py (committed by GitHub)
Date: 2024-11-20 13:04:05 +08:00
Parent: 9e05252b46
Commit: ad44437ba3
2 changed files with 4 additions and 7 deletions


@@ -1,5 +1,5 @@
 """PyTorch MAMBA model."""
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -243,10 +243,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "A_log" in name:
                 name = name.replace("A_log", "A")
@@ -258,5 +256,3 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
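
For reference, the net effect of these two hunks is the simplified loader sketched below. This is a reconstruction from the diff context, not the full file: the params_dict[name] lookup sits in unchanged lines outside the shown hunks, and the import path for default_weight_loader is assumed to be vLLM's usual one.

    from typing import Iterable, Tuple

    import torch
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Load checkpoint tensors into the module's parameters; after this
        # commit the method no longer tracks or returns loaded names.
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # Mamba checkpoints store A as "A_log"; the parameter is "A".
            if "A_log" in name:
                name = name.replace("A_log", "A")
            param = params_dict[name]
            # Prefer the parameter's custom loader, else the default copy.
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)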


@@ -193,7 +193,8 @@ class MLPSpeculator(nn.Module):
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            param = params_dict.get(name.replace("speculator.", ""))
+            name = name.replace("speculator.", "")
+            param = params_dict.get(name)
             if param is not None:
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
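
The MLPSpeculator fix strips the "speculator." checkpoint prefix into name before the dictionary lookup, so the same stripped name is what later ends up in loaded_params. Below is a minimal sketch of the corrected loop; the loaded_params.add(...) and return lines fall outside the shown hunk and are inferred from the loaded_params initialization above.

    from typing import Iterable, Set, Tuple

    import torch
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Strip the serialization prefix once, so the lookup and the
            # bookkeeping below both use the same stripped name.
            name = name.replace("speculator.", "")
            param = params_dict.get(name)
            if param is not None:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        return loaded_params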