diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index cbc46810a26a6..d0c8b3d1a3093 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -39,15 +39,15 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): value_layout = StridedLayout scale_layout = StridedLayout elif current_platform.is_rocm(): - from triton_kernels.tensor_details.layout import ( - GFX950MXScaleLayout, - StridedLayout, - ) - from vllm.platforms.rocm import on_gfx950 value_layout = StridedLayout - scale_layout = GFX950MXScaleLayout if on_gfx950() else StridedLayout + if on_gfx950(): + from triton_kernels.tensor_details.layout import GFX950MXScaleLayout + + scale_layout = GFX950MXScaleLayout + else: + scale_layout = StridedLayout else: value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout( mx_axis=1