[bugfix] Fix static asymmetric quantization case (#10334)
Signed-off-by: Daniël de Kok <me@danieldk.eu>
Signed-off-by: luka <luka@neuralmagic.com>
Co-authored-by: Daniël de Kok <me@danieldk.eu>
This commit is contained in:
parent 972112d82f
commit bf2ddc6610
@@ -86,10 +86,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
     assert torch_out.min() >= int8_traits.min and torch_out.max(
     ) <= int8_traits.max

-    ops_out = torch.empty_like(x, dtype=torch.int8)
-    scales_out = torch.empty_like(scales, dtype=torch.float32)
-    azp_out = torch.empty_like(azps, dtype=torch.int32)
-    torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out)
+    ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False)

     if (not torch.allclose(scales_out, scales)):
         print(torch.argmax(torch.abs(scales_out - scales)))
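For context: the test now goes through the public scaled_int8_quant wrapper instead of filling output buffers and calling torch.ops._C directly. A minimal reference sketch of what the dynamic asymmetric mode computes per token, assuming row-wise min/max and the int8 range [-128, 127]; the helper below is illustrative, not vLLM code:

import torch

def ref_dynamic_asym_int8(x: torch.Tensor):
    # Per-token (row-wise) range of the float activations.
    x_min = x.min(dim=-1, keepdim=True).values
    x_max = x.max(dim=-1, keepdim=True).values
    # Asymmetric scale spreads [min, max] over the 255 int8 steps.
    scale = (x_max - x_min) / 255.0
    # Zero point shifts the grid so that x_min lands on -128.
    azp = torch.round(-128.0 - x_min / scale).to(torch.int32)
    q = torch.clamp(torch.round(x / scale) + azp, -128, 127).to(torch.int8)
    return q, scale, azp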
@@ -119,7 +116,8 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,

     out1 = (x / scale_arg).round().clamp(int8_traits.min,
                                          int8_traits.max).to(torch.int8)
-    out2, _, _ = scaled_int8_quant(x, scale_arg)
+    out2, scale2, _ = scaled_int8_quant(x, scale_arg)
+    assert scale2 is scale_arg

     # big atol to account for rounding errors
     torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
@@ -145,11 +143,15 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,

     out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
                                              int8_traits.max).to(torch.int8)
-    out2 = torch.empty_like(x, dtype=torch.int8)
     scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
     azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")

-    torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg)
+    out2, scale2, azp2 = scaled_int8_quant(x,
+                                           scale_arg,
+                                           azp_arg,
+                                           symmetric=False)
+    assert scale2 is scale_arg
+    assert azp2 is azp_arg

     # big atol to account for rounding errors
     torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
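Note: out1 above is the standard static asymmetric mapping q = clamp(round(x / scale) + azp, -128, 127). As an aside (not part of the change), a fixed scale/azp pair is typically derived offline from a calibration range [lo, hi] so the full int8 range is used; the helper name here is hypothetical:

def static_asym_params(lo: float, hi: float):
    # 255 representable steps between int8 min (-128) and max (127).
    scale = (hi - lo) / 255.0
    # Zero point chosen so that lo quantizes exactly to -128.
    azp = round(-128 - lo / scale)
    return scale, azp

# Example: activations calibrated to [-1.5, 6.0] give scale ~0.0294 and azp = -77,
# so -1.5 maps to -128 and 6.0 maps to 127.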
@@ -184,6 +186,5 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
     val_i8 = int8_traits.max if is_max else int8_traits.min
     expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda")

-    out = torch.empty_like(expected)
-    torch.ops._C.static_scaled_int8_quant(out, x, scale, azp)
+    out, _, _ = scaled_int8_quant(x, scale, azp, symmetric=False)
    torch.testing.assert_close(expected, out, atol=0, rtol=0)
@@ -8,6 +8,7 @@ import pytest
 import torch
 from compressed_tensors.quantization import QuantizationType

+from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
@@ -74,6 +75,35 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
         assert output


+@pytest.mark.parametrize(
+    "model_path",
+    [
+        "neuralmagic/Llama-3.2-1B-quantized.w8a8"
+        # TODO static & asymmetric
+    ])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner,
+                                          example_prompts, model_path,
+                                          max_tokens, num_logprobs):
+    dtype = "bfloat16"
+
+    with hf_runner(model_path, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens, num_logprobs)
+
+    with vllm_runner(model_path, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
+
+
 def test_compressed_tensors_no_enforce_eager(vllm_runner):
     model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
     with vllm_runner(model_path) as llm:
@@ -510,10 +510,16 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
                           azp_adj: torch.Tensor,
                           azp: Optional[torch.Tensor] = None,
                           bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    """
+    :param azp_adj: In the per-tensor case, this should include the azp.
+        Always per-channel.
+    :param azp: Only set in the per-token case. Per-token if set.
+    """
     assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
     assert bias is None or bias.numel(
     ) == b.shape[1] and bias.dtype == out_dtype
+    assert azp is None or azp.numel() == a.shape[0]

     m = a.shape[0]
     n = b.shape[1]
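The docstring added here captures the crux of the fix: azp_adj carries the zero-point correction of an asymmetric int8 GEMM. Since x_float ≈ s_x * (x_q - azp), the product needs azp times the column sums of the weight subtracted out. A small self-contained sketch of that identity, using plain float tensors for illustration only:

import torch

x_q = torch.randint(-128, 128, (4, 8)).float()
w_q = torch.randint(-128, 128, (8, 3)).float()
azp = torch.randint(-50, 50, (4, 1)).float()    # per-token zero points
azp_adj = w_q.sum(dim=0, keepdim=True)          # column sums of the weight

lhs = (x_q - azp) @ w_q                         # zero-point-corrected product
rhs = x_q @ w_q - azp * azp_adj                 # GEMM plus the azp_adj term
assert torch.allclose(lhs, rhs)

In the per-tensor (static) case azp is a single scalar, so azp * azp_adj can be folded into azp_adj ahead of time, which is what "this should include the azp" refers to; in the per-token (dynamic) case azp changes per call and is passed separately.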
@@ -735,7 +741,7 @@ def scaled_int8_quant(
             azp is
             None), "azp must only be provided for asymmetric quantization."
        torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
-        return output, scale, None
+        return output, scale, azp

    # dynamic-per-token quantization.
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
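This one-line change is the behavioral fix: the static path used to drop the zero point by returning None, so asymmetric callers never saw it. A hedged usage sketch (tensor shapes, values, and the import path are assumptions based on this file, not taken from the diff):

import torch
from vllm._custom_ops import scaled_int8_quant

x = torch.randn(16, 64, device="cuda", dtype=torch.float16)
scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
azp = torch.tensor([-77], dtype=torch.int32, device="cuda")

q, scale_out, azp_out = scaled_int8_quant(x, scale, azp, symmetric=False)
# After this commit scale_out is scale and azp_out is azp; previously azp_out
# came back as None, so downstream code could not apply the asymmetric correction.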
@@ -82,9 +82,13 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
         # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
         if not self.input_symmetric:
-            layer.azp_adj = layer.weight.sum(dim=0,
-                                             keepdim=True,
-                                             dtype=torch.int32)
+            azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32)
+            if self.is_static_input_scheme:
+                # cutlass_w8a8 requires azp to be folded into azp_adj
+                # in the per-tensor case
+                azp_adj = layer.input_zero_point * azp_adj
+
+            layer.azp_adj = azp_adj
         else:
             layer.azp_adj = None

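Load-time folding in a nutshell: both the weight and a static per-tensor zero point are known when weights are processed, so the correction azp * W.sum(dim=0) collapses to one precomputed row vector. A tiny illustrative sketch; the names mirror the hunk but the tensors are made up:

import torch

w_q = torch.randint(-128, 128, (8, 3), dtype=torch.int32)    # quantized weight
input_zero_point = torch.tensor([5], dtype=torch.int32)      # static, per-tensor

azp_adj = w_q.sum(dim=0, keepdim=True, dtype=torch.int32)    # shape (1, 3)
azp_adj = input_zero_point * azp_adj                         # folded once, reused

# In the dynamic per-token case the zero point is only known per batch, so
# azp_adj stays unfolded and the per-token azp vector goes to the kernel instead.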
@@ -138,7 +142,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):

     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
-
         return apply_int8_linear(input=x,
                                  weight=layer.weight,
                                  weight_scale=layer.weight_scale,
@@ -211,13 +211,16 @@ def apply_int8_linear(
                                                symmetric=symmetric)

     if x_zp is not None:
+        # Currently, static is always per-tensor and dynamic is per-token
+        static = input_zero_point is not None
+        azp = None if static else x_zp
         return ops.cutlass_scaled_mm_azp(x_q,
                                          weight,
                                          scale_a=x_scale,
                                          scale_b=weight_scale,
                                          out_dtype=input.dtype,
                                          azp_adj=azp_adj,
-                                         azp=x_zp,
+                                         azp=azp,
                                          bias=bias)
     return ops.cutlass_scaled_mm(x_q,
                                  weight,
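Taken together with the scaled_int8_quant change above, this dispatch is where the bug surfaced: before the commit, x_zp came back as None for static asymmetric inputs, so execution fell through to the symmetric cutlass_scaled_mm call. A short summary of the intent of the branch (reviewer reading of the diff, not kernel documentation):

# static (per-tensor) zero point: already folded into azp_adj at weight load,
#   so the kernel receives azp=None and must not apply it a second time.
# dynamic (per-token) zero point: produced per batch by scaled_int8_quant,
#   so the per-token tensor is passed through unchanged as azp=x_zp.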