[kernel][perf] support non-contiguous input for rms_norm kernel (#28103)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
zhrrr 2025-11-21 11:39:09 +08:00 committed by GitHub
parent 0e741c12e3
commit a982f5b5ea
4 changed files with 77 additions and 33 deletions


@@ -117,3 +117,24 @@
       break;                                                                 \
     }                                                                        \
   }
+
+#define VLLM_DISPATCH_RANK234(NUM_DIMS, ...)                                 \
+  switch (NUM_DIMS) {                                                        \
+    case 2: {                                                                \
+      constexpr int tensor_rank = 2;                                         \
+      __VA_ARGS__();                                                         \
+      break;                                                                 \
+    }                                                                        \
+    case 3: {                                                                \
+      constexpr int tensor_rank = 3;                                         \
+      __VA_ARGS__();                                                         \
+      break;                                                                 \
+    }                                                                        \
+    case 4: {                                                                \
+      constexpr int tensor_rank = 4;                                         \
+      __VA_ARGS__();                                                         \
+      break;                                                                 \
+    }                                                                        \
+    default:                                                                 \
+      TORCH_CHECK(false, "Expects rank 2, 3 or 4 tensors but got ", NUM_DIMS); \
+  }
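The new dispatcher follows the same pattern as the VLLM_DISPATCH_FLOATING_TYPES / VLLM_DISPATCH_VEC_SIZE macros used below: each case defines a constexpr `tensor_rank` and then invokes the trailing callable, so the rank is available as a compile-time constant inside the lambda. A minimal usage sketch (illustration only, not part of the commit; it assumes the header defining the macro above is included, and `launch_for_rank` / `dispatch_example` are hypothetical names):

#include <torch/all.h>
// Assumes VLLM_DISPATCH_RANK234 from the header above is in scope.

template <int RANK>
void launch_for_rank(const torch::Tensor& t) {
  // A rank-specialized kernel launch would go here.
}

void dispatch_example(const torch::Tensor& t) {
  VLLM_DISPATCH_RANK234(t.dim(), [&] {
    // `tensor_rank` is the constexpr int (2, 3, or 4) injected by the macro,
    // so it can be used as a non-type template argument.
    launch_for_rank<tensor_rank>(t);
  });
}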


@@ -10,16 +10,38 @@
 namespace vllm {
 
 // TODO(woosuk): Further optimize this kernel.
-template <typename scalar_t, int VEC_SIZE>
+template <typename scalar_t, int VEC_SIZE, int NUM_DIMS>
 __global__ void rms_norm_kernel(
     scalar_t* __restrict__ out,           // [..., hidden_size]
     const scalar_t* __restrict__ input,   // [..., hidden_size]
-    const int64_t input_stride,
+    const int64_t input_stride_d2,        // input.stride(-2)
+    const int64_t input_stride_d3,        // input.stride(-3)
+    const int64_t input_stride_d4,        // input.stride(-4)
+    const int64_t input_shape_d2,         // input.size(-2)
+    const int64_t input_shape_d3,         // input.size(-3)
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
-  const scalar_t* input_row = input + blockIdx.x * input_stride;
+  const scalar_t* input_row;
+  if constexpr (NUM_DIMS == 2) {
+    // 2D for layernorm normal case [batch_size, hidden]
+    input_row = input + blockIdx.x * input_stride_d2;
+  } else if constexpr (NUM_DIMS == 3) {
+    // 3D for q/k norm [batch_size, num_heads, head_size]
+    int batch_idx = blockIdx.x / input_shape_d2;
+    int head_idx = blockIdx.x % input_shape_d2;
+    input_row =
+        input + batch_idx * input_stride_d3 + head_idx * input_stride_d2;
+  } else if constexpr (NUM_DIMS == 4) {
+    // 4D for transformers model_impl qk norm [batch, seq, head, head_dim]
+    int batch_idx = blockIdx.x / (input_shape_d3 * input_shape_d2);
+    int remaining = blockIdx.x % (input_shape_d3 * input_shape_d2);
+    int seq_idx = remaining / input_shape_d2;
+    int head_idx = remaining % input_shape_d2;
+    input_row = input + batch_idx * input_stride_d4 +
+                seq_idx * input_stride_d3 + head_idx * input_stride_d2;
+  }
 
   auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
 #pragma unroll
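The rank-3 and rank-4 branches decompose the flat block index into per-dimension indices and apply the tensor's real strides, so rows that are not packed back-to-back in memory are still addressed correctly. A small host-side sketch of the rank-3 arithmetic (illustration only, not part of the commit), using the q/k-norm layout where a head slice of a fused qkv buffer has a stride(-3) larger than num_heads * head_size:

#include <cassert>
#include <cstdint>

// Mirrors the NUM_DIMS == 3 branch: block batch * num_heads + head starts
// batch * stride(-3) + head * stride(-2) elements past the base pointer.
int64_t row_offset_3d(int64_t block_idx, int64_t shape_d2 /*num_heads*/,
                      int64_t stride_d3, int64_t stride_d2) {
  const int64_t batch_idx = block_idx / shape_d2;
  const int64_t head_idx = block_idx % shape_d2;
  return batch_idx * stride_d3 + head_idx * stride_d2;
}

int main() {
  // Example: 4 heads of size 128 sliced out of a fused qkv row of width
  // 3 * 4 * 128 = 1536, so stride(-3) = 1536 while stride(-2) = 128.
  const int64_t num_heads = 4, head_size = 128, qkv_width = 3 * 4 * 128;
  // The row for (batch = 2, head = 3) is block 2 * 4 + 3 = 11 and must start
  // 2 full qkv rows plus 3 heads into the buffer.
  assert(row_offset_3d(11, num_heads, qkv_width, head_size) ==
         2 * qkv_width + 3 * head_size);
  return 0;
}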
@@ -164,36 +186,42 @@ void rms_norm(torch::Tensor& out,    // [..., hidden_size]
               torch::Tensor& weight,  // [hidden_size]
               double epsilon) {
   TORCH_CHECK(out.is_contiguous());
+  if (input.stride(-1) != 1) {
+    input = input.contiguous();
+  }
   TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());
 
   int hidden_size = input.size(-1);
-
-  // We cannot just use `input.stride(-2)` if the tensor is not row-major.
-  // Instead, we use a 2d view to get the second-innermost stride.
-  // That way the dimensions (except the last one) can be arbitrarily permuted.
-  torch::Tensor input_view = input.view({-1, hidden_size});
-
-  int num_tokens = input_view.numel() / hidden_size;
-  int64_t input_stride = input_view.stride(-2);
+  int num_tokens = input.numel() / hidden_size;
+  int num_dims = input.dim();
+  int64_t input_stride_d2 = input.stride(-2);
+  int64_t input_stride_d3 = (num_dims >= 3) ? input.stride(-3) : 0;
+  int64_t input_stride_d4 = (num_dims >= 4) ? input.stride(-4) : 0;
+  int64_t input_shape_d2 = (num_dims >= 3) ? input.size(-2) : 0;
+  int64_t input_shape_d3 = (num_dims >= 4) ? input.size(-3) : 0;
 
   // For large num_tokens, use smaller blocks to increase SM concurrency.
   const int max_block_size = (num_tokens < 256) ? 1024 : 256;
   dim3 grid(num_tokens);
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-      input_view.scalar_type(), "rms_norm_kernel", [&] {
-        const int calculated_vec_size =
-            std::gcd(16 / sizeof(scalar_t), hidden_size);
-        const int block_size =
-            std::min(hidden_size / calculated_vec_size, max_block_size);
-        dim3 block(block_size);
-        VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
-          vllm::rms_norm_kernel<scalar_t, vec_size><<<grid, block, 0, stream>>>(
-              out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
-              input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
-              hidden_size);
-        });
-      });
+  VLLM_DISPATCH_RANK234(num_dims, [&] {
+    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
+      const int calculated_vec_size =
+          std::gcd(16 / sizeof(scalar_t), hidden_size);
+      const int block_size =
+          std::min(hidden_size / calculated_vec_size, max_block_size);
+      dim3 block(block_size);
+      VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
+        vllm::rms_norm_kernel<scalar_t, vec_size, tensor_rank>
+            <<<grid, block, 0, stream>>>(
+                out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+                input_stride_d2, input_stride_d3, input_stride_d4,
+                input_shape_d2, input_shape_d3, weight.data_ptr<scalar_t>(),
+                epsilon, num_tokens, hidden_size);
+      });
+    });
+  });
 }
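On the host side the launcher no longer flattens the input through a 2-D view; it gathers the real strides and sizes, falls back to a contiguous copy only when the innermost stride is not 1, and selects the kernel instantiation via VLLM_DISPATCH_RANK234. A sketch of the call pattern this enables (illustration only, not part of the commit; it assumes linkage against the extension that defines the rms_norm op shown above):

#include <torch/all.h>

// Forward declaration of the op from the diff above (normally provided by the
// vLLM C++ extension headers).
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
              double epsilon);

void qk_norm_example() {
  const int64_t num_tokens = 8, num_heads = 4, head_size = 128;
  auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
  // Fused qkv projection output; the q slice below is rank-3 and not
  // contiguous (its stride(-3) is 3 * num_heads * head_size), but its
  // innermost stride is still 1, so no copy is needed.
  torch::Tensor qkv =
      torch::randn({num_tokens, 3 * num_heads * head_size}, opts);
  torch::Tensor q = qkv.narrow(-1, 0, num_heads * head_size)
                        .view({num_tokens, num_heads, head_size});
  torch::Tensor weight = torch::ones({head_size}, opts);
  torch::Tensor out = torch::empty({num_tokens, num_heads, head_size}, opts);
  // Previously this path required q.contiguous(); the rank-3 kernel now reads
  // q through its real strides.
  rms_norm(out, q, weight, 1e-6);
}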


@@ -328,10 +328,7 @@ def rotary_embedding(
 def rms_norm(
     out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float
 ) -> None:
-    # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
-    # If removed, also need to remove contiguous in MatcherRMSNorm
-    input_contiguous = input.contiguous()
-    torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon)
+    torch.ops._C.rms_norm(out, input, weight, epsilon)
 
 
 def fused_add_rms_norm(


@@ -162,12 +162,10 @@ class MatcherRMSNorm(MatcherCustomOp):
         weight: torch.Tensor,
     ) -> torch.Tensor:
         result = torch.empty_like(input)
-        # TODO: support non-contiguous input for RMSNorm and remove this
-        input_contiguous = input.contiguous()
         _, result = auto_functionalized(
             RMS_OP,
             result=result,
-            input=input_contiguous,
+            input=input,
             weight=weight,
             epsilon=self.epsilon,
         )