diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index b2548e66827d..828111dc299e 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -911,15 +911,15 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
     # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
     # requantize the weight and input to the specific scale
     # at the same time.
-    if is_deep_gemm_e8m0_used():
+    should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
+        layer.orig_dtype, layer.weight)
+    if is_deep_gemm_e8m0_used() and should_use_deepgemm:
         block_sz = tuple(layer.weight_block_size)
         requant_weight_ue8m0_inplace(layer.weight.data,
                                      layer.weight_scale.data, block_sz)
     # SM90 Block FP8 CUTLASS requires row-major weight scales
     elif (current_platform.is_device_capability(90)
-          and cutlass_block_fp8_supported
-          and not should_use_deepgemm_for_fp8_linear(torch.bfloat16,
-                                                     layer.weight)):
+          and cutlass_block_fp8_supported and not should_use_deepgemm):
         layer.weight_scale = torch.nn.Parameter(
             layer.weight_scale.data.T.contiguous(),
             requires_grad=False)
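
Note: for readability, the post-patch decision flow of maybe_post_process_fp8_weight_block is sketched below in one piece. This is only an illustrative reconstruction of the hunk above, not the authoritative source: it assumes cutlass_block_fp8_supported is the function's second parameter, and that the helpers it calls (should_use_deepgemm_for_fp8_linear, is_deep_gemm_e8m0_used, requant_weight_ue8m0_inplace, current_platform) are already available in fp8_utils.py, as the existing code implies.

    # Sketch of the post-patch body; the signature beyond `layer` is assumed.
    def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
                                            cutlass_block_fp8_supported: bool):
        # Decide once whether DeepGEMM will handle this linear layer, keyed on
        # the layer's original dtype rather than a hard-coded torch.bfloat16.
        should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
            layer.orig_dtype, layer.weight)

        # Requantize weight scales to UE8M0 only when DeepGEMM is actually used.
        if is_deep_gemm_e8m0_used() and should_use_deepgemm:
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(layer.weight.data,
                                         layer.weight_scale.data, block_sz)
        # SM90 Block FP8 CUTLASS requires row-major weight scales.
        elif (current_platform.is_device_capability(90)
              and cutlass_block_fp8_supported and not should_use_deepgemm):
            layer.weight_scale = torch.nn.Parameter(
                layer.weight_scale.data.T.contiguous(), requires_grad=False)

The same should_use_deepgemm result now gates both branches, so the E8M0 requantization no longer fires for layers that will not be dispatched to DeepGEMM, and the CUTLASS branch no longer recomputes the check with a hard-coded dtype.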