diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index fd91c78c7cc4..28dba091f430 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -43,6 +43,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     _can_support_mxfp4,
     _swizzle_mxfp4,
+    get_padding_alignment,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 from vllm.model_executor.utils import set_weight_attrs
@@ -282,10 +283,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             )
             hidden_size = round_up(hidden_size, 128)
         elif current_platform.is_rocm():
+            pad_align = get_padding_alignment()
             intermediate_size_per_partition_after_pad = round_up(
-                intermediate_size_per_partition, 256
+                intermediate_size_per_partition, pad_align
             )
-            hidden_size = round_up(hidden_size, 256)
+            hidden_size = round_up(hidden_size, pad_align)
         else:
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 64
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index 5e87cadfb107..34a31bcf6a74 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -7,6 +7,7 @@ import torch
 
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton
 from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
 
 logger = init_logger(__name__)
@@ -99,6 +100,16 @@ def _can_support_mxfp4(
     )
 
 
+def get_padding_alignment() -> int:
+    """Return the padding alignment for the active Triton GPU target.
+
+    Yields 256 when the target architecture is "gfx950" and 128 for any
+    other architecture.
+    """
+    arch = triton.runtime.driver.active.get_current_target().arch
+    return 256 if arch in ("gfx950",) else 128
+
+
 def _dequant_mxfp4(
     x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
 ) -> torch.Tensor: