[Bugfix] Disable w16a16 2of4 sparse CompressedTensors24 (#12417)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: mgoin <michael@neuralmagic.com>

Parent: 9ddc35220b
Commit: aa2cd2c43d

tests/kernels/test_cutlass.py
@@ -2,7 +2,7 @@
 Run `pytest tests/kernels/test_cutlass.py`.
 """
-from typing import Optional, Type
+from typing import Type

 import pytest
 import torch
@@ -11,6 +11,8 @@ from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform

+from .utils import baseline_scaled_mm, to_fp8, to_int8
+
 MNK_FACTORS = [
     (1, 256, 128),
     (1, 16384, 1024),
@@ -41,34 +43,10 @@ capability = current_platform.get_device_capability()
 capability = capability[0] * 10 + capability[1]


-def to_fp8(tensor: torch.Tensor):
-    finfo = torch.finfo(torch.float8_e4m3fn)
-    return torch.round(tensor.clamp(
-        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
-
-
-def to_int8(tensor: torch.Tensor):
-    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
-
-
 def rand_int8(shape: tuple, device: str = "cuda"):
     return to_int8(torch.rand(shape, device=device) * 255 - 128)


-def baseline_scaled_mm(a: torch.Tensor,
-                       b: torch.Tensor,
-                       scale_a: torch.Tensor,
-                       scale_b: torch.Tensor,
-                       out_dtype: Type[torch.dtype],
-                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    output = (scale_a * (scale_b * (torch.mm(
-        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
-    if bias is not None:
-        output = output + bias
-
-    return output
-
-
 def cutlass_fp8_gemm_helper(m: int,
                             n: int,
                             k: int,

tests/kernels/test_cutlass_2of4_sparse.py (new file, 214 lines)
@@ -0,0 +1,214 @@
"""Tests for sparse cutlass kernels

Run `pytest tests/kernels/test_semi_structured.py`.
"""
from typing import Tuple, Type

import pytest
import torch
import torch.nn.functional as F

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    sparse_cutlass_supported)
from vllm.platforms import current_platform

from .utils import baseline_scaled_mm, to_fp8, to_int8

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]


def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.bfloat16)


def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.float16)


def prune_to_2_4(tensor):
    # Reshape tensor to [N, 4] where N is number of groups of 4
    original_shape = tensor.shape
    reshaped = tensor.reshape(-1, 4)

    # Get indices of top 2 absolute values in each group of 4
    _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)

    # Create binary mask
    mask = torch.zeros_like(reshaped)
    mask.scatter_(dim=1,
                  index=indices,
                  src=torch.ones_like(indices, dtype=mask.dtype))

    # Apply mask and reshape back
    pruned = reshaped * mask

    # Turn all -0.0 to 0.0
    pruned[pruned == -0.0] = 0.0

    return pruned.reshape(original_shape)


def make_rand_sparse_tensors(
    dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    a = torch.randn((m, k), device='cuda') * 5
    b = torch.randn((n, k), device='cuda').t() * 5

    b = prune_to_2_4(b.t()).t()

    if dtype == torch.int8:
        a, b = to_int8(a), to_int8(b)
    elif dtype == torch.float8_e4m3fn:
        a, b = to_fp8(a), to_fp8(b)
    elif dtype == torch.float16:
        a, b = to_fp16(a), to_fp16(b)
    elif dtype == torch.bfloat16:
        a, b = to_bf16(a), to_bf16(b)
    else:
        raise ValueError("unsupported dtype")

    b_compressed, e = ops.cutlass_sparse_compress(b.t())

    # Compressed B, Metadata, Original A, B
    return b_compressed, e, a, b


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():

    big_m = 1024
    m, n, k = 512, 512, 512

    # Create tensors
    b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
                                                     big_m, n, k)
    a = whole_a[0:m, 0:k]
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


MNK_FACTORS = [
    (1, 256, 128),
    (1, 16384, 1024),
    (1, 24576, 512),
    (16, 256, 512),
    (16, 16384, 128),
    (16, 24576, 4096),
    (32, 8192, 4096),
    (32, 16384, 4096),
    (33, 1024, 1024),
    (33, 8192, 128),
    (64, 2048, 512),
    (64, 16384, 1024),
    (100, 8192, 512),
    (128, 32768, 4096),
    (256, 4096, 4096),
    (512, 256, 1024),
    (512, 8192, 4096),
    (512, 16384, 128),
    (512, 24576, 128),
]


# Test working with a subset of A and B for sparse matmul
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype]):

    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
    scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=dtype)
    baseline = F.linear(a, b.T)

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1e-2)


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int):

    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
    scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
    scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=torch.bfloat16)

    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
                                  per_out_ch: bool, use_bias: bool):

    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
    scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
    scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=torch.bfloat16)

    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
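
Aside for readers of this diff (illustration only, not part of the commit): the prune_to_2_4 helper above enforces 2:4 structured sparsity, i.e. in every contiguous group of four values only the two with the largest magnitude are kept. A minimal CPU-only sketch of the same masking idea on a toy tensor with made-up values:

import torch

# Toy tensor: each row is one group of four values (made-up numbers).
x = torch.tensor([[0.1, -3.0, 2.0, 0.5],
                  [4.0, 0.2, -0.1, 1.0]])

# Keep the two largest-magnitude entries per group of four, zero the rest,
# mirroring prune_to_2_4.
_, idx = torch.topk(x.abs(), k=2, dim=1)
mask = torch.zeros_like(x).scatter_(1, idx, 1.0)
pruned = x * mask
pruned[pruned == -0.0] = 0.0                  # same -0.0 cleanup as the helper

assert (pruned != 0).sum(dim=1).max() <= 2    # 2:4 pattern holds per group
print(pruned)                                 # only [-3.0, 2.0] and [4.0, 1.0] survive

The tests then hand such a pruned weight matrix to ops.cutlass_sparse_compress, which returns the compressed values and metadata (b_compressed, e) consumed by cutlass_scaled_sparse_mm.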

tests/kernels/test_semi_structured.py (deleted, 134 lines)
@@ -1,134 +0,0 @@
"""Tests for sparse cutlass kernels

Run `pytest tests/kernels/test_semi_structured.py`.
"""
from typing import Optional, Tuple, Type

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    sparse_cutlass_supported)
from vllm.platforms import current_platform

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.Tensor):
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


def rand_int8(shape: tuple, device: str = "cuda"):
    return to_int8(torch.rand(shape, device=device) * 255 - 128)


def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.bfloat16)


def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.float16)


def prune_to_2_4(tensor):
    # Reshape tensor to [N, 4] where N is number of groups of 4
    original_shape = tensor.shape
    reshaped = tensor.reshape(-1, 4)

    # Get indices of top 2 absolute values in each group of 4
    _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)

    # Create binary mask
    mask = torch.zeros_like(reshaped)
    mask.scatter_(dim=1,
                  index=indices,
                  src=torch.ones_like(indices, dtype=mask.dtype))

    # Apply mask and reshape back
    pruned = reshaped * mask

    # Turn all -0.0 to 0.0
    pruned[pruned == -0.0] = 0.0

    return pruned.reshape(original_shape)


def make_rand_sparse_tensors(
    dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    a = torch.randn((m, k), device='cuda') * 5
    b = torch.randn((n, k), device='cuda').t() * 5

    b = prune_to_2_4(b.t()).t()

    if dtype == torch.int8:
        a, b = to_int8(a), to_int8(b)
    elif dtype == torch.float8_e4m3fn:
        a, b = to_fp8(a), to_fp8(b)
    elif dtype == torch.float16:
        a, b = to_fp16(a), to_fp16(b)
    elif dtype == torch.bfloat16:
        a, b = to_bf16(a), to_bf16(b)
    else:
        raise ValueError("unsupported dtype")

    b_compressed, e = ops.cutlass_sparse_compress(b.t())

    # Compressed B, Metadata, Original A, B
    return b_compressed, e, a, b


def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype],
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    output = (scale_a * (scale_b * (torch.mm(
        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
    if bias is not None:
        output = output + bias

    return output


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():

    big_m = 1024
    m, n, k = 512, 512, 512

    # Create tensors
    b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
                                                     big_m, n, k)
    a = whole_a[0:m, 0:k]
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)

tests/kernels/utils.py
@@ -5,7 +5,7 @@ import random
 import unittest
 from numbers import Number
 from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
-                    Union)
+                    Type, Union)

 import pytest
 import torch
@@ -1100,3 +1100,28 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                    kwargs,
                    test_utils=test_utils,
                    raise_exception=raise_exception) if cond else {}
+
+
+# For testing quantized linear kernels
+def to_fp8(tensor: torch.Tensor):
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    return torch.round(tensor.clamp(
+        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
+
+
+def to_int8(tensor: torch.Tensor):
+    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
+
+
+def baseline_scaled_mm(a: torch.Tensor,
+                       b: torch.Tensor,
+                       scale_a: torch.Tensor,
+                       scale_b: torch.Tensor,
+                       out_dtype: Type[torch.dtype],
+                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    output = (scale_a * (scale_b * (torch.mm(
+        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
+    if bias is not None:
+        output = output + bias
+
+    return output
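
For reference (illustration only, not part of the diff): with scalar per-tensor scales, baseline_scaled_mm reduces to scale_a * scale_b * (A @ B) computed in float32 and cast to out_dtype. A small CPU sketch using the helpers just added to tests/kernels/utils.py; the shapes and scale values are made up:

import torch

from tests.kernels.utils import baseline_scaled_mm, to_int8

a = to_int8(torch.randn(4, 8) * 5)   # [M, K] int8 "activations"
b = to_int8(torch.randn(8, 2) * 5)   # [K, N] int8 "weights"
scale_a = torch.tensor([[0.02]])     # per-tensor activation scale (made up)
scale_b = torch.tensor([[0.05]])     # per-tensor weight scale (made up)

ref = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float32)
assert torch.allclose(ref, scale_a * scale_b * (a.float() @ b.float()))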

tests/quantization/test_compressed_tensors.py
@@ -313,8 +313,10 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
     assert output


+@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
 @pytest.mark.skipif(not sparse_cutlass_supported(),
-                    reason="Sparse FP8 is not yet supported on this GPU type.")
+                    reason="2of4 Sparse is not yet supported on this GPU type."
+                    )
 @pytest.mark.parametrize(
     "args_2of4",
     [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")])

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -9,6 +9,7 @@ from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationType)
 from pydantic import BaseModel

+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
@@ -27,6 +28,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import current_platform

+logger = init_logger(__name__)
+
 __all__ = ["CompressedTensorsLinearMethod"]

 SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
@@ -79,6 +82,8 @@ class CompressedTensorsConfig(QuantizationConfig):
             return UnquantizedLinearMethod()
         if isinstance(layer, LinearBase):
             scheme = self.get_scheme(layer=layer, layer_name=prefix)
+            if scheme is None:
+                return UnquantizedLinearMethod()
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, Attention):
@@ -340,10 +345,10 @@ class CompressedTensorsConfig(QuantizationConfig):
             raise NotImplementedError(
                 "No compressed-tensors compatible scheme was found.")

-    def get_scheme(
-            self,
-            layer: torch.nn.Module,
-            layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
+    def get_scheme(self,
+                   layer: torch.nn.Module,
+                   layer_name: Optional[str] = None
+                   ) -> Optional["CompressedTensorsScheme"]:
         """
         compressed-tensors supports non uniform in the following way:

@@ -353,10 +358,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         which can be a full layer_name, a regex for a layer_name, or
         an nn.Module name.

-        We first check whether a layer is in the ignore group and use
-        CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the layer
-
-        We then detect whether a layer_name is found in any target and
+        Detect whether a layer_name is found in any target and
         use the quantization scheme corresponding to the matched target
         to select the CompressedTensorsScheme used for infernece.
         """
@@ -394,6 +396,13 @@ class CompressedTensorsConfig(QuantizationConfig):
         if self.supports_cutlass_24(weight_quant=weight_quant,
                                     input_quant=input_quant,
                                     sparsity_scheme=sparsity_scheme):
+            # FIXME(tlrmchlsmth): layers using W16A16 CUTLASS 2:4 sparse kernels
+            # currently produce bad output in some cases
+            if weight_quant is None:
+                logger.warning_once(
+                    "CompressedTensors24 scheme is disabled for the w16a16 "
+                    "case. Falling back to UnquantizedLinearMethod")
+                return None
             # Have a valid sparsity scheme
             # Validate layer is supported by Cutlass 2:4 Kernel
             scheme = CompressedTensors24(quantized=weight_quant is not None
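
Net effect of the change, summarized as a standalone sketch (the classes and function below are simplified stand-ins, not vLLM code): a layer that qualifies for the CUTLASS 2:4 sparse path but has no weight quantization (w16a16) now gets no CompressedTensors24 scheme, and the config falls back to an unquantized linear method instead.

from typing import Optional


class UnquantizedLinearMethod:          # stand-in for the real vLLM class
    pass


class CompressedTensors24:              # stand-in for the real scheme
    def __init__(self, quantized: bool):
        self.quantized = quantized


def pick_linear_method(weight_quant, supports_cutlass_24: bool):
    # Mirrors the new logic: w16a16 means weight_quant is None, so the
    # 2:4 sparse scheme is skipped and the caller falls back.
    scheme: Optional[CompressedTensors24] = None
    if supports_cutlass_24 and weight_quant is not None:
        scheme = CompressedTensors24(quantized=True)
    if scheme is None:
        return UnquantizedLinearMethod()
    return scheme


print(type(pick_linear_method(weight_quant=None,
                              supports_cutlass_24=True)).__name__)
# -> UnquantizedLinearMethod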