diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index b72f51aa52bfa..711bdfd688501 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -699,8 +699,9 @@ class FusedMoE(torch.nn.Module): tp_rank=self.tp_rank) return - # Case weight scales and zero_points - if ("scale" in weight_name or "zero" in weight_name): + # Case weight scales, zero_points and offset + if ("scale" in weight_name or "zero" in weight_name + or "offset" in weight_name): # load the weight scales and zp based on the quantization scheme # supported weight scales/zp can be found in # FusedMoeWeightScaleSupported