diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 2818674775cc..3d94626e5d8c 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -450,10 +450,10 @@ class Fp8LinearMethod(LinearMethodBase):
             # Activations not quantized for marlin.
             del layer.input_scale
 
-        # On B200, if E8M0 for DeepGemm is used, we need to
+        # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
         # requantize the weight and input to the specific scale
         # at the same time.
-        if is_deep_gemm_e8m0_used():
+        if is_deep_gemm_e8m0_used() and self.block_quant:
             assert layer.weight_block_size is not None
             block_sz = tuple(layer.weight_block_size)
             requant_weight_ue8m0_inplace(
@@ -905,7 +905,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             del layer.w13_input_scale
             del layer.w2_input_scale
 
-        if is_deep_gemm_e8m0_used():
+        if is_deep_gemm_e8m0_used() and self.block_quant:
             assert layer.weight_block_size is not None
             # Re-quantise the expert weights so their scales are UE8M0.
             block_sz = tuple(layer.weight_block_size)
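
For context on the guarded call above, the following is a minimal, self-contained sketch of the idea behind UE8M0 re-quantisation: DeepGEMM's E8M0 path expects power-of-two block scales, so each per-block scale is rounded up to a power of two and the corresponding weight tile is rescaled so that weight * scale is preserved. This is not the vLLM implementation (requant_weight_ue8m0_inplace works in place on FP8 tensors); the helper name requant_block_ue8m0_sketch and the use of plain float tensors are assumptions for illustration only.

    # Illustration only; NOT vLLM's requant_weight_ue8m0_inplace.
    import torch


    def requant_block_ue8m0_sketch(
        weight: torch.Tensor,          # dequantised (float) view of a block-quantised weight
        scale: torch.Tensor,           # one float scale per (block_n, block_k) tile
        block_size: tuple[int, int],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        block_n, block_k = block_size
        # Round every scale up to the nearest power of two (UE8M0-representable).
        # Rounding up shrinks the stored values, so they cannot overflow FP8.
        tiny = torch.finfo(scale.dtype).tiny
        new_scale = torch.exp2(torch.ceil(torch.log2(scale.clamp(min=tiny))))
        out = weight.clone()
        n, k = out.shape
        for i in range(0, n, block_n):
            for j in range(0, k, block_k):
                si, sj = i // block_n, j // block_k
                # Rescale the tile so tile * new_scale == old tile * old scale.
                out[i:i + block_n, j:j + block_k] *= scale[si, sj] / new_scale[si, sj]
        # The real code would cast `out` back to FP8 (e.g. torch.float8_e4m3fn) in place.
        return out, new_scale


    if __name__ == "__main__":
        w = torch.randn(256, 512)
        s = torch.rand(2, 4) + 0.5                 # 128x128 blocks -> 2x4 scale grid
        w2, s2 = requant_block_ue8m0_sketch(w, s, (128, 128))
        mantissa, _ = torch.frexp(s2)
        assert torch.all(mantissa == 0.5)          # every new scale is a power of two
        assert torch.allclose(w2[:128, :128] * s2[0, 0], w[:128, :128] * s[0, 0])

The `and self.block_quant` guard added in the patch matters because the block-wise scale tensors (`weight_block_size`, `*_weight_scale_inv`) only exist for block-quantised checkpoints; per-tensor FP8 models would otherwise trip the assert even though they never use the DeepGEMM E8M0 path.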