mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 00:49:10 +08:00
[ROCm] [Bugfix] Fix fused_qknorm_rope_kernel rocm compatibility (#28500)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent
c5f10cc139
commit
edb59a9470
@ -35,10 +35,12 @@
|
|||||||
CHECK_TH_CUDA(x); \
|
CHECK_TH_CUDA(x); \
|
||||||
CHECK_CONTIGUOUS(x)
|
CHECK_CONTIGUOUS(x)
|
||||||
|
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
#define FINAL_MASK 0xffffffffffffffffULL
|
||||||
|
#else
|
||||||
#define FINAL_MASK 0xffffffff
|
#define FINAL_MASK 0xffffffff
|
||||||
|
#endif
|
||||||
|
|
||||||
// TODO: suport for AMD ROCM platform
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
namespace tensorrt_llm::common {
|
namespace tensorrt_llm::common {
|
||||||
template <typename T, int num>
|
template <typename T, int num>
|
||||||
struct packed_as;
|
struct packed_as;
|
||||||
@ -97,7 +99,7 @@ __global__ void fusedQKNormRopeKernel(
|
|||||||
int64_t const* position_ids, // Position IDs for RoPE
|
int64_t const* position_ids, // Position IDs for RoPE
|
||||||
int const num_tokens // Number of tokens
|
int const num_tokens // Number of tokens
|
||||||
) {
|
) {
|
||||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
|
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
|
||||||
if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
|
if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
|
||||||
std::is_same_v<scalar_t_cache, c10::BFloat16>) {
|
std::is_same_v<scalar_t_cache, c10::BFloat16>) {
|
||||||
return;
|
return;
|
||||||
@ -247,7 +249,7 @@ __global__ void fusedQKNormRopeKernel(
|
|||||||
// values.
|
// values.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < numElemsPerThread; i++) {
|
for (int i = 0; i < numElemsPerThread; i++) {
|
||||||
elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
|
elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
|
||||||
if (laneId < 16) {
|
if (laneId < 16) {
|
||||||
elements2[i] = -elements2[i];
|
elements2[i] = -elements2[i];
|
||||||
}
|
}
|
||||||
@ -280,7 +282,7 @@ __global__ void fusedQKNormRopeKernel(
|
|||||||
*reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
|
*reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
|
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
@ -414,5 +416,3 @@ void fused_qk_norm_rope(
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // not USE_ROCM
|
|
||||||
@ -175,7 +175,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"float epsilon) -> ()");
|
"float epsilon) -> ()");
|
||||||
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
|
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
// Function for fused QK Norm and RoPE
|
// Function for fused QK Norm and RoPE
|
||||||
ops.def(
|
ops.def(
|
||||||
"fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
|
"fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
|
||||||
@ -183,7 +182,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
|
"Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
|
||||||
"bool is_neox, Tensor position_ids) -> ()");
|
"bool is_neox, Tensor position_ids) -> ()");
|
||||||
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
|
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
|
||||||
#endif
|
|
||||||
|
|
||||||
// Apply repetition penalties to logits in-place
|
// Apply repetition penalties to logits in-place
|
||||||
ops.def(
|
ops.def(
|
||||||
|
|||||||
@ -67,9 +67,9 @@ struct _typeConvert<c10::Half> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) || defined(USE_ROCM)
|
||||||
// CUDA_ARCH < 800 does not have BF16 support
|
// CUDA_ARCH < 800 does not have BF16 support
|
||||||
// TODO: Add in ROCm support once public headers handle bf16 maturely
|
// ROCm 7.0+ supports bfloat16
|
||||||
template <>
|
template <>
|
||||||
struct _typeConvert<c10::BFloat16> {
|
struct _typeConvert<c10::BFloat16> {
|
||||||
static constexpr bool exists = true;
|
static constexpr bool exists = true;
|
||||||
@ -89,7 +89,8 @@ struct _typeConvert<c10::BFloat16> {
|
|||||||
return __float22bfloat162_rn(x);
|
return __float22bfloat162_rn(x);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) ||
|
||||||
|
// defined(USE_ROCM)
|
||||||
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
|
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
|
||||||
// 12000))
|
// 12000))
|
||||||
|
|
||||||
|
|||||||
@ -113,8 +113,8 @@ class QKNormRoPETestModel(torch.nn.Module):
|
|||||||
@pytest.mark.parametrize("enable_rope_custom_op", [True])
|
@pytest.mark.parametrize("enable_rope_custom_op", [True])
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not current_platform.is_cuda(),
|
not current_platform.is_cuda_alike(),
|
||||||
reason="Only test on cuda platform",
|
reason="Only test on cuda and rocm platform",
|
||||||
)
|
)
|
||||||
def test_qk_norm_rope_fusion(
|
def test_qk_norm_rope_fusion(
|
||||||
eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype
|
eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype
|
||||||
|
|||||||
@ -44,8 +44,8 @@ def _apply_qk_norm_rope(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not current_platform.is_cuda(),
|
not current_platform.is_cuda_alike(),
|
||||||
reason="fused_qk_norm_rope custom op requires cuda platform",
|
reason="fused_qk_norm_rope custom op requires cuda and rocm platform",
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
|||||||
@ -184,10 +184,10 @@ class PassConfig:
|
|||||||
"Fusion enabled but reshape elimination disabled. "
|
"Fusion enabled but reshape elimination disabled. "
|
||||||
"Allreduce + rms norm + quant (fp8) fusion might not work"
|
"Allreduce + rms norm + quant (fp8) fusion might not work"
|
||||||
)
|
)
|
||||||
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda():
|
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike():
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"QK Norm + RoPE fusion enabled but the current platform is not "
|
"QK Norm + RoPE fusion enabled but the current platform is not "
|
||||||
"CUDA. The fusion will be disabled."
|
"CUDA or ROCm. The fusion will be disabled."
|
||||||
)
|
)
|
||||||
self.enable_qk_norm_rope_fusion = False
|
self.enable_qk_norm_rope_fusion = False
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user