[BugFix] Fix weight loading for Mixtral with TP (#2208)
This commit is contained in: parent de60a3fb93 · commit ba4f826738

With tensor parallelism, each worker now materializes only the experts assigned to it (unassigned slots become None), and load_weights skips checkpoint weights that belong to other workers' experts. This replaces the earlier DummyModule placeholder approach, which is removed here.
@@ -49,7 +49,6 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -94,30 +93,6 @@ class MixtralMLP(nn.Module):
         return current_hidden_states
 
 
-class DummyModule(nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-
-        self.w1 = nn.Linear(0, 0, bias=False)
-        self.w2 = nn.Linear(0, 0, bias=False)
-        self.w3 = nn.Linear(0, 0, bias=False)
-
-        set_weight_attrs(self.w1.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w2.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w3.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-
-    def forward(self, *args, **kwargs) -> None:
-        raise NotImplementedError()
-
-    def dummy_weight_loader(self, *args, **kwargs) -> None:  # pylint: disable=unused-argument
-        # Noop
-        return
-
-
 class MixtralMoE(nn.Module):
 
     def __init__(
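For context, the removed DummyModule relied on vLLM's set_weight_attrs helper to attach a no-op weight_loader to each zero-sized placeholder parameter, so checkpoint weights destined for other workers were silently discarded. A minimal sketch of that attach-an-attribute mechanism (simplified; the real helper lives in vllm.model_executor.utils and may differ in detail):

import torch

def set_weight_attrs_sketch(weight: torch.Tensor, attrs: dict) -> None:
    # Attach arbitrary attributes (e.g. a "weight_loader" callable)
    # directly onto the parameter object; load_weights later retrieves
    # them via getattr(param, "weight_loader", default_weight_loader).
    for key, value in attrs.items():
        setattr(weight, key, value)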
@@ -147,7 +122,7 @@ class MixtralMoE(nn.Module):
                        config.hidden_size,
                        config.intermediate_size,
                        linear_method=linear_method)
-            if idx in self.expert_indicies else DummyModule()
+            if idx in self.expert_indicies else None
             for idx in range(self.num_total_experts)
         ])
         self.gate = ReplicatedLinear(config.hidden_size,
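With the placeholders gone, each tensor-parallel rank keeps a real MixtralMLP only at the indices in self.expert_indicies and None everywhere else. A rough, self-contained sketch of an even contiguous split of experts across ranks (the np.array_split strategy and the helper name are assumptions for illustration, not taken from this diff):

import numpy as np

def expert_indicies_for_rank(num_total_experts: int, rank: int,
                             world_size: int) -> list:
    # Near-even contiguous split: 8 experts over 2 ranks gives
    # [0, 1, 2, 3] on rank 0 and [4, 5, 6, 7] on rank 1.
    return np.array_split(range(num_total_experts), world_size)[rank].tolist()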
@@ -427,6 +402,10 @@ class MixtralForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip experts that are not assigned to this worker.
+                if ("block_sparse_moe.experts." in name
+                        and name not in params_dict):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
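Because unassigned expert slots are now None, their parameters never appear in named_parameters(), so load_weights can skip foreign experts with a plain name check. A minimal standalone sketch of that guard, where params_dict stands in for dict(self.named_parameters()) on one rank:

def should_skip_expert_weight(name: str, params_dict: dict) -> bool:
    # Expert weights that belong to another tensor-parallel rank are
    # absent from this rank's parameter dict; skip them by name.
    return "block_sparse_moe.experts." in name and name not in params_dict

# Example: a rank that owns expert 0 but not expert 7.
params = {"model.layers.0.block_sparse_moe.experts.0.w1.weight": None}
assert should_skip_expert_weight(
    "model.layers.0.block_sparse_moe.experts.7.w1.weight", params)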