diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 13473857b3309..e61b401a78a2b 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -49,7 +49,6 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -94,30 +93,6 @@ class MixtralMLP(nn.Module):
         return current_hidden_states
 
 
-class DummyModule(nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-
-        self.w1 = nn.Linear(0, 0, bias=False)
-        self.w2 = nn.Linear(0, 0, bias=False)
-        self.w3 = nn.Linear(0, 0, bias=False)
-
-        set_weight_attrs(self.w1.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w2.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-        set_weight_attrs(self.w3.weight,
-                         {"weight_loader": self.dummy_weight_loader})
-
-    def forward(self, *args, **kwargs) -> None:
-        raise NotImplementedError()
-
-    def dummy_weight_loader(self, *args, **kwargs) -> None:  # pylint: disable=unused-argument
-        # Noop
-        return
-
-
 class MixtralMoE(nn.Module):
 
     def __init__(
@@ -147,7 +122,7 @@ class MixtralMoE(nn.Module):
                        config.hidden_size,
                        config.intermediate_size,
                        linear_method=linear_method)
-            if idx in self.expert_indicies else DummyModule()
+            if idx in self.expert_indicies else None
             for idx in range(self.num_total_experts)
         ])
         self.gate = ReplicatedLinear(config.hidden_size,
@@ -427,6 +402,10 @@ class MixtralForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip experts that are not assigned to this worker.
+                if ("block_sparse_moe.experts." in name
+                        and name not in params_dict):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
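
The diff above replaces the `DummyModule` placeholder (and its no-op weight loader) with plain `None` entries in the expert list, and instead skips non-local expert weights at load time by checking `params_dict`. Below is a minimal, self-contained sketch of that pattern, not vLLM's actual code: `ToyMoE`, `TP_SIZE`, `TP_RANK`, and the simulated checkpoint are all hypothetical, and the contiguous expert partitioning is only assumed to be in the spirit of `MixtralMoE.expert_indicies` (spelling kept as in the source); vLLM's exact partitioning may differ.

```python
# Sketch of the load-time skip pattern adopted by the diff above.
import torch
import torch.nn as nn

NUM_TOTAL_EXPERTS = 8
TP_SIZE = 4  # assumed tensor-parallel world size
TP_RANK = 1  # assumed rank of this worker

# Contiguous slice of experts owned by this rank (assumption; see lead-in).
experts_per_rank = NUM_TOTAL_EXPERTS // TP_SIZE
expert_indicies = list(
    range(TP_RANK * experts_per_rank, (TP_RANK + 1) * experts_per_rank))


class ToyMoE(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        # Unassigned slots hold None rather than DummyModule placeholders,
        # mirroring the MixtralMoE change above; nn.ModuleList accepts None.
        self.experts = nn.ModuleList([
            nn.Linear(16, 16, bias=False) if idx in expert_indicies else None
            for idx in range(NUM_TOTAL_EXPERTS)
        ])


model = ToyMoE()
params_dict = dict(model.named_parameters())

# Simulated checkpoint stream containing every expert's weight.
checkpoint = {
    f"experts.{idx}.weight": torch.zeros(16, 16)
    for idx in range(NUM_TOTAL_EXPERTS)
}

for name, loaded_weight in checkpoint.items():
    # Skip experts that are not assigned to this worker; their parameter
    # names are simply absent because the corresponding slot is None.
    if "experts." in name and name not in params_dict:
        continue
    params_dict[name].data.copy_(loaded_weight)

# Only the locally owned experts were materialized and loaded.
assert sorted(params_dict) == [f"experts.{i}.weight" for i in expert_indicies]
```

The design choice this illustrates: because `None` entries contribute no parameters to `named_parameters()`, the single "name not in params_dict" check at load time is sufficient, and no placeholder module with a dummy `weight_loader` is needed.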