mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 13:51:19 +08:00
[NVFP4][Perf] Tune NVFP4 input quant kernel for small batch size (#30897)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
b471092d3a
commit
06d490282f
177
benchmarks/kernels/bench_nvfp4_quant.py
Normal file
177
benchmarks/kernels/bench_nvfp4_quant.py
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from weight_shapes import WEIGHT_SHAPES
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.scalar_type import scalar_types
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
from vllm.utils.flashinfer import flashinfer_fp4_quantize
|
||||||
|
|
||||||
|
if not current_platform.has_device_capability(100):
|
||||||
|
raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")
|
||||||
|
|
||||||
|
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||||
|
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||||
|
|
||||||
|
PROVIDER_CFGS = {
|
||||||
|
"vllm": dict(backend="vllm", enabled=True),
|
||||||
|
"flashinfer": dict(backend="flashinfer", enabled=True),
|
||||||
|
}
|
||||||
|
|
||||||
|
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
|
||||||
|
|
||||||
|
|
||||||
|
def compute_global_scale(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Compute global scale for FP4 quantization."""
|
||||||
|
amax = torch.abs(tensor).max().to(torch.float32)
|
||||||
|
return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size"],
|
||||||
|
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
|
||||||
|
x_log=False,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=_enabled,
|
||||||
|
line_names=_enabled,
|
||||||
|
ylabel="us (lower is better)",
|
||||||
|
plot_name="NVFP4 Input Quantization Latency (us)",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, provider, N, K):
|
||||||
|
M = batch_size
|
||||||
|
device = "cuda"
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
# Create input tensor
|
||||||
|
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# Compute global scale for activation
|
||||||
|
a_global_scale = compute_global_scale(a)
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
cfg = PROVIDER_CFGS[provider]
|
||||||
|
|
||||||
|
if cfg["backend"] == "vllm":
|
||||||
|
# vLLM's FP4 quantization
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: ops.scaled_fp4_quant(a, a_global_scale),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
elif cfg["backend"] == "flashinfer":
|
||||||
|
# FlashInfer's FP4 quantization
|
||||||
|
# Use is_sf_swizzled_layout=True to match vLLM's output format
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: flashinfer_fp4_quantize(
|
||||||
|
a, a_global_scale, is_sf_swizzled_layout=True
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert ms to us for better readability at small batch sizes
|
||||||
|
to_us = lambda t_ms: t_ms * 1000
|
||||||
|
return to_us(ms), to_us(max_ms), to_us(min_ms)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_shapes(args):
|
||||||
|
out = []
|
||||||
|
for model, tp_size in itertools.product(args.models, args.tp_sizes):
|
||||||
|
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||||
|
KN[tp_dim] //= tp_size
|
||||||
|
KN.append(model)
|
||||||
|
out.append(KN)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str):
|
||||||
|
"""Test accuracy between vLLM and FlashInfer FP4 quantization."""
|
||||||
|
# Create input tensor
|
||||||
|
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# Compute global scale
|
||||||
|
a_global_scale = compute_global_scale(a)
|
||||||
|
|
||||||
|
# vLLM quantization
|
||||||
|
vllm_fp4, vllm_scale = ops.scaled_fp4_quant(a, a_global_scale)
|
||||||
|
|
||||||
|
# FlashInfer quantization (with swizzled layout to match vLLM's output)
|
||||||
|
flashinfer_fp4, flashinfer_scale = flashinfer_fp4_quantize(
|
||||||
|
a, a_global_scale, is_sf_swizzled_layout=True
|
||||||
|
)
|
||||||
|
flashinfer_scale = flashinfer_scale.view(torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
# Compare outputs
|
||||||
|
torch.testing.assert_close(
|
||||||
|
vllm_fp4,
|
||||||
|
flashinfer_fp4,
|
||||||
|
)
|
||||||
|
print(f"M={M}, K={K}, dtype={dtype}: PASSED")
|
||||||
|
|
||||||
|
|
||||||
|
def test_accuracy():
|
||||||
|
"""Run accuracy tests across various shapes."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Running accuracy tests: vLLM vs FlashInfer")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
device = "cuda"
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
# Test various batch sizes and hidden dimensions
|
||||||
|
Ms = [1, 1024]
|
||||||
|
Ks = [4096]
|
||||||
|
|
||||||
|
for M in Ms:
|
||||||
|
for K in Ks:
|
||||||
|
_test_accuracy_once(M, K, dtype, device)
|
||||||
|
|
||||||
|
print("\nAll accuracy tests passed!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Benchmark NVFP4 quantization: vLLM vs FlashInfer"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||||
|
choices=list(WEIGHT_SHAPES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to save benchmark results",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--accuracy",
|
||||||
|
action="store_true",
|
||||||
|
help="Run accuracy tests",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.accuracy:
|
||||||
|
test_accuracy()
|
||||||
|
|
||||||
|
for K, N, model in prepare_shapes(args):
|
||||||
|
print(f"\n{model}, N={N} K={K}")
|
||||||
|
benchmark.run(
|
||||||
|
print_data=True,
|
||||||
|
save_path=args.save_path,
|
||||||
|
N=N,
|
||||||
|
K=K,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\nBenchmark finished!")
|
||||||
@ -74,6 +74,9 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
|||||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
||||||
"Vec size is not matched.");
|
"Vec size is not matched.");
|
||||||
|
|
||||||
|
// Precompute SF layout parameter (constant for entire kernel).
|
||||||
|
int32_t const numKTiles = (numCols + 63) / 64;
|
||||||
|
|
||||||
// Get the global scaling factor, which will be applied to the SF.
|
// Get the global scaling factor, which will be applied to the SF.
|
||||||
// Note SFScale is the same as next GEMM's alpha, which is
|
// Note SFScale is the same as next GEMM's alpha, which is
|
||||||
// (448.f / (Alpha_A / 6.f)).
|
// (448.f / (Alpha_A / 6.f)).
|
||||||
@ -101,7 +104,7 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
|||||||
auto sf_out =
|
auto sf_out =
|
||||||
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
||||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||||
rowIdx, colIdx, numCols, SFout);
|
rowIdx, colIdx, numKTiles, SFout);
|
||||||
|
|
||||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(out_silu_mul, SFScaleVal,
|
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(out_silu_mul, SFScaleVal,
|
||||||
sf_out);
|
sf_out);
|
||||||
|
|||||||
@ -25,6 +25,7 @@
|
|||||||
#include <cuda_fp8.h>
|
#include <cuda_fp8.h>
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
|
#include "cuda_utils.h"
|
||||||
#include "nvfp4_utils.cuh"
|
#include "nvfp4_utils.cuh"
|
||||||
#include "launch_bounds_utils.h"
|
#include "launch_bounds_utils.h"
|
||||||
|
|
||||||
@ -44,6 +45,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|||||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
||||||
"Vec size is not matched.");
|
"Vec size is not matched.");
|
||||||
|
|
||||||
|
// Precompute SF layout parameter (constant for entire kernel).
|
||||||
|
int32_t const numKTiles = (numCols + 63) / 64;
|
||||||
|
|
||||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
||||||
|
|
||||||
@ -112,17 +116,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|||||||
// (448.f / (Alpha_A / 6.f)).
|
// (448.f / (Alpha_A / 6.f)).
|
||||||
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
||||||
|
|
||||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
|
||||||
// The actual output_scales dim is computed from the padded numCols.
|
|
||||||
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
|
|
||||||
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
|
|
||||||
uint32_t* SFout_in_expert =
|
uint32_t* SFout_in_expert =
|
||||||
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
|
SFout + output_scale_offset_by_experts[expert_idx] * numKTiles;
|
||||||
|
|
||||||
auto sf_out =
|
auto sf_out =
|
||||||
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
||||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||||
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
|
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
|
||||||
|
|
||||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||||
}
|
}
|
||||||
@ -140,6 +140,10 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
|||||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
||||||
"Vec size is not matched.");
|
"Vec size is not matched.");
|
||||||
|
|
||||||
|
// Precompute SF layout parameter (constant for entire kernel).
|
||||||
|
int32_t const numKTiles = (numCols + 63) / 64;
|
||||||
|
|
||||||
extern __shared__ uint32_t shared_input_offsets[];
|
extern __shared__ uint32_t shared_input_offsets[];
|
||||||
|
|
||||||
// Load input offsets into shared memory.
|
// Load input offsets into shared memory.
|
||||||
@ -202,16 +206,13 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
|||||||
|
|
||||||
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
||||||
|
|
||||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
|
||||||
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
|
|
||||||
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
|
|
||||||
uint32_t* SFout_in_expert =
|
uint32_t* SFout_in_expert =
|
||||||
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
|
SFout + output_scale_offset_by_experts[expert_idx] * numKTiles;
|
||||||
|
|
||||||
auto sf_out =
|
auto sf_out =
|
||||||
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
||||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||||
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
|
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
|
||||||
|
|
||||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||||
}
|
}
|
||||||
@ -222,12 +223,8 @@ void quant_impl(void* output, void* output_scale, void* input,
|
|||||||
void* input_global_scale, void* input_offset_by_experts,
|
void* input_global_scale, void* input_offset_by_experts,
|
||||||
void* output_scale_offset_by_experts, int m_topk, int k,
|
void* output_scale_offset_by_experts, int m_topk, int k,
|
||||||
int n_experts, cudaStream_t stream) {
|
int n_experts, cudaStream_t stream) {
|
||||||
// TODO: this multiProcessorCount should be cached.
|
int multiProcessorCount =
|
||||||
int device;
|
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
|
||||||
cudaGetDevice(&device);
|
|
||||||
int multiProcessorCount;
|
|
||||||
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount,
|
|
||||||
device);
|
|
||||||
|
|
||||||
// Grid, Block size.
|
// Grid, Block size.
|
||||||
// Each thread converts 8 values.
|
// Each thread converts 8 values.
|
||||||
|
|||||||
@ -38,6 +38,12 @@ __host__ __device__ inline Int round_up(Int x, Int y) {
|
|||||||
return (x + y - 1) / y * y;
|
return (x + y - 1) / y * y;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Compute effective rows for grid configuration with swizzled SF layouts.
|
||||||
|
inline int computeEffectiveRows(int m) {
|
||||||
|
constexpr int ROW_TILE = 128;
|
||||||
|
return round_up(m, ROW_TILE);
|
||||||
|
}
|
||||||
|
|
||||||
// Use UE4M3 by default.
|
// Use UE4M3 by default.
|
||||||
template <class Type, bool UE8M0_SF = false>
|
template <class Type, bool UE8M0_SF = false>
|
||||||
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||||
@ -49,6 +55,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|||||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
||||||
"Vec size is not matched.");
|
"Vec size is not matched.");
|
||||||
|
|
||||||
|
// Precompute SF layout parameter (constant for entire kernel).
|
||||||
|
int32_t const numKTiles = (numCols + 63) / 64;
|
||||||
|
|
||||||
int sf_m = round_up<int>(numRows, 128);
|
int sf_m = round_up<int>(numRows, 128);
|
||||||
int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
|
int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
|
||||||
int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
|
int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
|
||||||
@ -79,7 +88,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|||||||
auto sf_out =
|
auto sf_out =
|
||||||
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
||||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||||
rowIdx, colIdx, numCols, SFout);
|
rowIdx, colIdx, numKTiles, SFout);
|
||||||
|
|
||||||
out_pos =
|
out_pos =
|
||||||
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
|
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
|
||||||
@ -87,43 +96,6 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale,
|
|
||||||
int64_t* output, int32_t* SFOuput, bool useUE8M0,
|
|
||||||
int multiProcessorCount, cudaStream_t stream) {
|
|
||||||
// Grid, Block size.
|
|
||||||
// Each thread converts 8 values.
|
|
||||||
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
|
|
||||||
// Get number of blocks per SM
|
|
||||||
int const numBlocksPerSM =
|
|
||||||
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
|
|
||||||
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
|
|
||||||
|
|
||||||
// Launch the cvt kernel.
|
|
||||||
if (useUE8M0) {
|
|
||||||
cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(
|
|
||||||
m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
|
|
||||||
reinterpret_cast<uint32_t*>(SFOuput));
|
|
||||||
} else {
|
|
||||||
cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
|
|
||||||
m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
|
|
||||||
reinterpret_cast<uint32_t*>(SFOuput));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Instantiate the function.
|
|
||||||
template void invokeFP4Quantization(int m, int n, half const* input,
|
|
||||||
float const* SFScale, int64_t* output,
|
|
||||||
int32_t* SFOuput, bool useUE8M0,
|
|
||||||
int multiProcessorCount,
|
|
||||||
cudaStream_t stream);
|
|
||||||
|
|
||||||
template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input,
|
|
||||||
float const* SFScale, int64_t* output,
|
|
||||||
int32_t* SFOuput, bool useUE8M0,
|
|
||||||
int multiProcessorCount,
|
|
||||||
cudaStream_t stream);
|
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||||
@ -147,13 +119,19 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
|||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||||
|
|
||||||
// We don't support e8m0 scales at this moment.
|
// Grid, Block size. Each thread converts 8 values.
|
||||||
bool useUE8M0 = false;
|
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
|
||||||
|
int const numBlocksPerSM =
|
||||||
|
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
|
||||||
|
int effectiveRows = vllm::computeEffectiveRows(m);
|
||||||
|
dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));
|
||||||
|
|
||||||
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
||||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||||
vllm::invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr,
|
// NOTE: We don't support e8m0 scales at this moment.
|
||||||
sf_out, useUE8M0, multiProcessorCount, stream);
|
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
|
||||||
|
m, n, input_ptr, input_sf_ptr, reinterpret_cast<uint32_t*>(output_ptr),
|
||||||
|
reinterpret_cast<uint32_t*>(sf_out));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@ -128,51 +128,42 @@ inline __device__ float reciprocal_approximate_ftz(float a) {
|
|||||||
return b;
|
return b;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Compute SF output offset for swizzled tensor core layout.
|
||||||
|
// SF layout: [numMTiles, numKTiles, 32, 4, 4]
|
||||||
|
// Caller must precompute: numKTiles = (numCols + 63) / 64
|
||||||
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
|
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
|
||||||
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
|
__device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(
|
||||||
int numCols,
|
int rowIdx, int colIdx, int32_t numKTiles, SFType* SFout) {
|
||||||
SFType* SFout) {
|
|
||||||
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
|
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
|
||||||
CVT_FP4_NUM_THREADS_PER_SF == 2);
|
CVT_FP4_NUM_THREADS_PER_SF == 2);
|
||||||
|
|
||||||
// One pair of threads write one SF to global memory.
|
// One pair of threads write one SF to global memory.
|
||||||
// TODO: stage through smem for packed STG.32
|
// TODO: stage through smem for packed STG.32
|
||||||
// is it better than STG.8 from 4 threads ?
|
// is it better than STG.8 from 4 threads ?
|
||||||
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
|
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) {
|
||||||
// SF vector index (16 elements share one SF in the K dimension).
|
return nullptr;
|
||||||
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
|
|
||||||
int32_t mIdx = rowIdx;
|
|
||||||
|
|
||||||
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
|
|
||||||
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
|
|
||||||
|
|
||||||
int32_t mTileIdx = mIdx / (32 * 4);
|
|
||||||
// SF vector size 16.
|
|
||||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
|
||||||
int32_t numKTiles = (numCols + factor - 1) / factor;
|
|
||||||
int64_t mTileStride = numKTiles * 32 * 4 * 4;
|
|
||||||
|
|
||||||
int32_t kTileIdx = (kIdx / 4);
|
|
||||||
int64_t kTileStride = 32 * 4 * 4;
|
|
||||||
|
|
||||||
// M tile layout [32, 4] is column-major.
|
|
||||||
int32_t outerMIdx = (mIdx % 32);
|
|
||||||
int64_t outerMStride = 4 * 4;
|
|
||||||
|
|
||||||
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
|
|
||||||
int64_t innerMStride = 4;
|
|
||||||
|
|
||||||
int32_t innerKIdx = (kIdx % 4);
|
|
||||||
int64_t innerKStride = 1;
|
|
||||||
|
|
||||||
// Compute the global offset.
|
|
||||||
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
|
|
||||||
outerMIdx * outerMStride + innerMIdx * innerMStride +
|
|
||||||
innerKIdx * innerKStride;
|
|
||||||
|
|
||||||
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
|
|
||||||
}
|
}
|
||||||
return nullptr;
|
|
||||||
|
// SF vector index (16 elements share one SF in the K dimension).
|
||||||
|
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
|
||||||
|
int32_t mIdx = rowIdx;
|
||||||
|
|
||||||
|
// Decompose indices using bitwise ops (all divisors are powers of 2).
|
||||||
|
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
|
||||||
|
int32_t mTileIdx = mIdx >> 7; // mIdx / 128
|
||||||
|
int32_t outerMIdx = mIdx & 31; // mIdx % 32
|
||||||
|
int32_t innerMIdx = (mIdx >> 5) & 3; // (mIdx / 32) % 4
|
||||||
|
int32_t kTileIdx = kIdx >> 2; // kIdx / 4
|
||||||
|
int32_t innerKIdx = kIdx & 3; // kIdx % 4
|
||||||
|
|
||||||
|
// Compute global SF offset: mTileIdx * (numKTiles * 512) + kTileIdx * 512 +
|
||||||
|
// outerMIdx * 16 + innerMIdx * 4 + innerKIdx
|
||||||
|
// Use bitwise OR for non-overlapping lower bits.
|
||||||
|
int64_t SFOffset = (static_cast<int64_t>(mTileIdx) * numKTiles + kTileIdx)
|
||||||
|
<< 9 |
|
||||||
|
(outerMIdx << 4) | (innerMIdx << 2) | innerKIdx;
|
||||||
|
|
||||||
|
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quantizes the provided PackedVec into the uint32_t output
|
// Quantizes the provided PackedVec into the uint32_t output
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user