[ROCm] [Bugfix] Fix fused_qknorm_rope_kernel rocm compatibility (#28500)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
TJian 2025-11-12 05:01:14 -08:00 committed by GitHub
parent c5f10cc139
commit edb59a9470
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 37 additions and 38 deletions

View File

@ -35,10 +35,12 @@
CHECK_TH_CUDA(x); \ CHECK_TH_CUDA(x); \
CHECK_CONTIGUOUS(x) CHECK_CONTIGUOUS(x)
#ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL
#else
#define FINAL_MASK 0xffffffff #define FINAL_MASK 0xffffffff
#endif
// TODO: support for AMD ROCM platform
#ifndef USE_ROCM
namespace tensorrt_llm::common { namespace tensorrt_llm::common {
template <typename T, int num> template <typename T, int num>
struct packed_as; struct packed_as;
@ -97,7 +99,7 @@ __global__ void fusedQKNormRopeKernel(
int64_t const* position_ids, // Position IDs for RoPE int64_t const* position_ids, // Position IDs for RoPE
int const num_tokens // Number of tokens int const num_tokens // Number of tokens
) { ) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 #if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) || if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
std::is_same_v<scalar_t_cache, c10::BFloat16>) { std::is_same_v<scalar_t_cache, c10::BFloat16>) {
return; return;
@ -247,7 +249,7 @@ __global__ void fusedQKNormRopeKernel(
// values. // values.
#pragma unroll #pragma unroll
for (int i = 0; i < numElemsPerThread; i++) { for (int i = 0; i < numElemsPerThread; i++) {
elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16); elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
if (laneId < 16) { if (laneId < 16) {
elements2[i] = -elements2[i]; elements2[i] = -elements2[i];
} }
@ -280,7 +282,7 @@ __global__ void fusedQKNormRopeKernel(
*reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec; *reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
} }
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 #if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
} }
#endif #endif
} }
@ -414,5 +416,3 @@ void fused_qk_norm_rope(
}); });
}); });
} }
#endif // not USE_ROCM

View File

@ -175,7 +175,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()"); "float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
#ifndef USE_ROCM
// Function for fused QK Norm and RoPE // Function for fused QK Norm and RoPE
ops.def( ops.def(
"fused_qk_norm_rope(Tensor! qkv, int num_heads_q, " "fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
@ -183,7 +182,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, " "Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
"bool is_neox, Tensor position_ids) -> ()"); "bool is_neox, Tensor position_ids) -> ()");
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope); ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
#endif
// Apply repetition penalties to logits in-place // Apply repetition penalties to logits in-place
ops.def( ops.def(

View File

@ -67,9 +67,9 @@ struct _typeConvert<c10::Half> {
} }
}; };
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) || defined(USE_ROCM)
// CUDA_ARCH < 800 does not have BF16 support // CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely // ROCm 7.0+ supports bfloat16
template <> template <>
struct _typeConvert<c10::BFloat16> { struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true; static constexpr bool exists = true;
@ -89,7 +89,8 @@ struct _typeConvert<c10::BFloat16> {
return __float22bfloat162_rn(x); return __float22bfloat162_rn(x);
} }
}; };
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) ||
// defined(USE_ROCM)
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= #endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000)) // 12000))

View File

@ -113,8 +113,8 @@ class QKNormRoPETestModel(torch.nn.Module):
@pytest.mark.parametrize("enable_rope_custom_op", [True]) @pytest.mark.parametrize("enable_rope_custom_op", [True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.skipif( @pytest.mark.skipif(
not current_platform.is_cuda(), not current_platform.is_cuda_alike(),
reason="Only test on cuda platform", reason="Only test on cuda and rocm platform",
) )
def test_qk_norm_rope_fusion( def test_qk_norm_rope_fusion(
eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype

View File

@ -44,8 +44,8 @@ def _apply_qk_norm_rope(
@pytest.mark.skipif( @pytest.mark.skipif(
not current_platform.is_cuda(), not current_platform.is_cuda_alike(),
reason="fused_qk_norm_rope custom op requires cuda platform", reason="fused_qk_norm_rope custom op requires cuda and rocm platform",
) )
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)

View File

@ -184,10 +184,10 @@ class PassConfig:
"Fusion enabled but reshape elimination disabled. " "Fusion enabled but reshape elimination disabled. "
"Allreduce + rms norm + quant (fp8) fusion might not work" "Allreduce + rms norm + quant (fp8) fusion might not work"
) )
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda(): if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike():
logger.warning_once( logger.warning_once(
"QK Norm + RoPE fusion enabled but the current platform is not " "QK Norm + RoPE fusion enabled but the current platform is not "
"CUDA. The fusion will be disabled." "CUDA or ROCm. The fusion will be disabled."
) )
self.enable_qk_norm_rope_fusion = False self.enable_qk_norm_rope_fusion = False