From c488b928a736109cc0a0c824340b928a4f118b2f Mon Sep 17 00:00:00 2001 From: TJian Date: Mon, 14 Jul 2025 00:23:28 -0700 Subject: [PATCH] [ROCm] [Bugfix] [Critical]: Fix mamba compilation bug (#20883) Signed-off-by: tjtanaa Co-authored-by: vllmellm --- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 5f9209979341e..5766fbab4e871 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -7,7 +7,11 @@ #include #include -#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#ifdef USE_ROCM + #include <c10/hip/HIPException.h> // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK +#else + #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#endif #ifndef USE_ROCM #include @@ -320,8 +324,13 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) { dim3 grid(params.batch, params.dim / kNRows); auto kernel = &selective_scan_fwd_kernel; if (kSmemSize >= 48 * 1024) { +#ifdef USE_ROCM + C10_HIP_CHECK(hipFuncSetAttribute( + reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); +#else C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); +#endif } kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK();