[Perf][Kernels] Vectorize csrc/activations_kernels.cu (#29512)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-12-16 17:56:02 -05:00 committed by GitHub
parent b6ec077e05
commit 0a1ab1e565
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 175 additions and 37 deletions

View File

@ -13,8 +13,8 @@ from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
batch_size_range = [1, 16, 128]
seq_len_range = [1, 16, 64, 1024, 4096]
intermediate_size = [3072, 9728, 12288]
configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))

View File

@ -15,19 +15,61 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
const scalar_t& y) {
return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
}
// Activation and gating kernel template.
// Check if all pointers are 16-byte aligned for int4 vectorized access
__device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
return (reinterpret_cast<uintptr_t>(ptr) & 15) == 0;
}
// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first>
__global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
const scalar_t* x_ptr = input + token_idx * 2 * d;
const scalar_t* y_ptr = x_ptr + d;
scalar_t* out_ptr = out + token_idx * d;
// Check alignment for 128-bit vectorized access.
// All three pointers must be 16-byte aligned for safe int4 operations.
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
is_16byte_aligned(out_ptr);
if (aligned && d >= VEC_SIZE) {
// Fast path: 128-bit vectorized loop
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
auto* xp = reinterpret_cast<scalar_t*>(&x);
auto* yp = reinterpret_cast<scalar_t*>(&y);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = compute<scalar_t, ACT_FN, act_first>(xp[j], yp[j]);
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = compute<scalar_t, ACT_FN, act_first>(VLLM_LDG(&x_ptr[i]),
VLLM_LDG(&y_ptr[i]));
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
out_ptr[idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
}
}
}
@ -120,50 +162,115 @@ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__ void act_and_mul_kernel_with_param(
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
const float param) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = ACT_FN(x, param) * y;
const scalar_t* x_ptr = input + token_idx * 2 * d;
const scalar_t* y_ptr = x_ptr + d;
scalar_t* out_ptr = out + token_idx * d;
// Check alignment for 128-bit vectorized access
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
is_16byte_aligned(out_ptr);
if (aligned && d >= VEC_SIZE) {
// Fast path: 128-bit vectorized loop
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
auto* xp = reinterpret_cast<scalar_t*>(&x);
auto* yp = reinterpret_cast<scalar_t*>(&y);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = ACT_FN(xp[j], param) * yp[j];
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&x_ptr[i]), param) * VLLM_LDG(&y_ptr[i]);
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
out_ptr[idx] = ACT_FN(x, param) * y;
}
}
}
template <typename T>
__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
float alpha, float limit) {
// clamp gate: min=None, max=limit
const float gate_f = (float)gate;
const float clamped_gate = gate_f > limit ? limit : gate_f;
// clamp up: min=-limit, max=limit
const float up_f = (float)up;
const float clamped_up =
up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
// glu = gate * sigmoid(gate * alpha)
const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
const float glu = clamped_gate * sigmoid_val;
// (up + 1) * glu
return (T)((clamped_up + 1.0f) * glu);
// Clamp gate to (-inf, limit] and up to [-limit, limit]
const float g = fminf((float)gate, limit);
const float u = fmaxf(fminf((float)up, limit), -limit);
// glu = gate * sigmoid(gate * alpha), then return (up + 1) * glu
return (T)((u + 1.0f) * g / (1.0f + expf(-g * alpha)));
}
// Interleaved gate/up: input has [gate0, up0, gate1, up1, ...].
template <typename scalar_t,
scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
const float)>
__global__ void swigluoai_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const scalar_t* __restrict__ input, // [..., 2 * d] (interleaved)
const int d, const float alpha, const float limit) {
// For interleaved data: input has 2*d elements per token (gate/up pairs)
// output has d elements per token
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
constexpr int PAIRS = VEC_SIZE / 2; // Number of gate/up pairs per int4 load
const int64_t token_idx = blockIdx.x;
// TODO: Vectorize loads and stores.
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
// gate = x[..., ::2] (even indices)
const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
// up = x[..., 1::2] (odd indices)
const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
const scalar_t* in_ptr = input + token_idx * 2 * d;
scalar_t* out_ptr = out + token_idx * d;
out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
// Check alignment for 128-bit vectorized access on input.
// For output we use int2 (64-bit) which has 8-byte alignment requirement.
const bool in_aligned = is_16byte_aligned(in_ptr);
const bool out_aligned =
(reinterpret_cast<uintptr_t>(out_ptr) & 7) == 0; // 8-byte for int2
if (in_aligned && out_aligned && d >= PAIRS) {
// Fast path: vectorized loop
// Each int4 load gives VEC_SIZE elements = PAIRS gate/up pairs
// Each int2 store writes PAIRS output elements
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
int2* out_vec = reinterpret_cast<int2*>(out_ptr);
const int num_vecs = d / PAIRS;
const int vec_end = num_vecs * PAIRS;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 v = VLLM_LDG(&in_vec[i]);
int2 r;
auto* vp = reinterpret_cast<scalar_t*>(&v);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < PAIRS; j++) {
rp[j] = ACT_FN(vp[2 * j], vp[2 * j + 1], alpha, limit);
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[2 * i]),
VLLM_LDG(&in_ptr[2 * i + 1]), alpha, limit);
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
// gate = x[..., ::2] (even indices)
const scalar_t gate = VLLM_LDG(&in_ptr[2 * idx]);
// up = x[..., 1::2] (odd indices)
const scalar_t up = VLLM_LDG(&in_ptr[2 * idx + 1]);
out_ptr[idx] = ACT_FN(gate, up, alpha, limit);
}
}
}
@ -217,10 +324,41 @@ __global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d]
const int d) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
const scalar_t* in_ptr = input + token_idx * d;
scalar_t* out_ptr = out + token_idx * d;
// Check alignment for 128-bit vectorized access
const bool aligned = is_16byte_aligned(in_ptr) && is_16byte_aligned(out_ptr);
if (aligned && d >= VEC_SIZE) {
// Fast path: 128-bit vectorized loop
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 v = VLLM_LDG(&in_vec[i]), r;
auto* vp = reinterpret_cast<scalar_t*>(&v);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = ACT_FN(vp[j]);
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[i]));
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&in_ptr[idx]);
out_ptr[idx] = ACT_FN(x);
}
}
}