diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index e29f66dca313f..7927bd0d200d8 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -30,16 +30,11 @@ def ref_dynamic_per_token_quant( if quant_dtype == torch.int8 else torch.finfo(quant_dtype) ) - qtype_traits_max = ( - ROCM_FP8FNUZ_MAX - if current_platform.is_rocm() and current_platform.is_fp8_fnuz() - else qtype_traits.max - ) - qtype_traits_min = ( - -ROCM_FP8FNUZ_MAX - if current_platform.is_rocm() and current_platform.is_fp8_fnuz() - else qtype_traits.min + use_fp8fnuz = ( + current_platform.is_fp8_fnuz() and quant_dtype == current_platform.fp8_dtype() ) + qtype_traits_max = ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.max + qtype_traits_min = -ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.min qtype_max = as_float32_tensor(qtype_traits_max) s_1 = as_float32_tensor(1.0) s_512 = as_float32_tensor(512.0)