From f065de4e88e7651f1f68fc4c0ca95b79d4577b89 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Mon, 12 May 2025 19:02:07 -0400
Subject: [PATCH] Fix FBGEMM integration (#18002)

Signed-off-by: mgoin
---
 .../layers/quantization/fbgemm_fp8.py         |  4 +++-
 .../quantization/utils/marlin_utils_fp8.py    | 20 ++++++++-----------
 2 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index 1fa2b3a8eeeaa..163aabb45c648 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -63,7 +63,9 @@ class FBGEMMFp8Config(QuantizationConfig):
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
-            if is_layer_skipped(prefix, self.ignore_list):
+            if is_layer_skipped(prefix=prefix,
+                                ignored_layers=self.ignore_list,
+                                fused_mapping=self.packed_modules_mapping):
                 return UnquantizedLinearMethod()
             return FBGEMMFp8LinearMethod(self)
         return None
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
index 3080d2a0da876..08812debd321b 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
@@ -86,6 +86,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,

     part_size_n = layer.output_size_per_partition
     part_size_k = layer.input_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)

     if size_k_first:
         assert layer.weight.shape == (part_size_k, part_size_n)
@@ -119,14 +120,11 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
         scales = layer.weight_scale_inv.to(layer.orig_dtype)
         del layer.weight_scale_inv

-    if layer.weight_block_size is None:
-        group_size = -1
-    else:
-        group_size = layer.weight_block_size[1]
+    group_size = -1 if weight_block_size is None else weight_block_size[1]

     # marlin kernel only support channel-wise and group-wise quantization
     # we need to convert the scales
-    if layer.weight_block_size is None:
+    if weight_block_size is None:
         if scales.nelement() == 1:
             # tensor-wise quantization -> channel-wise quantization
             # (1, 1) =>(repeat)=> (1, size_n)
@@ -149,7 +147,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
             # =>(repeat)=> (size_k // block_size[1], size_n)
             if not size_k_first:
                 scales = scales.T.contiguous()
-            block_n = layer.weight_block_size[0]
+            block_n = weight_block_size[0]
             scales = scales.repeat_interleave(block_n, 1)
             # size_n may not divisible by block_size[0]
             scales = scales[:, :part_size_n]
@@ -173,6 +171,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
     e = layer.num_experts
     k = layer.hidden_size
     n = layer.intermediate_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)

     # WORKSPACE
     device = layer.w13_weight.device
@@ -213,10 +212,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,

     # WEIGHT SCALES
     # Permute scales
-    if layer.weight_block_size is None:
-        group_size = -1
-    else:
-        group_size = layer.weight_block_size[1]
+    group_size = -1 if weight_block_size is None else weight_block_size[1]

     for name in ["w13", "w2"]:
         if name + "_weight_scale" in dir(layer):
@@ -236,7 +232,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,

     # marlin kernel only support channel-wise and group-wise quantization
     # we need to convert the scales
-    if layer.weight_block_size is None:
+    if weight_block_size is None:
         if scales.nelement() == e:
             # tensor-wise quantization -> channel-wise quantization
             # (e, 1, 1) =>(repeat)=> (e, 1, size_n)
@@ -259,7 +255,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
             # =>(repeat)=> (e, size_k // block_size[1], size_n)
             if not size_k_first:
                 scales = scales.permute(0, 2, 1)
-            block_n = layer.weight_block_size[0]
+            block_n = weight_block_size[0]
             scales = scales.repeat_interleave(block_n, 2)
             # size_n may not divisible by block_size[0]
             scales = scales[..., :size_n].contiguous()