mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 14:17:16 +08:00
[ROCm] [Bugfix] Fix fused_qknorm_rope_kernel rocm compatibility (#28500)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent
c5f10cc139
commit
edb59a9470
@ -35,10 +35,12 @@
|
|||||||
CHECK_TH_CUDA(x); \
|
CHECK_TH_CUDA(x); \
|
||||||
CHECK_CONTIGUOUS(x)
|
CHECK_CONTIGUOUS(x)
|
||||||
|
|
||||||
#define FINAL_MASK 0xffffffff
|
#ifdef USE_ROCM
|
||||||
|
#define FINAL_MASK 0xffffffffffffffffULL
|
||||||
|
#else
|
||||||
|
#define FINAL_MASK 0xffffffff
|
||||||
|
#endif
|
||||||
|
|
||||||
// TODO: support for AMD ROCm platform
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
namespace tensorrt_llm::common {
|
namespace tensorrt_llm::common {
|
||||||
template <typename T, int num>
|
template <typename T, int num>
|
||||||
struct packed_as;
|
struct packed_as;
|
||||||
@ -60,7 +62,7 @@ struct packed_as<uint, 4> {
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__inline__ __device__ T warpReduceSum(T val) {
|
__inline__ __device__ T warpReduceSum(T val) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1)
|
for (int mask = 16; mask > 0; mask >>= 1)
|
||||||
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
|
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
|
||||||
return val;
|
return val;
|
||||||
@ -97,12 +99,12 @@ __global__ void fusedQKNormRopeKernel(
|
|||||||
int64_t const* position_ids, // Position IDs for RoPE
|
int64_t const* position_ids, // Position IDs for RoPE
|
||||||
int const num_tokens // Number of tokens
|
int const num_tokens // Number of tokens
|
||||||
) {
|
) {
|
||||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
|
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
|
||||||
if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
|
if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
|
||||||
std::is_same_v<scalar_t_cache, c10::BFloat16>) {
|
std::is_same_v<scalar_t_cache, c10::BFloat16>) {
|
||||||
return;
|
return;
|
||||||
} else {
|
} else {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
using Converter = vllm::_typeConvert<scalar_t_in>;
|
using Converter = vllm::_typeConvert<scalar_t_in>;
|
||||||
static_assert(Converter::exists,
|
static_assert(Converter::exists,
|
||||||
@ -179,7 +181,7 @@ __global__ void fusedQKNormRopeKernel(
|
|||||||
{
|
{
|
||||||
vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
|
vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
|
||||||
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
|
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < num_packed_elems; i++) {
|
for (int i = 0; i < num_packed_elems; i++) {
|
||||||
// Interpret the generic vector chunk as the specific packed type
|
// Interpret the generic vector chunk as the specific packed type
|
||||||
T2_in packed_val = *(reinterpret_cast<T2_in*>(&vec) + i);
|
T2_in packed_val = *(reinterpret_cast<T2_in*>(&vec) + i);
|
||||||
@ -200,7 +202,7 @@ __global__ void fusedQKNormRopeKernel(
|
|||||||
float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
|
float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
|
||||||
|
|
||||||
// Normalize elements
|
// Normalize elements
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < numElemsPerThread; i++) {
|
for (int i = 0; i < numElemsPerThread; i++) {
|
||||||
int dim = laneId * numElemsPerThread + i;
|
int dim = laneId * numElemsPerThread + i;
|
||||||
float weight = isQ ? Converter::convert(q_weight[dim])
|
float weight = isQ ? Converter::convert(q_weight[dim])
|
||||||
@ -222,7 +224,7 @@ __global__ void fusedQKNormRopeKernel(
|
|||||||
|
|
||||||
if constexpr (interleave) {
|
if constexpr (interleave) {
|
||||||
// Perform interleaving. Use pre-computed cos/sin values.
|
// Perform interleaving. Use pre-computed cos/sin values.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < numElemsPerThread / 2; ++i) {
|
for (int i = 0; i < numElemsPerThread / 2; ++i) {
|
||||||
int const idx0 = 2 * i;
|
int const idx0 = 2 * i;
|
||||||
int const idx1 = 2 * i + 1;
|
int const idx1 = 2 * i + 1;
|
||||||
@ -245,9 +247,9 @@ __global__ void fusedQKNormRopeKernel(
|
|||||||
__syncwarp();
|
__syncwarp();
|
||||||
// Get the data from the other half of the warp. Use pre-computed cos/sin
|
// Get the data from the other half of the warp. Use pre-computed cos/sin
|
||||||
// values.
|
// values.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < numElemsPerThread; i++) {
|
for (int i = 0; i < numElemsPerThread; i++) {
|
||||||
elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
|
elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
|
||||||
if (laneId < 16) {
|
if (laneId < 16) {
|
||||||
elements2[i] = -elements2[i];
|
elements2[i] = -elements2[i];
|
||||||
}
|
}
|
||||||
@ -269,7 +271,7 @@ __global__ void fusedQKNormRopeKernel(
|
|||||||
{
|
{
|
||||||
vec_T vec;
|
vec_T vec;
|
||||||
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
|
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < num_packed_elems; i++) {
|
for (int i = 0; i < num_packed_elems; i++) {
|
||||||
// Convert from float2 back to the specific packed type
|
// Convert from float2 back to the specific packed type
|
||||||
T2_in packed_val = Converter::convert(
|
T2_in packed_val = Converter::convert(
|
||||||
@ -280,21 +282,21 @@ __global__ void fusedQKNormRopeKernel(
|
|||||||
*reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
|
*reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
|
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// Borrowed from
|
// Borrowed from
|
||||||
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
|
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
|
||||||
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
|
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
|
||||||
if (interleave) { \
|
if (interleave) { \
|
||||||
const bool INTERLEAVE = true; \
|
const bool INTERLEAVE = true; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
} else { \
|
} else { \
|
||||||
const bool INTERLEAVE = false; \
|
const bool INTERLEAVE = false; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t_in, typename scalar_t_cache>
|
template <typename scalar_t_in, typename scalar_t_cache>
|
||||||
void launchFusedQKNormRope(void* qkv, int const num_tokens,
|
void launchFusedQKNormRope(void* qkv, int const num_tokens,
|
||||||
@ -414,5 +416,3 @@ void fused_qk_norm_rope(
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // not USE_ROCM
|
|
||||||
@ -175,7 +175,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"float epsilon) -> ()");
|
"float epsilon) -> ()");
|
||||||
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
|
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
// Function for fused QK Norm and RoPE
|
// Function for fused QK Norm and RoPE
|
||||||
ops.def(
|
ops.def(
|
||||||
"fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
|
"fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
|
||||||
@ -183,7 +182,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
|
"Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
|
||||||
"bool is_neox, Tensor position_ids) -> ()");
|
"bool is_neox, Tensor position_ids) -> ()");
|
||||||
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
|
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
|
||||||
#endif
|
|
||||||
|
|
||||||
// Apply repetition penalties to logits in-place
|
// Apply repetition penalties to logits in-place
|
||||||
ops.def(
|
ops.def(
|
||||||
|
|||||||
@ -67,9 +67,9 @@ struct _typeConvert<c10::Half> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) || defined(USE_ROCM)
|
||||||
// CUDA_ARCH < 800 does not have BF16 support
|
// CUDA_ARCH < 800 does not have BF16 support
|
||||||
// TODO: Add in ROCm support once public headers handle bf16 maturely
|
// ROCm 7.0+ supports bfloat16
|
||||||
template <>
|
template <>
|
||||||
struct _typeConvert<c10::BFloat16> {
|
struct _typeConvert<c10::BFloat16> {
|
||||||
static constexpr bool exists = true;
|
static constexpr bool exists = true;
|
||||||
@ -89,7 +89,8 @@ struct _typeConvert<c10::BFloat16> {
|
|||||||
return __float22bfloat162_rn(x);
|
return __float22bfloat162_rn(x);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) ||
|
||||||
|
// defined(USE_ROCM)
|
||||||
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
|
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
|
||||||
// 12000))
|
// 12000))
|
||||||
|
|
||||||
|
|||||||
@ -113,8 +113,8 @@ class QKNormRoPETestModel(torch.nn.Module):
|
|||||||
@pytest.mark.parametrize("enable_rope_custom_op", [True])
|
@pytest.mark.parametrize("enable_rope_custom_op", [True])
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not current_platform.is_cuda(),
|
not current_platform.is_cuda_alike(),
|
||||||
reason="Only test on cuda platform",
|
reason="Only test on cuda and rocm platform",
|
||||||
)
|
)
|
||||||
def test_qk_norm_rope_fusion(
|
def test_qk_norm_rope_fusion(
|
||||||
eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype
|
eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype
|
||||||
|
|||||||
@ -44,8 +44,8 @@ def _apply_qk_norm_rope(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not current_platform.is_cuda(),
|
not current_platform.is_cuda_alike(),
|
||||||
reason="fused_qk_norm_rope custom op requires cuda platform",
|
reason="fused_qk_norm_rope custom op requires cuda and rocm platform",
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
|||||||
@ -184,10 +184,10 @@ class PassConfig:
|
|||||||
"Fusion enabled but reshape elimination disabled. "
|
"Fusion enabled but reshape elimination disabled. "
|
||||||
"Allreduce + rms norm + quant (fp8) fusion might not work"
|
"Allreduce + rms norm + quant (fp8) fusion might not work"
|
||||||
)
|
)
|
||||||
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda():
|
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike():
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"QK Norm + RoPE fusion enabled but the current platform is not "
|
"QK Norm + RoPE fusion enabled but the current platform is not "
|
||||||
"CUDA. The fusion will be disabled."
|
"CUDA or ROCm. The fusion will be disabled."
|
||||||
)
|
)
|
||||||
self.enable_qk_norm_rope_fusion = False
|
self.enable_qk_norm_rope_fusion = False
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user