mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-04 00:57:55 +08:00
[ROCm] Small functional changes for gptoss (#25201)
Signed-off-by: jpvillam <jpvillam@amd.com> Co-authored-by: jpvillam <jpvillam@amd.com>
This commit is contained in:
parent
5e25b12236
commit
bde2a1a8a4
@@ -212,12 +212,15 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
intermediate_size_per_partition_after_pad = round_up(
|
intermediate_size_per_partition_after_pad = round_up(
|
||||||
intermediate_size_per_partition, 256)
|
intermediate_size_per_partition, 256)
|
||||||
hidden_size = round_up(hidden_size, 256)
|
hidden_size = round_up(hidden_size, 256)
|
||||||
elif current_platform.is_rocm() or (
|
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
||||||
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
|
||||||
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
|
|
||||||
intermediate_size_per_partition_after_pad = round_up(
|
intermediate_size_per_partition_after_pad = round_up(
|
||||||
intermediate_size_per_partition, 128)
|
intermediate_size_per_partition, 128)
|
||||||
hidden_size = round_up(hidden_size, 128)
|
hidden_size = round_up(hidden_size, 128)
|
||||||
|
elif current_platform.is_rocm():
|
||||||
|
intermediate_size_per_partition_after_pad = round_up(
|
||||||
|
intermediate_size_per_partition, 256)
|
||||||
|
hidden_size = round_up(hidden_size, 256)
|
||||||
else:
|
else:
|
||||||
intermediate_size_per_partition_after_pad = round_up(
|
intermediate_size_per_partition_after_pad = round_up(
|
||||||
intermediate_size_per_partition, 64)
|
intermediate_size_per_partition, 64)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from typing import Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -21,6 +21,10 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
|||||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||||
from triton_kernels.tensor_details import layout
|
from triton_kernels.tensor_details import layout
|
||||||
from triton_kernels.tensor_details.layout import StridedLayout
|
from triton_kernels.tensor_details.layout import StridedLayout
|
||||||
|
|
||||||
|
value_layout_opts: dict[str, Any] = {}
|
||||||
|
scale_layout_opts: dict[str, Any] = {}
|
||||||
|
|
||||||
if (current_platform.is_cuda()
|
if (current_platform.is_cuda()
|
||||||
and current_platform.is_device_capability(90)
|
and current_platform.is_device_capability(90)
|
||||||
and not is_torch_equal_or_newer("2.8.1")):
|
and not is_torch_equal_or_newer("2.8.1")):
|
||||||
@@ -28,8 +32,15 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
|||||||
"Mxfp4 on hopper is running on torch < 2.8.1, "
|
"Mxfp4 on hopper is running on torch < 2.8.1, "
|
||||||
"this cause swizling to be disabled, which may "
|
"this cause swizling to be disabled, which may "
|
||||||
"cause performance degradation. Please upgrade to torch nightly")
|
"cause performance degradation. Please upgrade to torch nightly")
|
||||||
value_layout, value_layout_opts = StridedLayout, dict()
|
value_layout = StridedLayout
|
||||||
scale_layout, scale_layout_opts = StridedLayout, dict()
|
scale_layout = StridedLayout
|
||||||
|
elif current_platform.is_rocm():
|
||||||
|
from triton_kernels.tensor_details.layout import (GFX950MXScaleLayout,
|
||||||
|
StridedLayout)
|
||||||
|
|
||||||
|
from vllm.platforms.rocm import on_gfx950
|
||||||
|
value_layout = StridedLayout
|
||||||
|
scale_layout = GFX950MXScaleLayout if on_gfx950() else StridedLayout
|
||||||
else:
|
else:
|
||||||
value_layout, value_layout_opts = \
|
value_layout, value_layout_opts = \
|
||||||
layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
||||||
|
|||||||
@@ -118,6 +118,12 @@ def on_gfx9() -> bool:
|
|||||||
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def on_gfx950() -> bool:
|
||||||
|
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
|
||||||
|
return any(arch in GPU_ARCH for arch in ["gfx950"])
|
||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def use_rocm_custom_paged_attention(
|
def use_rocm_custom_paged_attention(
|
||||||
qtype: torch.dtype,
|
qtype: torch.dtype,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user