mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-27 05:28:44 +08:00
[NVFP4][Perf] Tune NVFP4 input quant kernel for small batch size (#30897)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
b471092d3a
commit
06d490282f
177
benchmarks/kernels/bench_nvfp4_quant.py
Normal file
177
benchmarks/kernels/bench_nvfp4_quant.py
Normal file
@ -0,0 +1,177 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.flashinfer import flashinfer_fp4_quantize
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
PROVIDER_CFGS = {
|
||||
"vllm": dict(backend="vllm", enabled=True),
|
||||
"flashinfer": dict(backend="flashinfer", enabled=True),
|
||||
}
|
||||
|
||||
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
|
||||
|
||||
|
||||
def compute_global_scale(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute global scale for FP4 quantization."""
|
||||
amax = torch.abs(tensor).max().to(torch.float32)
|
||||
return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=_enabled,
|
||||
line_names=_enabled,
|
||||
ylabel="us (lower is better)",
|
||||
plot_name="NVFP4 Input Quantization Latency (us)",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, N, K):
|
||||
M = batch_size
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Create input tensor
|
||||
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||
|
||||
# Compute global scale for activation
|
||||
a_global_scale = compute_global_scale(a)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
cfg = PROVIDER_CFGS[provider]
|
||||
|
||||
if cfg["backend"] == "vllm":
|
||||
# vLLM's FP4 quantization
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: ops.scaled_fp4_quant(a, a_global_scale),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif cfg["backend"] == "flashinfer":
|
||||
# FlashInfer's FP4 quantization
|
||||
# Use is_sf_swizzled_layout=True to match vLLM's output format
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: flashinfer_fp4_quantize(
|
||||
a, a_global_scale, is_sf_swizzled_layout=True
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
# Convert ms to us for better readability at small batch sizes
|
||||
to_us = lambda t_ms: t_ms * 1000
|
||||
return to_us(ms), to_us(max_ms), to_us(min_ms)
|
||||
|
||||
|
||||
def prepare_shapes(args):
|
||||
out = []
|
||||
for model, tp_size in itertools.product(args.models, args.tp_sizes):
|
||||
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||
KN[tp_dim] //= tp_size
|
||||
KN.append(model)
|
||||
out.append(KN)
|
||||
return out
|
||||
|
||||
|
||||
def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str):
|
||||
"""Test accuracy between vLLM and FlashInfer FP4 quantization."""
|
||||
# Create input tensor
|
||||
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||
|
||||
# Compute global scale
|
||||
a_global_scale = compute_global_scale(a)
|
||||
|
||||
# vLLM quantization
|
||||
vllm_fp4, vllm_scale = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
|
||||
# FlashInfer quantization (with swizzled layout to match vLLM's output)
|
||||
flashinfer_fp4, flashinfer_scale = flashinfer_fp4_quantize(
|
||||
a, a_global_scale, is_sf_swizzled_layout=True
|
||||
)
|
||||
flashinfer_scale = flashinfer_scale.view(torch.float8_e4m3fn)
|
||||
|
||||
# Compare outputs
|
||||
torch.testing.assert_close(
|
||||
vllm_fp4,
|
||||
flashinfer_fp4,
|
||||
)
|
||||
print(f"M={M}, K={K}, dtype={dtype}: PASSED")
|
||||
|
||||
|
||||
def test_accuracy():
|
||||
"""Run accuracy tests across various shapes."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Running accuracy tests: vLLM vs FlashInfer")
|
||||
print("=" * 60)
|
||||
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Test various batch sizes and hidden dimensions
|
||||
Ms = [1, 1024]
|
||||
Ks = [4096]
|
||||
|
||||
for M in Ms:
|
||||
for K in Ks:
|
||||
_test_accuracy_once(M, K, dtype, device)
|
||||
|
||||
print("\nAll accuracy tests passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark NVFP4 quantization: vLLM vs FlashInfer"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||
choices=list(WEIGHT_SHAPES.keys()),
|
||||
)
|
||||
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to save benchmark results",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--accuracy",
|
||||
action="store_true",
|
||||
help="Run accuracy tests",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.accuracy:
|
||||
test_accuracy()
|
||||
|
||||
for K, N, model in prepare_shapes(args):
|
||||
print(f"\n{model}, N={N} K={K}")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
save_path=args.save_path,
|
||||
N=N,
|
||||
K=K,
|
||||
)
|
||||
|
||||
print("\nBenchmark finished!")
|
||||
@ -74,6 +74,9 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
||||
"Vec size is not matched.");
|
||||
|
||||
// Precompute SF layout parameter (constant for entire kernel).
|
||||
int32_t const numKTiles = (numCols + 63) / 64;
|
||||
|
||||
// Get the global scaling factor, which will be applied to the SF.
|
||||
// Note SFScale is the same as next GEMM's alpha, which is
|
||||
// (448.f / (Alpha_A / 6.f)).
|
||||
@ -101,7 +104,7 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
auto sf_out =
|
||||
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||
rowIdx, colIdx, numCols, SFout);
|
||||
rowIdx, colIdx, numKTiles, SFout);
|
||||
|
||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(out_silu_mul, SFScaleVal,
|
||||
sf_out);
|
||||
|
||||
@ -25,6 +25,7 @@
|
||||
#include <cuda_fp8.h>
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "nvfp4_utils.cuh"
|
||||
#include "launch_bounds_utils.h"
|
||||
|
||||
@ -44,6 +45,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
||||
"Vec size is not matched.");
|
||||
|
||||
// Precompute SF layout parameter (constant for entire kernel).
|
||||
int32_t const numKTiles = (numCols + 63) / 64;
|
||||
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
||||
|
||||
@ -112,17 +116,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
// (448.f / (Alpha_A / 6.f)).
|
||||
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
||||
|
||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||
// The actual output_scales dim is computed from the padded numCols.
|
||||
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
|
||||
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
|
||||
uint32_t* SFout_in_expert =
|
||||
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
|
||||
SFout + output_scale_offset_by_experts[expert_idx] * numKTiles;
|
||||
|
||||
auto sf_out =
|
||||
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
|
||||
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
|
||||
|
||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||
}
|
||||
@ -140,6 +140,10 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
||||
"Vec size is not matched.");
|
||||
|
||||
// Precompute SF layout parameter (constant for entire kernel).
|
||||
int32_t const numKTiles = (numCols + 63) / 64;
|
||||
|
||||
extern __shared__ uint32_t shared_input_offsets[];
|
||||
|
||||
// Load input offsets into shared memory.
|
||||
@ -202,16 +206,13 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
|
||||
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
||||
|
||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
|
||||
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
|
||||
uint32_t* SFout_in_expert =
|
||||
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
|
||||
SFout + output_scale_offset_by_experts[expert_idx] * numKTiles;
|
||||
|
||||
auto sf_out =
|
||||
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
|
||||
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
|
||||
|
||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||
}
|
||||
@ -222,12 +223,8 @@ void quant_impl(void* output, void* output_scale, void* input,
|
||||
void* input_global_scale, void* input_offset_by_experts,
|
||||
void* output_scale_offset_by_experts, int m_topk, int k,
|
||||
int n_experts, cudaStream_t stream) {
|
||||
// TODO: this multiProcessorCount should be cached.
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int multiProcessorCount;
|
||||
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount,
|
||||
device);
|
||||
int multiProcessorCount =
|
||||
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
|
||||
|
||||
// Grid, Block size.
|
||||
// Each thread converts 8 values.
|
||||
|
||||
@ -38,6 +38,12 @@ __host__ __device__ inline Int round_up(Int x, Int y) {
|
||||
return (x + y - 1) / y * y;
|
||||
}
|
||||
|
||||
// Compute effective rows for grid configuration with swizzled SF layouts.
|
||||
inline int computeEffectiveRows(int m) {
|
||||
constexpr int ROW_TILE = 128;
|
||||
return round_up(m, ROW_TILE);
|
||||
}
|
||||
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
@ -49,6 +55,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
||||
"Vec size is not matched.");
|
||||
|
||||
// Precompute SF layout parameter (constant for entire kernel).
|
||||
int32_t const numKTiles = (numCols + 63) / 64;
|
||||
|
||||
int sf_m = round_up<int>(numRows, 128);
|
||||
int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
|
||||
int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
|
||||
@ -79,7 +88,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
auto sf_out =
|
||||
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||
rowIdx, colIdx, numCols, SFout);
|
||||
rowIdx, colIdx, numKTiles, SFout);
|
||||
|
||||
out_pos =
|
||||
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
|
||||
@ -87,43 +96,6 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale,
|
||||
int64_t* output, int32_t* SFOuput, bool useUE8M0,
|
||||
int multiProcessorCount, cudaStream_t stream) {
|
||||
// Grid, Block size.
|
||||
// Each thread converts 8 values.
|
||||
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
|
||||
// Get number of blocks per SM
|
||||
int const numBlocksPerSM =
|
||||
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
|
||||
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
|
||||
|
||||
// Launch the cvt kernel.
|
||||
if (useUE8M0) {
|
||||
cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(
|
||||
m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(SFOuput));
|
||||
} else {
|
||||
cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
|
||||
m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(SFOuput));
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiate the function.
|
||||
template void invokeFP4Quantization(int m, int n, half const* input,
|
||||
float const* SFScale, int64_t* output,
|
||||
int32_t* SFOuput, bool useUE8M0,
|
||||
int multiProcessorCount,
|
||||
cudaStream_t stream);
|
||||
|
||||
template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input,
|
||||
float const* SFScale, int64_t* output,
|
||||
int32_t* SFOuput, bool useUE8M0,
|
||||
int multiProcessorCount,
|
||||
cudaStream_t stream);
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
@ -147,13 +119,19 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
|
||||
// We don't support e8m0 scales at this moment.
|
||||
bool useUE8M0 = false;
|
||||
// Grid, Block size. Each thread converts 8 values.
|
||||
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
|
||||
int const numBlocksPerSM =
|
||||
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
|
||||
int effectiveRows = vllm::computeEffectiveRows(m);
|
||||
dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||
vllm::invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr,
|
||||
sf_out, useUE8M0, multiProcessorCount, stream);
|
||||
// NOTE: We don't support e8m0 scales at this moment.
|
||||
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
|
||||
m, n, input_ptr, input_sf_ptr, reinterpret_cast<uint32_t*>(output_ptr),
|
||||
reinterpret_cast<uint32_t*>(sf_out));
|
||||
});
|
||||
}
|
||||
|
||||
@ -128,51 +128,42 @@ inline __device__ float reciprocal_approximate_ftz(float a) {
|
||||
return b;
|
||||
}
|
||||
|
||||
// Compute SF output offset for swizzled tensor core layout.
|
||||
// SF layout: [numMTiles, numKTiles, 32, 4, 4]
|
||||
// Caller must precompute: numKTiles = (numCols + 63) / 64
|
||||
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
|
||||
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
|
||||
int numCols,
|
||||
SFType* SFout) {
|
||||
__device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(
|
||||
int rowIdx, int colIdx, int32_t numKTiles, SFType* SFout) {
|
||||
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
|
||||
CVT_FP4_NUM_THREADS_PER_SF == 2);
|
||||
|
||||
// One pair of threads write one SF to global memory.
|
||||
// TODO: stage through smem for packed STG.32
|
||||
// is it better than STG.8 from 4 threads ?
|
||||
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
|
||||
// SF vector index (16 elements share one SF in the K dimension).
|
||||
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
|
||||
int32_t mIdx = rowIdx;
|
||||
|
||||
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
|
||||
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
|
||||
|
||||
int32_t mTileIdx = mIdx / (32 * 4);
|
||||
// SF vector size 16.
|
||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||
int32_t numKTiles = (numCols + factor - 1) / factor;
|
||||
int64_t mTileStride = numKTiles * 32 * 4 * 4;
|
||||
|
||||
int32_t kTileIdx = (kIdx / 4);
|
||||
int64_t kTileStride = 32 * 4 * 4;
|
||||
|
||||
// M tile layout [32, 4] is column-major.
|
||||
int32_t outerMIdx = (mIdx % 32);
|
||||
int64_t outerMStride = 4 * 4;
|
||||
|
||||
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
|
||||
int64_t innerMStride = 4;
|
||||
|
||||
int32_t innerKIdx = (kIdx % 4);
|
||||
int64_t innerKStride = 1;
|
||||
|
||||
// Compute the global offset.
|
||||
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
|
||||
outerMIdx * outerMStride + innerMIdx * innerMStride +
|
||||
innerKIdx * innerKStride;
|
||||
|
||||
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
|
||||
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
return nullptr;
|
||||
|
||||
// SF vector index (16 elements share one SF in the K dimension).
|
||||
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
|
||||
int32_t mIdx = rowIdx;
|
||||
|
||||
// Decompose indices using bitwise ops (all divisors are powers of 2).
|
||||
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
|
||||
int32_t mTileIdx = mIdx >> 7; // mIdx / 128
|
||||
int32_t outerMIdx = mIdx & 31; // mIdx % 32
|
||||
int32_t innerMIdx = (mIdx >> 5) & 3; // (mIdx / 32) % 4
|
||||
int32_t kTileIdx = kIdx >> 2; // kIdx / 4
|
||||
int32_t innerKIdx = kIdx & 3; // kIdx % 4
|
||||
|
||||
// Compute global SF offset: mTileIdx * (numKTiles * 512) + kTileIdx * 512 +
|
||||
// outerMIdx * 16 + innerMIdx * 4 + innerKIdx
|
||||
// Use bitwise OR for non-overlapping lower bits.
|
||||
int64_t SFOffset = (static_cast<int64_t>(mTileIdx) * numKTiles + kTileIdx)
|
||||
<< 9 |
|
||||
(outerMIdx << 4) | (innerMIdx << 2) | innerKIdx;
|
||||
|
||||
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
|
||||
}
|
||||
|
||||
// Quantizes the provided PackedVec into the uint32_t output
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user