From 81ee45298d741f97050d0d7acdbaf164ac62e0da Mon Sep 17 00:00:00 2001
From: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com>
Date: Tue, 23 Sep 2025 18:39:50 -0500
Subject: [PATCH] [ROCm] Small functional changes for gptoss (#25201)

Signed-off-by: jpvillam
Co-authored-by: jpvillam
Signed-off-by: yewentao256
---
 .../model_executor/layers/quantization/mxfp4.py    |  9 ++++++---
 .../layers/quantization/utils/mxfp4_utils.py       | 17 ++++++++++++++---
 vllm/platforms/rocm.py                             |  6 ++++++
 3 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index a71c8d32a22c7..b710f6ee249b1 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -212,12 +212,15 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 256)
             hidden_size = round_up(hidden_size, 256)
-        elif current_platform.is_rocm() or (
-                self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
-                or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
+        elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
+              or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 128)
             hidden_size = round_up(hidden_size, 128)
+        elif current_platform.is_rocm():
+            intermediate_size_per_partition_after_pad = round_up(
+                intermediate_size_per_partition, 256)
+            hidden_size = round_up(hidden_size, 256)
         else:
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 64)
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index d61ca7ad5dc4e..fb1d041f34499 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
 
 import torch
 
@@ -21,6 +21,10 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
     from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
     from triton_kernels.tensor_details import layout
     from triton_kernels.tensor_details.layout import StridedLayout
+
+    value_layout_opts: dict[str, Any] = {}
+    scale_layout_opts: dict[str, Any] = {}
+
     if (current_platform.is_cuda()
             and current_platform.is_device_capability(90)
             and not is_torch_equal_or_newer("2.8.1")):
@@ -28,8 +32,15 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
             "Mxfp4 on hopper is running on torch < 2.8.1, "
             "this causes swizzling to be disabled, which may "
             "cause performance degradation. Please upgrade to torch nightly")
-        value_layout, value_layout_opts = StridedLayout, dict()
-        scale_layout, scale_layout_opts = StridedLayout, dict()
+        value_layout = StridedLayout
+        scale_layout = StridedLayout
+    elif current_platform.is_rocm():
+        from triton_kernels.tensor_details.layout import (GFX950MXScaleLayout,
+                                                          StridedLayout)
+
+        from vllm.platforms.rocm import on_gfx950
+        value_layout = StridedLayout
+        scale_layout = GFX950MXScaleLayout if on_gfx950() else StridedLayout
     else:
         value_layout, value_layout_opts = \
             layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 878718489fa88..942fd1973f4f3 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -118,6 +118,12 @@ def on_gfx9() -> bool:
     return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
 
 
+@cache
+def on_gfx950() -> bool:
+    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
+    return any(arch in GPU_ARCH for arch in ["gfx950"])
+
+
 @cache
 def use_rocm_custom_paged_attention(
     qtype: torch.dtype,
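
Note for reviewers: the easy-to-miss behavioral change is in the elif reordering in
Mxfp4MoEMethod.create_weights. Before this patch, ROCm shared the 128-element padding
branch with the FlashInfer SM100/SM90 backends; now those backends are tested first and
ROCm falls through to its own branch, padding both the intermediate size and the hidden
size to multiples of 256. Below is a minimal standalone sketch of that branch order and
of the gfx950 scale-layout gate. It is illustrative only: round_up is re-derived locally,
and the backend/arch values are plain strings standing in for Mxfp4Backend members,
gcnArchName, and the triton_kernels layout classes, so nothing here imports vLLM.

# sketch_gptoss_rocm.py -- illustrative sketch, not vLLM's actual API.

def round_up(x: int, multiple: int) -> int:
    # Same arithmetic as vllm.utils.round_up: smallest multiple of `multiple` >= x.
    return ((x + multiple - 1) // multiple) * multiple

def padded_sizes(is_rocm: bool, backend: str,
                 intermediate: int, hidden: int) -> tuple[int, int]:
    # Mirrors the reordered elif chain: the FlashInfer backends are checked
    # first, so ROCm no longer lands in the 128 branch and pads to 256 instead.
    if backend in ("SM100_FI_MXFP4_MXFP8_CUTLASS", "SM90_FI_MXFP4_BF16"):
        m = 128
    elif is_rocm:
        m = 256
    else:
        m = 64
    return round_up(intermediate, m), round_up(hidden, m)

def scale_layout_name(gcn_arch: str) -> str:
    # Mirrors the on_gfx950() gate: only gfx950 selects the MX scale layout.
    return "GFX950MXScaleLayout" if "gfx950" in gcn_arch else "StridedLayout"

if __name__ == "__main__":
    # On ROCm, an intermediate size of 2880 now pads up to 3072
    # (the next multiple of 256) rather than staying at 2880 (already a
    # multiple of 128).
    print(padded_sizes(True, "TRITON", 2880, 2880))     # (3072, 3072)
    print(scale_layout_name("gfx950:sramecc+:xnack-"))  # GFX950MXScaleLayout
    print(scale_layout_name("gfx942:sramecc+:xnack-"))  # StridedLayout

The sketch also shows why on_gfx950() is added to vllm/platforms/rocm.py as a cached,
substring-based check: gcnArchName strings carry feature suffixes (e.g.
"gfx950:sramecc+:xnack-"), so an exact equality test would miss them.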