diff --git a/csrc/ops.h b/csrc/ops.h
index 4bb7857b15032..d302f04913266 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -299,6 +299,14 @@ void per_token_group_quant_int8(const torch::Tensor& input,
                                 torch::Tensor& output_q,
                                 torch::Tensor& output_s, int64_t group_size,
                                 double eps, double int8_min, double int8_max);
+
+// Fused activation quantization + DeepGEMM-compatible UE8M0-packed scales.
+void per_token_group_quant_8bit_packed(const torch::Tensor& input,
+                                       torch::Tensor& output_q,
+                                       torch::Tensor& output_s_packed,
+                                       int64_t group_size, double eps,
+                                       double min_8bit, double max_8bit);
+
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
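For reference, the UE8M0 packing that the new CUDA kernel below implements can be emulated in a few lines of PyTorch: each per-token-group scale is rounded up to a power of two, only its 8-bit IEEE-754 exponent is kept, and four exponents are packed into one int32 word along K. This is a minimal sketch for verification only; the helper name `pack_ue8m0_scales_ref` and the use of a 2D input and an FP8 E4M3 target are illustrative assumptions, not part of this change.

```python
import torch

def pack_ue8m0_scales_ref(x: torch.Tensor, group_size: int,
                          eps: float = 1e-10) -> torch.Tensor:
    """Reference UE8M0 packing: 4 exponent bytes per int32 word along K."""
    mn, k = x.shape
    groups_per_row = k // group_size
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Per-group absmax, floored by eps, then scale rounded up to a power of two.
    absmax = x.float().abs().view(mn, groups_per_row, group_size).amax(-1)
    scale = (absmax.clamp_min(eps) / fp8_max).clamp_min(1e-10)
    scale = torch.exp2(torch.ceil(torch.log2(scale)))
    # A UE8M0 scale is just the 8-bit IEEE-754 exponent of that power of two.
    exponent = (scale.view(torch.int32) >> 23) & 0xFF
    # Pack 4 exponents per int32 word; byte (g % 4) holds group g.
    k_packed = (groups_per_row + 3) // 4
    packed = torch.zeros(mn, k_packed, dtype=torch.int32)
    for g in range(groups_per_row):
        # int32 wraparound in the top byte is intentional: only the raw bits
        # matter, matching the kernel's atomicOr into a uint32 buffer.
        packed[:, g // 4] |= exponent[:, g] << (8 * (g % 4))
    return packed
```

The element values match what the kernel writes; only the physical layout differs, since the kernel additionally stores the packed-K dimension MN-major with the MN extent padded to a multiple of 4 for TMA.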
diff --git a/csrc/quantization/w8a8/fp8/per_token_group_quant.cu b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu
index e3ab0676b254e..f9ac874c43730 100644
--- a/csrc/quantization/w8a8/fp8/per_token_group_quant.cu
+++ b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu
@@ -206,6 +206,191 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
 
 #undef LAUNCH_KERNEL
 }
 
+template <typename T, typename DST_DTYPE>
+__global__ void per_token_group_quant_8bit_packed_kernel(
+    const T* __restrict__ input, void* __restrict__ output_q,
+    unsigned int* __restrict__ output_s_packed, const int group_size,
+    const int num_groups, const int groups_per_block, const int groups_per_row,
+    const int mn, const int tma_aligned_mn, const float eps,
+    const float min_8bit, const float max_8bit) {
+  const int threads_per_group = 16;
+  const int64_t local_group_id = threadIdx.x / threads_per_group;
+  const int lane_id = threadIdx.x % threads_per_group;
+
+  const int64_t block_group_id = blockIdx.x * groups_per_block;
+  const int64_t global_group_id = block_group_id + local_group_id;
+  if (global_group_id >= num_groups) {
+    return;
+  }
+
+  const int64_t block_group_offset = global_group_id * group_size;
+
+  float local_absmax = eps;
+
+  const T* group_input = input + block_group_offset;
+  DST_DTYPE* group_output =
+      static_cast<DST_DTYPE*>(output_q) + block_group_offset;
+
+  // Shared memory to cache each group's data to avoid double DRAM reads.
+  extern __shared__ __align__(16) char smem_raw[];
+  T* smem = reinterpret_cast<T*>(smem_raw);
+  T* smem_group = smem + local_group_id * group_size;
+
+  constexpr int vec_size = 16 / sizeof(T);
+  using vec_t = vllm::vec_n_t<T, vec_size>;
+
+  // copy global -> shared & compute absmax
+  auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
+    float abs_v = fabsf(static_cast<float>(src));
+    local_absmax = fmaxf(local_absmax, abs_v);
+    dst = src;
+  };
+
+  vllm::vectorize_with_alignment<vec_size>(
+      group_input,        // in
+      smem_group,         // out (shared)
+      group_size,         // elements per group
+      lane_id,            // thread id
+      threads_per_group,  // stride in group
+      scalar_op_cache);   // scalar handler
+
+  local_absmax = GroupReduceMax(local_absmax);
+
+  float y_s = local_absmax / max_8bit;
+  y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
+
+  // pack 4 scales into a uint32
+  if (lane_id == 0) {
+    // map flat group id to 2D indices (mn_idx, sf_k_idx)
+    const int sf_k_idx = static_cast<int>(global_group_id % groups_per_row);
+    const int mn_idx = static_cast<int>(global_group_id / groups_per_row);
+
+    if (mn_idx < mn) {
+      // each uint32 in output_s_packed stores 4 packed scales
+      const int sf_k_pack_idx = sf_k_idx / 4;
+      const int pos = sf_k_idx % 4;
+
+      // Reinterpret the UE8M0 scale y_s as IEEE bits, extract the 8-bit
+      // exponent, and place it into the correct byte of the 32-bit word.
+      const unsigned int bits = __float_as_uint(y_s);
+      const unsigned int exponent = (bits >> 23u) & 0xffu;
+      const unsigned int contrib = exponent << (pos * 8u);
+
+      const int out_idx = sf_k_pack_idx * tma_aligned_mn + mn_idx;
+      // atomically OR the 8-bit exponent into the packed scales buffer
+      atomicOr(output_s_packed + out_idx, contrib);
+    }
+  }
+
+  __syncthreads();
+
+  // quantize shared -> global 8-bit
+  auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
+    float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
+    dst = DST_DTYPE(q);
+  };
+
+  vllm::vectorize_with_alignment<vec_size>(
+      smem_group,         // in (shared)
+      group_output,       // out (global quant tensor)
+      group_size,         // elements
+      lane_id,            // tid
+      threads_per_group,  // stride
+      scalar_op_quant);   // scalar handler
+}
+
+void per_token_group_quant_8bit_packed(const torch::Tensor& input,
+                                       torch::Tensor& output_q,
+                                       torch::Tensor& output_s_packed,
+                                       int64_t group_size, double eps,
+                                       double min_8bit, double max_8bit) {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(output_q.is_contiguous());
+
+  const int64_t k = input.size(-1);
+  TORCH_CHECK(k % group_size == 0, "Last dimension (", k,
+              ") must be divisible by group_size (", group_size, ").");
+
+  const int64_t mn = input.numel() / k;
+  const int64_t groups_per_row = k / group_size;
+  const int64_t num_groups = mn * groups_per_row;
+
+  TORCH_CHECK(output_s_packed.dim() == 2,
+              "output_s_packed must be 2D, got dim=", output_s_packed.dim(),
+              ".");
+
+  const int64_t k_num_packed_sfk = (groups_per_row + 3) / 4;
+  const int64_t tma_aligned_mn = ((mn + 3) / 4) * 4;
+
+  TORCH_CHECK(output_s_packed.scalar_type() == at::ScalarType::Int,
+              "output_s_packed must have dtype int32 for UE8M0-packed scales.");
+  // DeepGEMM expects SFA scales in MN-major form with shape
+  // [mn, ceil_div(K, 128 * 4)] and a TMA-aligned stride on the last
+  // dimension.
+  TORCH_CHECK(output_s_packed.size(0) == mn &&
+                  output_s_packed.size(1) == k_num_packed_sfk,
+              "output_s_packed shape must be [", mn, ", ", k_num_packed_sfk,
+              "], but got [", output_s_packed.size(0), ", ",
+              output_s_packed.size(1), "].");
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  constexpr int THREADS_PER_GROUP = 16;
+
+  int groups_per_block = 1;
+
+  if (num_groups % 16 == 0) {
+    groups_per_block = 16;
+  } else if (num_groups % 8 == 0) {
+    groups_per_block = 8;
+  } else if (num_groups % 4 == 0) {
+    groups_per_block = 4;
+  } else if (num_groups % 2 == 0) {
+    groups_per_block = 2;
+  }
+
+  auto dst_type = output_q.scalar_type();
+  const int num_blocks = num_groups / groups_per_block;
+  const int num_threads = groups_per_block * THREADS_PER_GROUP;
+
+  // Zero-initialize the packed scales, since we use atomicOr to accumulate
+  // exponents from different groups.
+  output_s_packed.zero_();
+
+#define LAUNCH_PACKED_KERNEL(T, DST_DTYPE)                                \
+  do {                                                                    \
+    dim3 grid(num_blocks);                                                \
+    dim3 block(num_threads);                                              \
+    size_t smem_bytes =                                                   \
+        static_cast<size_t>(groups_per_block) * group_size * sizeof(T);   \
+    per_token_group_quant_8bit_packed_kernel<T, DST_DTYPE>                \
+        <<<grid, block, smem_bytes, stream>>>(                            \
+            static_cast<T*>(input.data_ptr()), output_q.data_ptr(),       \
+            reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()),  \
+            static_cast<int>(group_size), static_cast<int>(num_groups),   \
+            groups_per_block, static_cast<int>(groups_per_row),           \
+            static_cast<int>(mn), static_cast<int>(tma_aligned_mn),       \
+            static_cast<float>(eps), static_cast<float>(min_8bit),        \
+            static_cast<float>(max_8bit));                                \
+  } while (0)
+
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "per_token_group_quant_8bit_packed", ([&] {
+        if (dst_type == at::ScalarType::Float8_e4m3fn) {
+          LAUNCH_PACKED_KERNEL(scalar_t, __nv_fp8_e4m3);
+        } else if (dst_type == at::ScalarType::Char) {
+          LAUNCH_PACKED_KERNEL(scalar_t, int8_t);
+        } else {
+          TORCH_CHECK(
+              false,
+              "per_token_group_quant_8bit_packed only supports FP8/INT8 "
+              "outputs.");
+        }
+      }));
+
+#undef LAUNCH_PACKED_KERNEL
+}
+
 void per_token_group_quant_fp8(const torch::Tensor& input,
                                torch::Tensor& output_q, torch::Tensor& output_s,
                                int64_t group_size, double eps, double fp8_min,
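The shape and stride checks in the host wrapper above encode DeepGEMM's TMA-aligned, MN-major scale layout. A minimal sketch of how a caller would allocate such a tensor; the sizes are illustrative and not taken from the change.

```python
import torch

mn, k, group_size = 130, 2048, 128
groups_per_row = k // group_size              # 16 UE8M0 scales per row
k_num_packed_sfk = (groups_per_row + 3) // 4  # 4 int32 words per row
tma_aligned_mn = (mn + 3) // 4 * 4            # MN padded to 132 for TMA

# Logical shape [mn, k_num_packed_sfk], but stride (1, tma_aligned_mn):
# consecutive rows are adjacent in memory (MN-major) and each packed-K
# column starts on a TMA-aligned boundary.
s_packed = torch.empty_strided(
    (mn, k_num_packed_sfk), (1, tma_aligned_mn),
    device="cuda", dtype=torch.int32,
)
# Element [i, j] lives at flat offset j * tma_aligned_mn + i, which is the
# out_idx = sf_k_pack_idx * tma_aligned_mn + mn_idx computed in the kernel.
```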
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 914227838558a..23ac1d9abeea9 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -617,6 +617,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("per_token_group_fp8_quant", torch::kCUDA,
            &per_token_group_quant_fp8);
 
+  // Compute per-token-group 8-bit quantized tensor and UE8M0-packed,
+  // TMA-aligned scales for DeepGEMM.
+  ops.def(
+      "per_token_group_fp8_quant_packed(Tensor input, Tensor! output_q, "
+      "Tensor! output_s_packed, int group_size, float eps, float fp8_min, "
+      "float fp8_max) -> ()");
+  ops.impl("per_token_group_fp8_quant_packed", torch::kCUDA,
+           &per_token_group_quant_8bit_packed);
+
   // Compute per-token-group INT8 quantized tensor and scaling factor.
   ops.def(
       "per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
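With the binding above in place, the op can be exercised directly from Python. This is a hedged end-to-end sketch: the tensor names and sizes are illustrative, and the packed-scale allocation follows the layout described earlier rather than any helper from this change.

```python
import torch

x = torch.randn(64, 1024, dtype=torch.bfloat16, device="cuda")
group_size = 128
mn, k = x.shape
k_packed = (k // group_size + 3) // 4
tma_mn = (mn + 3) // 4 * 4

q = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s_packed = torch.empty_strided((mn, k_packed), (1, tma_mn),
                               device=x.device, dtype=torch.int32)

fp8 = torch.finfo(torch.float8_e4m3fn)
torch.ops._C.per_token_group_fp8_quant_packed(
    x, q, s_packed, group_size, 1e-10, fp8.min, fp8.max
)
# q now holds the FP8 values; s_packed holds the UE8M0 exponents, 4 per word.
```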
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index 9f47e692d5ae2..4a64736ed767b 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -23,9 +23,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
+    per_token_group_quant_fp8_packed_for_deepgemm,
     silu_mul_per_token_group_quant_fp8_colmajor,
 )
 from vllm.utils.deep_gemm import (
+    DeepGemmQuantScaleFMT,
     get_mk_alignment_for_contiguous_layout,
     m_grouped_fp8_gemm_nt_contiguous,
 )
@@ -157,23 +159,40 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def _act_mul_quant(
         self, input: torch.Tensor, output: torch.Tensor, activation: str
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        if activation == "silu":
-            return silu_mul_per_token_group_quant_fp8_colmajor(
-                input=input, output=output
-            )
-        else:
-            # This is a fallback path. If we find ourselves using any activation other
-            # than silu, we should add that activation to
-            # silu_mul_per_token_group_quant_fp8_colmajor kernel as it is much faster.
+        assert self.block_shape is not None
+        block_k = self.block_shape[1]
+        scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
+
+        # 1. DeepGEMM UE8M0: use the packed per-token-group quant kernel.
+        if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
             M_sum, N = input.size()
             act_out = torch.empty(
                 (M_sum, N // 2), dtype=input.dtype, device=input.device
             )
             self.activation(activation, act_out, input)
-            assert self.block_shape is not None
-            return per_token_group_quant_fp8(
-                act_out, self.block_shape[1], column_major_scales=True, out_q=output
+            a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm(
+                act_out,
+                block_k,
+                out_q=output,
             )
+            return a2q, a2q_scale
+
+        # 2. Hopper / non-UE8M0: prefer the fused SiLU+mul+quant kernel.
+        if activation == "silu":
+            use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
+            return silu_mul_per_token_group_quant_fp8_colmajor(
+                input=input,
+                output=output,
+                use_ue8m0=use_ue8m0,
+            )
+
+        # 3. Fallback path for non-SiLU activations in non-UE8M0 cases.
+        M_sum, N = input.size()
+        act_out = torch.empty((M_sum, N // 2), dtype=input.dtype, device=input.device)
+        self.activation(activation, act_out, input)
+        return per_token_group_quant_fp8(
+            act_out, block_k, column_major_scales=True, out_q=output
+        )
 
     def apply(
         self,
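For the UE8M0 branch above, the helper defined in fp8_utils.py (next file) returns the quantized activations together with the packed scales. A small sketch of the expected shapes and dtypes, assuming this change is applied, a CUDA device, and a DeepGEMM-style block_k of 128 with toy tensor sizes:

```python
import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8_packed_for_deepgemm,
)

act_out = torch.randn(256, 4096, dtype=torch.bfloat16, device="cuda")
a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm(
    act_out, 128, use_ue8m0=True
)
assert a2q.shape == act_out.shape        # FP8 tensor, same logical shape
assert a2q_scale.dtype == torch.int32    # UE8M0 exponents, packed 4-per-word
assert a2q_scale.shape == (256, (4096 // 128 + 3) // 4)  # [mn, ceil(sf_k / 4)]
```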
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 6e73833d1ae1c..7e1bda8639ac7 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -269,7 +269,11 @@ class W8A8BlockFp8LinearOp:
         weight_scale: torch.Tensor,
     ) -> torch.Tensor:
         assert self.deepgemm_input_quant_op is not None
-        q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
+        q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
+            input_2d,
+            group_size=self.act_quant_group_shape.col,
+            use_ue8m0=True,
+        )
         output = torch.empty(
             (q_input.shape[0], weight.shape[0]),
             dtype=torch.bfloat16,
@@ -791,6 +795,80 @@ def per_token_group_quant_fp8(
     return x_q, x_s
 
 
+def per_token_group_quant_fp8_packed_for_deepgemm(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    use_ue8m0: bool | None = None,
+    out_q: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """FP8 per-token-group quantization for DeepGEMM.
+
+    Returns:
+        (x_q, x_s_packed)
+        x_q: FP8 activations, same shape as `x`.
+        x_s_packed: int32 tensor with logical shape
+            [mn, ceil(num_groups_per_row / 4)], laid out with a
+            TMA-aligned stride along the packed-K dimension.
+    """
+    if use_ue8m0 is None:
+        use_ue8m0 = is_deep_gemm_e8m0_used()
+    # The DeepGEMM UE8M0-packed layout requires UE8M0 scales.
+    assert use_ue8m0, (
+        "per_token_group_quant_fp8_packed_for_deepgemm requires UE8M0 scales."
+    )
+
+    dtype = current_platform.fp8_dtype()
+    assert x.shape[-1] % group_size == 0, (
+        f"the last dimension of `x` {x.shape[-1]} must be divisible "
+        f"by `group_size` {group_size}"
+    )
+    assert x.stride(-1) == 1, "`x` groups must be contiguous"
+
+    finfo = torch.finfo(dtype)
+    fp8_min, fp8_max = finfo.min, finfo.max
+
+    # Compute the DeepGEMM-style packed scale tensor shape.
+    hidden_dim = x.shape[-1]
+    mn = x.numel() // hidden_dim
+    num_groups_per_row = hidden_dim // group_size
+    k_num_packed_sf_k = (num_groups_per_row + 3) // 4
+    tma_aligned_mn = ((mn + 3) // 4) * 4
+
+    x_s_packed = torch.empty_strided(
+        (mn, k_num_packed_sf_k),
+        (1, tma_aligned_mn),
+        device=x.device,
+        dtype=torch.int32,
+    )
+
+    # CUDA kernel path only (DeepGEMM + E8M0 is CUDA-specific).
+    assert current_platform.is_cuda(), (
+        "per_token_group_quant_fp8_packed_for_deepgemm is only valid on CUDA "
+        "platforms using DeepGEMM."
+    )
+
+    x_contiguous = x.contiguous()
+    if out_q is not None:
+        x_q_local = out_q
+    else:
+        x_q_local = torch.empty_like(x_contiguous, device=x.device, dtype=dtype)
+
+    torch.ops._C.per_token_group_fp8_quant_packed(
+        x_contiguous,
+        x_q_local,
+        x_s_packed,
+        group_size,
+        eps,
+        fp8_min,
+        fp8_max,
+    )
+
+    # Return a tensor with the original logical shape.
+    x_q = x_q_local.view_as(x)
+    return x_q, x_s_packed
+
+
 @triton.jit
 def _w8a8_triton_block_scaled_mm(
     # Pointers to inputs and output
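For debugging, or to feed a reference matmul, the packed int32 words produced by this path can be expanded back into ordinary float32 power-of-two scales. This is an illustrative sketch; the helper name `unpack_ue8m0_scales` is not part of this change.

```python
import torch

def unpack_ue8m0_scales(s_packed: torch.Tensor,
                        groups_per_row: int) -> torch.Tensor:
    """Expand [mn, ceil(groups/4)] int32 words into [mn, groups] float scales."""
    words = s_packed.to(torch.int64)  # widen so the top byte is sign-safe
    bytes_ = torch.stack([(words >> (8 * i)) & 0xFF for i in range(4)], dim=-1)
    exponents = bytes_.reshape(s_packed.shape[0], -1)[:, :groups_per_row]
    # Each byte is the IEEE-754 exponent of a power-of-two float32 scale.
    return torch.exp2(exponents.float() - 127.0)
```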