diff --git a/benchmarks/kernels/benchmark_activation.py b/benchmarks/kernels/benchmark_activation.py
index 66268b71b3de6..d31e67057d8f6 100644
--- a/benchmarks/kernels/benchmark_activation.py
+++ b/benchmarks/kernels/benchmark_activation.py
@@ -13,8 +13,8 @@ from vllm.triton_utils import triton
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 
-batch_size_range = [1, 16, 32, 64, 128]
-seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
+batch_size_range = [1, 16, 128]
+seq_len_range = [1, 16, 64, 1024, 4096]
 intermediate_size = [3072, 9728, 12288]
 configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))
diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index a4a880f13cf7e..8268065ef02c8 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -15,19 +15,61 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
                                             const scalar_t& y) {
   return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
 }
 
-// Activation and gating kernel template.
+// Check whether a pointer is 16-byte aligned for int4 vectorized access.
+__device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
+  return (reinterpret_cast<uintptr_t>(ptr) & 15) == 0;
+}
+
+// Activation and gating kernel template.
 template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
           bool act_first>
 __global__ void act_and_mul_kernel(
     scalar_t* __restrict__ out,          // [..., d]
     const scalar_t* __restrict__ input,  // [..., 2, d]
     const int d) {
+  constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
   const int64_t token_idx = blockIdx.x;
-  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
-    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
-    out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
+  const scalar_t* x_ptr = input + token_idx * 2 * d;
+  const scalar_t* y_ptr = x_ptr + d;
+  scalar_t* out_ptr = out + token_idx * d;
+
+  // Check alignment for 128-bit vectorized access.
+  // All three pointers must be 16-byte aligned for safe int4 operations.
+  const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
+                       is_16byte_aligned(out_ptr);
+
+  if (aligned && d >= VEC_SIZE) {
+    // Fast path: 128-bit vectorized loop
+    const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
+    const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
+    int4* out_vec = reinterpret_cast<int4*>(out_ptr);
+    const int num_vecs = d / VEC_SIZE;
+    const int vec_end = num_vecs * VEC_SIZE;
+
+    for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
+      int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
+      auto* xp = reinterpret_cast<const scalar_t*>(&x);
+      auto* yp = reinterpret_cast<const scalar_t*>(&y);
+      auto* rp = reinterpret_cast<scalar_t*>(&r);
+#pragma unroll
+      for (int j = 0; j < VEC_SIZE; j++) {
+        rp[j] = compute<scalar_t, ACT_FN, act_first>(xp[j], yp[j]);
+      }
+      out_vec[i] = r;
+    }
+    // Scalar cleanup for remaining elements
+    for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
+      out_ptr[i] = compute<scalar_t, ACT_FN, act_first>(VLLM_LDG(&x_ptr[i]),
+                                                        VLLM_LDG(&y_ptr[i]));
+    }
+  } else {
+    // Scalar fallback for unaligned data or small d
+    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+      const scalar_t x = VLLM_LDG(&x_ptr[idx]);
+      const scalar_t y = VLLM_LDG(&y_ptr[idx]);
+      out_ptr[idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
+    }
   }
 }
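Aside for reviewers (illustration only, not part of the patch): the per-token pointers are input + token_idx * 2 * d and out + token_idx * d, so assuming the base torch allocations are themselves 16-byte aligned, the fast path fires for every token exactly when d * sizeof(scalar_t) is a multiple of 16. A minimal host-side sketch of that condition; rows_stay_aligned and the sample widths are made up for this example.

// Hypothetical helper, for illustration only.
#include <cstddef>
#include <cstdio>

// True when every token's x_ptr, y_ptr and out_ptr stay 16-byte aligned,
// assuming the base pointers are 16-byte aligned. The tightest constraint
// comes from y_ptr = x_ptr + d and out_ptr = out + token_idx * d, i.e.
// d * elem_size must be a multiple of 16 bytes.
static bool rows_stay_aligned(size_t elem_size, size_t d) {
  return (d * elem_size) % 16 == 0;
}

int main() {
  const size_t elem_sizes[] = {2, 4};  // fp16/bf16, fp32
  const size_t dims[] = {3072, 9728, 12288, 100, 7};
  for (size_t es : elem_sizes) {
    for (size_t d : dims) {
      std::printf("elem=%zuB d=%-5zu vectorized for all tokens: %s\n", es, d,
                  rows_stay_aligned(es, d) ? "yes" : "no");
    }
  }
  return 0;
}

For the intermediate sizes used in the benchmark above (3072, 9728, 12288) the condition holds for both 2-byte and 4-byte dtypes, so the scalar fallback mainly covers odd widths and non-contiguous slices.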
@@ -120,50 +162,115 @@ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
 __global__ void act_and_mul_kernel_with_param(
     scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
     const float param) {
+  constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
   const int64_t token_idx = blockIdx.x;
-  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
-    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
-    out[token_idx * d + idx] = ACT_FN(x, param) * y;
+  const scalar_t* x_ptr = input + token_idx * 2 * d;
+  const scalar_t* y_ptr = x_ptr + d;
+  scalar_t* out_ptr = out + token_idx * d;
+
+  // Check alignment for 128-bit vectorized access
+  const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
+                       is_16byte_aligned(out_ptr);
+
+  if (aligned && d >= VEC_SIZE) {
+    // Fast path: 128-bit vectorized loop
+    const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
+    const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
+    int4* out_vec = reinterpret_cast<int4*>(out_ptr);
+    const int num_vecs = d / VEC_SIZE;
+    const int vec_end = num_vecs * VEC_SIZE;
+
+    for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
+      int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
+      auto* xp = reinterpret_cast<const scalar_t*>(&x);
+      auto* yp = reinterpret_cast<const scalar_t*>(&y);
+      auto* rp = reinterpret_cast<scalar_t*>(&r);
+#pragma unroll
+      for (int j = 0; j < VEC_SIZE; j++) {
+        rp[j] = ACT_FN(xp[j], param) * yp[j];
+      }
+      out_vec[i] = r;
+    }
+    // Scalar cleanup for remaining elements
+    for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
+      out_ptr[i] = ACT_FN(VLLM_LDG(&x_ptr[i]), param) * VLLM_LDG(&y_ptr[i]);
+    }
+  } else {
+    // Scalar fallback for unaligned data or small d
+    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+      const scalar_t x = VLLM_LDG(&x_ptr[idx]);
+      const scalar_t y = VLLM_LDG(&y_ptr[idx]);
+      out_ptr[idx] = ACT_FN(x, param) * y;
+    }
   }
 }
 
 template <typename T>
 __device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
                                                float alpha, float limit) {
-  // clamp gate: min=None, max=limit
-  const float gate_f = (float)gate;
-  const float clamped_gate = gate_f > limit ? limit : gate_f;
-
-  // clamp up: min=-limit, max=limit
-  const float up_f = (float)up;
-  const float clamped_up =
-      up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
-
-  // glu = gate * sigmoid(gate * alpha)
-  const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
-  const float glu = clamped_gate * sigmoid_val;
-
-  // (up + 1) * glu
-  return (T)((clamped_up + 1.0f) * glu);
+  // Clamp gate to (-inf, limit] and up to [-limit, limit]
+  const float g = fminf((float)gate, limit);
+  const float u = fmaxf(fminf((float)up, limit), -limit);
+  // glu = gate * sigmoid(gate * alpha), then return (up + 1) * glu
+  return (T)((u + 1.0f) * g / (1.0f + expf(-g * alpha)));
 }
 
+// Interleaved gate/up: input has [gate0, up0, gate1, up1, ...].
 template <typename scalar_t,
           scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
                              const float)>
 __global__ void swigluoai_and_mul_kernel(
     scalar_t* __restrict__ out,          // [..., d]
-    const scalar_t* __restrict__ input,  // [..., 2, d]
+    const scalar_t* __restrict__ input,  // [..., 2 * d] (interleaved)
     const int d, const float alpha, const float limit) {
+  // For interleaved data: input has 2*d elements per token (gate/up pairs),
+  // output has d elements per token.
+  constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
+  constexpr int PAIRS = VEC_SIZE / 2;  // Number of gate/up pairs per int4 load
   const int64_t token_idx = blockIdx.x;
-  // TODO: Vectorize loads and stores.
-  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-    // gate = x[..., ::2]  (even indices)
-    const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
-    // up = x[..., 1::2]  (odd indices)
-    const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
+  const scalar_t* in_ptr = input + token_idx * 2 * d;
+  scalar_t* out_ptr = out + token_idx * d;
 
-    out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
+  // Check alignment for 128-bit vectorized access on input.
+  // For output we use int2 (64-bit), which has an 8-byte alignment requirement.
+  const bool in_aligned = is_16byte_aligned(in_ptr);
+  const bool out_aligned =
+      (reinterpret_cast<uintptr_t>(out_ptr) & 7) == 0;  // 8-byte for int2
+
+  if (in_aligned && out_aligned && d >= PAIRS) {
+    // Fast path: vectorized loop.
+    // Each int4 load gives VEC_SIZE elements = PAIRS gate/up pairs;
+    // each int2 store writes PAIRS output elements.
+    const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
+    int2* out_vec = reinterpret_cast<int2*>(out_ptr);
+    const int num_vecs = d / PAIRS;
+    const int vec_end = num_vecs * PAIRS;
+
+    for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
+      int4 v = VLLM_LDG(&in_vec[i]);
+      int2 r;
+      auto* vp = reinterpret_cast<const scalar_t*>(&v);
+      auto* rp = reinterpret_cast<scalar_t*>(&r);
+#pragma unroll
+      for (int j = 0; j < PAIRS; j++) {
+        rp[j] = ACT_FN(vp[2 * j], vp[2 * j + 1], alpha, limit);
+      }
+      out_vec[i] = r;
+    }
+    // Scalar cleanup for remaining elements
+    for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
+      out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[2 * i]),
+                          VLLM_LDG(&in_ptr[2 * i + 1]), alpha, limit);
+    }
+  } else {
+    // Scalar fallback for unaligned data or small d
+    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+      // gate = x[..., ::2]  (even indices)
+      const scalar_t gate = VLLM_LDG(&in_ptr[2 * idx]);
+      // up = x[..., 1::2]  (odd indices)
+      const scalar_t up = VLLM_LDG(&in_ptr[2 * idx + 1]);
+      out_ptr[idx] = ACT_FN(gate, up, alpha, limit);
+    }
   }
 }
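Aside (illustration only, not part of the patch): for 2-byte dtypes VEC_SIZE is 8 and PAIRS is 4, so lane j of vector i reads gate in_ptr[8 * i + 2 * j] and up in_ptr[8 * i + 2 * j + 1] and writes out_ptr[4 * i + j]. A small host-only sketch checking that this mapping matches the scalar ::2 / 1::2 indexing; the constants and names here are made up for the example.

// Host-only index-mapping check, for illustration only.
#include <cassert>
#include <cstdio>

int main() {
  const int d = 32;  // output elements per token in this toy example
  const int VEC_SIZE = 8, PAIRS = VEC_SIZE / 2;  // fp16/bf16 case
  for (int vec = 0; vec < d / PAIRS; ++vec) {    // one int4 load + one int2 store
    for (int j = 0; j < PAIRS; ++j) {
      int out_idx = vec * PAIRS + j;              // element written by the int2 store
      int gate_idx = vec * VEC_SIZE + 2 * j;      // element read from the int4 load
      int up_idx = vec * VEC_SIZE + 2 * j + 1;
      // Must agree with the scalar loop: gate = in[2*i], up = in[2*i + 1].
      assert(gate_idx == 2 * out_idx && up_idx == 2 * out_idx + 1);
    }
  }
  std::printf("vectorized gate/up indexing matches the scalar ::2 / 1::2 layout\n");
  return 0;
}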
@@ -217,10 +324,41 @@ __global__ void activation_kernel(
     scalar_t* __restrict__ out,          // [..., d]
     const scalar_t* __restrict__ input,  // [..., d]
     const int d) {
+  constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
   const int64_t token_idx = blockIdx.x;
-  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-    const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
-    out[token_idx * d + idx] = ACT_FN(x);
+  const scalar_t* in_ptr = input + token_idx * d;
+  scalar_t* out_ptr = out + token_idx * d;
+
+  // Check alignment for 128-bit vectorized access
+  const bool aligned = is_16byte_aligned(in_ptr) && is_16byte_aligned(out_ptr);
+
+  if (aligned && d >= VEC_SIZE) {
+    // Fast path: 128-bit vectorized loop
+    const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
+    int4* out_vec = reinterpret_cast<int4*>(out_ptr);
+    const int num_vecs = d / VEC_SIZE;
+    const int vec_end = num_vecs * VEC_SIZE;
+
+    for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
+      int4 v = VLLM_LDG(&in_vec[i]), r;
+      auto* vp = reinterpret_cast<const scalar_t*>(&v);
+      auto* rp = reinterpret_cast<scalar_t*>(&r);
+#pragma unroll
+      for (int j = 0; j < VEC_SIZE; j++) {
+        rp[j] = ACT_FN(vp[j]);
+      }
+      out_vec[i] = r;
+    }
+    // Scalar cleanup for remaining elements
+    for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
+      out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[i]));
+    }
+  } else {
+    // Scalar fallback for unaligned data or small d
+    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+      const scalar_t x = VLLM_LDG(&in_ptr[idx]);
+      out_ptr[idx] = ACT_FN(x);
+    }
   }
 }
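Aside (illustration only, not part of the patch or of vLLM's test suite): a minimal, float-only standalone program that mirrors the aligned-vs-fallback structure above and checks both paths against a host reference. The kernel name act_vec_kernel, the silu-style activation and the tolerance are all invented for this sketch; the real kernels dispatch over dtypes with VLLM_LDG and the usual vLLM macros. With d = 3075, some token rows land 16-byte aligned (vectorized body plus a 3-element scalar tail) and others do not (scalar fallback), so both branches get exercised.

// Standalone sanity check of the vectorized/fallback pattern, illustration only.
#include <cuda_runtime.h>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Invented silu-style activation, usable from the host reference and the device.
__host__ __device__ inline float my_silu(float x) { return x / (1.0f + expf(-x)); }

__global__ void act_vec_kernel(float* out, const float* in, int d) {
  constexpr int VEC_SIZE = 16 / sizeof(float);  // 4 floats per int4
  const float* in_ptr = in + blockIdx.x * static_cast<int64_t>(d);
  float* out_ptr = out + blockIdx.x * static_cast<int64_t>(d);
  const bool aligned = ((reinterpret_cast<uintptr_t>(in_ptr) |
                         reinterpret_cast<uintptr_t>(out_ptr)) & 15) == 0;
  if (aligned && d >= VEC_SIZE) {
    // Vectorized body plus scalar tail, mirroring the patch.
    const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
    int4* out_vec = reinterpret_cast<int4*>(out_ptr);
    const int num_vecs = d / VEC_SIZE;
    for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
      int4 v = in_vec[i], r;
      const float* vp = reinterpret_cast<const float*>(&v);
      float* rp = reinterpret_cast<float*>(&r);
#pragma unroll
      for (int j = 0; j < VEC_SIZE; j++) rp[j] = my_silu(vp[j]);
      out_vec[i] = r;
    }
    for (int i = num_vecs * VEC_SIZE + threadIdx.x; i < d; i += blockDim.x)
      out_ptr[i] = my_silu(in_ptr[i]);
  } else {
    // Scalar fallback.
    for (int i = threadIdx.x; i < d; i += blockDim.x) out_ptr[i] = my_silu(in_ptr[i]);
  }
}

int main() {
  const int tokens = 8, d = 3075;  // odd d: some tokens vectorize, others fall back
  std::vector<float> h_in(static_cast<size_t>(tokens) * d), h_out(h_in.size());
  for (size_t i = 0; i < h_in.size(); i++)
    h_in[i] = 0.01f * (static_cast<int>(i % 200) - 100);
  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_out, h_out.size() * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float), cudaMemcpyHostToDevice);
  act_vec_kernel<<<tokens, 256>>>(d_out, d_in, d);
  cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float), cudaMemcpyDeviceToHost);
  int mismatches = 0;
  for (size_t i = 0; i < h_out.size(); i++)
    if (std::fabs(h_out[i] - my_silu(h_in[i])) > 1e-5f) mismatches++;
  std::printf("mismatches: %d\n", mismatches);
  cudaFree(d_in);
  cudaFree(d_out);
  return mismatches != 0;
}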