From 9471879bd42f38c4b1151539e05d7a3a3ed5c3c5 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Mon, 29 Sep 2025 21:15:19 -0400
Subject: [PATCH] [Bug] Fix Weight Loading for Block FP8 Cutlass SM90 (#25909)

Signed-off-by: yewentao256
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: simon-mo
---
 .../model_executor/layers/quantization/utils/fp8_utils.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index b2548e66827d3..828111dc299ec 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -911,15 +911,15 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
     # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
     # requantize the weight and input to the specific scale
     # at the same time.
-    if is_deep_gemm_e8m0_used():
+    should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
+        layer.orig_dtype, layer.weight)
+    if is_deep_gemm_e8m0_used() and should_use_deepgemm:
         block_sz = tuple(layer.weight_block_size)
         requant_weight_ue8m0_inplace(layer.weight.data,
                                      layer.weight_scale.data, block_sz)
     # SM90 Block FP8 CUTLASS requires row-major weight scales
     elif (current_platform.is_device_capability(90)
-          and cutlass_block_fp8_supported
-          and not should_use_deepgemm_for_fp8_linear(torch.bfloat16,
-                                                     layer.weight)):
+          and cutlass_block_fp8_supported and not should_use_deepgemm):
         layer.weight_scale = torch.nn.Parameter(
             layer.weight_scale.data.T.contiguous(),
             requires_grad=False)
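
Note for reviewers: the net effect of the patch is that DeepGEMM eligibility is computed once from the
layer's real dtype (layer.orig_dtype) instead of a hard-coded torch.bfloat16, and that single result gates
both the UE8M0 requantization branch and the SM90 CUTLASS scale transpose. The sketch below illustrates the
resulting control flow only; the helper stubs (is_deep_gemm_e8m0_used, should_use_deepgemm_for_fp8_linear,
requant_weight_ue8m0_inplace, is_sm90) are hypothetical placeholders standing in for vLLM internals, not the
actual implementations.

    import torch

    # --- Hypothetical stand-ins so the sketch is self-contained; not vLLM code. ---
    def is_deep_gemm_e8m0_used() -> bool:
        return False  # placeholder

    def should_use_deepgemm_for_fp8_linear(orig_dtype, weight) -> bool:
        return orig_dtype == torch.bfloat16  # placeholder heuristic

    def requant_weight_ue8m0_inplace(weight, weight_scale, block_sz) -> None:
        pass  # placeholder

    def is_sm90() -> bool:
        return True  # stand-in for current_platform.is_device_capability(90)
    # ------------------------------------------------------------------------------

    def post_process_fp8_weight_block(layer, cutlass_block_fp8_supported: bool):
        # Decide once, from the layer's original dtype, whether DeepGEMM will be
        # used; the pre-patch code hard-coded torch.bfloat16 here.
        should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
            layer.orig_dtype, layer.weight)

        if is_deep_gemm_e8m0_used() and should_use_deepgemm:
            # Requantize weight and scales to UE8M0 only when DeepGEMM is
            # actually the selected backend for this layer.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(layer.weight.data,
                                         layer.weight_scale.data, block_sz)
        elif (is_sm90() and cutlass_block_fp8_supported
              and not should_use_deepgemm):
            # SM90 block-FP8 CUTLASS kernels expect row-major weight scales,
            # so transpose them once at weight-loading time.
            layer.weight_scale = torch.nn.Parameter(
                layer.weight_scale.data.T.contiguous(), requires_grad=False)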