[ROCm] Small functional changes for gptoss (#25201)

Signed-off-by: jpvillam <jpvillam@amd.com>
Co-authored-by: jpvillam <jpvillam@amd.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Juan Villamizar 2025-09-23 18:39:50 -05:00 committed by yewentao256
parent d12433adfc
commit 81ee45298d
3 changed files with 26 additions and 6 deletions


@@ -212,12 +212,15 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 256)
             hidden_size = round_up(hidden_size, 256)
-        elif current_platform.is_rocm() or (
-                self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
-                or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
+        elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
+              or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 128)
             hidden_size = round_up(hidden_size, 128)
+        elif current_platform.is_rocm():
+            intermediate_size_per_partition_after_pad = round_up(
+                intermediate_size_per_partition, 256)
+            hidden_size = round_up(hidden_size, 256)
         else:
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 64)
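
For reference, the padding the reordered branches select can be sketched as
below; round_up is assumed to behave like vllm.utils.round_up (round x up to
the next multiple of m), and 2880 is only an illustrative size:

    def round_up(x: int, multiple: int) -> int:
        # Smallest multiple of `multiple` that is >= x.
        return ((x + multiple - 1) // multiple) * multiple

    # With the branches reordered, ROCm now pads to 256 instead of sharing
    # the 128 branch with the FlashInfer backends:
    assert round_up(2880, 256) == 3072  # new ROCm branch
    assert round_up(2880, 128) == 2944  # SM100/SM90 FlashInfer branches
    assert round_up(2880, 64) == 2880   # default branch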


@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Callable, Optional
+from typing import Any, Callable, Optional

 import torch
@@ -21,6 +21,10 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
     from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
     from triton_kernels.tensor_details import layout
     from triton_kernels.tensor_details.layout import StridedLayout
+    value_layout_opts: dict[str, Any] = {}
+    scale_layout_opts: dict[str, Any] = {}
     if (current_platform.is_cuda()
             and current_platform.is_device_capability(90)
             and not is_torch_equal_or_newer("2.8.1")):
@@ -28,8 +32,15 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
             "Mxfp4 on hopper is running on torch < 2.8.1, "
             "this causes swizzling to be disabled, which may "
             "cause performance degradation. Please upgrade to torch nightly")
-        value_layout, value_layout_opts = StridedLayout, dict()
-        scale_layout, scale_layout_opts = StridedLayout, dict()
+        value_layout = StridedLayout
+        scale_layout = StridedLayout
+    elif current_platform.is_rocm():
+        from triton_kernels.tensor_details.layout import (GFX950MXScaleLayout,
+                                                          StridedLayout)
+        from vllm.platforms.rocm import on_gfx950
+        value_layout = StridedLayout
+        scale_layout = GFX950MXScaleLayout if on_gfx950() else StridedLayout
     else:
         value_layout, value_layout_opts = \
             layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
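
Condensed, the layout selection in _swizzle_mxfp4 now follows the pattern
sketched below: the opts dicts start empty so most branches only pick a
layout class. The string names stand in for the real triton_kernels layout
classes, and the default-branch options are illustrative, not the actual
keyword arguments:

    from typing import Any

    def pick_layouts(hopper_old_torch: bool, is_rocm: bool, on_gfx950: bool):
        # Pre-initialized empty, so only the default branch overrides them.
        value_layout_opts: dict[str, Any] = {}
        scale_layout_opts: dict[str, Any] = {}
        if hopper_old_torch:
            value_layout = scale_layout = "StridedLayout"  # swizzling disabled
        elif is_rocm:
            value_layout = "StridedLayout"
            # Only gfx950 has a dedicated MX-scale swizzle layout.
            scale_layout = ("GFX950MXScaleLayout" if on_gfx950
                            else "StridedLayout")
        else:
            value_layout, value_layout_opts = "default_w_layout", {"mx_axis": 1}
            scale_layout, scale_layout_opts = "default_scale_layout", {}
        return (value_layout, value_layout_opts), (scale_layout, scale_layout_opts)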


@@ -118,6 +118,12 @@ def on_gfx9() -> bool:
     return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])


+@cache
+def on_gfx950() -> bool:
+    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
+    return any(arch in GPU_ARCH for arch in ["gfx950"])
+
+
 @cache
 def use_rocm_custom_paged_attention(
     qtype: torch.dtype,
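
The new probe mirrors the existing on_gfx9() helper. A minimal standalone
sketch with a hypothetical caller, assuming gcnArchName is the ROCm arch
string exposed by torch's device properties:

    from functools import cache

    import torch

    @cache
    def on_gfx950() -> bool:
        # @cache pins the answer to the device visible at the first call;
        # the probe runs once per process.
        return "gfx950" in torch.cuda.get_device_properties("cuda").gcnArchName

    def pick_scale_layout() -> str:
        # Hypothetical caller: only gfx950 gets the dedicated MX-scale layout.
        return "GFX950MXScaleLayout" if on_gfx950() else "StridedLayout"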