[Performance] Fused blockwise quant RMS norm (#27883)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>

Parent: 0044c4038c
Commit: af0444bf40
@ -14,6 +14,9 @@ from tqdm import tqdm
import vllm._custom_ops as ops
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -22,6 +25,7 @@ class bench_params_t:
|
||||
hidden_size: int
|
||||
add_residual: bool
|
||||
dtype: torch.dtype
|
||||
group_size: list[int]
|
||||
|
||||
def description(self):
|
||||
return (
|
||||
@ -29,6 +33,7 @@ class bench_params_t:
|
||||
f"x D {self.hidden_size} "
|
||||
f"x R {self.add_residual} "
|
||||
f"x DT {self.dtype}"
|
||||
f"x GS {self.group_size}"
|
||||
)
|
||||
|
||||
|
||||
@ -38,10 +43,11 @@ def get_bench_params() -> list[bench_params_t]:
|
||||
HIDDEN_SIZES = list(range(1024, 8129, 1024))
|
||||
ADD_RESIDUAL = [True, False]
|
||||
DTYPES = [torch.bfloat16, torch.float]
|
||||
GROUP_SIZES = [[1, 64], [1, 128]]
|
||||
|
||||
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
|
||||
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES, GROUP_SIZES)
|
||||
bench_params = list(
|
||||
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
|
||||
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3], x[4]), combinations)
|
||||
)
|
||||
return bench_params
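
A minimal sketch of the enumeration above; NUM_TOKENS is defined earlier in the file and its values are assumed here for illustration. Each (tokens, hidden size, residual, dtype) combination is now benchmarked once per group-size setting:

from itertools import product

NUM_TOKENS = [1, 16, 256]                     # assumed values for illustration
HIDDEN_SIZES = list(range(1024, 8129, 1024))  # 1024 .. 7168, as above
ADD_RESIDUAL = [True, False]
DTYPES = ["bfloat16", "float"]
GROUP_SIZES = [[1, 64], [1, 128]]

n_configs = len(
    list(product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES, GROUP_SIZES))
)
print(n_configs)  # 3 * 7 * 2 * 2 * 2 = 168 with the assumed NUM_TOKENS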
@ -52,6 +58,7 @@ def unfused_int8_impl(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: list[int],
|
||||
):
|
||||
# Norm
|
||||
torch_out = None
|
||||
@ -69,6 +76,7 @@ def unfused_fp8_impl(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: list[int],
|
||||
):
|
||||
# Norm
|
||||
torch_out = None
|
||||
@ -81,23 +89,63 @@ def unfused_fp8_impl(
|
||||
torch_out, _ = ops.scaled_fp8_quant(torch_out)
|
||||
|
||||
|
||||
def unfused_groupwise_fp8_impl(
|
||||
rms_norm_layer: RMSNorm,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: list[int],
|
||||
):
|
||||
# Norm
|
||||
torch_out = None
|
||||
if residual is None:
|
||||
torch_out = rms_norm_layer.forward_cuda(x, residual)
|
||||
else:
|
||||
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)
|
||||
|
||||
# Quant
|
||||
torch_out, _ = per_token_group_quant_fp8(
|
||||
torch_out, group_size=group_size[1], use_ue8m0=False
|
||||
)
|
||||
|
||||
|
||||
def fused_impl(
|
||||
rms_norm_layer: RMSNorm, # this stores the weights
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: list[int],
|
||||
):
|
||||
out, _ = ops.rms_norm_dynamic_per_token_quant(
|
||||
x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
|
||||
)
|
||||
|
||||
|
||||
def fused_groupwise_impl(
|
||||
rms_norm_layer: RMSNorm, # this stores the weights
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: list[int],
|
||||
):
|
||||
out, _ = ops.rms_norm_per_block_quant(
|
||||
x,
|
||||
rms_norm_layer.weight,
|
||||
1e-6,
|
||||
quant_dtype,
|
||||
group_size,
|
||||
residual=residual,
|
||||
is_scale_transposed=True,
|
||||
)
|
||||
|
||||
|
||||
# Bench functions
|
||||
def bench_fn(
|
||||
rms_norm_layer: RMSNorm,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: list[int],
|
||||
label: str,
|
||||
sub_label: str,
|
||||
fn: Callable,
|
||||
@ -110,10 +158,11 @@ def bench_fn(
|
||||
"x": x,
|
||||
"residual": residual,
|
||||
"quant_dtype": quant_dtype,
|
||||
"group_size": group_size,
|
||||
"fn": fn,
|
||||
}
|
||||
return TBenchmark.Timer(
|
||||
stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
|
||||
stmt="fn(rms_norm_layer, x, residual, quant_dtype, group_size)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
@ -147,6 +196,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
|
||||
x,
|
||||
residual,
|
||||
torch.int8,
|
||||
params.group_size,
|
||||
label,
|
||||
sub_label,
|
||||
unfused_int8_impl,
|
||||
@ -161,6 +211,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
|
||||
x,
|
||||
residual,
|
||||
torch.float8_e4m3fn,
|
||||
params.group_size,
|
||||
label,
|
||||
sub_label,
|
||||
unfused_fp8_impl,
|
||||
@ -175,6 +226,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
|
||||
x,
|
||||
residual,
|
||||
torch.int8,
|
||||
params.group_size,
|
||||
label,
|
||||
sub_label,
|
||||
fused_impl,
|
||||
@ -189,6 +241,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
|
||||
x,
|
||||
residual,
|
||||
torch.float8_e4m3fn,
|
||||
params.group_size,
|
||||
label,
|
||||
sub_label,
|
||||
fused_impl,
|
||||
@ -196,6 +249,36 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
|
||||
)
|
||||
)
|
||||
|
||||
# unfused groupwise fp8 impl.
|
||||
timers.append(
|
||||
bench_fn(
|
||||
layer,
|
||||
x,
|
||||
residual,
|
||||
torch.float8_e4m3fn,
|
||||
params.group_size,
|
||||
label,
|
||||
sub_label,
|
||||
unfused_groupwise_fp8_impl,
|
||||
"unfused_groupwise_fp8_impl",
|
||||
)
|
||||
)
|
||||
|
||||
# fused groupwise fp8 impl.
|
||||
timers.append(
|
||||
bench_fn(
|
||||
layer,
|
||||
x,
|
||||
residual,
|
||||
torch.float8_e4m3fn,
|
||||
params.group_size,
|
||||
label,
|
||||
sub_label,
|
||||
fused_groupwise_impl,
|
||||
"fused_groupwise_fp8_impl",
|
||||
)
|
||||
)
|
||||
|
||||
print_timers(timers)
|
||||
|
||||
return timers
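
bench() now also times the groupwise FP8 path, both unfused (RMSNorm.forward_cuda followed by per_token_group_quant_fp8) and fused into the new ops.rms_norm_per_block_quant, alongside the existing int8/fp8 per-token timers. A hedged helper for comparing a pair of the returned measurements (torch.utils.benchmark Measurement exposes .median in seconds):

def speedup(unfused_measurement, fused_measurement) -> float:
    # e.g. speedup(unfused_groupwise_fp8, fused_groupwise_fp8) > 1.0 means the
    # fused kernel is faster for that configuration
    return unfused_measurement.median / fused_measurement.median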
@ -118,6 +118,24 @@
|
||||
} \
|
||||
}
|
||||
|
||||
#define VLLM_DISPATCH_BOOL(expr, const_expr, ...) \
|
||||
if (expr) { \
|
||||
constexpr bool const_expr = true; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool const_expr = false; \
|
||||
__VA_ARGS__(); \
|
||||
}
|
||||
|
||||
#define VLLM_DISPATCH_GROUP_SIZE(group_size, const_group_size, ...) \
|
||||
if (group_size == 128) { \
|
||||
constexpr int const_group_size = 128; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (group_size == 64) { \
|
||||
constexpr int const_group_size = 64; \
|
||||
__VA_ARGS__(); \
|
||||
}
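
Conceptually, these dispatch macros turn runtime values into compile-time template parameters by selecting one of a small set of pre-compiled specializations. Note that VLLM_DISPATCH_GROUP_SIZE has no fallback branch, so the host wrapper shown further down must TORCH_CHECK that group_size is 64 or 128 before dispatching. A rough Python analogy, with a table of strings standing in for the template instantiations:

SPECIALIZED_KERNELS = {
    (gs, has_residual, transposed): f"kernel<{gs}, {has_residual}, {transposed}>"
    for gs in (64, 128)
    for has_residual in (False, True)
    for transposed in (False, True)
}

def dispatch(group_size: int, has_residual: bool, transposed: bool) -> str:
    # Mirrors the nested VLLM_DISPATCH_* macros: pick a specialization by value.
    key = (group_size, has_residual, transposed)
    if key not in SPECIALIZED_KERNELS:
        raise ValueError(f"Unsupported group size: {group_size}")
    return SPECIALIZED_KERNELS[key]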
|
||||
|
||||
#define VLLM_DISPATCH_RANK234(NUM_DIMS, ...) \
|
||||
switch (NUM_DIMS) { \
|
||||
case 2: { \
|
||||
|
||||
@ -128,6 +128,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
|
||||
std::optional<torch::Tensor> scale_ub,
|
||||
std::optional<torch::Tensor> residual);
|
||||
|
||||
void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor const& weight,
|
||||
torch::Tensor& scales, double const epsilon,
|
||||
std::optional<torch::Tensor> scale_ub,
|
||||
std::optional<torch::Tensor> residual,
|
||||
int64_t group_size, bool is_scale_transposed);
|
||||
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
std::optional<torch::Tensor> key, int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||
|
||||
@ -31,14 +31,15 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
|
||||
|
||||
// RMS Norm + Quant
|
||||
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
|
||||
token_scale = 1.0f / token_scale;
|
||||
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true,
|
||||
has_residual>(
|
||||
out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
|
||||
out, input, weight, rms, &token_scale, hidden_size, residual);
|
||||
} else {
|
||||
// FP8 - Do not invert token_scale for exact match with FBGemm
|
||||
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false,
|
||||
has_residual>(
|
||||
out, input, weight, rms, token_scale, hidden_size, residual);
|
||||
out, input, weight, rms, &token_scale, hidden_size, residual);
|
||||
}
|
||||
}
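
The signature change here (passing &token_scale instead of a float value) lets norm_and_quant receive its scale through a pointer, so the same quantization loop can later consume either a single per-token scale or an array of per-group scales. A sketch of the resulting lookup, assuming the non-transposed scale layout:

def pick_scale(scales: list[float], i: int, group_size: int) -> float:
    # group_size == 0 means per-token quant: one scalar scale for the whole row;
    # otherwise element i belongs to group i // group_size.
    return scales[0] if group_size == 0 else scales[i // group_size]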
|
||||
|
||||
@ -75,14 +76,52 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
|
||||
|
||||
// RMS Norm + Quant
|
||||
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
|
||||
token_scale = 1.0f / token_scale;
|
||||
vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
|
||||
out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
|
||||
out, input, weight, rms, &token_scale, hidden_size, residual);
|
||||
} else {
|
||||
// FP8 - Do not invert s_token_scale for exact match with FBGemm
|
||||
vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
|
||||
out, input, weight, rms, token_scale, hidden_size, residual);
|
||||
out, input, weight, rms, &token_scale, hidden_size, residual);
|
||||
}
|
||||
}
|
||||
|
||||
// RMS norm + quant kernel
|
||||
template <typename scalar_t, typename scalar_out_t, bool has_residual = false,
|
||||
bool is_scale_transposed = false, int32_t group_size = 0>
|
||||
__global__ void rms_norm_per_block_quant_kernel(
|
||||
scalar_out_t* __restrict__ out, // [..., hidden_size]
|
||||
float* __restrict__ scales, // [num_tokens, hidden_size / group_size]
|
||||
// or
|
||||
// [hidden_size / group_size, num_tokens]
|
||||
scalar_t const* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t const* __restrict__ weight, // [hidden_size]
|
||||
float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
|
||||
scalar_t* __restrict__ residual = nullptr) {
|
||||
float rms;
|
||||
// Compute RMS
|
||||
// Always able to vectorize due to constraints on hidden_size
|
||||
vllm::vectorized::compute_rms<scalar_t, has_residual>(
|
||||
&rms, input, hidden_size, var_epsilon, residual);
|
||||
|
||||
// Compute Scale
|
||||
// Always able to vectorize due to constraints on hidden_size and group_size
|
||||
vllm::vectorized::compute_dynamic_per_token_scales<
|
||||
scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>(
|
||||
nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual);
|
||||
|
||||
// RMS Norm + Quant
|
||||
// Always able to vectorize due to constraints on hidden_size
|
||||
// For int8, don't invert token_scale here; the inversion happens inside
// norm_and_quant instead. Individual elements of token_scale can be shared
// by multiple threads, so deferring the inversion there avoids extra
// synchronization overhead.
|
||||
vllm::vectorized::norm_and_quant<
|
||||
scalar_t, scalar_out_t, std::is_same_v<scalar_out_t, int8_t>,
|
||||
has_residual, is_scale_transposed, group_size>(
|
||||
out, input, weight, rms, scales, hidden_size, residual);
|
||||
}
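
In PyTorch terms, the kernel above fuses an RMS norm with a per-group dynamic FP8 quantization. A minimal reference sketch of that math (assuming an FP8 output, no residual, no scale_ub, and a recent PyTorch with float8 dtypes); the unit tests further down check the real op against RMSNorm plus per_token_group_quant_fp8 in essentially this way:

import torch

def rms_norm_block_quant_ref(x, weight, eps, group_size):
    # RMS norm in float32, cast back to the activation dtype before the weight
    # multiply, mirroring static_cast<scalar_t>(x * rms) * weight[i] above.
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
    # Per-group dynamic scale: absmax of each group_size-wide chunk / fp8 max.
    # The small floor on the scale is an assumption standing in for the
    # kernel's dtype-dependent min_scaling_factor.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    groups = normed.float().unflatten(-1, (-1, group_size))        # [T, H/G, G]
    scales = groups.abs().amax(dim=-1).clamp(min=1e-10) / fp8_max  # [T, H/G]
    q = (groups / scales.unsqueeze(-1)).clamp(-fp8_max, fp8_max)
    return q.flatten(-2).to(torch.float8_e4m3fn), scales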
} // namespace vllm
|
||||
|
||||
// Residual add + RMS norm + dynamic per token
|
||||
@ -103,30 +142,19 @@ void rms_norm_dynamic_per_token_quant_dispatch(
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (residual.has_value()) {
|
||||
VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] {
|
||||
VLLM_DISPATCH_QUANT_TYPES(
|
||||
out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
|
||||
vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
|
||||
true>
|
||||
has_residual>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
||||
var_epsilon, hidden_size, residual->data_ptr<scalar_in_t>());
|
||||
var_epsilon, hidden_size,
|
||||
has_residual ? residual->data_ptr<scalar_in_t>() : nullptr);
|
||||
});
|
||||
|
||||
} else {
|
||||
VLLM_DISPATCH_QUANT_TYPES(
|
||||
out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
|
||||
vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
|
||||
false>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
||||
var_epsilon, hidden_size, nullptr);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void rms_norm_dynamic_per_token_quant(
|
||||
@ -157,3 +185,79 @@ void rms_norm_dynamic_per_token_quant(
|
||||
out, input, weight, scales, var_epsilon, scale_ub, residual);
|
||||
});
|
||||
}
|
||||
|
||||
// Residual add + RMS norm + dynamic per-block (groupwise) quant
|
||||
void rms_norm_per_block_quant_dispatch(
|
||||
torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor const& input, // [..., hidden_size]
|
||||
torch::Tensor const& weight, // [hidden_size]
|
||||
torch::Tensor& scales, // [num_tokens, hidden_size / group_size] or
|
||||
// [hidden_size / group_size, num_tokens]
|
||||
int32_t group_size,
|
||||
double const var_epsilon, // Variance epsilon used in norm calculation
|
||||
std::optional<at::Tensor> const& scale_ub,
|
||||
std::optional<at::Tensor>& residual, bool is_scale_transposed) {
|
||||
int32_t hidden_size = input.size(-1);
|
||||
auto num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
const int max_block_size = (num_tokens <= 256) ? 512 : 256;
|
||||
dim3 block(std::min(hidden_size, max_block_size));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "rms_norm_per_block_quant_fp_dispatch", [&] {
|
||||
using scalar_in_t = scalar_t;
|
||||
VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
|
||||
VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] {
|
||||
VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
|
||||
VLLM_DISPATCH_QUANT_TYPES(
|
||||
out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] {
|
||||
vllm::rms_norm_per_block_quant_kernel<scalar_in_t, scalar_t,
|
||||
has_residual,
|
||||
transpose_scale, gs>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_in_t>(),
|
||||
weight.data_ptr<scalar_in_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>()
|
||||
: nullptr,
|
||||
var_epsilon, hidden_size,
|
||||
has_residual ? residual->data_ptr<scalar_in_t>()
|
||||
: nullptr);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
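
The launch configuration above uses one CUDA block per token and widens the block for small batches; a small sketch of the same arithmetic:

def launch_config(num_tokens: int, hidden_size: int) -> tuple[int, int]:
    grid = num_tokens                                   # one block per token
    max_block_size = 512 if num_tokens <= 256 else 256  # wider blocks for small batches
    block = min(hidden_size, max_block_size)
    return grid, block

print(launch_config(257, 4096))  # (257, 256)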
|
||||
|
||||
void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor const& weight,
|
||||
torch::Tensor& scales, double const var_epsilon,
|
||||
std::optional<torch::Tensor> scale_ub,
|
||||
std::optional<torch::Tensor> residual,
|
||||
int64_t group_size, bool is_scale_transposed) {
|
||||
static c10::ScalarType kFp8Type = is_fp8_ocp()
|
||||
? c10::ScalarType::Float8_e4m3fn
|
||||
: c10::ScalarType::Float8_e4m3fnuz;
|
||||
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
|
||||
|
||||
if (scale_ub.has_value()) {
|
||||
TORCH_CHECK(out.dtype() == kFp8Type);
|
||||
}
|
||||
TORCH_CHECK(weight.dtype() == input.dtype());
|
||||
TORCH_CHECK(scales.dtype() == torch::kFloat32);
|
||||
if (residual) {
|
||||
TORCH_CHECK(residual->scalar_type() == input.scalar_type());
|
||||
}
|
||||
|
||||
TORCH_CHECK(group_size == 128 || group_size == 64,
|
||||
"Unsupported group size: ", group_size);
|
||||
|
||||
rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size,
|
||||
var_epsilon, scale_ub, residual,
|
||||
is_scale_transposed);
|
||||
}
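
For a [num_tokens, hidden_size] input, the two scale layouts accepted by this entry point look as follows (illustrative shapes only; the transposed form is the column-major layout used when DeepGEMM/CUTLASS blockwise kernels consume the scales, as in the fusion pass changes below):

num_tokens, hidden_size, group_size = 257, 4096, 128
row_major_scales = (num_tokens, hidden_size // group_size)  # is_scale_transposed=False
col_major_scales = (hidden_size // group_size, num_tokens)  # is_scale_transposed=True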
|
||||
@ -9,6 +9,7 @@
|
||||
#include "quant_conversions.cuh"
|
||||
|
||||
#include "../../cub_helpers.h"
|
||||
#include "../../cuda_compat.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
@ -43,62 +44,150 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
|
||||
*rms = s_rms;
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
|
||||
__device__ float warpReduceMaxSpecialized(volatile float* val, int64_t tid,
|
||||
int64_t thread_in_warp,
|
||||
int64_t reduced_elems) {
|
||||
static_assert(WARP_SIZE == 32 || WARP_SIZE == 64);
|
||||
if constexpr (WARP_SIZE == 64) {
|
||||
if (thread_in_warp + 64 < reduced_elems)
|
||||
val[tid] = fmaxf(val[tid], val[tid + 64]);
|
||||
}
|
||||
if (thread_in_warp + 32 < reduced_elems)
|
||||
val[tid] = fmaxf(val[tid], val[tid + 32]);
|
||||
if (thread_in_warp + 16 < reduced_elems)
|
||||
val[tid] = fmaxf(val[tid], val[tid + 16]);
|
||||
if (thread_in_warp + 8 < reduced_elems)
|
||||
val[tid] = fmaxf(val[tid], val[tid + 8]);
|
||||
if (thread_in_warp + 4 < reduced_elems)
|
||||
val[tid] = fmaxf(val[tid], val[tid + 4]);
|
||||
if (thread_in_warp + 2 < reduced_elems)
|
||||
val[tid] = fmaxf(val[tid], val[tid + 2]);
|
||||
if (thread_in_warp + 1 < reduced_elems)
|
||||
val[tid] = fmaxf(val[tid], val[tid + 1]);
|
||||
return val[tid];
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename scalar_out_t, bool has_residual = false,
|
||||
bool is_scale_transposed = false>
|
||||
__device__ void compute_dynamic_per_token_scales(
|
||||
float* __restrict__ token_scale, float* __restrict__ all_token_scales,
|
||||
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
|
||||
float const rms, float const* __restrict__ scale_ub,
|
||||
int32_t const hidden_size,
|
||||
scalar_t const* __restrict__ residual = nullptr) {
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
;
|
||||
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
|
||||
|
||||
int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr,
|
||||
int32_t const group_size = 0) {
|
||||
float block_absmax_val_maybe = 0.0f;
|
||||
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float x = static_cast<float>(input[token_offset + i]);
|
||||
if constexpr (has_residual) {
|
||||
x += static_cast<float>(residual[token_offset + i]);
|
||||
}
|
||||
|
||||
x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
|
||||
block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
|
||||
}
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
block_absmax_val_maybe =
|
||||
BlockReduce(reduceStore)
|
||||
.Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
|
||||
|
||||
__shared__ float s_token_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
float scale = 0.0f;
|
||||
if (scale_ub) {
|
||||
scale = min(block_absmax_val_maybe, *scale_ub);
|
||||
} else {
|
||||
scale = block_absmax_val_maybe;
|
||||
}
|
||||
// token scale computation
|
||||
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
|
||||
s_token_scale = scale; // Shared memory store
|
||||
all_token_scales[blockIdx.x] = scale; // Global output store
|
||||
}
|
||||
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
|
||||
__syncthreads();
|
||||
if (group_size > 0) {
|
||||
__shared__ float s_max_vals[1024];
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
int64_t num_groups = hidden_size / group_size;
|
||||
int64_t const threads_per_group = blockDim.x / num_groups;
|
||||
int64_t const thread_in_group = threadIdx.x % threads_per_group;
|
||||
int64_t const group_offset = threadIdx.x / threads_per_group * group_size;
|
||||
int64_t const thread_offset = group_offset + thread_in_group;
|
||||
int64_t const thread_end =
|
||||
min(group_offset + group_size, static_cast<int64_t>(hidden_size));
|
||||
for (auto i = thread_offset; i < thread_end; i += threads_per_group) {
|
||||
float x = static_cast<float>(input[token_offset + i]);
|
||||
if constexpr (has_residual) {
|
||||
x += static_cast<float>(residual[token_offset + i]);
|
||||
}
|
||||
x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
|
||||
block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
|
||||
}
|
||||
s_max_vals[threadIdx.x] = block_absmax_val_maybe;
|
||||
__syncthreads();
|
||||
|
||||
*token_scale = s_token_scale;
|
||||
int64_t const warp_size = WARP_SIZE;
|
||||
int64_t const num_warps = blockDim.x / warp_size;
|
||||
int64_t const warp_id = threadIdx.x / warp_size;
|
||||
int64_t const thread_in_warp = threadIdx.x % warp_size;
|
||||
int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps;
|
||||
for (auto i = 0; i < groups_per_warp; ++i) {
|
||||
int64_t const group_id = i * num_warps + warp_id;
|
||||
if (group_id < num_groups) {
|
||||
int64_t warp_start = group_id * threads_per_group;
|
||||
int64_t const start = warp_start + thread_in_warp;
|
||||
int64_t const warp_end = min(warp_start + threads_per_group,
|
||||
static_cast<int64_t>(hidden_size));
|
||||
for (auto j = start; j + warp_size < warp_end; j += warp_size) {
|
||||
s_max_vals[start] =
|
||||
fmaxf(s_max_vals[start], s_max_vals[j + warp_size]);
|
||||
}
|
||||
warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp,
|
||||
min(warp_end - warp_start, warp_size));
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (thread_in_group == 0 && thread_offset < thread_end) {
|
||||
block_absmax_val_maybe = s_max_vals[threadIdx.x];
|
||||
float scale = 0.0f;
|
||||
if (scale_ub) {
|
||||
scale = min(block_absmax_val_maybe, *scale_ub);
|
||||
} else {
|
||||
scale = block_absmax_val_maybe;
|
||||
}
|
||||
// token scale computation
|
||||
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
|
||||
// Global output store
|
||||
if constexpr (is_scale_transposed) {
|
||||
all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x +
|
||||
blockIdx.x] = scale;
|
||||
} else {
|
||||
all_token_scales[blockIdx.x * num_groups +
|
||||
threadIdx.x / threads_per_group] = scale;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
} else {
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
|
||||
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float x = static_cast<float>(input[token_offset + i]);
|
||||
if constexpr (has_residual) {
|
||||
x += static_cast<float>(residual[token_offset + i]);
|
||||
}
|
||||
|
||||
x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
|
||||
block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
|
||||
}
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
block_absmax_val_maybe =
|
||||
BlockReduce(reduceStore)
|
||||
.Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
|
||||
|
||||
__shared__ float s_token_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
float scale = 0.0f;
|
||||
if (scale_ub) {
|
||||
scale = min(block_absmax_val_maybe, *scale_ub);
|
||||
} else {
|
||||
scale = block_absmax_val_maybe;
|
||||
}
|
||||
// token scale computation
|
||||
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
|
||||
s_token_scale = scale; // Shared memory store
|
||||
all_token_scales[blockIdx.x] = scale; // Global output store
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
*token_scale = s_token_scale;
|
||||
}
|
||||
}
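
In the group-wise branch above, the block's threads are partitioned across the hidden_size / group_size groups before the shared-memory max reduction; threads whose computed range is empty simply do no work. A sketch of that index arithmetic:

def group_iteration(tid: int, block_dim: int, hidden_size: int, group_size: int):
    # Mirrors the index arithmetic in the group-wise branch above.
    num_groups = hidden_size // group_size
    threads_per_group = block_dim // num_groups          # assumes block_dim >= num_groups
    thread_in_group = tid % threads_per_group
    group_offset = (tid // threads_per_group) * group_size
    start = group_offset + thread_in_group
    end = min(group_offset + group_size, hidden_size)
    return list(range(start, end, threads_per_group))    # elements this thread reduces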
|
||||
|
||||
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
|
||||
bool has_residual = false>
|
||||
bool has_residual = false, bool is_scale_transposed = false>
|
||||
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
|
||||
scalar_t const* __restrict__ input,
|
||||
scalar_t const* __restrict__ weight,
|
||||
float const rms, float const scale,
|
||||
float const rms, float* const scale,
|
||||
int32_t const hidden_size,
|
||||
scalar_t* __restrict__ residual = nullptr) {
|
||||
scalar_t* __restrict__ residual = nullptr,
|
||||
int32_t const group_size = 0) {
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
;
|
||||
|
||||
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float x = static_cast<float>(input[token_offset + i]);
|
||||
@ -109,8 +198,21 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
|
||||
// Norm
|
||||
x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
|
||||
// Quant
|
||||
// In the groupwise case the stored scales are not pre-inverted, so when
// is_scale_inverted is true the inversion happens here.
|
||||
int64_t scale_idx = 0;
|
||||
if (group_size > 0) {
|
||||
if constexpr (is_scale_transposed) {
|
||||
scale_idx = (i / group_size) * gridDim.x + blockIdx.x;
|
||||
} else {
|
||||
scale_idx = blockIdx.x * (hidden_size / group_size) + i / group_size;
|
||||
}
|
||||
}
|
||||
auto scale_val =
|
||||
(group_size > 0
|
||||
? (is_scale_inverted ? 1.0f / scale[scale_idx] : scale[scale_idx])
|
||||
: *scale);
|
||||
output[token_offset + i] =
|
||||
ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(x, scale);
|
||||
ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(x, scale_val);
|
||||
}
|
||||
}
|
||||
|
||||
@ -178,95 +280,191 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
|
||||
|
||||
// Vectorized version of vllm::compute_dynamic_per_token_scales
|
||||
// hidden_size must be a multiple of 4
|
||||
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
|
||||
template <typename scalar_t, typename scalar_out_t, bool has_residual = false,
|
||||
bool is_scale_transposed = false, int32_t group_size = 0>
|
||||
__device__ void compute_dynamic_per_token_scales(
|
||||
float* __restrict__ token_scale, float* __restrict__ all_token_scales,
|
||||
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
|
||||
float const rms, float const* __restrict__ scale_ub,
|
||||
int32_t const hidden_size,
|
||||
scalar_t const* __restrict__ residual = nullptr) {
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
;
|
||||
|
||||
// Vectorized input/weight/residual to better utilize memory bandwidth.
|
||||
vec4_t<scalar_t> const* vec_input =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
|
||||
vec4_t<scalar_t> const* vec_weight =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(weight);
|
||||
vec4_t<scalar_t> const* vec_residual = nullptr;
|
||||
if constexpr (has_residual) {
|
||||
vec_residual =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
|
||||
}
|
||||
|
||||
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
|
||||
|
||||
const int VEC_SIZE = 4;
|
||||
int32_t const num_vec_elems = hidden_size >> 2;
|
||||
float block_absmax_val_maybe = 0.0f;
|
||||
|
||||
#pragma unroll 4
|
||||
for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
|
||||
vec4_t<scalar_t> in = vec_input[i];
|
||||
vec4_t<scalar_t> const w = vec_weight[i];
|
||||
// Vectorized input/weight/residual to better utilize memory bandwidth.
|
||||
vec4_t<scalar_t> const* vec_input = nullptr;
|
||||
vec4_t<scalar_t> const* vec_weight = nullptr;
|
||||
vec4_t<scalar_t> const* vec_residual = nullptr;
|
||||
|
||||
vec4_t<float> x;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
x.val[j] = static_cast<float>(in.val[j]);
|
||||
}
|
||||
if constexpr (group_size > 0) {
|
||||
__shared__ float s_max_vals[1024];
|
||||
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
int64_t const num_groups = hidden_size / group_size;
|
||||
int64_t const threads_per_group = blockDim.x / num_groups;
|
||||
int64_t const thread_in_group = threadIdx.x % threads_per_group;
|
||||
int64_t const group_offset =
|
||||
threadIdx.x / threads_per_group * (group_size >> 2);
|
||||
int64_t const thread_offset = group_offset + thread_in_group;
|
||||
int64_t const thread_end = min(group_offset + (group_size >> 2),
|
||||
static_cast<int64_t>(hidden_size >> 2));
|
||||
vec_input = reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
|
||||
vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight);
|
||||
if constexpr (has_residual) {
|
||||
vec4_t<scalar_t> r = vec_residual[i];
|
||||
vec_residual =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
|
||||
}
|
||||
int32_t const num_vec_elems = thread_end;
|
||||
|
||||
#pragma unroll 4
|
||||
for (auto i = thread_offset; i < num_vec_elems; i += threads_per_group) {
|
||||
vec4_t<scalar_t> in = vec_input[i];
|
||||
vec4_t<scalar_t> const w = vec_weight[i];
|
||||
|
||||
vec4_t<float> x;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
x.val[j] += static_cast<float>(r.val[j]);
|
||||
x.val[j] = static_cast<float>(in.val[j]);
|
||||
}
|
||||
|
||||
if constexpr (has_residual) {
|
||||
vec4_t<scalar_t> r = vec_residual[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
x.val[j] += static_cast<float>(r.val[j]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
block_absmax_val_maybe =
|
||||
fmaxf(block_absmax_val_maybe,
|
||||
fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
|
||||
}
|
||||
}
|
||||
|
||||
s_max_vals[threadIdx.x] = block_absmax_val_maybe;
|
||||
__syncthreads();
|
||||
|
||||
int64_t const warp_size = WARP_SIZE;
|
||||
int64_t const num_warps = blockDim.x / warp_size;
|
||||
int64_t const warp_id = threadIdx.x / warp_size;
|
||||
int64_t const thread_in_warp = threadIdx.x % warp_size;
|
||||
int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps;
|
||||
for (auto i = 0; i < groups_per_warp; ++i) {
|
||||
int64_t const group_id = i * num_warps + warp_id;
|
||||
if (group_id < num_groups) {
|
||||
int64_t warp_start = group_id * threads_per_group;
|
||||
int64_t const start = warp_start + thread_in_warp;
|
||||
int64_t const warp_end = min(warp_start + threads_per_group,
|
||||
static_cast<int64_t>(hidden_size));
|
||||
for (auto j = start; j + warp_size < warp_end; j += warp_size) {
|
||||
s_max_vals[start] =
|
||||
fmaxf(s_max_vals[start], s_max_vals[j + warp_size]);
|
||||
}
|
||||
warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp,
|
||||
min(warp_end - warp_start, warp_size));
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (thread_in_group == 0 && thread_offset < thread_end) {
|
||||
block_absmax_val_maybe = s_max_vals[threadIdx.x];
|
||||
float scale = 0.0f;
|
||||
if (scale_ub) {
|
||||
scale = min(block_absmax_val_maybe, *scale_ub);
|
||||
} else {
|
||||
scale = block_absmax_val_maybe;
|
||||
}
|
||||
// token scale computation
|
||||
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
|
||||
// Global output store
|
||||
if constexpr (is_scale_transposed) {
|
||||
all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x +
|
||||
blockIdx.x] = scale;
|
||||
} else {
|
||||
all_token_scales[blockIdx.x * num_groups +
|
||||
threadIdx.x / threads_per_group] = scale;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
} else {
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
vec_input = reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
|
||||
vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight);
|
||||
if constexpr (has_residual) {
|
||||
vec_residual =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
|
||||
}
|
||||
|
||||
int32_t const num_vec_elems = (hidden_size >> 2);
|
||||
|
||||
#pragma unroll 4
|
||||
for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
|
||||
vec4_t<scalar_t> in = vec_input[i];
|
||||
vec4_t<scalar_t> const w = vec_weight[i];
|
||||
|
||||
vec4_t<float> x;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
block_absmax_val_maybe =
|
||||
fmaxf(block_absmax_val_maybe,
|
||||
fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
x.val[j] = static_cast<float>(in.val[j]);
|
||||
}
|
||||
|
||||
if constexpr (has_residual) {
|
||||
vec4_t<scalar_t> r = vec_residual[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
x.val[j] += static_cast<float>(r.val[j]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
block_absmax_val_maybe =
|
||||
fmaxf(block_absmax_val_maybe,
|
||||
fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
block_absmax_val_maybe =
|
||||
BlockReduce(reduceStore)
|
||||
.Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
block_absmax_val_maybe =
|
||||
BlockReduce(reduceStore)
|
||||
.Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
|
||||
|
||||
__shared__ float s_token_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
float scale = 0.0f;
|
||||
if (scale_ub) {
|
||||
scale = min(block_absmax_val_maybe, *scale_ub);
|
||||
} else {
|
||||
scale = block_absmax_val_maybe;
|
||||
__shared__ float s_token_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
float scale = 0.0f;
|
||||
if (scale_ub) {
|
||||
scale = min(block_absmax_val_maybe, *scale_ub);
|
||||
} else {
|
||||
scale = block_absmax_val_maybe;
|
||||
}
|
||||
// token scale computation
|
||||
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
|
||||
s_token_scale = scale; // shared memory store
|
||||
all_token_scales[blockIdx.x] = scale; // global output store
|
||||
}
|
||||
// token scale computation
|
||||
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
|
||||
s_token_scale = scale; // shared memory store
|
||||
all_token_scales[blockIdx.x] = scale; // global output store
|
||||
}
|
||||
__syncthreads();
|
||||
__syncthreads();
|
||||
|
||||
*token_scale = s_token_scale;
|
||||
*token_scale = s_token_scale;
|
||||
}
|
||||
}
|
||||
|
||||
// hidden_size must be a multiple of 4
|
||||
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
|
||||
bool has_residual = false>
|
||||
bool has_residual = false, bool is_scale_transposed = false,
|
||||
int32_t group_size = 0>
|
||||
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
|
||||
scalar_t const* __restrict__ input,
|
||||
scalar_t const* __restrict__ weight,
|
||||
float const rms, float const scale,
|
||||
float const rms, float* const scale,
|
||||
int32_t const hidden_size,
|
||||
scalar_t* __restrict__ residual = nullptr) {
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
;
|
||||
|
||||
// Vectorized input/output/weight/residual to better utilize memory bandwidth.
|
||||
vec4_t<scalar_t> const* vec_input =
|
||||
@ -311,10 +509,26 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
|
||||
}
|
||||
|
||||
q8x4_t<scalar_out_t> out;
|
||||
|
||||
float scale_val;
|
||||
|
||||
if constexpr (group_size > 0) {
|
||||
int64_t const num_groups = hidden_size / group_size;
|
||||
int64_t scale_idx = 0;
|
||||
if constexpr (is_scale_transposed) {
|
||||
scale_idx = (i * VEC_SIZE / group_size) * gridDim.x + blockIdx.x;
|
||||
} else {
|
||||
scale_idx = blockIdx.x * num_groups + i * VEC_SIZE / group_size;
|
||||
}
|
||||
scale_val =
|
||||
is_scale_inverted ? 1.0f / scale[scale_idx] : scale[scale_idx];
|
||||
} else {
|
||||
scale_val = *scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
out.val[j] = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
|
||||
static_cast<scalar_t>(x.val[j] * rms) * w.val[j], scale);
|
||||
static_cast<scalar_t>(x.val[j] * rms) * w.val[j], scale_val);
|
||||
}
|
||||
vec_output[i] = out;
|
||||
}
|
||||
|
||||
@ -215,6 +215,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
|
||||
&rms_norm_dynamic_per_token_quant);
|
||||
|
||||
// Fused Layernorm + Block quant kernels
|
||||
ops.def(
|
||||
"rms_norm_per_block_quant(Tensor! result, Tensor input, "
|
||||
"Tensor weight, Tensor! scale, float epsilon, "
|
||||
"Tensor? scale_ub, Tensor!? residual, int group_size, "
|
||||
"bool is_scale_transposed) -> ()");
|
||||
ops.impl("rms_norm_per_block_quant", torch::kCUDA, &rms_norm_per_block_quant);
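
The schema string marks result, scale, and the optional residual as mutated in place (Tensor! / Tensor!?), which is what lets the fusion pass below wrap the op with auto_functionalized. A hypothetical direct call through the registered schema (argument order follows the def string; sizes are made up):

import torch
import vllm._custom_ops  # noqa: F401  (importing this loads the compiled _C extension)

num_tokens, hidden, group = 257, 4096, 128
x = torch.rand(num_tokens, hidden, dtype=torch.bfloat16, device="cuda")
weight = torch.rand(hidden, dtype=torch.bfloat16, device="cuda")
out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
scales = torch.empty(num_tokens, hidden // group, dtype=torch.float32, device="cuda")
torch.ops._C.rms_norm_per_block_quant(
    out, x, weight, scales, 1e-6, None, None, group, False
)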
|
||||
|
||||
// Rotary embedding
|
||||
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
|
||||
ops.def(
|
||||
|
||||
@ -18,6 +18,9 @@ from vllm.config import (
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
@ -25,10 +28,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
cutlass_block_fp8_supported,
|
||||
cutlass_fp8_supported,
|
||||
maybe_create_device_identity,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
||||
|
||||
from ..utils import override_cutlass_fp8_supported
|
||||
from .backend import TestBackend
|
||||
@ -44,7 +49,7 @@ class TestModel(torch.nn.Module):
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float,
|
||||
static: bool,
|
||||
group_shape: GroupShape,
|
||||
cuda_force_torch: bool,
|
||||
*args,
|
||||
**kwargs,
|
||||
@ -52,8 +57,17 @@ class TestModel(torch.nn.Module):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.cuda_force_torch = cuda_force_torch
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||
if group_shape.is_per_group():
|
||||
self.wscale = [
|
||||
torch.rand(
|
||||
(hidden_size // group_shape[1], hidden_size // group_shape[1]),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
for _ in range(3)
|
||||
]
|
||||
else:
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
static = group_shape == GroupShape.PER_TENSOR
|
||||
quant_scale = ScaleDesc(torch.float32, static, group_shape)
|
||||
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
|
||||
if static:
|
||||
@ -61,18 +75,29 @@ class TestModel(torch.nn.Module):
|
||||
else:
|
||||
self.scale = [None for _ in range(3)]
|
||||
self.w = [
|
||||
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
for _ in range(3)
|
||||
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
|
||||
]
|
||||
if not group_shape.is_per_group():
|
||||
self.w = [self.w[0].t() for _ in range(3)]
|
||||
|
||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=static,
|
||||
if group_shape.is_per_group():
|
||||
self.fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
|
||||
act_quant_group_shape=group_shape,
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
|
||||
use_aiter_and_is_supported=False,
|
||||
)
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
|
||||
else:
|
||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=static,
|
||||
act_quant_group_shape=group_shape,
|
||||
)
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
||||
|
||||
self.enable_rms_norm_custom_op = self.norm[0].enabled()
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
||||
self.group_shape = group_shape
|
||||
|
||||
def forward(self, x):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
@ -119,11 +144,19 @@ class TestModel(torch.nn.Module):
|
||||
)
|
||||
|
||||
|
||||
GROUP_SHAPES = [
|
||||
GroupShape.PER_TOKEN,
|
||||
GroupShape.PER_TENSOR,
|
||||
GroupShape(1, 128),
|
||||
GroupShape(1, 64),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("hidden_size", [64])
|
||||
@pytest.mark.parametrize("hidden_size", [256])
|
||||
@pytest.mark.parametrize("num_tokens", [257])
|
||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||
@pytest.mark.parametrize("static", [True, False])
|
||||
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
|
||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
||||
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
|
||||
# cuda_force_torch used to test torch code path on platforms that
|
||||
@ -139,7 +172,7 @@ def test_fusion_rmsnorm_quant(
|
||||
hidden_size,
|
||||
num_tokens,
|
||||
eps,
|
||||
static,
|
||||
group_shape,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
cuda_force_torch,
|
||||
@ -149,6 +182,15 @@ def test_fusion_rmsnorm_quant(
|
||||
torch.manual_seed(1)
|
||||
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
|
||||
|
||||
if not enable_quant_fp8_custom_op and group_shape.is_per_group():
|
||||
pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")
|
||||
|
||||
# Skip group shape (1, 64) when running with CUTLASS or DeepGEMM
|
||||
if group_shape == GroupShape(1, 64) and (
|
||||
cutlass_block_fp8_supported() or is_deep_gemm_supported()
|
||||
):
|
||||
pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm")
|
||||
|
||||
custom_ops = []
|
||||
if enable_rms_norm_custom_op:
|
||||
custom_ops.append("+rms_norm")
|
||||
@ -172,8 +214,7 @@ def test_fusion_rmsnorm_quant(
|
||||
|
||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
||||
backend2 = TestBackend(noop_pass, cleanup_pass)
|
||||
model = TestModel(hidden_size, eps, static, cuda_force_torch)
|
||||
|
||||
model = TestModel(hidden_size, eps, group_shape, cuda_force_torch)
|
||||
# First dimension dynamic
|
||||
x = torch.rand(num_tokens, hidden_size)
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
@ -8,6 +8,12 @@ import torch
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_group_quant_int8,
|
||||
)
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float]
|
||||
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
|
||||
@ -21,6 +27,7 @@ NUM_TOKENS_HIDDEN_SIZES = [
|
||||
|
||||
ADD_RESIDUAL = [False, True]
|
||||
SCALE_UBS = [True, False]
|
||||
GROUP_SIZES = [None, [1, 64], [1, 128]]
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
@ -45,12 +52,13 @@ def ref_rms_norm(
|
||||
return out, residual
|
||||
|
||||
|
||||
def ref_dynamic_per_token_quant(
|
||||
def ref_dynamic_per_token_or_block_quant(
|
||||
rms_norm_layer: RMSNorm,
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
residual: torch.Tensor | None,
|
||||
scale_ub: torch.Tensor | None,
|
||||
group_size: list[int] | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
if scale_ub is not None:
|
||||
assert quant_dtype == torch.float8_e4m3fn
|
||||
@ -59,13 +67,24 @@ def ref_dynamic_per_token_quant(
|
||||
torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual)
|
||||
|
||||
# Quant
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
torch_out, scales = ops.scaled_fp8_quant(
|
||||
torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
|
||||
)
|
||||
if group_size is not None:
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
torch_out, scales = per_token_group_quant_fp8(
|
||||
torch_out, group_size=group_size[1], use_ue8m0=False
|
||||
)
|
||||
else:
|
||||
assert quant_dtype == torch.int8
|
||||
torch_out, scales = per_token_group_quant_int8(
|
||||
torch_out, group_size=group_size[1]
|
||||
)
|
||||
else:
|
||||
assert quant_dtype == torch.int8
|
||||
torch_out, scales, _ = ops.scaled_int8_quant(torch_out)
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
torch_out, scales = ops.scaled_fp8_quant(
|
||||
torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
|
||||
)
|
||||
else:
|
||||
assert quant_dtype == torch.int8
|
||||
torch_out, scales, _ = ops.scaled_int8_quant(torch_out)
|
||||
|
||||
return torch_out, scales, residual
|
||||
|
||||
@ -76,24 +95,32 @@ def ref_impl(
|
||||
quant_dtype: torch.dtype,
|
||||
residual: torch.Tensor | None,
|
||||
scale_ub: torch.Tensor | None,
|
||||
group_size: list[int] | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
return ref_dynamic_per_token_quant(
|
||||
rms_norm_layer, x, quant_dtype, residual, scale_ub
|
||||
return ref_dynamic_per_token_or_block_quant(
|
||||
rms_norm_layer, x, quant_dtype, residual, scale_ub, group_size
|
||||
)
|
||||
|
||||
|
||||
def ops_dynamic_per_token_quant(
|
||||
def ops_dynamic_per_token_or_block_quant(
|
||||
weight: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
residual: torch.Tensor | None,
|
||||
scale_ub: torch.Tensor | None,
|
||||
group_size: list[int] | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
if residual is not None:
|
||||
residual = residual.clone()
|
||||
out, scales = ops.rms_norm_dynamic_per_token_quant(
|
||||
x, weight, EPS, quant_dtype, scale_ub, residual
|
||||
)
|
||||
if group_size is not None:
|
||||
out, scales = ops.rms_norm_per_block_quant(
|
||||
x, weight, EPS, quant_dtype, group_size, scale_ub, residual, True
|
||||
)
|
||||
scales = scales.contiguous()
|
||||
else:
|
||||
out, scales = ops.rms_norm_dynamic_per_token_quant(
|
||||
x, weight, EPS, quant_dtype, scale_ub, residual
|
||||
)
|
||||
return out, scales, residual
|
||||
|
||||
|
||||
@ -103,8 +130,11 @@ def ops_impl(
|
||||
quant_dtype: torch.dtype,
|
||||
residual: torch.Tensor | None,
|
||||
scale_ub: torch.Tensor | None,
|
||||
group_size: list[int] | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub)
|
||||
return ops_dynamic_per_token_or_block_quant(
|
||||
weight, x, quant_dtype, residual, scale_ub, group_size
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
|
||||
@ -112,6 +142,7 @@ def ops_impl(
|
||||
@pytest.mark.parametrize("has_scale_ub", SCALE_UBS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
|
||||
@pytest.mark.parametrize("group_size", GROUP_SIZES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
@ -122,6 +153,7 @@ def test_rms_norm(
|
||||
has_scale_ub: bool,
|
||||
dtype: torch.dtype,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: list[int] | None,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
@ -130,6 +162,14 @@ def test_rms_norm(
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
if group_size is not None and hidden_size % group_size[1] != 0:
|
||||
# skip
|
||||
return
|
||||
|
||||
if group_size is not None and has_scale_ub:
|
||||
# blockwise baseline doesn't support scale_ub
|
||||
return
|
||||
|
||||
if has_scale_ub and quant_dtype != torch.float8_e4m3fn:
|
||||
# skip
|
||||
return
|
||||
@ -150,10 +190,10 @@ def test_rms_norm(
|
||||
scale_ub = None
|
||||
|
||||
ref_out, ref_scales, ref_residual = ref_impl(
|
||||
layer, x, quant_dtype, residual, scale_ub
|
||||
layer, x, quant_dtype, residual, scale_ub, group_size
|
||||
)
|
||||
ops_out, ops_scales, ops_residual = ops_impl(
|
||||
layer.weight, x, quant_dtype, residual, scale_ub
|
||||
layer.weight, x, quant_dtype, residual, scale_ub, group_size
|
||||
)
|
||||
|
||||
assert ref_out.dtype == quant_dtype
|
||||
@ -166,11 +206,15 @@ def test_rms_norm(
|
||||
assert torch.allclose(ref_scales, ops_scales)
|
||||
a = ref_out.to(dtype=torch.float32)
|
||||
b = ops_out.to(dtype=torch.float32)
|
||||
ok = torch.allclose(a, b)
|
||||
ok = torch.allclose(a, b, atol=1e-6)
|
||||
if not ok:
|
||||
# fallback: compare dequantized values with relaxed tolerance
|
||||
a_deq = a * ref_scales.view(-1, 1)
|
||||
b_deq = b * ops_scales.view(-1, 1)
|
||||
if group_size is None:
|
||||
a_deq = a * ref_scales.view(-1, 1)
|
||||
b_deq = b * ops_scales.view(-1, 1)
|
||||
else:
|
||||
a_deq = a * ref_scales.repeat_interleave(group_size[1], dim=1)
|
||||
b_deq = b * ops_scales.repeat_interleave(group_size[1], dim=1)
|
||||
# NOTE: It is possible that some future test cases trigger this
|
||||
# max diff due to precision issues. If such an error is
|
||||
# encountered, it's recommended to inspect the differences between
|
||||
|
||||
@ -436,6 +436,46 @@ def rms_norm_dynamic_per_token_quant(
|
||||
return output, scales
|
||||
|
||||
|
||||
# fused rms_norm + blockwise (per-group) quant op
|
||||
def rms_norm_per_block_quant(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: list[int],
|
||||
scale_ub: torch.Tensor | None = None,
|
||||
residual: torch.Tensor | None = None,
|
||||
is_scale_transposed: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert len(group_size) == 2
|
||||
output = torch.empty_like(input, dtype=quant_dtype)
|
||||
if is_scale_transposed:
|
||||
scales = torch.empty(
|
||||
(input.shape[-1] // group_size[1], input.numel() // input.shape[-1]),
|
||||
device=input.device,
|
||||
dtype=torch.float32,
|
||||
).transpose(0, 1)
|
||||
else:
|
||||
scales = torch.empty(
|
||||
(input.numel() // input.shape[-1], input.shape[-1] // group_size[1]),
|
||||
device=input.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
torch.ops._C.rms_norm_per_block_quant(
|
||||
output,
|
||||
input,
|
||||
weight,
|
||||
scales,
|
||||
epsilon,
|
||||
scale_ub,
|
||||
residual,
|
||||
group_size[1],
|
||||
is_scale_transposed,
|
||||
)
|
||||
return output, scales
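
Illustrative usage of the wrapper just defined (hypothetical sizes; FP8 output with column-major scales, as the fused groupwise benchmark above requests):

import torch
import vllm._custom_ops as ops

x = torch.rand(257, 4096, dtype=torch.bfloat16, device="cuda")
weight = torch.rand(4096, dtype=torch.bfloat16, device="cuda")
out, scales = ops.rms_norm_per_block_quant(
    x, weight, 1e-6, torch.float8_e4m3fn, [1, 128],
    scale_ub=None, residual=None, is_scale_transposed=True,
)
print(out.shape, out.dtype)           # torch.Size([257, 4096]) torch.float8_e4m3fn
print(scales.shape, scales.stride())  # torch.Size([257, 32]) (1, 257): column-major storage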
# quantization ops
|
||||
# awq
|
||||
def awq_dequantize(
|
||||
|
||||
@ -15,13 +15,22 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
kFp8Dynamic64Sym,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
kStaticTensorScale,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_block_fp8_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
is_deep_gemm_e8m0_used,
|
||||
should_use_deepgemm_for_fp8_linear_for_nk,
|
||||
)
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
@ -58,6 +67,9 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
}
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
|
||||
if current_platform.is_cuda():
|
||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
|
||||
|
||||
class FusedRMSQuantKey(NamedTuple):
|
||||
@ -90,6 +102,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
|
||||
FusedRMSQuantKey(
|
||||
kFp8DynamicTokenSym, True
|
||||
): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8Dynamic128Sym, False
|
||||
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8Dynamic128Sym, True
|
||||
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8Dynamic64Sym, False
|
||||
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8Dynamic64Sym, True
|
||||
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
|
||||
}
|
||||
|
||||
|
||||
@ -100,6 +124,15 @@ class RMSNormQuantPattern:
|
||||
config = get_current_vllm_config()
|
||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||
|
||||
# groupwise FP8 linear uses column-major scales when DeepGEMM or CUTLASS is used
|
||||
using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk(
|
||||
self.model_dtype,
|
||||
config.model_config.hf_config.intermediate_size,
|
||||
config.model_config.hf_config.hidden_size,
|
||||
)
|
||||
use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()
|
||||
use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False
|
||||
|
||||
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
|
||||
self.FUSED_OP = FUSED_OPS[key]
|
||||
|
||||
@ -108,7 +141,9 @@ class RMSNormQuantPattern:
|
||||
if not key.fused_add
|
||||
else MatcherFusedAddRMSNorm(epsilon)
|
||||
)
|
||||
self.quant_matcher = MatcherQuantFP8(key.quant)
|
||||
self.quant_matcher = MatcherQuantFP8(
|
||||
key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0
|
||||
)
|
||||
|
||||
|
||||
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
@ -218,6 +253,120 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
)
|
||||
|
||||
|
||||
class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape,
|
||||
symmetric=True,
|
||||
):
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
self.group_shape = group_shape
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
return result, residual, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
|
||||
):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
scale = self.quant_matcher.make_scale(
|
||||
input, transposed=self.quant_matcher.use_col_major_scales
|
||||
)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
scale_ub=None,
|
||||
residual=residual,
|
||||
group_size=self.group_shape[1],
|
||||
is_scale_transposed=self.quant_matcher.use_col_major_scales,
|
||||
)
|
||||
|
||||
# result, residual, scale
|
||||
return at[1], at[3], at[2]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class RMSNormGroupQuantPattern(RMSNormQuantPattern):
    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric=True,
    ):
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        self.group_shape = group_shape
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(input: torch.Tensor, weight: torch.Tensor):
            result_rms = self.rmsnorm_matcher(input, weight)
            result, scale = self.quant_matcher(result_rms)
            return result, scale

        def replacement(input: torch.Tensor, weight: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(
                input, transposed=self.quant_matcher.use_col_major_scales
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
                group_size=self.group_shape[1],
                is_scale_transposed=self.quant_matcher.use_col_major_scales,
            )

            # result, scale
            return at[1], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )


class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    def __init__(
        self,
@ -340,6 +489,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
        # Make sure fused add patterns are before simple rms norm,
        # as the latter is a subset of the former in torch ops
        for epsilon in [1e-5, 1e-6]:
            # Fuse fused_add_rms_norm + fp8 group quant
            FusedAddRMSNormGroupQuantPattern(
                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
            ).register(self.patterns)

            # Fuse rms_norm + fp8 group quant
            RMSNormGroupQuantPattern(
                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
            ).register(self.patterns)

            FusedAddRMSNormGroupQuantPattern(
                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
            ).register(self.patterns)

            # Fuse rms_norm + fp8 group quant
            RMSNormGroupQuantPattern(
                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
            ).register(self.patterns)

            # Fuse fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
@ -366,9 +534,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    def uuid(self) -> Any:
        return self.hash_source(
            self,
            RMSNormGroupQuantPattern,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
            FusedAddRMSNormGroupQuantPattern,
        )

@ -13,6 +13,8 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    _normalize_quant_group_shape,
    kFp8Dynamic64Sym,
    kFp8Dynamic128Sym,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
@ -35,6 +37,10 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501

if current_platform.is_cuda():
    QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
    QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501

SILU_MUL_OP = torch.ops._C.silu_and_mul.default


@ -224,12 +230,20 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):


class MatcherQuantFP8(MatcherCustomOp):
    def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
    def __init__(
        self,
        quant_key: QuantKey,
        enabled: bool | None = None,
        use_col_major_scales: bool = False,
        use_e8m0: bool = False,
    ):
        if enabled is None:
            enabled = QuantFP8.enabled()

        super().__init__(enabled)
        self.quant_key = quant_key
        self.use_col_major_scales = use_col_major_scales
        self.use_e8m0 = use_e8m0
        assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
        self.QUANT_OP = QUANT_OPS[quant_key]

@ -248,6 +262,27 @@ class MatcherQuantFP8(MatcherCustomOp):
            input.shape, device=input.device, dtype=self.quant_key.dtype
        )

        if self.quant_key.scale.group_shape.is_per_group():
            assert scale is None
            scale = self.make_scale(input, transposed=self.use_col_major_scales)

            finfo = torch.finfo(self.quant_key.dtype)
            fp8_min = finfo.min
            fp8_max = finfo.max

            _, result, scale = auto_functionalized(
                self.QUANT_OP,
                input=input,
                output_q=result,
                output_s=scale,
                group_size=self.quant_key.scale.group_shape[1],
                eps=1e-10,
                fp8_min=fp8_min,
                fp8_max=fp8_max,
                scale_ue8m0=self.use_e8m0,
            )
            return result, scale

        if self.quant_key.scale.static:
            assert scale is not None
            _, result = auto_functionalized(
@ -269,7 +304,7 @@ class MatcherQuantFP8(MatcherCustomOp):
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.quant_fp8(input, scale)

    def make_scale(self, input: torch.Tensor):
    def make_scale(self, input: torch.Tensor, transposed: bool = False):
        normalized_group_shape = _normalize_quant_group_shape(
            input, self.quant_key.scale.group_shape
        )
@ -277,6 +312,11 @@ class MatcherQuantFP8(MatcherCustomOp):
            input.shape[0] // normalized_group_shape[0],
            input.shape[1] // normalized_group_shape[1],
        )
        if transposed:
            scale_shape = tuple(reversed(scale_shape))
            return torch.empty(
                scale_shape, device=input.device, dtype=torch.float32
            ).permute(-1, -2)

        return torch.empty(scale_shape, device=input.device, dtype=torch.float32)


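The transposed branch above is what "column-major scales" means in practice: the buffer is allocated with its two dimensions swapped and then permuted back, so the logical shape stays (tokens, hidden // group_size) while the strides become column-major. A minimal sketch, with hypothetical sizes:

import torch

tokens, hidden, group = 4, 1024, 128
scale_shape = (tokens, hidden // group)

row_major = torch.empty(scale_shape, dtype=torch.float32)
col_major = torch.empty(tuple(reversed(scale_shape)), dtype=torch.float32).permute(-1, -2)

print(row_major.shape, row_major.stride())  # torch.Size([4, 8]) (8, 1)
print(col_major.shape, col_major.stride())  # torch.Size([4, 8]) (1, 4)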
@ -733,7 +733,7 @@ def per_token_group_quant_fp8(
    assert out_q is None or out_q.shape == x.shape
    x_q = out_q
    if x_q is None:
        x_q = torch.empty_like(x, device=x.device, dtype=dtype)
        x_q = torch.empty(x.shape, device=x.device, dtype=dtype)

    # Allocate the scale tensor in either row- or column-major format.
    if column_major_scales:

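The one-line change above replaces torch.empty_like(x, ...) with torch.empty(x.shape, ...), presumably so the quantized output is always allocated contiguously even when x is a non-contiguous view (empty_like with the default preserve_format can copy the input's strides). A quick illustration of the difference:

import torch

x = torch.randn(8, 4).t()   # non-contiguous view, shape (4, 8), strides (1, 4)
a = torch.empty_like(x)     # preserve_format copies the transposed strides
b = torch.empty(x.shape)    # always a fresh contiguous buffer

print(x.is_contiguous(), a.is_contiguous(), b.is_contiguous())  # False False True
print(a.stride(), b.stride())                                   # (1, 4) (8, 1)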
@ -115,6 +115,12 @@ kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale)

kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128))
kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True)

kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64))
kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)


# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):

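As the comment above describes, a -1 entry in a GroupShape expands to the full extent of that dimension. A standalone sketch of that normalization (normalize_group_shape is a stand-in, not the vLLM implementation):

import torch

def normalize_group_shape(x: torch.Tensor, group_shape: tuple[int, int]) -> tuple[int, int]:
    # Expand -1 entries to the full extent of the corresponding dim of x.
    return (
        x.shape[0] if group_shape[0] == -1 else group_shape[0],
        x.shape[1] if group_shape[1] == -1 else group_shape[1],
    )

x = torch.empty(4, 1024)
print(normalize_group_shape(x, (-1, -1)))  # (4, 1024) -> per-tensor scale
print(normalize_group_shape(x, (1, -1)))   # (1, 1024) -> per-token scale
print(normalize_group_shape(x, (1, 128)))  # (1, 128)  -> groupwise scale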
@ -381,6 +381,22 @@ def should_use_deepgemm_for_fp8_linear(
    )


def should_use_deepgemm_for_fp8_linear_for_nk(
    output_dtype: torch.dtype,
    shape0: int,
    shape1: int,
    supports_deep_gemm: bool | None = None,
):
    if supports_deep_gemm is None:
        supports_deep_gemm = is_deep_gemm_supported()
    return (
        supports_deep_gemm
        and output_dtype == torch.bfloat16
        and shape0 % 128 == 0
        and shape1 % 128 == 0
    )


__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
@ -394,6 +410,7 @@ __all__ = [
    "is_deep_gemm_supported",
    "get_num_sms",
    "should_use_deepgemm_for_fp8_linear",
    "should_use_deepgemm_for_fp8_linear_for_nk",
    "get_col_major_tma_aligned_tensor",
    "get_mk_alignment_for_contiguous_layout",
]
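Finally, a usage sketch for the new helper: it gates the DeepGEMM block-FP8 path on a bfloat16 output dtype and 128-aligned shapes. The module path below is assumed from the surrounding context, and supports_deep_gemm=True is passed explicitly so the check can be exercised without DeepGEMM installed.

import torch
# Assumed module path for the helper added above.
from vllm.utils.deep_gemm import should_use_deepgemm_for_fp8_linear_for_nk

ok = should_use_deepgemm_for_fp8_linear_for_nk(
    torch.bfloat16, 14336, 4096, supports_deep_gemm=True
)
not_ok = should_use_deepgemm_for_fp8_linear_for_nk(
    torch.bfloat16, 14336, 4000, supports_deep_gemm=True  # 4000 % 128 != 0
)
print(ok, not_ok)  # True False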