[Kernel]: Cutlass 2:4 Sparsity + FP8/Int8 Quant Support (#10995)

Co-authored-by: Faraz Shahsavan <faraz.shahsavan@gmail.com>
Co-authored-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
This commit is contained in:
Dipika Sikka 2024-12-18 09:57:16 -05:00 committed by GitHub
parent f04e407e6b
commit 60508ffda9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 2365 additions and 117 deletions

View File

@ -206,7 +206,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use")
set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_TAG v3.5.1
GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227
GIT_PROGRESS TRUE
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
GIT_SHALLOW TRUE
GIT_SHALLOW FALSE
)
endif()
FetchContent_MakeAvailable(cutlass)
@ -241,7 +241,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_compressor_entry.cu"
"csrc/cutlass_extensions/common.cpp")
set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
@ -271,11 +274,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
#
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# The cutlass_scaled_mm cutlass_scaled_sparse_mm, and cutlass_compressor kernels
# For Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/sparse/cutlass/sparse_compressor_c3x.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@ -284,12 +290,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"later if you intend on running FP8 sparse or quantized models on "
"Hopper.")
else()
message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
message(STATUS "Not building cutlass_c3x as no compatible archs found "
"in CUDA target architectures")
endif()
@ -404,7 +410,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)

View File

@ -0,0 +1,384 @@
import argparse
import copy
import itertools
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import make_rand_sparse_tensors
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]
# bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
**kwargs) -> TMeasurement:
min_run_time = 1
globals = {
"args": args,
"kwargs": kwargs,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(*args, **kwargs)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
assert dtype == torch.int8
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
torch.bfloat16)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref):
print("Incorrect results")
print(out)
print(out_ref)
else:
print("Correct results")
timers = []
# pytorch impl - bfloat16
timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm, a.to(dtype=torch.bfloat16),
b.to(dtype=torch.bfloat16)))
# pytorch impl - float16
timers.append(
bench_fn(label, sub_label,
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
# cutlass impl
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.bfloat16))
# cutlass with bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
bias))
# cutlass sparse impl
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
scale_b, torch.bfloat16))
# cutlass sparse with bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
scale_b, torch.bfloat16, bias))
return timers
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
assert dtype == torch.float8_e4m3fn
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n,
k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
torch.bfloat16)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref):
print("Incorrect results")
print(out)
print(out_ref)
else:
print("Correct results")
timers = []
# pytorch impl w. bf16
timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda")))
# pytorch impl: bf16 output, without fp8 fast accum
timers.append(
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16))
# pytorch impl: bf16 output, with fp8 fast accum
timers.append(
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True))
# pytorch impl: fp16 output, without fp8 fast accum
timers.append(
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16))
# pytorch impl: fp16 output, with fp8 fast accum
timers.append(
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16,
use_fast_accum=True))
# cutlass impl: bf16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.bfloat16))
# cutlass impl: bf16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
scale_b, torch.bfloat16))
# cutlass impl: fp16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
scale_b, torch.float16))
# cutlass impl: bf16 output, with bias
timers.append(
bench_fn(label, sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
scale_b, torch.bfloat16, bias))
# cutlass impl: fp16 output, with bias
timers.append(
bench_fn(label, sub_label,
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
scale_b, torch.float16, bias.to(dtype=torch.float16)))
return timers
def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label)
if dtype == torch.float8_e4m3fn:
return bench_fp8(dtype, m, k, n, label, sub_label)
raise ValueError("unsupported type")
# runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
def run(dtype: torch.dtype,
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
f"MKN=({m}x{k}x{n})")
print_timers(timers)
results.extend(timers)
return results
# output makers
def make_output(data: Iterable[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]],
base_description: str,
timestamp=None):
print(f"== All Results {base_description} ====")
print_timers(data)
# pickle all the results
timestamp = int(time.time()) if timestamp is None else timestamp
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
pkl.dump(data, f)
# argparse runners
def run_square_bench(args):
dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs)
make_output(data, MKNs, f"square_bench-{args.dtype}")
def run_range_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
n = len(dim_sizes)
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
MKNs = list(zip(Ms, Ks, Ns))
data = run(args.dtype, MKNs)
make_output(data, MKNs, f"range_bench-{args.dtype}")
def run_model_bench(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
KNs = []
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KNs.append(KN)
return KNs
model_bench_data = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
Ms = args.batch_sizes
KNs = model_shapes(model, tp_size)
MKNs = []
for m in Ms:
for k, n in KNs:
MKNs.append((m, k, n))
data = run(args.dtype, MKNs)
model_bench_data.append(data)
# Print all results
for data, model_tp in zip(model_bench_data, models_tps):
model, tp_size = model_tp
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
print_timers(data)
timestamp = int(time.time())
all_data = []
for d in model_bench_data:
all_data.extend(d)
# pickle all data
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
pkl.dump(all_data, f)
if __name__ == '__main__':
def to_torch_dtype(dt):
if dt == "int8":
return torch.int8
if dt == "fp8":
return torch.float8_e4m3fn
raise ValueError("unsupported dtype")
parser = FlexibleArgumentParser(
description="""
Benchmark Cutlass GEMM.
To run square GEMMs:
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
To run constant N and K and sweep M:
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
To run dimensions from a model:
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--dtype",
type=to_torch_dtype,
required=True,
help="Available options are ['int8', 'fp8']")
subparsers = parser.add_subparsers(dest="cmd")
square_parser = subparsers.add_parser("square_bench")
square_parser.add_argument("--dim-start", type=int, required=True)
square_parser.add_argument("--dim-end", type=int, required=True)
square_parser.add_argument("--dim-increment", type=int, required=True)
square_parser.set_defaults(func=run_square_bench)
range_parser = subparsers.add_parser("range_bench")
range_parser.add_argument("--dim-start", type=int, required=True)
range_parser.add_argument("--dim-end", type=int, required=True)
range_parser.add_argument("--dim-increment", type=int, required=True)
range_parser.add_argument("--m-constant", type=int, default=None)
range_parser.add_argument("--n-constant", type=int, default=None)
range_parser.add_argument("--k-constant", type=int, default=None)
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument("--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys())
model_parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
model_parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
args.func(args)

View File

@ -0,0 +1,96 @@
# Cutlass bench utils
from typing import Iterable, Tuple
import torch
import vllm._custom_ops as ops
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.bfloat16)
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.float16)
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5
if dtype == torch.int8:
return to_int8(a), to_int8(b)
if dtype == torch.float8_e4m3fn:
return to_fp8(a), to_fp8(b)
raise ValueError("unsupported dtype")
def prune_to_2_4(tensor):
# Reshape tensor to [N, 4] where N is number of groups of 4
original_shape = tensor.shape
reshaped = tensor.reshape(-1, 4)
# Get indices of top 2 absolute values in each group of 4
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
# Create binary mask
mask = torch.zeros_like(reshaped)
mask.scatter_(dim=1,
index=indices,
src=torch.ones_like(indices, dtype=mask.dtype))
# Apply mask and reshape back
pruned = reshaped * mask
# Turn all -0.0 to 0.0
pruned[pruned == -0.0] = 0.0
return pruned.reshape(original_shape)
def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5
b = prune_to_2_4(b.t()).t()
if dtype == torch.int8:
a, b = to_int8(a), to_int8(b)
elif dtype == torch.float8_e4m3fn:
a, b = to_fp8(a), to_fp8(b)
elif dtype == torch.float16:
a, b = to_fp16(a), to_fp16(b)
elif dtype == torch.bfloat16:
a, b = to_bf16(a), to_bf16(b)
else:
raise ValueError("unsupported dtype")
b_compressed, e = ops.cutlass_sparse_compress(b.t())
# Compressed B, Metadata, Original A, B
return b_compressed, e, a, b
def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype,
m: int, n: int, k: int) -> \
Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
ABs = []
for _ in range(num_tensors):
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
if b_comp is not None:
ABs.append(make_rand_sparse_tensors(dtype, m, n, k))
BComps, Es, As, Bs = zip(*ABs)
return list(BComps), list(Es), list(As), list(Bs)

View File

@ -8,6 +8,7 @@ from typing import Callable, Iterable, List, Tuple
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import make_rand_tensors
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
@ -17,31 +18,6 @@ DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]
# helpers
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5
if dtype == torch.int8:
return to_int8(a), to_int8(b)
if dtype == torch.float8_e4m3fn:
return to_fp8(a), to_fp8(b)
raise ValueError("unsupported dtype")
# bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
@ -386,4 +362,4 @@ Benchmark Cutlass GEMM.
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
args.func(args)
args.func(args)

View File

@ -40,4 +40,4 @@ WEIGHT_SHAPES = {
([8192, 57344], 1),
([28672, 8192], 0),
],
}
}

7
csrc/core/math.hpp Normal file
View File

@ -0,0 +1,7 @@
#include <climits>
#include <iostream>
inline uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

View File

@ -0,0 +1,11 @@
#include "cutlass_extensions/common.hpp"
int32_t get_sm_version_num() {
int32_t major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
0);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
0);
int32_t version_num = major_capability * 10 + minor_capability;
return version_num;
}

View File

@ -0,0 +1,35 @@
#pragma once
#include "cutlass/cutlass.h"
#include <climits>
#include "cuda_runtime.h"
#include <iostream>
/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
}
/**
* Panic wrapper for unwinding CUDA runtime errors
*/
#define CUDA_CHECK(status) \
{ \
cudaError_t error = status; \
TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
}
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
int max_shared_mem_per_block_opt_in = 0;
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
device);
return max_shared_mem_per_block_opt_in;
}
int32_t get_sm_version_num();

View File

@ -36,13 +36,13 @@ struct ScaledEpilogueBase {
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// This utility function constructs the arguments for the load descriptors

View File

@ -162,6 +162,15 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias);
void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& e,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
torch::Tensor& e, torch::Tensor const& a);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,

View File

@ -1,27 +0,0 @@
#pragma once
#include "cutlass/cutlass.h"
#include <climits>
/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
TORCH_CHECK(status == cutlass::Status::kSuccess, \
cutlassGetStatusString(status)) \
}
inline uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
int max_shared_mem_per_block_opt_in = 0;
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
device);
return max_shared_mem_per_block_opt_in;
}

View File

@ -21,7 +21,8 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "common.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on
using namespace cute;

View File

@ -24,7 +24,8 @@
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "common.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on
using namespace cute;

View File

@ -3,6 +3,8 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
@ -79,16 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
return false;
}
int32_t get_sm_version_num() {
int32_t major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
0);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
0);
int32_t version_num = major_capability * 10 + minor_capability;
return version_num;
}
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,

View File

@ -0,0 +1,163 @@
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#include "sparse_scaled_mm_c3x.cuh"
#include "cutlass/numeric_conversion.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
// clang-format on
using namespace cute;
using namespace vllm;
/// Make A structured sparse by replacing elements with 0 and compress it
template <typename ElementA_, typename ElementAcc_>
bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a) {
// Checks for conformality
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
TORCH_CHECK(a.dim() == 2)
// Check for strides and alignment
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
TORCH_CHECK(a.stride(1) == 1)
int m = a.size(0);
int k = a.size(1);
// Sparse kernel setup; this kernel is not used for matmul,
// but just for setting up the compressor utility
// A matrix configuration
using ElementA = ElementA_;
using LayoutTagA = cutlass::layout::RowMajor;
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
// B matrix configuration
using ElementB = ElementA;
using LayoutTagB = cutlass::layout::ColumnMajor;
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
// C/D matrix configuration
using ElementC = float;
using LayoutTagC = cutlass::layout::ColumnMajor;
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
// Core kernel configurations
using ElementAccumulator = ElementAcc_;
using TileShape = Shape<_128, _128, _128>;
using TileShapeRef = Shape<_128, _128, _64>;
using ClusterShape = Shape<_1, _2, _1>;
using KernelSchedule = typename std::conditional<
std::is_same_v<ElementA, cutlass::float_e4m3_t>,
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
cutlass::gemm::KernelTmaWarpSpecialized>::type;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using ProblemShape = Shape<int, int, int, int>;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC,
AlignmentC, ElementC, LayoutTagC, AlignmentC,
EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA,
LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
using StrideE = StrideA;
using StrideA = Stride<int64_t, Int<1>, int64_t>;
// The n (=1) dimension does not matter for the compressor
typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1};
using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA;
using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE;
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
// Offline compressor kernel
using CompressorUtility =
cutlass::transform::kernel::StructuredSparseCompressorUtility<
ProblemShape, ElementA, LayoutTagA, SparseConfig>;
using CompressorKernel =
cutlass::transform::kernel::StructuredSparseCompressor<
ProblemShape, ElementA, LayoutTagA, SparseConfig,
cutlass::arch::Sm90>;
using Compressor =
cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
auto [M, N, K, L] = prob_shape;
StrideA stride_A;
stride_A =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
CompressorUtility compressor_utility(prob_shape, stride_A);
int ME = compressor_utility.get_metadata_m_physical();
int KE = compressor_utility.get_metadata_k_physical();
int KC = compressor_utility.get_tensorA_k_physical();
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
auto a_meta_ptr = static_cast<typename Gemm::CollectiveMainloop::ElementE*>(
a_meta.data_ptr());
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
typename Compressor::Arguments arguments{
prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}};
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(compressor_op.can_implement(arguments));
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
CUTLASS_CHECK(compressor_op.run());
CUDA_CHECK(cudaDeviceSynchronize());
return true;
}
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a) {
if (a.dtype() == torch::kBFloat16) {
return cutlass_sparse_compress<cutlass::bfloat16_t, float>(a_nzs, a_meta,
a);
} else if (a.dtype() == torch::kFloat16) {
return cutlass_sparse_compress<cutlass::half_t, float>(a_nzs, a_meta, a);
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
return cutlass_sparse_compress<cutlass::float_e4m3_t, float>(a_nzs, a_meta,
a);
} else if (a.dtype() == torch::kInt8) {
return cutlass_sparse_compress<int8_t, int32_t>(a_nzs, a_meta, a);
}
return false;
}

View File

@ -0,0 +1,42 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a);
#endif
bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a) {
// Checks for conformality
TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2);
TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) &&
a_nzs.size(1) * 2 == a.size(1) &&
a_meta.size(1) * 2 * 4 == a.size(1));
// Considering elemsPerMetaElem = 8b / 2b_per_nz = 4
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 &&
a_meta.stride(1) == 1); // Row-major
TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if (version_num >= 90) {
return cutlass_sparse_compress_sm90(a_nzs, a_meta, a);
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_sparse_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
}

View File

@ -0,0 +1,303 @@
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#include "sparse_scaled_mm_c3x.cuh"
// clang-format on
using namespace cute;
using namespace vllm;
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
using Cutlass3xGemmDefault =
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM256 =
typename sm90_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM512 =
typename sm90_fp8_config_M512<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm1 =
typename sm90_fp8_config_1<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm2 =
typename sm90_fp8_config_2<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm3 =
typename sm90_fp8_config_3<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm4 =
typename sm90_fp8_config_4<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm5 =
typename sm90_fp8_config_5<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm6 =
typename sm90_fp8_config_6<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm7 =
typename sm90_fp8_config_7<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm8 =
typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;
uint32_t const n = bt_nzs.size(0);
uint32_t const m = a.size(0); // Batch size
uint32_t const mp2 =
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
if (mp2 <= 64) {
if (n == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm2>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (n == 4096 || n == 6144) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm1>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
} else if (mp2 <= 128) {
if (n == 4096) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm3>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (n == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm5>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (n == 6144) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm4>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
} else if (mp2 <= 256) {
if (n == 4096) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm6>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (n == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (n == 6144) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
} else {
if (n == 6144 || n == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (n == 4096) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
}
// Otherwise the default heuristic
if (mp2 <= 64) {
// n in [1, 64]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// n in (64, 128]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// n in (128, 256]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM256>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else {
// n in (256, inf)
return cutlass_sparse_gemm_caller<Cutlass3xGemmM512>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
}
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::half_t>());
TORCH_CHECK(a.dtype() == torch::kFloat16);
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
using Cutlass3xGemmDefault =
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
// m in (128, inf)
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::bfloat16_t>());
TORCH_CHECK(a.dtype() == torch::kBFloat16);
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
using Cutlass3xGemmDefault =
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
// m in (128, inf)
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
using Cutlass3xGemmDefault =
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM32NBig =
typename sm90_int8_config_M32_NBig<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM32NSmall =
typename sm90_int8_config_M32_NSmall<InType, OutType,
Epilogue>::Cutlass3xGemm;
uint32_t const n = out.size(1);
bool const is_small_n = n < 8192;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 32) {
// m in [1, 32]
if (is_small_n) {
return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NSmall>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else {
return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NBig>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
} else if (mp2 <= 64) {
// m in (32, 64]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// m in (64, 128]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
}
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
}
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
}
} else if (a.dtype() == torch::kFloat16) {
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t,
cutlass::bfloat16_t, Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t, cutlass::half_t,
Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
}
} else { // a.dtype() == torch::kBFloat16
TORCH_CHECK(a.dtype() == torch::kBFloat16);
TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
cutlass::bfloat16_t, Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
cutlass::half_t, Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
}
}
}
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
out, a, bt_nzs, bt_meta, b_scales, a_scales, *bias);
} else {
return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>(
out, a, bt_nzs, bt_meta, b_scales, a_scales);
}
}
#endif

View File

@ -0,0 +1,496 @@
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/common.hpp"
#include "cutlass_extensions/torch_utils.hpp"
// clang-format on
using namespace cute;
/*
This file defines sparse quantized GEMM operations using the CUTLASS 3.x API,
for NVIDIA GPUs with sm90a (Hopper) or later.
*/
namespace {
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};
using GemmUniversalMode = cutlass::gemm::GemmUniversalMode;
template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule, typename AccType,
typename TileSchedule = cutlass::gemm::PersistentScheduler,
GemmUniversalMode Mode_ = GemmUniversalMode::kGemm>
struct cutlass_sparse_3x_gemm {
static const GemmUniversalMode Mode = Mode_;
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc = AccType;
using EpilogueDescriptor =
cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
ElementD, EpilogueSchedule>;
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
using ElementC = void;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
using LayoutC_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
using EVTCompute = typename Epilogue::EVTCompute;
static constexpr int AlignmentA =
128 / cutlass::sizeof_bits<ElementAB>::value;
static constexpr int AlignmentB =
128 / cutlass::sizeof_bits<ElementAB>::value;
static constexpr int AlignmentCD =
128 / cutlass::sizeof_bits<ElementD>::value;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAcc, ElementAcc, ElementC, LayoutC_Transpose, AlignmentCD,
ElementD, LayoutD_Transpose, AlignmentCD, EpilogueSchedule,
EVTCompute>::CollectiveOp;
static constexpr size_t CEStorageSize =
sizeof(typename CollectiveEpilogue::SharedStorage);
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(CEStorageSize)>;
// clang-format off
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
ElementAB, cutlass::layout::RowMajor, AlignmentA,
ElementAB, cutlass::layout::ColumnMajor, AlignmentB,
ElementAcc, TileShape, ClusterShape,
Stages,
KernelSchedule>::CollectiveOp;
// clang-format on
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
TileSchedule>>;
struct GemmKernel : public KernelType {};
};
template <typename Gemm, typename... EpilogueArgs>
void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
// Interface stride expected from the argument a (will get transposed)
// We compute C^T = B^T * A^T, but we assume B is transposed before
// compression and hence the bt_* naming
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
using LayoutD = cutlass::layout::RowMajor;
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
auto layout_A = make_cute_layout<StrideA>(a, "A");
auto layout_D = make_cute_layout<StrideD>(out, "D");
// Transpose A and D
// A doesn't need to be transposed since cutlass expects a NxK matrix
// for B (which is At)
auto stride_At = layout_A.stride();
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
using GemmKernel = typename Gemm::GemmKernel;
typename GemmKernel::ProblemShape prob_shape{
static_cast<int>(bt_nzs.size(0)), static_cast<int>(size<0>(layout_A)),
static_cast<int>(size<1>(layout_A)), 1};
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
LayoutB b_layout = SparseConfig::fill_layoutA(prob_shape);
LayoutE e_layout = SparseConfig::fill_layoutE(prob_shape);
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(bt_nzs.data_ptr());
auto e_ptr = static_cast<ElementE*>(bt_meta.data_ptr());
typename GemmKernel::MainloopArguments mainloop_args{
b_ptr, b_layout, a_ptr, stride_At, e_ptr, e_layout};
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
Gemm::Epilogue::prepare_args(
std::forward<EpilogueArgs>(epilogue_params)...),
c_ptr, stride_Dt, c_ptr, stride_Dt};
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
prob_shape, mainloop_args, epilogue_args};
// Launch the CUTLASS GEMM kernel.
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
GemmOp gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
}
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_config_default {};
template <typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<half_t, OutType, Epilogue> {
// M in (128, inf)
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<half_t, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
};
template <typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<cutlass::bfloat16_t, OutType, Epilogue> {
// M in (128, inf)
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule,
float>;
};
//////////////////////// Cherry-Picking Kernels ////////////////////////
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_1 {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_8, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_2 {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
using EpilogueSchedule =
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
using TileShape = Shape<_128, _64, _256>;
using ClusterShape = Shape<_8, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_3 {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _2, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_4 {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using EpilogueSchedule =
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
using TileShape = Shape<_64, _128, _256>;
using ClusterShape = Shape<_8, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_5 {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_8, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_6 {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _256>;
using ClusterShape = Shape<_1, _2, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_7 {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
using EpilogueSchedule =
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
using TileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_8 {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
using EpilogueSchedule =
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
using TileShape = Shape<_128, _256, _128>;
using ClusterShape = Shape<_8, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
};
////////////////////////////////////////////////////////////////////////
template <typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<cutlass::float_e4m3_t, OutType, Epilogue> {
// M in (128, inf)
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _2, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<cutlass::float_e4m3_t, OutType, Epilogue,
TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule, float>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
// M in [1, 64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using EpilogueSchedule =
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _1, _1>;
using TileSchedule = cutlass::gemm::PersistentScheduler;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float,
TileSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
// M in (64, 128]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _256>;
using ClusterShape = Shape<_1, _1, _1>;
using TileSchedule = cutlass::gemm::PersistentScheduler;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float,
TileSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M256 {
// M in (128, 256]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
using EpilogueSchedule =
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
using TileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_1, _1, _1>;
using TileSchedule = cutlass::gemm::PersistentScheduler;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float,
TileSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M512 {
// M in (256, ]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
using EpilogueSchedule =
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
using TileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_1, _1, _1>;
using TileSchedule = cutlass::gemm::PersistentScheduler;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float,
TileSchedule>;
};
template <typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<int8_t, OutType, Epilogue> {
// For M > 128 and any N
using KernelSchedule =
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<int8_t, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M128 {
// For M in (64, 128] and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule =
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M64 {
// For M in (32, 64] and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NBig {
// For M in [1, 32] and N >= 8192
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _256>;
using ClusterShape = Shape<_1, _4, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NSmall {
// For M in [1, 32] and N < 8192
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>;
};
} // namespace

View File

@ -0,0 +1,59 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& e,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
#endif
void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
// Checks for conformality
TORCH_CHECK(a.dim() == 2 && bt_nzs.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(1) == bt_nzs.size(0) && bt_nzs.size(1) * 2 == a.size(1) &&
a.size(0) == c.size(0));
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == bt_nzs.size(0));
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && bt_nzs.stride(1) == 1 &&
c.stride(1) == 1); // Row-major
TORCH_CHECK(c.stride(0) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(bt_nzs.stride(0) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->numel() == bt_nzs.size(0) && bias->is_contiguous() &&
bias->dim() == 1);
}
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if (version_num >= 90) {
cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
bias);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_sparse_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
}

View File

@ -321,6 +321,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
// CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
"cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
" Tensor bt_nzs,"
" Tensor bt_meta, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);
// CUTLASS sparse matrix compressor
ops.def(
"cutlass_sparse_compress_entry(Tensor! a_nzs, Tensor! a_meta,"
" Tensor a) -> bool");
ops.impl("cutlass_sparse_compress_entry", &cutlass_sparse_compress_entry);
// Mamba selective scan kernel
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"

View File

@ -83,7 +83,7 @@ exclude = [
]
[tool.codespell]
ignore-words-list = "dout, te, indicies, subtile"
ignore-words-list = "dout, te, indicies, subtile, ElementE"
skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
[tool.isort]

View File

@ -0,0 +1,131 @@
"""Tests for sparse cutlass kernels
Run `pytest tests/kernels/test_semi_structured.py`.
"""
from typing import Optional, Tuple, Type
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def rand_int8(shape: tuple, device: str = "cuda"):
return to_int8(torch.rand(shape, device=device) * 255 - 128)
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.bfloat16)
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.float16)
def prune_to_2_4(tensor):
# Reshape tensor to [N, 4] where N is number of groups of 4
original_shape = tensor.shape
reshaped = tensor.reshape(-1, 4)
# Get indices of top 2 absolute values in each group of 4
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
# Create binary mask
mask = torch.zeros_like(reshaped)
mask.scatter_(dim=1,
index=indices,
src=torch.ones_like(indices, dtype=mask.dtype))
# Apply mask and reshape back
pruned = reshaped * mask
# Turn all -0.0 to 0.0
pruned[pruned == -0.0] = 0.0
return pruned.reshape(original_shape)
def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5
b = prune_to_2_4(b.t()).t()
if dtype == torch.int8:
a, b = to_int8(a), to_int8(b)
elif dtype == torch.float8_e4m3fn:
a, b = to_fp8(a), to_fp8(b)
elif dtype == torch.float16:
a, b = to_fp16(a), to_fp16(b)
elif dtype == torch.bfloat16:
a, b = to_bf16(a), to_bf16(b)
else:
raise ValueError("unsupported dtype")
b_compressed, e = ops.cutlass_sparse_compress(b.t())
# Compressed B, Metadata, Original A, B
return b_compressed, e, a, b
def baseline_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
output = (scale_a * (scale_b * (torch.mm(
a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
if bias is not None:
output = output + bias
return output
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.")
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():
big_m = 1024
m, n, k = 512, 512, 512
# Create tensors
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
big_m, n, k)
a = whole_a[0:m, 0:k]
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)

View File

@ -10,9 +10,11 @@ from compressed_tensors.quantization import QuantizationType
from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
CompressedTensors24, CompressedTensorsLinearMethod,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16)
from vllm.platforms import current_platform
@pytest.mark.parametrize(
@ -208,3 +210,98 @@ def test_compressed_tensors_kv_cache(vllm_runner):
with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
output = llm.generate_greedy("Hello world!", max_tokens=20)
assert output
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.")
def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensors24)
assert qkv_proj.scheme.weight_quant.strategy == weight_strategy
assert qkv_proj.scheme.input_quant.strategy == input_strategy
assert qkv_proj.scheme.quantized
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
assert sparsity_map.get("Linear").format == "dense"
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel",
"token"),
("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
"channel", "tensor"),
("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", "tensor",
"tensor"),
("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
"tensor", "token"),
])
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
_test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
"channel", "token"),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing", "tensor",
"tensor"),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
"tensor", "token"),
])
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.int8
_test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize(
"args_2of4",
[("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")])
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
model = args_2of4
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensors24)
assert qkv_proj.scheme.weight_quant is None
assert qkv_proj.scheme.input_quant is None
assert not qkv_proj.scheme.quantized
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
assert sparsity_map.get("Linear").format == "dense"
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output

View File

@ -21,6 +21,8 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main, 90
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main, 90
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main

View File

@ -26,6 +26,10 @@ do
export QUANTIZATION=${array[0]}
export MODEL_NAME=${array[1]}
export REVISION=${array[2]}
# If array length is larger than 3, then MIN_CAPABILITY is provided
if [ ${#array[@]} -gt 3 ]; then
export MIN_CAPABILITY=${array[3]}
fi
pytest -s weight_loading/test_weight_loading.py || LOCAL_SUCCESS=$?
if [[ $LOCAL_SUCCESS == 0 ]]; then

View File

@ -1,14 +1,21 @@
import os
import pytest
import torch
from vllm.platforms import current_platform
MAX_MODEL_LEN = 1024
MODEL_NAME = os.environ.get("MODEL_NAME",
"robertgshaw2/zephyr-7b-beta-channelwise-gptq")
REVISION = os.environ.get("REVISION", "main")
QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")
MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "89")
@pytest.mark.skipif(
not current_platform.has_device_capability(int(MIN_CAPABILITY)),
reason="Current system does not have minimum capability.")
def test_weight_loading(vllm_runner):
"""
Test parameter weight loading with tp>1.

View File

@ -552,6 +552,109 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
return out
def cutlass_sparse_compress(a: torch.Tensor) \
-> Tuple[torch.Tensor, torch.Tensor]:
"""
Compresses a sparse matrix for use with Cutlass sparse operations.
This function takes a dense tensor and compresses it into two components:
non-zero elements and metadata. The compressed representation is compatible
with Cutlass sparse kernels.
Args:
a (torch.Tensor):
The input tensor to be compressed. Must have one of the following data types:
- `torch.int8`
- `torch.float8_e4m3fn`
- `torch.bfloat16`
- `torch.float16`
Returns:
Tuple[torch.Tensor, torch.Tensor]:
A tuple containing:
- `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
- `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.
Raises:
ValueError: If the compression operation fails.
Notes:
- The `a_meta` tensor has a data type of `torch.uint8`.
- Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`).
- The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor.
- The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`.
"""
assert (a.dtype in [
torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16
])
assert (a.is_contiguous())
# a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4
elemsPerMetaElem = 4
m = a.shape[0]
k = a.shape[1]
assert (k % 2 == 0)
a_nzs = torch.empty((m, k // 2), dtype=a.dtype, device=a.device)
a_meta = torch.empty((m, k // 2 // elemsPerMetaElem),
dtype=torch.uint8,
device=a.device)
if not (torch.ops._C.cutlass_sparse_compress_entry(a_nzs, a_meta, a)):
raise ValueError
assert (a_nzs.is_contiguous())
assert (a_meta.is_contiguous())
return a_nzs, a_meta
def cutlass_scaled_sparse_mm(
a: torch.Tensor,
bt_nzs: torch.Tensor,
bt_meta: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Performs a scaled sparse matrix multiplication using Cutlass.
Steps:
1. Create a dense matrix `a` of shape (m, k) on the CUDA device:
`a = torch.randn((m, k), device='cuda')`.
2. Create a dense matrix `b` of shape (k, n) on the CUDA device:
`b = torch.randn((k, n), device='cuda')`.
3. Prune matrix `b` to 2:4 sparsity along the specified dimension:
`b = prune_to_2_4(b, dim=0)`.
4. Compress the transposed sparse matrix `b.t()`:
`bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`.
5. Perform sparse matrix multiplication using the compressed matrix,
applying scaling factors for `a` and `b`, and the output data type:
`out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`.
Returns:
- The result of the scaled sparse matrix multiplication.
"""
assert (bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.shape[0] == bt_nzs.shape[0] \
and bias.dtype == out_dtype
m = a.shape[0]
n = bt_nzs.shape[0]
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
torch.ops._C.cutlass_scaled_sparse_mm(out, a, bt_nzs, bt_meta, scale_a,
scale_b, bias)
return out
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,

View File

@ -1,7 +1,9 @@
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, List, Literal, Optional, cast
import torch
from compressed_tensors.config import CompressionFormat
from compressed_tensors.config import (CompressionFormat,
SparsityCompressionConfig,
SparsityStructure)
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
@ -15,7 +17,7 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
CompressedTensorsMoEMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
@ -27,20 +29,29 @@ from vllm.platforms import current_platform
__all__ = ["CompressedTensorsLinearMethod"]
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]
class CompressedTensorsConfig(QuantizationConfig):
def __init__(self,
target_scheme_map: Dict[str, Any],
ignore: List[str],
quant_format: str,
kv_cache_scheme: Optional[Dict[str, Any]] = None):
def __init__(
self,
target_scheme_map: Dict[str, Any],
ignore: List[str],
quant_format: str,
sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
kv_cache_scheme: Optional[Dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None,
):
self.ignore = ignore
self.quant_format = quant_format
# Map from [target -> scheme]
self.target_scheme_map = target_scheme_map
self.kv_cache_scheme = kv_cache_scheme
self.sparsity_scheme_map = sparsity_scheme_map
self.config = config
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
@ -78,8 +89,50 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
ignore: List[str] = cast(List[str], config.get("ignore", []))
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(
config=config)
sparsity_scheme_map = cls._sparsity_scheme_map_from_config(
config=config)
return cls(
target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format,
sparsity_scheme_map=sparsity_scheme_map,
config=config,
)
@classmethod
def _sparsity_scheme_map_from_config(
cls, config: Dict[str,
Any]) -> Dict[str, SparsityCompressionConfig]:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
sparsity compression configurations
"""
if (sparsity_config := config.get(SPARSITY_CONFIG_NAME)) is None:
return dict()
sparsity_config = SparsityCompressionConfig.model_validate(
sparsity_config)
sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
target: sparsity_config
for target in sparsity_config.targets or list()
}
return sparse_scheme_map
@classmethod
def _quantization_scheme_map_from_config(
cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
quantization_args for weights and input activations
"""
target_scheme_map: Dict[str, Any] = dict()
ignore = cast(List[str], config.get("ignore"))
quant_format = cast(str, config.get("format"))
# The quant_config has multiple config_groups, each containing
@ -90,12 +143,14 @@ class CompressedTensorsConfig(QuantizationConfig):
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
for _, quant_config in config["config_groups"].items():
config_groups = config.get("config_groups", dict())
for _, quant_config in config_groups.items():
targets = quant_config.get("targets")
for target in targets:
target_scheme_map[target] = {}
target_scheme_map[target][
"weights"] = QuantizationArgs.parse_obj(
"weights"] = QuantizationArgs.model_validate(
quant_config.get("weights"))
target_scheme_map[target]["input_activations"] = None
@ -110,13 +165,9 @@ class CompressedTensorsConfig(QuantizationConfig):
"weights"].type == QuantizationType.FLOAT
else:
target_scheme_map[target][
"input_activations"] = QuantizationArgs.parse_obj(
"input_activations"] = QuantizationArgs.model_validate( # noqa: E501
quant_config.get("input_activations"))
return cls(target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format,
kv_cache_scheme=config.get("kv_cache_scheme"))
return target_scheme_map
@classmethod
def get_config_filenames(cls) -> List[str]:
@ -315,23 +366,105 @@ class CompressedTensorsConfig(QuantizationConfig):
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
# need to make accelerate optional in ct to do this
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys())
# Find the quant_scheme
scheme_dict = self.target_scheme_map[matched_target]
scheme = self._get_scheme_from_parts(
weight_quant=scheme_dict["weights"],
input_quant=scheme_dict["input_activations"])
# Will be empty for models with only sparsity
if self.target_scheme_map:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys())
scheme_dict = self.target_scheme_map[matched_target]
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
elif self.sparsity_scheme_map:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.sparsity_scheme_map.keys())
weight_quant = None
input_quant = None
# For models with sparsity, assumes that the sparse layers are also
# quantized for cutlass 2:4 support
sparsity_scheme: Optional[
SparsityCompressionConfig] = self.sparsity_scheme_map.get(
matched_target)
if self.supports_cutlass_24(weight_quant=weight_quant,
input_quant=input_quant,
sparsity_scheme=sparsity_scheme):
# Have a valid sparsity scheme
# Validate layer is supported by Cutlass 2:4 Kernel
scheme = CompressedTensors24(quantized=weight_quant is not None
or input_quant is not None,
weight_quant=weight_quant,
input_quant=input_quant)
else:
# Find the quant_scheme
scheme = self._get_scheme_from_parts( # type: ignore
weight_quant=weight_quant,
input_quant=input_quant,
)
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
return scheme
@staticmethod
def supports_cutlass_24(
weight_quant: Optional[QuantizationArgs],
input_quant: Optional[QuantizationArgs],
sparsity_scheme: Optional[SparsityCompressionConfig] = None
) -> bool:
"""
Check if the layer is supported by the Cutlass 2:4 Kernel
Conditions:
- Overarching condition: Sparsity Structure is 2:4
- Unquantized cases are supported
- Weight only quantization is not-supported
- Supported weight quantization strategies are TENSOR and CHANNEL
- Supported input quantization strategies are TENSOR and TOKEN
- Only 8 bit quantization is supported
:return: True if the layer is supported by the Cutlass 2:4 Kernel
False otherwise
"""
is_valid_sparsity = (sparsity_scheme is not None
and sparsity_scheme.sparsity_structure
== SparsityStructure.TWO_FOUR.value
and sparsity_scheme.format == "dense")
if not is_valid_sparsity:
return False
# Unquantized cases are supported
if weight_quant is None and input_quant is None:
return True
# Weight only quantization is not-supported
if weight_quant is not None and input_quant is None:
return False
supported_weight_quant_strategies = [
QuantizationStrategy.TENSOR.value,
QuantizationStrategy.CHANNEL.value
]
assert weight_quant is not None
assert input_quant is not None
if weight_quant.strategy not in supported_weight_quant_strategies:
return False
supported_input_quant_strategies = [
QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value
]
if input_quant.strategy not in supported_input_quant_strategies:
return False
return weight_quant.num_bits == input_quant.num_bits == 8
class CompressedTensorsLinearMethod(LinearMethodBase):

View File

@ -7,13 +7,12 @@ from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
CompressedTensorsWNA16)
from .compressed_tensors_24 import CompressedTensors24 # isort: skip
__all__ = [
"CompressedTensorsScheme",
"CompressedTensorsWNA16",
"CompressedTensorsW8A16Fp8",
"CompressedTensorsW4A16Sparse24",
"CompressedTensorsW8A8Int8",
"CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS",
"W4A16SPARSE24_SUPPORTED_BITS",
"CompressedTensorsScheme", "CompressedTensorsWNA16",
"CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
"CompressedTensors24"
]

View File

@ -0,0 +1,203 @@
from typing import Callable, List, Optional
import torch
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
__all__ = ["CompressedTensors24"]
class CompressedTensors24(CompressedTensorsScheme):
def __init__(self,
quantized: bool = False,
weight_quant: Optional[QuantizationArgs] = None,
input_quant: Optional[QuantizationArgs] = None):
self.quantized = quantized
self.weight_quant = weight_quant
self.input_quant = input_quant
@classmethod
def get_min_capability(cls) -> int:
# Only cutlass 3.x kernels are implemented so far
return 90
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
self.output_dtype = params_dtype
layer.logical_widths = output_partition_sizes
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
# parameter to store uncompressed weight
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=self.weights_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
# Check if quantized, not just 2:4 Sparse
if self.quantized:
if (self.weight_quant and self.weight_quant.strategy
== QuantizationStrategy.CHANNEL.value):
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
assert (self.weight_quant and self.weight_quant.strategy
== QuantizationStrategy.TENSOR.value)
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
# input quant will be non-none
if self.input_quant and not self.input_quant.dynamic:
# register input quant scale
assert (self.input_quant.strategy ==
QuantizationStrategy.TENSOR.value)
input_scale = BasevLLMParameter(data=torch.empty(
1, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
else:
# for sparse-only, pass in 1 for weight/input scales
weight_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
input_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("input_scale", input_scale)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight", weight)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""
Compress weights after loading. Store compressed weight and meta
tensor
:post-condition: layer.w_compressed and layer.meta are
set to the compressed weight and meta tensor in the
format expected by the Cutlass kernels
:param layer: The layer with the weights to be processed
"""
# torch.compile workaround
if hasattr(layer, "input_scale"):
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
if self.weight_quant:
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths),
requires_grad=False)
else:
# torch.compile workaround
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False)
w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
layer.meta = torch.nn.Parameter(meta, requires_grad=False)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Returns the output tensor for the layer with 2:4
sparse compressed weights, given the input tensor
and bias
:param layer: The layer with 2:4 sparse compressed
weights to be used for the computation
:param x: The input tensor to the layer
:param bias: The bias to be added to the output tensor
:return: The output tensor of the layer
"""
if self.quantized:
scale = None
if hasattr(layer, "input_scale"):
scale = layer.input_scale
if self.weights_dtype == torch.int8:
ops_output = ops.scaled_int8_quant(x, scale=scale)
q_input = ops_output[0]
input_scale = ops_output[1]
else:
assert self.weights_dtype == torch.float8_e4m3fn
if scale is not None:
q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
else:
q_input, input_scale = ops.scaled_fp8_quant(
x, use_per_token_if_dynamic=True)
else:
# Not quantized, nothing to do with the input_scales, use as is
input_scale = layer.input_scale
q_input = x
out = ops.cutlass_scaled_sparse_mm(a=q_input,
bt_nzs=layer.weight,
bt_meta=layer.meta,
scale_a=input_scale,
scale_b=layer.weight_scale,
out_dtype=self.output_dtype,
bias=bias)
assert out.is_contiguous()
return out
def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
if not self.quantized:
return params_dtype
assert self.weight_quant is not None
assert self.input_quant is not None
is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8
if not is_8_bits:
raise ValueError("Cutlass only supports 8-bit quantization")
if (self.weight_quant.type == QuantizationType.FLOAT
and self.input_quant.type == QuantizationType.FLOAT):
return torch.float8_e4m3fn
if (self.weight_quant.type == QuantizationType.INT
and self.input_quant.type == QuantizationType.INT):
return torch.int8
raise ValueError("Quantization type not supported by Cutlass")
def check_24(tensor):
new_tensor = tensor.view(-1, 4)
zero_counts = (new_tensor == 0).sum(dim=1)
return (zero_counts >= 2).all().item()