[Bugfix] Disable w16a16 2of4 sparse CompressedTensors24 (#12417)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
Tyler Michael Smith 2025-01-26 06:59:58 -05:00 committed by GitHub
parent 9ddc35220b
commit aa2cd2c43d
6 changed files with 263 additions and 169 deletions
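
In effect, CompressedTensorsConfig.get_scheme can now return None when a layer matches the CUTLASS 2:4 sparse path but has no weight quantization (w16a16), and the caller falls back to UnquantizedLinearMethod instead of building a CompressedTensors24 scheme. The sketch below is a minimal stand-in for that control flow, using placeholder classes and simplified signatures rather than vLLM's real ones:

# A minimal sketch (not vLLM's actual classes) of the fallback path this commit
# introduces: for 2:4 sparse layers with unquantized 16-bit weights (w16a16),
# no CompressedTensors24 scheme is built and the layer stays unquantized.
from typing import Optional


class UnquantizedLinearMethod:  # stand-in for vLLM's class of the same name
    pass


class CompressedTensorsLinearMethod:  # stand-in wrapper around a layer scheme
    def __init__(self, scheme: object):
        self.scheme = scheme


def get_scheme(weight_quant: Optional[object],
               supports_cutlass_24: bool) -> Optional[object]:
    """Return a scheme, or None to request an unquantized fallback."""
    if supports_cutlass_24 and weight_quant is None:
        # w16a16 2:4 sparse CUTLASS kernels currently produce bad output,
        # so the CompressedTensors24 scheme is disabled for this case.
        return None
    return object()  # placeholder for a real CompressedTensors24 scheme


def get_quant_method(weight_quant: Optional[object],
                     supports_cutlass_24: bool):
    scheme = get_scheme(weight_quant, supports_cutlass_24)
    if scheme is None:
        return UnquantizedLinearMethod()
    return CompressedTensorsLinearMethod(scheme)


# With no weight quantization, a 2:4 sparse layer now falls back:
assert isinstance(get_quant_method(None, True), UnquantizedLinearMethod)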


@@ -2,7 +2,7 @@
Run `pytest tests/kernels/test_cutlass.py`.
"""
-from typing import Optional, Type
+from typing import Type

import pytest
import torch
@@ -11,6 +11,8 @@ from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform

+from .utils import baseline_scaled_mm, to_fp8, to_int8

MNK_FACTORS = [
    (1, 256, 128),
    (1, 16384, 1024),
@@ -41,34 +43,10 @@ capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]

-def to_fp8(tensor: torch.Tensor):
-    finfo = torch.finfo(torch.float8_e4m3fn)
-    return torch.round(tensor.clamp(
-        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)

-def to_int8(tensor: torch.Tensor):
-    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)

def rand_int8(shape: tuple, device: str = "cuda"):
    return to_int8(torch.rand(shape, device=device) * 255 - 128)

-def baseline_scaled_mm(a: torch.Tensor,
-                       b: torch.Tensor,
-                       scale_a: torch.Tensor,
-                       scale_b: torch.Tensor,
-                       out_dtype: Type[torch.dtype],
-                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    output = (scale_a * (scale_b * (torch.mm(
-        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
-    if bias is not None:
-        output = output + bias
-    return output

def cutlass_fp8_gemm_helper(m: int,
                            n: int,
                            k: int,


@@ -0,0 +1,214 @@
"""Tests for sparse cutlass kernels

Run `pytest tests/kernels/test_semi_structured.py`.
"""
from typing import Tuple, Type

import pytest
import torch
import torch.nn.functional as F

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    sparse_cutlass_supported)
from vllm.platforms import current_platform

from .utils import baseline_scaled_mm, to_fp8, to_int8

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]


def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.bfloat16)


def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.float16)


def prune_to_2_4(tensor):
    # Reshape tensor to [N, 4] where N is number of groups of 4
    original_shape = tensor.shape
    reshaped = tensor.reshape(-1, 4)

    # Get indices of top 2 absolute values in each group of 4
    _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)

    # Create binary mask
    mask = torch.zeros_like(reshaped)
    mask.scatter_(dim=1,
                  index=indices,
                  src=torch.ones_like(indices, dtype=mask.dtype))

    # Apply mask and reshape back
    pruned = reshaped * mask

    # Turn all -0.0 to 0.0
    pruned[pruned == -0.0] = 0.0

    return pruned.reshape(original_shape)


def make_rand_sparse_tensors(
        dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    a = torch.randn((m, k), device='cuda') * 5
    b = torch.randn((n, k), device='cuda').t() * 5

    b = prune_to_2_4(b.t()).t()

    if dtype == torch.int8:
        a, b = to_int8(a), to_int8(b)
    elif dtype == torch.float8_e4m3fn:
        a, b = to_fp8(a), to_fp8(b)
    elif dtype == torch.float16:
        a, b = to_fp16(a), to_fp16(b)
    elif dtype == torch.bfloat16:
        a, b = to_bf16(a), to_bf16(b)
    else:
        raise ValueError("unsupported dtype")

    b_compressed, e = ops.cutlass_sparse_compress(b.t())

    # Compressed B, Metadata, Original A, B
    return b_compressed, e, a, b


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():
    big_m = 1024
    m, n, k = 512, 512, 512

    # Create tensors
    b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
                                                     big_m, n, k)
    a = whole_a[0:m, 0:k]
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


MNK_FACTORS = [
    (1, 256, 128),
    (1, 16384, 1024),
    (1, 24576, 512),
    (16, 256, 512),
    (16, 16384, 128),
    (16, 24576, 4096),
    (32, 8192, 4096),
    (32, 16384, 4096),
    (33, 1024, 1024),
    (33, 8192, 128),
    (64, 2048, 512),
    (64, 16384, 1024),
    (100, 8192, 512),
    (128, 32768, 4096),
    (256, 4096, 4096),
    (512, 256, 1024),
    (512, 8192, 4096),
    (512, 16384, 128),
    (512, 24576, 128),
]


# Test working with a subset of A and B for sparse matmul
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype]):
    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
    scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=dtype)
    baseline = F.linear(a, b.T)

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1e-2)


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int):
    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
    scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
    scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
                                  per_out_ch: bool, use_bias: bool):
    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
    scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
    scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
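
For reference, prune_to_2_4 above keeps the two largest-magnitude values in each contiguous group of four, which is the 2:4 (2of4) structured-sparsity pattern the CUTLASS kernels consume. Below is a CPU-only sanity check of that behavior, with the helper copied (lightly condensed) from the test file above so the snippet is self-contained:

# Standalone sanity check for 2:4 pruning (no CUDA or vLLM install required).
import torch


def prune_to_2_4(tensor):
    # Same logic as above: keep the top-2 magnitudes in each group of 4.
    original_shape = tensor.shape
    reshaped = tensor.reshape(-1, 4)
    _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
    mask = torch.zeros_like(reshaped)
    mask.scatter_(dim=1, index=indices,
                  src=torch.ones_like(indices, dtype=mask.dtype))
    pruned = reshaped * mask
    pruned[pruned == -0.0] = 0.0
    return pruned.reshape(original_shape)


x = torch.tensor([[1.0, -3.0, 0.5, 2.0],
                  [4.0, 0.1, -0.2, -5.0]])
y = prune_to_2_4(x)
# Each row of 4 keeps exactly its two largest-magnitude entries:
# tensor([[ 0., -3.,  0.,  2.],
#         [ 4.,  0.,  0., -5.]])
print(y)
assert (y != 0).reshape(-1, 4).sum(dim=1).eq(2).all()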


@@ -1,134 +0,0 @@
"""Tests for sparse cutlass kernels

Run `pytest tests/kernels/test_semi_structured.py`.
"""
from typing import Optional, Tuple, Type

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    sparse_cutlass_supported)
from vllm.platforms import current_platform

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.Tensor):
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


def rand_int8(shape: tuple, device: str = "cuda"):
    return to_int8(torch.rand(shape, device=device) * 255 - 128)


def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.bfloat16)


def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.float16)


def prune_to_2_4(tensor):
    # Reshape tensor to [N, 4] where N is number of groups of 4
    original_shape = tensor.shape
    reshaped = tensor.reshape(-1, 4)

    # Get indices of top 2 absolute values in each group of 4
    _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)

    # Create binary mask
    mask = torch.zeros_like(reshaped)
    mask.scatter_(dim=1,
                  index=indices,
                  src=torch.ones_like(indices, dtype=mask.dtype))

    # Apply mask and reshape back
    pruned = reshaped * mask

    # Turn all -0.0 to 0.0
    pruned[pruned == -0.0] = 0.0

    return pruned.reshape(original_shape)


def make_rand_sparse_tensors(
        dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    a = torch.randn((m, k), device='cuda') * 5
    b = torch.randn((n, k), device='cuda').t() * 5

    b = prune_to_2_4(b.t()).t()

    if dtype == torch.int8:
        a, b = to_int8(a), to_int8(b)
    elif dtype == torch.float8_e4m3fn:
        a, b = to_fp8(a), to_fp8(b)
    elif dtype == torch.float16:
        a, b = to_fp16(a), to_fp16(b)
    elif dtype == torch.bfloat16:
        a, b = to_bf16(a), to_bf16(b)
    else:
        raise ValueError("unsupported dtype")

    b_compressed, e = ops.cutlass_sparse_compress(b.t())

    # Compressed B, Metadata, Original A, B
    return b_compressed, e, a, b


def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype],
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    output = (scale_a * (scale_b * (torch.mm(
        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
    if bias is not None:
        output = output + bias

    return output


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():
    big_m = 1024
    m, n, k = 512, 512, 512

    # Create tensors
    b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
                                                     big_m, n, k)
    a = whole_a[0:m, 0:k]
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


@@ -5,7 +5,7 @@ import random
import unittest
from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
-                    Union)
+                    Type, Union)

import pytest
import torch
@@ -1100,3 +1100,28 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
        kwargs,
        test_utils=test_utils,
        raise_exception=raise_exception) if cond else {}
+
+
+# For testing quantized linear kernels
+def to_fp8(tensor: torch.Tensor):
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    return torch.round(tensor.clamp(
+        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
+
+
+def to_int8(tensor: torch.Tensor):
+    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
+
+
+def baseline_scaled_mm(a: torch.Tensor,
+                       b: torch.Tensor,
+                       scale_a: torch.Tensor,
+                       scale_b: torch.Tensor,
+                       out_dtype: Type[torch.dtype],
+                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    output = (scale_a * (scale_b * (torch.mm(
+        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
+    if bias is not None:
+        output = output + bias
+
+    return output
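
The baseline_scaled_mm helper added above is per-tensor scaling wrapped around a float32 matmul, i.e. out = (scale_a * (scale_b * (A_f32 @ B_f32))).to(out_dtype). A tiny CPU-only illustration of that formula with hand-checkable numbers (the values below are arbitrary, not taken from the tests):

# Illustrates the reference formula used by baseline_scaled_mm (CPU-only).
import torch

a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([[1., 0.], [0., 1.]])  # identity, so the matmul returns a
scale_a = torch.tensor(0.5)
scale_b = torch.tensor(2.0)

# scale_a * scale_b == 1.0, so the scaled result equals a @ b == a.
out = (scale_a * (scale_b * torch.mm(a, b))).to(torch.bfloat16)
print(out)  # tensor([[1., 2.], [3., 4.]], dtype=torch.bfloat16)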


@@ -313,8 +313,10 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
        assert output


+@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
@pytest.mark.skipif(not sparse_cutlass_supported(),
-                    reason="Sparse FP8 is not yet supported on this GPU type.")
+                    reason="2of4 Sparse is not yet supported on this GPU type."
+                    )
@pytest.mark.parametrize(
    "args_2of4",
    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")])


@@ -9,6 +9,7 @@ from compressed_tensors.quantization import (QuantizationArgs,
                                             QuantizationType)
from pydantic import BaseModel

+from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
@@ -27,6 +28,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform

+logger = init_logger(__name__)
+
__all__ = ["CompressedTensorsLinearMethod"]

SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
@@ -79,6 +82,8 @@ class CompressedTensorsConfig(QuantizationConfig):
            return UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
+            if scheme is None:
+                return UnquantizedLinearMethod()
            layer.scheme = scheme
            return CompressedTensorsLinearMethod(self)
        if isinstance(layer, Attention):
@@ -340,10 +345,10 @@
            raise NotImplementedError(
                "No compressed-tensors compatible scheme was found.")

-    def get_scheme(
-            self,
-            layer: torch.nn.Module,
-            layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
+    def get_scheme(self,
+                   layer: torch.nn.Module,
+                   layer_name: Optional[str] = None
+                   ) -> Optional["CompressedTensorsScheme"]:
        """
        compressed-tensors supports non uniform in the following way:
@@ -353,10 +358,7 @@
        which can be a full layer_name, a regex for a layer_name, or
        an nn.Module name.

-        We first check whether a layer is in the ignore group and use
-        CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the layer
-
-        We then detect whether a layer_name is found in any target and
+        Detect whether a layer_name is found in any target and
        use the quantization scheme corresponding to the matched target
        to select the CompressedTensorsScheme used for infernece.
        """
@@ -394,6 +396,13 @@
        if self.supports_cutlass_24(weight_quant=weight_quant,
                                    input_quant=input_quant,
                                    sparsity_scheme=sparsity_scheme):
+            # FIXME(tlrmchlsmth): layers using W16A16 CUTLASS 2:4 sparse kernels
+            # currently produce bad output in some cases
+            if weight_quant is None:
+                logger.warning_once(
+                    "CompressedTensors24 scheme is disabled for the w16a16 "
+                    "case. Falling back to UnquantizedLinearMethod")
+                return None
            # Have a valid sparsity scheme
            # Validate layer is supported by Cutlass 2:4 Kernel
            scheme = CompressedTensors24(quantized=weight_quant is not None