[Perf] Deepgemm fused layout kernel for activations, 4.3% throughput improvement, 10.7% TTFT improvement. (#29546)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Parent: b0f4866a77
Commit: 541a2ef892
@@ -299,6 +299,14 @@ void per_token_group_quant_int8(const torch::Tensor& input,
                                torch::Tensor& output_q,
                                torch::Tensor& output_s, int64_t group_size,
                                double eps, double int8_min, double int8_max);

// Fused activation quantization + DeepGEMM-compatible UE8M0-packed scales.
void per_token_group_quant_8bit_packed(const torch::Tensor& input,
                                       torch::Tensor& output_q,
                                       torch::Tensor& output_s_packed,
                                       int64_t group_size, double eps,
                                       double min_8bit, double max_8bit);

#endif

void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
@@ -206,6 +206,191 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
#undef LAUNCH_KERNEL
}

template <typename T, typename DST_DTYPE>
__global__ void per_token_group_quant_8bit_packed_kernel(
    const T* __restrict__ input, void* __restrict__ output_q,
    unsigned int* __restrict__ output_s_packed, const int group_size,
    const int num_groups, const int groups_per_block, const int groups_per_row,
    const int mn, const int tma_aligned_mn, const float eps,
    const float min_8bit, const float max_8bit) {
  const int threads_per_group = 16;
  const int64_t local_group_id = threadIdx.x / threads_per_group;
  const int lane_id = threadIdx.x % threads_per_group;

  const int64_t block_group_id = blockIdx.x * groups_per_block;
  const int64_t global_group_id = block_group_id + local_group_id;
  if (global_group_id >= num_groups) {
    return;
  }

  const int64_t block_group_offset = global_group_id * group_size;

  float local_absmax = eps;

  const T* group_input = input + block_group_offset;
  DST_DTYPE* group_output =
      static_cast<DST_DTYPE*>(output_q) + block_group_offset;

  // shared memory to cache each group's data to avoid double DRAM reads.
  extern __shared__ __align__(16) char smem_raw[];
  T* smem = reinterpret_cast<T*>(smem_raw);
  T* smem_group = smem + local_group_id * group_size;

  constexpr int vec_size = 16 / sizeof(T);
  using vec_t = vllm::vec_n_t<T, vec_size>;

  // copy global -> shared & compute absmax
  auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
    float abs_v = fabsf(static_cast<float>(src));
    local_absmax = fmaxf(local_absmax, abs_v);
    dst = src;
  };

  vllm::vectorize_with_alignment<vec_size>(
      group_input,        // in
      smem_group,         // out (shared)
      group_size,         // elements per group
      lane_id,            // thread id
      threads_per_group,  // stride in group
      scalar_op_cache);   // scalar handler

  local_absmax = GroupReduceMax(local_absmax);

  float y_s = local_absmax / max_8bit;
  y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));

  // pack 4 scales into a uint32
  if (lane_id == 0) {
    // map flat group id to 2D indices (mn_idx, sf_k_idx)
    const int sf_k_idx = static_cast<int>(global_group_id % groups_per_row);
    const int mn_idx = static_cast<int>(global_group_id / groups_per_row);

    if (mn_idx < mn) {
      // each uint32 in output_s_packed stores 4 packed scales
      const int sf_k_pack_idx = sf_k_idx / 4;
      const int pos = sf_k_idx % 4;

      // reinterpret the UE8M0 scale y_s as IEEE bits, extract the 8-bit
      // exponent, and place it into the correct byte of the 32-bit word.
      const unsigned int bits = __float_as_uint(y_s);
      const unsigned int exponent = (bits >> 23u) & 0xffu;
      const unsigned int contrib = exponent << (pos * 8u);

      const int out_idx = sf_k_pack_idx * tma_aligned_mn + mn_idx;
      // atomically OR 8-bit exponent into the packed scales buffer
      atomicOr(output_s_packed + out_idx, contrib);
    }
  }

  __syncthreads();

  // quantize shared -> global 8-bit
  auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
    float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
    dst = DST_DTYPE(q);
  };

  vllm::vectorize_with_alignment<vec_size>(
      smem_group,         // in (shared)
      group_output,       // out (global quant tensor)
      group_size,         // elements
      lane_id,            // tid
      threads_per_group,  // stride
      scalar_op_quant);   // scalar handler
}

void per_token_group_quant_8bit_packed(const torch::Tensor& input,
                                       torch::Tensor& output_q,
                                       torch::Tensor& output_s_packed,
                                       int64_t group_size, double eps,
                                       double min_8bit, double max_8bit) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(output_q.is_contiguous());

  const int64_t k = input.size(-1);
  TORCH_CHECK(k % group_size == 0, "Last dimension (", k,
              ") must be divisible by group_size (", group_size, ").");

  const int64_t mn = input.numel() / k;
  const int64_t groups_per_row = k / group_size;
  const int64_t num_groups = mn * groups_per_row;

  TORCH_CHECK(output_s_packed.dim() == 2,
              "output_s_packed must be 2D, got dim=", output_s_packed.dim(),
              ".");

  const int64_t k_num_packed_sfk = (groups_per_row + 3) / 4;
  const int64_t tma_aligned_mn = ((mn + 3) / 4) * 4;

  TORCH_CHECK(output_s_packed.scalar_type() == at::ScalarType::Int,
              "output_s_packed must have dtype int32 for UE8M0-packed scales.");
  // DeepGEMM expects SFA scales in MN-major form with shape
  // [mn, ceil_div(K, 128 * 4)] and TMA-aligned stride on the last
  // dimension.
  TORCH_CHECK(output_s_packed.size(0) == mn &&
                  output_s_packed.size(1) == k_num_packed_sfk,
              "output_s_packed shape must be [", mn, ", ", k_num_packed_sfk,
              "], but got [", output_s_packed.size(0), ", ",
              output_s_packed.size(1), "].");

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  constexpr int THREADS_PER_GROUP = 16;

  int groups_per_block = 1;

  if (num_groups % 16 == 0) {
    groups_per_block = 16;
  } else if (num_groups % 8 == 0) {
    groups_per_block = 8;
  } else if (num_groups % 4 == 0) {
    groups_per_block = 4;
  } else if (num_groups % 2 == 0) {
    groups_per_block = 2;
  }

  auto dst_type = output_q.scalar_type();
  const int num_blocks = num_groups / groups_per_block;
  const int num_threads = groups_per_block * THREADS_PER_GROUP;

  // zero-initialize packed scales, since we use atomicOr to accumulate
  // exponents from different groups.
  output_s_packed.zero_();

#define LAUNCH_PACKED_KERNEL(T, DST_DTYPE)                                \
  do {                                                                    \
    dim3 grid(num_blocks);                                                \
    dim3 block(num_threads);                                              \
    size_t smem_bytes =                                                   \
        static_cast<size_t>(groups_per_block) * group_size * sizeof(T);   \
    per_token_group_quant_8bit_packed_kernel<T, DST_DTYPE>                \
        <<<grid, block, smem_bytes, stream>>>(                            \
            static_cast<const T*>(input.data_ptr()), output_q.data_ptr(), \
            reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()),  \
            static_cast<int>(group_size), static_cast<int>(num_groups),   \
            groups_per_block, static_cast<int>(groups_per_row),           \
            static_cast<int>(mn), static_cast<int>(tma_aligned_mn),       \
            static_cast<float>(eps), static_cast<float>(min_8bit),        \
            static_cast<float>(max_8bit));                                \
  } while (0)

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "per_token_group_quant_8bit_packed", ([&] {
        if (dst_type == at::ScalarType::Float8_e4m3fn) {
          LAUNCH_PACKED_KERNEL(scalar_t, __nv_fp8_e4m3);
        } else if (dst_type == at::ScalarType::Char) {
          LAUNCH_PACKED_KERNEL(scalar_t, int8_t);
        } else {
          TORCH_CHECK(
              false,
              "per_token_group_quant_8bit_packed only supports FP8/INT8 "
              "outputs.");
        }
      }));

#undef LAUNCH_PACKED_KERNEL
}

void per_token_group_quant_fp8(const torch::Tensor& input,
                               torch::Tensor& output_q, torch::Tensor& output_s,
                               int64_t group_size, double eps, double fp8_min,
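For readers tracing the layout, the scale rounding and packing performed by the kernel above reduce to a few lines of arithmetic. The sketch below is an illustrative NumPy model only (not part of the commit; the helper names are made up), mirroring the kernel's y_s computation and out_idx indexing:

import numpy as np

FP8_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max

def ue8m0_scale(group_absmax: float, eps: float = 1e-10) -> np.float32:
    # y_s = exp2(ceil(log2(absmax / fp8_max))): round the scale up to a power of two.
    y_s = max(group_absmax, eps) / FP8_MAX
    return np.float32(2.0 ** np.ceil(np.log2(max(abs(y_s), 1e-10))))

def pack_ue8m0_scales(scales: np.ndarray, tma_aligned_mn: int) -> np.ndarray:
    """Pack power-of-two scales of shape [mn, groups_per_row] into uint32 words,
    one exponent byte per scale, MN-major with a TMA-aligned stride."""
    mn, groups_per_row = scales.shape
    packed = np.zeros(((groups_per_row + 3) // 4) * tma_aligned_mn, dtype=np.uint32)
    for mn_idx in range(mn):
        for sf_k_idx in range(groups_per_row):
            bits = int(np.float32(scales[mn_idx, sf_k_idx]).view(np.uint32))
            exponent = (bits >> 23) & 0xFF                        # UE8M0 keeps only the exponent byte
            out_idx = (sf_k_idx // 4) * tma_aligned_mn + mn_idx   # same indexing as the kernel's out_idx
            packed[out_idx] |= np.uint32(exponent << (sf_k_idx % 4) * 8)
    return packed.view(np.int32)

For example, a group absmax of 3.7 gives y_s = 3.7 / 448 ≈ 0.00826, which rounds up to 2^-6 = 0.015625; the stored exponent byte is the IEEE-754 biased exponent 127 - 6 = 121 (0x79).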
@@ -617,6 +617,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.impl("per_token_group_fp8_quant", torch::kCUDA,
           &per_token_group_quant_fp8);

  // Compute per-token-group 8-bit quantized tensor and UE8M0-packed,
  // TMA-aligned scales for DeepGEMM.
  ops.def(
      "per_token_group_fp8_quant_packed(Tensor input, Tensor! output_q, "
      "Tensor! output_s_packed, int group_size, float eps, float fp8_min, "
      "float fp8_max) -> ()");
  ops.impl("per_token_group_fp8_quant_packed", torch::kCUDA,
           &per_token_group_quant_8bit_packed);

  // Compute per-token-group INT8 quantized tensor and scaling factor.
  ops.def(
      "per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
@@ -23,9 +23,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8,
    per_token_group_quant_fp8_packed_for_deepgemm,
    silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.utils.deep_gemm import (
    DeepGemmQuantScaleFMT,
    get_mk_alignment_for_contiguous_layout,
    m_grouped_fp8_gemm_nt_contiguous,
)
@@ -157,23 +159,40 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
    def _act_mul_quant(
        self, input: torch.Tensor, output: torch.Tensor, activation: str
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if activation == "silu":
            return silu_mul_per_token_group_quant_fp8_colmajor(
                input=input, output=output
            )
        else:
            # This is a fallback path. If we find ourselves using any activation other
            # than silu, we should add that activation to
            # silu_mul_per_token_group_quant_fp8_colmajor kernel as it is much faster.
        assert self.block_shape is not None
        block_k = self.block_shape[1]
        scale_fmt = DeepGemmQuantScaleFMT.from_oracle()

        # 1. DeepGemm UE8M0: use packed per-token-group quant
        if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
            M_sum, N = input.size()
            act_out = torch.empty(
                (M_sum, N // 2), dtype=input.dtype, device=input.device
            )
            self.activation(activation, act_out, input)
            assert self.block_shape is not None
            return per_token_group_quant_fp8(
                act_out, self.block_shape[1], column_major_scales=True, out_q=output
            a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm(
                act_out,
                block_k,
                out_q=output,
            )
            return a2q, a2q_scale

        # 2. Hopper / non-E8M0: prefer the fused SiLU+mul+quant kernel
        if activation == "silu":
            use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
            return silu_mul_per_token_group_quant_fp8_colmajor(
                input=input,
                output=output,
                use_ue8m0=use_ue8m0,
            )

        # 3. fallback path for non-SiLU activations in non-UE8M0 cases.
        M_sum, N = input.size()
        act_out = torch.empty((M_sum, N // 2), dtype=input.dtype, device=input.device)
        self.activation(activation, act_out, input)
        return per_token_group_quant_fp8(
            act_out, block_k, column_major_scales=True, out_q=output
        )

    def apply(
        self,
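The (M_sum, N // 2) buffers above follow the usual gate/up act-and-mul convention: the activation is applied to the first half of the last dimension and multiplied element-wise by the second half, so the activated output has half as many columns as the input. A minimal sketch of that convention (illustrative only, assuming SiLU; not part of the commit):

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x holds the gate and up projections concatenated along the last dim.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

# (M_sum, N) input -> (M_sum, N // 2) act_out, which is then group-quantized
# to FP8 with either column-major or UE8M0-packed scales, as in the code above.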
@@ -269,7 +269,11 @@ class W8A8BlockFp8LinearOp:
        weight_scale: torch.Tensor,
    ) -> torch.Tensor:
        assert self.deepgemm_input_quant_op is not None
        q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
        q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
            input_2d,
            group_size=self.act_quant_group_shape.col,
            use_ue8m0=True,
        )
        output = torch.empty(
            (q_input.shape[0], weight.shape[0]),
            dtype=torch.bfloat16,
@@ -791,6 +795,80 @@ def per_token_group_quant_fp8(
    return x_q, x_s


def per_token_group_quant_fp8_packed_for_deepgemm(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    use_ue8m0: bool | None = None,
    out_q: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """FP8 per-token-group quantization for DeepGEMM.

    Returns:
        (x_q, x_s_packed)
        x_q: FP8 activations, same shape as `x`.
        x_s_packed: Int32 tensor with logical shape
            [mn, ceil(num_groups_per_row / 4)], laid out with
            TMA-aligned stride along the packed-K dimension.
    """
    if use_ue8m0 is None:
        use_ue8m0 = is_deep_gemm_e8m0_used()
    # for DeepGEMM UE8M0-packed layout we *require* UE8M0 scales.
    assert use_ue8m0, (
        "per_token_group_quant_fp8_packed_for_deepgemm requires UE8M0 scales."
    )

    dtype = current_platform.fp8_dtype()
    assert x.shape[-1] % group_size == 0, (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}"
    )
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    finfo = torch.finfo(dtype)
    fp8_min, fp8_max = finfo.min, finfo.max

    # compute DeepGEMM-style packed scale tensor shape.
    hidden_dim = x.shape[-1]
    mn = x.numel() // hidden_dim
    num_groups_per_row = hidden_dim // group_size
    k_num_packed_sf_k = (num_groups_per_row + 3) // 4
    tma_aligned_mn = ((mn + 3) // 4) * 4

    x_s_packed = torch.empty_strided(
        (mn, k_num_packed_sf_k),
        (1, tma_aligned_mn),
        device=x.device,
        dtype=torch.int32,
    )

    # CUDA kernel path only (DeepGEMM + E8M0 is CUDA-specific).
    assert current_platform.is_cuda(), (
        "per_token_group_quant_fp8_packed_for_deepgemm is only valid on CUDA "
        "platforms using DeepGEMM."
    )

    x_contiguous = x.contiguous()
    if out_q is not None:
        x_q_local = out_q
    else:
        x_q_local = torch.empty_like(x_contiguous, device=x.device, dtype=dtype)

    torch.ops._C.per_token_group_fp8_quant_packed(
        x_contiguous,
        x_q_local,
        x_s_packed,
        group_size,
        eps,
        fp8_min,
        fp8_max,
    )

    # return a tensor with the original logical shape.
    x_q = x_q_local.view_as(x)
    return x_q, x_s_packed


@triton.jit
def _w8a8_triton_block_scaled_mm(
    # Pointers to inputs and output
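A quick shape-level illustration of the wrapper above (not part of the commit; requires a CUDA device and a vLLM build that registers the new op, with shapes and strides following the docstring and the host-side checks):

import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8_packed_for_deepgemm,
)

# 8 tokens, hidden size 512, group_size 128 -> 4 groups per row, packed into 1 int32 word.
x = torch.randn(8, 512, dtype=torch.bfloat16, device="cuda")
x_q, x_s_packed = per_token_group_quant_fp8_packed_for_deepgemm(x, 128, use_ue8m0=True)

assert x_q.shape == x.shape              # FP8 activations, same logical shape as x
assert x_s_packed.shape == (8, 1)        # [mn, ceil(groups_per_row / 4)], dtype int32
assert x_s_packed.stride() == (1, 8)     # MN-major, last-dim stride padded to tma_aligned_mn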