From f533b5837fa67f53957a12387a01067f9edef0d8 Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Date: Mon, 24 Mar 2025 19:45:30 -0400
Subject: [PATCH] [ROCm][Kernel] MoE weights padding (#14454)

Signed-off-by: Gregory Shtrasberg
Signed-off-by: charlifu
Co-authored-by: charlifu
---
 tests/kernels/test_moe.py                     | 45 ++++++++++++++-----
 vllm/envs.py                                  |  5 +++
 .../layers/fused_moe/fused_moe.py             |  6 +--
 vllm/model_executor/layers/fused_moe/layer.py | 19 ++++++++
 .../model_executor/layers/quantization/fp8.py |  6 +--
 5 files changed, 65 insertions(+), 16 deletions(-)

diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index 52893f4329ec6..653d2734afe89 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -3,8 +3,11 @@
 
 Run `pytest tests/kernels/test_moe.py`.
 """
+
 import pytest
 import torch
+from torch.nn import Parameter
+from torch.nn import functional as F
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
@@ -37,6 +40,7 @@ TOP_KS = [2, 6]
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 def test_fused_moe(
     m: int,
     n: int,
@@ -45,6 +49,7 @@ def test_fused_moe(
     topk: int,
     ep_size: int,
     dtype: torch.dtype,
+    padding: bool,
 ):
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -65,16 +70,7 @@ def test_fused_moe(
     else:
         e_map = None
 
-    triton_output = fused_moe(a,
-                              w1,
-                              w2,
-                              score,
-                              topk,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk, e_map)
-    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     iterative_output = iterative_moe(a,
                                      w1,
                                      w2,
@@ -83,6 +79,23 @@ def test_fused_moe(
                                      global_num_experts=e,
                                      expert_map=e_map,
                                      renormalize=False)
+
+    # Pad the weight if moe padding is enabled
+    if padding:
+        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
+        torch.cuda.empty_cache()
+        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
+        torch.cuda.empty_cache()
+
+    triton_output = fused_moe(a,
+                              w1,
+                              w2,
+                              score,
+                              topk,
+                              global_num_experts=e,
+                              expert_map=e_map,
+                              renormalize=False)
+    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     torch.testing.assert_close(iterative_output,
                                torch_output,
                                atol=2e-2,
@@ -202,8 +215,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
 
 @pytest.mark.parametrize("dtype",
                          [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 @torch.inference_mode()
-def test_mixtral_moe(dtype: torch.dtype):
+def test_mixtral_moe(dtype: torch.dtype, padding: bool):
     """Make sure our Mixtral MoE implementation agrees with the one from
     huggingface."""
 
@@ -233,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype):
     # vLLM uses 1D query [num_tokens, hidden_dim]
     vllm_inputs = hf_inputs.flatten(0, 1)
 
+    # Pad the weight if moe padding is enabled
+    if padding:
+        vllm_moe.experts.w13_weight = Parameter(F.pad(
+            vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128],
+                                                requires_grad=False)
+        torch.cuda.empty_cache()
+        vllm_moe.experts.w2_weight = Parameter(F.pad(
+            vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
+                                               requires_grad=False)
+        torch.cuda.empty_cache()
+
     # Run forward passes for both MoE blocks
     hf_states, _ = hf_moe.forward(hf_inputs)
     vllm_states = vllm_moe.forward(vllm_inputs)
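
For readers unfamiliar with the trick, the snippet below is a standalone illustration (not part of the patch) of what the pad-then-slice pattern in the new test branches does to a weight tensor; the shapes are arbitrary and only chosen so the 128-element pad applies cleanly.

import torch
import torch.nn.functional as F

# Example weight: (num_experts, rows, k); sizes are illustrative only.
w = torch.randn((4, 256, 4096), dtype=torch.float16)
# Pad the last dimension by 128 elements, then slice the pad back off.
padded = F.pad(w, (0, 128), "constant", 0)[..., 0:-128]

assert padded.shape == w.shape   # logical shape is unchanged
assert torch.equal(padded, w)    # values are unchanged
# Physically, each row now starts 128 fp16 elements (256 bytes) further from
# the previous one, which is the layout the padded code paths exercise.
assert padded.stride(-2) == w.stride(-2) + 128
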
diff --git a/vllm/envs.py b/vllm/envs.py
index e97d37017b441..f0fd20c70e3b2 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -75,6 +75,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_USE_AITER: bool = False
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
+    VLLM_ROCM_MOE_PADDING: bool = True
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
@@ -520,6 +521,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_ROCM_FP8_PADDING":
     lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
 
+    # Pad the weights for the moe kernel
+    "VLLM_ROCM_MOE_PADDING":
+    lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))),
+
     # Divisor for dynamic query scale factor calculation for FP8 KV Cache
     "Q_SCALE_CONSTANT":
     lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 4143ccce52557..4de020ff81c0e 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -800,7 +800,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             expert_ids,
             num_tokens_post_padded,
             B.shape[1],
-            A.shape[1],
+            B.shape[2],
             EM,
             topk_ids.numel(),
             A.stride(0),
@@ -1322,8 +1322,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
+    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
     assert hidden_states.dtype in [
         torch.float32, torch.float16, torch.bfloat16
     ]
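
The kernel-side changes above follow from that layout: a padded weight is no longer contiguous as a whole, but its reduction dimension still has unit stride and its logical shape is untouched. A standalone sketch of both facts (again outside the patch, with made-up sizes):

import torch
import torch.nn.functional as F

k = 4096
w1 = torch.randn((4, 256, k), dtype=torch.float16)
w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]

# The old `w1.is_contiguous()` assert would reject the padded weight ...
assert not w1.is_contiguous()
# ... while the relaxed check still holds: the last (reduction) dim is dense.
assert w1.stride(-1) == 1
# The logical K read from B.shape[2] is unaffected by padding; only the row
# pitch, B.stride(1) in elements, grows by the pad amount.
assert w1.shape[2] == k
assert w1.stride(1) == k + 128
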
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 739d216e6e80c..bc134f676159e 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -5,6 +5,7 @@ from enum import Enum
 from typing import Callable, List, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter
 
 from vllm import envs
@@ -96,9 +97,27 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
+        # Pad the weight tensor. This is an optimization on ROCm platform, which
+        # can benefit from tensors located far enough from one another in memory
+        if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm()
+                and weight.stride(-1) == 1
+                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
+            num_pad = 256 // weight.element_size()
+            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
+            torch.cuda.empty_cache()
+        return weight
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
 
+        layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
+            layer.w13_weight.data),
+                                              requires_grad=False)
+        layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
+            layer.w2_weight.data),
+                                             requires_grad=False)
+
         if current_platform.is_cpu():
             if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                 import intel_extension_for_pytorch as ipex
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 2d5d8e6adc9c1..d92b0931a6ee0 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -255,7 +255,7 @@ class Fp8LinearMethod(LinearMethodBase):
         else:
             layer.register_parameter("input_scale", None)
 
-    def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor:
+    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
         # Pad the weight tensor. This is an optimization on ROCm platform, which
         # can benefit from tensors located far enough from one another in memory
         if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
@@ -279,7 +279,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight.data
             weight_scale_inv = layer.weight_scale_inv.data
 
-            weight = self.add_padding_to_weight(weight)
+            weight = self._maybe_pad_weight(weight)
 
             # Torch.compile cannot use Parameter subclasses.
             layer.weight = Parameter(weight, requires_grad=False)
@@ -343,7 +343,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 logical_widths=layer.logical_widths,
             )
 
-            weight = self.add_padding_to_weight(weight)
+            weight = self._maybe_pad_weight(weight)
             # Update layer with new values.
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
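
As a closing note, the 256-byte pad in `_maybe_pad_weight` corresponds to 128 fp16/bf16 elements, which is why the tests pad by a hard-coded 128. A condensed standalone version of the gating logic, with the ROCm and env checks replaced by a plain flag so it runs anywhere, might look like this (illustrative only, not the vLLM code itself):

import torch
import torch.nn.functional as F


def maybe_pad_weight(weight: torch.Tensor,
                     enable_padding: bool = True) -> torch.Tensor:
    # Only pad row-major weights whose row pitch is already 512-byte aligned.
    if (enable_padding and weight.stride(-1) == 1
            and (weight.stride(-2) * weight.element_size()) % 512 == 0):
        num_pad = 256 // weight.element_size()  # 128 elements for fp16/bf16
        weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
    return weight


w2 = maybe_pad_weight(torch.randn((4, 4096, 1024), dtype=torch.bfloat16))
assert w2.stride(-1) == 1 and not w2.is_contiguous()

At runtime the behavior is controlled by the new VLLM_ROCM_MOE_PADDING environment variable (enabled by default, ROCm only); setting VLLM_ROCM_MOE_PADDING=0 turns the MoE weight padding off, mirroring the existing VLLM_ROCM_FP8_PADDING switch.
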