[Bugfix] Fix test fused quant layernorm tests (#27865)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
commit 171133f929
parent 32787d0644
@@ -1,5 +1,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include <cmath>
 
@@ -275,6 +276,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
   int const num_tokens = input.numel() / hidden_size;
   dim3 const grid(num_tokens);
   dim3 const block(std::min(hidden_size, 256));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@@ -306,6 +308,7 @@ void dynamic_scaled_int8_quant(
   int const num_tokens = input.numel() / hidden_size;
   dim3 const grid(num_tokens);
   dim3 const block(std::min(hidden_size, 256));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
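The hunks above touch the CUDA int8 quantization kernels: both static_scaled_int8_quant and dynamic_scaled_int8_quant now install an at::cuda::OptionalCUDAGuard for the input's device before querying the current stream, so the kernel launches target the GPU that actually owns `input` rather than whichever device happens to be current. The remaining hunks below change the fused quant + layernorm test. As a minimal Python-side sketch of the same guard pattern (illustrative only; torch.cuda.device stands in for OptionalCUDAGuard, and launch_on_input_device is a made-up helper name):

import torch

def launch_on_input_device(x: torch.Tensor) -> torch.Tensor:
    # torch.cuda.device(...) plays the role of OptionalCUDAGuard: it makes
    # x.device the current CUDA device for the duration of the block and
    # restores the previous device afterwards.
    with torch.cuda.device(x.device):
        # a real kernel / custom-op call would go here; empty_like stands in
        out = torch.empty_like(x, dtype=torch.int8)
    return out

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    x = torch.randn(4, 1024, device="cuda:1")
    assert launch_on_input_device(x).device == x.device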
@@ -11,7 +11,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 
 DTYPES = [torch.bfloat16, torch.float]
 QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
-VEC_HIDDEN_SIZES = range(1024, 1030)
+VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
 # Avoid combinatorial explosion with full Cartesian product
 NUM_TOKENS_HIDDEN_SIZES = [
     *[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]],
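The narrowed VEC_HIDDEN_SIZES keeps the aligned size 1024 plus three unaligned sizes and drops two values from the old range(1024, 1030), shrinking the (num_tokens, hidden_size) sweep. A quick standalone check of exactly what was dropped (not part of the test):

old = set(range(1024, 1030))
new = {1024, 1025, 1027, 1029}
assert new < old and old - new == {1026, 1028}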
@@ -65,7 +65,7 @@ def ref_dynamic_per_token_quant(
         )
     else:
         assert quant_dtype == torch.int8
-        torch_out, scales = ops.scaled_int8_quant(torch_out)
+        torch_out, scales, _ = ops.scaled_int8_quant(torch_out)
 
     return torch_out, scales, residual
 
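ops.scaled_int8_quant now returns a third element (the asymmetric zero point, which should be None in the symmetric case used here), so the reference helper unpacks and discards it. For orientation, a pure-PyTorch sketch of dynamic per-token symmetric int8 quantization, i.e. the operation this call performs for the reference path (illustrative reference math, not the vLLM kernel):

import torch

def ref_dynamic_int8_quant(x: torch.Tensor):
    x = x.to(torch.float32)
    # one scale per token (row): map the row's max-abs onto the int8 range
    scales = x.abs().amax(dim=-1, keepdim=True) / 127.0
    scales = scales.clamp(min=1e-10)  # guard all-zero rows
    q = torch.clamp(torch.round(x / scales), -128, 127).to(torch.int8)
    return q, scales

x = torch.randn(4, 1024)
q, scales = ref_dynamic_int8_quant(x)
assert q.dtype == torch.int8 and scales.shape == (4, 1)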
@@ -109,7 +109,7 @@ def ops_impl(
 
 @pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
 @pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
-@pytest.mark.parametrize("scale_ub", SCALE_UBS)
+@pytest.mark.parametrize("has_scale_ub", SCALE_UBS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
@@ -119,7 +119,7 @@ def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
     add_residual: bool,
-    scale_ub: bool,
+    has_scale_ub: bool,
     dtype: torch.dtype,
     quant_dtype: torch.dtype,
     seed: int,
@@ -130,7 +130,7 @@ def test_rms_norm(
     torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
 
-    if scale_ub is not None and quant_dtype != torch.float8_e4m3fn:
+    if has_scale_ub and quant_dtype != torch.float8_e4m3fn:
         # skip
         return
 
@@ -143,9 +143,11 @@ def test_rms_norm(
     scale = 1 / (hidden_size)
     x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
     residual = torch.randn_like(x) * scale if add_residual else None
-    if scale_ub is not None:
+    if has_scale_ub:
         rms_x, _ = ref_rms_norm(layer, x, residual)
         scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda")
+    else:
+        scale_ub = None
 
     ref_out, ref_scales, ref_residual = ref_impl(
         layer, x, quant_dtype, residual, scale_ub
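The test now parametrizes over a plain boolean has_scale_ub and only materializes the scale_ub tensor (derived from the RMS-normed activations) when the flag is set, falling back to None otherwise. For context, a hedged sketch of how a scale upper bound is commonly applied in dynamic per-token fp8 quantization (assumed semantics for illustration, not code lifted from the vLLM kernels):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3fn

def per_token_fp8_scales(x: torch.Tensor, scale_ub: torch.Tensor | None):
    # per-token max-abs, optionally clamped by the upper bound before the
    # scale is derived, so a single outlier cannot inflate the scale
    amax = x.abs().amax(dim=-1, keepdim=True).to(torch.float32)
    if scale_ub is not None:
        amax = torch.clamp(amax, max=scale_ub)
    return amax / FP8_MAX

x = torch.randn(4, 1024)
print(per_token_fp8_scales(x, torch.tensor(1.0)).shape)  # torch.Size([4, 1])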
@@ -156,14 +158,27 @@ def test_rms_norm(
 
     assert ref_out.dtype == quant_dtype
     assert ops_out.dtype == quant_dtype
-    assert torch.allclose(ref_scales, ops_scales)
     if quant_dtype == torch.int8:
+        assert torch.allclose(ref_scales, ops_scales, atol=1e-6)
         # big atol to account for round-off errors.
         assert torch.allclose(ref_out, ops_out, atol=1)
     else:
-        assert torch.allclose(
-            ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
-        )
+        assert torch.allclose(ref_scales, ops_scales)
+        a = ref_out.to(dtype=torch.float32)
+        b = ops_out.to(dtype=torch.float32)
+        ok = torch.allclose(a, b)
+        if not ok:
+            # fallback: compare dequantized values with relaxed tolerance
+            a_deq = a * ref_scales.view(-1, 1)
+            b_deq = b * ops_scales.view(-1, 1)
+            # NOTE: It is possible that some future test cases trigger this
+            # max diff due to precision issues. If such an error is
+            # encountered, it's recommended to inspect the differences between
+            # all corresponding elements from each tensor (e.g. by looping over
+            # them) and checking how many the max diff error shows up on (just
+            # a few bad elements should still be considered acceptable).
+            ok = torch.allclose(a_deq, b_deq, rtol=5e-2, atol=5e-2)
+        assert ok
     if add_residual:
         assert torch.allclose(ref_residual, ops_residual)
 
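The reworked fp8 branch first tries an element-wise comparison of the raw quantized outputs and, only if that fails, falls back to comparing dequantized values (quantized output times per-token scale) with a relaxed tolerance. The idea is that the reference path and the fused kernel can round a borderline value to adjacent fp8 codes, and can produce marginally different scales, while still encoding nearly the same real numbers. A small standalone illustration of that effect (made-up values, not taken from the test):

import torch

def dequant(q: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # recover approximate real values from fp8 codes and per-token scales
    return q.to(torch.float32) * scales.view(-1, 1)

# adjacent-ish fp8 e4m3 codes: 3.5 and 3.75 are both exactly representable
q_ref = torch.tensor([[10.0, -3.5]]).to(torch.float8_e4m3fn)
q_ops = torch.tensor([[10.0, -3.75]]).to(torch.float8_e4m3fn)
scales = torch.tensor([[0.01]])

# the raw codes differ well beyond the default allclose tolerance ...
assert not torch.allclose(q_ref.to(torch.float32), q_ops.to(torch.float32))
# ... but the dequantized values agree under the relaxed tolerance
assert torch.allclose(dequant(q_ref, scales), dequant(q_ops, scales),
                      rtol=5e-2, atol=5e-2)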