From ed3a1d2106a42b1522013b304e5dc9ca172385ec Mon Sep 17 00:00:00 2001
From: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Date: Tue, 6 May 2025 20:39:48 -0400
Subject: [PATCH] [ROCm] fix num_stages for default moe config to avoid triton
 OutOfResource error (#17744)

Signed-off-by: Hongxia Yang
---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 075b98d148607..f6305822c2dc1 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -747,13 +747,15 @@ def get_default_config(
     if dtype == "fp8_w8a8" and block_shape is not None:
         # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
         # BLOCK_SIZE_K must be divisible by block_shape[1]
+        # num_stages=3 can cause triton.runtime.errors.OutOfResources
+        # on ROCm, set it to 2 instead.
         config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": block_shape[0],
             "BLOCK_SIZE_K": block_shape[1],
             "GROUP_SIZE_M": 32,
             "num_warps": 4,
-            "num_stages": 3,
+            "num_stages": 3 if not current_platform.is_rocm() else 2,
         }
     elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
         # moe wna16 kernels
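
For context, a minimal standalone sketch of the default that the hunk above now
selects for block-wise fp8 MoE. The helper name
_default_fp8_blockwise_moe_config is hypothetical and only for illustration;
current_platform.is_rocm() is the same vllm.platforms check used in the diff.

    # Sketch: default Triton config for block-wise fp8 fused MoE,
    # dropping num_stages from 3 to 2 on ROCm so the kernel stays within
    # shared-memory limits and avoids triton.runtime.errors.OutOfResources.
    from vllm.platforms import current_platform


    def _default_fp8_blockwise_moe_config(block_shape: list[int]) -> dict[str, int]:
        # Hypothetical helper, not part of the patch; mirrors the dict built
        # inside get_default_config() for dtype == "fp8_w8a8" with block_shape.
        return {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_shape[0],  # must be divisible by block_shape[0]
            "BLOCK_SIZE_K": block_shape[1],  # must be divisible by block_shape[1]
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 2 if current_platform.is_rocm() else 3,
        }

Fewer pipeline stages mean less software pipelining and slightly lower
throughput, but they also reduce the per-CTA resource footprint, which is why
2 is the safer default on ROCm while CUDA keeps 3.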