From 35d801f13fa5bd79ae74707388b1fa4e1caf9ba5 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Mon, 10 Nov 2025 19:08:40 -0500 Subject: [PATCH] [Feature] Refactor batch invariant fp8 DeepGEMM (#27606) Signed-off-by: yewentao256 --- .../model_executor/layers/quantization/fp8.py | 98 +++---------------- 1 file changed, 11 insertions(+), 87 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f5fc750baaea7..c7d5b251cf4ef 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -43,7 +43,6 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, @@ -95,11 +94,9 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils.deep_gemm import ( - fp8_gemm_nt, get_col_major_tma_aligned_tensor, is_deep_gemm_e8m0_used, is_deep_gemm_supported, - should_use_deepgemm_for_fp8_linear, ) from vllm.utils.flashinfer import has_flashinfer_moe from vllm.utils.import_utils import has_deep_gemm @@ -554,83 +551,19 @@ class Fp8LinearMethod(LinearMethodBase): # if batch invariant mode is enabled, prefer DeepGEMM FP8 path # we will use BF16 dequant when DeepGEMM is not supported. if vllm_is_batch_invariant(): - # Call is_deep_gemm_supported() ahead of time for torch.compile - # dynamo has trouble tracing through - if self.block_quant and should_use_deepgemm_for_fp8_linear( - torch.bfloat16, layer.weight, self.use_deep_gemm - ): - # use group quant consistent with block size across K - assert self.act_q_group_shape is not None - q_input, input_scale = QuantFP8( - False, - self.act_q_group_shape, - column_major_scales=True, - )(x) - - output_2d = torch.empty( - (q_input.shape[0], layer.weight.shape[0]), - dtype=torch.bfloat16, - device=q_input.device, - ) - fp8_gemm_nt( - (q_input, input_scale), - (layer.weight, layer.weight_scale), - output_2d, - ) - if bias is not None: - output_2d = output_2d + bias - return output_2d - - # Dequantize FP8 weights to BF16 - weight_fp8 = layer.weight.to(torch.bfloat16) - weight_scale = layer.weight_scale.to(torch.bfloat16) - - # Handle different quantization granularities if self.block_quant: - # Block-wise quantization: - # - Weight is NOT transposed, shape is [N, K] (output_size, input_size) - # - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!) assert self.weight_block_size is not None - block_n, block_k = self.weight_block_size # Note: order is [N, K] - - N, K = weight_fp8.shape - - # determine expected number of blocks along N and K - num_blocks_n = (N + block_n - 1) // block_n - num_blocks_k = (K + block_k - 1) // block_k - - # scale layout may be [num_blocks_n, num_blocks_k] - # or [num_blocks_k, num_blocks_n] depending on backend - if weight_scale.dim() != 2: - raise RuntimeError( - f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}" - ) - - scale_rows, scale_cols = weight_scale.shape - if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n): - if num_blocks_n == num_blocks_k: - # ambiguous square case, warn and skip transpose - logger.warning( - "Batch-invariant FP8: square block-scale %dx%d; " - "skipping transpose to avoid misorientation.", - scale_rows, - scale_cols, - ) - else: - # clear KN -> transpose to NK - weight_scale = weight_scale.t() - - # Expand scale to match weight dimensions - # scale_expanded should have shape [N, K] - scale_expanded = weight_scale.repeat_interleave( - block_n, dim=0 - ).repeat_interleave(block_k, dim=1) - # Trim to exact weight size (in case of padding) - scale_expanded = scale_expanded[:N, :K] - weight_bf16 = weight_fp8 * scale_expanded + return self.w8a8_block_fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + ) else: - # Per-tensor quantization: weight IS transposed to [K, N] - # scale should be scalar or [1] or per-output-channel [N] + # per-tensor/channel: dequant to BF16 and run GEMM + weight_fp8 = layer.weight.to(torch.bfloat16) + weight_scale = layer.weight_scale.to(torch.bfloat16) if weight_scale.numel() == 1: # Per-tensor: simple scalar multiplication weight_bf16 = weight_fp8 * weight_scale @@ -649,16 +582,7 @@ class Fp8LinearMethod(LinearMethodBase): else: # Fallback weight_bf16 = weight_fp8 * weight_scale - - # For block quant, weight is [N, K], for per-tensor it's [K, N] - # F.linear expects weight to be [N, K], so: - if self.block_quant: - # Already in correct shape [N, K] - output = torch.nn.functional.linear(x, weight_bf16, bias) - else: - # Need to transpose back: [K, N] -> [N, K] - output = torch.nn.functional.linear(x, weight_bf16.t(), bias) - return output + return torch.nn.functional.linear(x, weight_bf16.t(), bias) if self.use_marlin: return apply_fp8_marlin_linear(