[Bugfix] Fix Mamba model initialization and MLP Speculator weights loading (#10456)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2024-11-20 13:04:05 +08:00 committed by GitHub
parent 9e05252b46
commit ad44437ba3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 7 deletions

View File

@ -1,5 +1,5 @@
"""PyTorch MAMBA model."""
from typing import Iterable, List, Optional, Set, Tuple
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@ -243,10 +243,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "A_log" in name:
name = name.replace("A_log", "A")
@ -258,5 +256,3 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

View File

@ -193,7 +193,8 @@ class MLPSpeculator(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
param = params_dict.get(name.replace("speculator.", ""))
name = name.replace("speculator.", "")
param = params_dict.get(name)
if param is not None:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)