[Kernel]: Cutlass 2:4 Sparsity + FP8/Int8 Quant Support (#10995)
Co-authored-by: Faraz Shahsavan <faraz.shahsavan@gmail.com>
Co-authored-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>

This commit is contained in:
  parent f04e407e6b
  commit 60508ffda9
@@ -206,7 +206,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

   # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
-  set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use")
+  set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")

   # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
   if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     FetchContent_Declare(
         cutlass
         GIT_REPOSITORY https://github.com/nvidia/cutlass.git
-        GIT_TAG v3.5.1
+        GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227
         GIT_PROGRESS TRUE

         # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
         # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
         # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
-        GIT_SHALLOW TRUE
+        GIT_SHALLOW FALSE
     )
   endif()
   FetchContent_MakeAvailable(cutlass)
@@ -241,7 +241,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/custom_all_reduce.cu"
     "csrc/permute_cols.cu"
-    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
+    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+    "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
+    "csrc/sparse/cutlass/sparse_compressor_entry.cu"
+    "csrc/cutlass_extensions/common.cpp")

   set_gencode_flags_for_srcs(
     SRCS "${VLLM_EXT_SRC}"
@@ -271,11 +274,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   endif()

   #
-  # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
+  # The cutlass_scaled_mm, cutlass_scaled_sparse_mm, and cutlass_compressor kernels
+  # for Hopper (c3x, i.e. CUTLASS 3.x) require
   # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
   cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
+    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
+             "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
+             "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@@ -284,12 +290,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
     else()
       if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
-        message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
+        message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is "
                        "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
-                       "later if you intend on running FP8 quantized models on "
+                       "later if you intend on running FP8 sparse or quantized models on "
                        "Hopper.")
       else()
-        message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
+        message(STATUS "Not building cutlass_c3x as no compatible archs found "
                        "in CUDA target architectures")
       endif()
@@ -404,7 +410,7 @@ define_gpu_extension_target(
   SOURCES ${VLLM_EXT_SRC}
   COMPILE_FLAGS ${VLLM_GPU_FLAGS}
   ARCHITECTURES ${VLLM_GPU_ARCHES}
-  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
+  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
   USE_SABI 3
   WITH_SOABI)
||||
benchmarks/cutlass_benchmarks/sparse_benchmarks.py (new file, 384 lines)
@@ -0,0 +1,384 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import pickle as pkl
|
||||
import time
|
||||
from typing import Callable, Iterable, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from utils import make_rand_sparse_tensors
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
DEFAULT_TP_SIZES = [1]
|
||||
|
||||
|
||||
# bench
|
||||
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||
**kwargs) -> TMeasurement:
|
||||
min_run_time = 1
|
||||
|
||||
globals = {
|
||||
"args": args,
|
||||
"kwargs": kwargs,
|
||||
"fn": fn,
|
||||
}
|
||||
return TBenchmark.Timer(
|
||||
stmt="fn(*args, **kwargs)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description=description,
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
|
||||
|
||||
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
assert dtype == torch.int8
|
||||
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
|
||||
torch.bfloat16)
|
||||
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
|
||||
|
||||
if not torch.allclose(out, out_ref):
|
||||
print("Incorrect results")
|
||||
print(out)
|
||||
print(out_ref)
|
||||
else:
|
||||
print("Correct results")
|
||||
|
||||
timers = []
|
||||
# pytorch impl - bfloat16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm, a.to(dtype=torch.bfloat16),
|
||||
b.to(dtype=torch.bfloat16)))
|
||||
|
||||
# pytorch impl - float16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label,
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
|
||||
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
|
||||
|
||||
# cutlass impl
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||
torch.bfloat16))
|
||||
|
||||
# cutlass with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias))
|
||||
|
||||
# cutlass sparse impl
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.bfloat16))
|
||||
|
||||
# cutlass sparse with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.bfloat16, bias))
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
assert dtype == torch.float8_e4m3fn
|
||||
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n,
|
||||
k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
|
||||
torch.bfloat16)
|
||||
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
|
||||
|
||||
if not torch.allclose(out, out_ref):
|
||||
print("Incorrect results")
|
||||
print(out)
|
||||
print(out_ref)
|
||||
else:
|
||||
print("Correct results")
|
||||
|
||||
timers = []
|
||||
|
||||
# pytorch impl w. bf16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda")))
|
||||
|
||||
# pytorch impl: bf16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16))
|
||||
|
||||
# pytorch impl: bf16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True))
|
||||
|
||||
# pytorch impl: fp16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16))
|
||||
|
||||
# pytorch impl: fp16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16,
|
||||
use_fast_accum=True))
|
||||
|
||||
# cutlass impl: bf16 output
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||
torch.bfloat16))
|
||||
|
||||
# cutlass impl: bf16 output
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.bfloat16))
|
||||
|
||||
# cutlass impl: fp16 output
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.float16))
|
||||
|
||||
# cutlass impl: bf16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label,
|
||||
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.bfloat16, bias))
|
||||
|
||||
# cutlass impl: fp16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label,
|
||||
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.float16, bias.to(dtype=torch.float16)))
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
if dtype == torch.int8:
|
||||
return bench_int8(dtype, m, k, n, label, sub_label)
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
return bench_fp8(dtype, m, k, n, label, sub_label)
|
||||
raise ValueError("unsupported type")
|
||||
|
||||
|
||||
# runner
|
||||
def print_timers(timers: Iterable[TMeasurement]):
|
||||
compare = TBenchmark.Compare(timers)
|
||||
compare.print()
|
||||
|
||||
|
||||
def run(dtype: torch.dtype,
|
||||
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
results = []
|
||||
for m, k, n in MKNs:
|
||||
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
|
||||
f"MKN=({m}x{k}x{n})")
|
||||
print_timers(timers)
|
||||
results.extend(timers)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# output makers
|
||||
def make_output(data: Iterable[TMeasurement],
|
||||
MKNs: Iterable[Tuple[int, int, int]],
|
||||
base_description: str,
|
||||
timestamp=None):
|
||||
print(f"== All Results {base_description} ====")
|
||||
print_timers(data)
|
||||
|
||||
# pickle all the results
|
||||
timestamp = int(time.time()) if timestamp is None else timestamp
|
||||
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
|
||||
pkl.dump(data, f)
|
||||
|
||||
|
||||
# argparse runners
|
||||
|
||||
|
||||
def run_square_bench(args):
|
||||
dim_sizes = list(
|
||||
range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||
data = run(args.dtype, MKNs)
|
||||
|
||||
make_output(data, MKNs, f"square_bench-{args.dtype}")
|
||||
|
||||
|
||||
def run_range_bench(args):
|
||||
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
|
||||
n = len(dim_sizes)
|
||||
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
|
||||
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
|
||||
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
|
||||
MKNs = list(zip(Ms, Ks, Ns))
|
||||
data = run(args.dtype, MKNs)
|
||||
|
||||
make_output(data, MKNs, f"range_bench-{args.dtype}")
|
||||
|
||||
|
||||
def run_model_bench(args):
|
||||
print("Benchmarking models:")
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
||||
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
|
||||
KNs = []
|
||||
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
|
||||
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||
KNs.append(KN)
|
||||
return KNs
|
||||
|
||||
model_bench_data = []
|
||||
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||
for model, tp_size in models_tps:
|
||||
Ms = args.batch_sizes
|
||||
KNs = model_shapes(model, tp_size)
|
||||
MKNs = []
|
||||
for m in Ms:
|
||||
for k, n in KNs:
|
||||
MKNs.append((m, k, n))
|
||||
|
||||
data = run(args.dtype, MKNs)
|
||||
model_bench_data.append(data)
|
||||
|
||||
# Print all results
|
||||
for data, model_tp in zip(model_bench_data, models_tps):
|
||||
model, tp_size = model_tp
|
||||
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
|
||||
print_timers(data)
|
||||
|
||||
timestamp = int(time.time())
|
||||
|
||||
all_data = []
|
||||
for d in model_bench_data:
|
||||
all_data.extend(d)
|
||||
# pickle all data
|
||||
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
|
||||
pkl.dump(all_data, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
def to_torch_dtype(dt):
|
||||
if dt == "int8":
|
||||
return torch.int8
|
||||
if dt == "fp8":
|
||||
return torch.float8_e4m3fn
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="""
|
||||
Benchmark Cutlass GEMM.
|
||||
|
||||
To run square GEMMs:
|
||||
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
|
||||
|
||||
To run constant N and K and sweep M:
|
||||
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
|
||||
|
||||
To run dimensions from a model:
|
||||
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
|
||||
|
||||
Output:
|
||||
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
|
||||
""", # noqa: E501
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
|
||||
parser.add_argument("--dtype",
|
||||
type=to_torch_dtype,
|
||||
required=True,
|
||||
help="Available options are ['int8', 'fp8']")
|
||||
subparsers = parser.add_subparsers(dest="cmd")
|
||||
|
||||
square_parser = subparsers.add_parser("square_bench")
|
||||
square_parser.add_argument("--dim-start", type=int, required=True)
|
||||
square_parser.add_argument("--dim-end", type=int, required=True)
|
||||
square_parser.add_argument("--dim-increment", type=int, required=True)
|
||||
square_parser.set_defaults(func=run_square_bench)
|
||||
|
||||
range_parser = subparsers.add_parser("range_bench")
|
||||
range_parser.add_argument("--dim-start", type=int, required=True)
|
||||
range_parser.add_argument("--dim-end", type=int, required=True)
|
||||
range_parser.add_argument("--dim-increment", type=int, required=True)
|
||||
range_parser.add_argument("--m-constant", type=int, default=None)
|
||||
range_parser.add_argument("--n-constant", type=int, default=None)
|
||||
range_parser.add_argument("--k-constant", type=int, default=None)
|
||||
range_parser.set_defaults(func=run_range_bench)
|
||||
|
||||
model_parser = subparsers.add_parser("model_bench")
|
||||
model_parser.add_argument("--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES.keys())
|
||||
model_parser.add_argument("--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_TP_SIZES)
|
||||
model_parser.add_argument("--batch-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_BATCH_SIZES)
|
||||
model_parser.set_defaults(func=run_model_bench)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.func(args)
|
||||
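The setup that bench_int8 and bench_fp8 perform before timing reduces to the following standalone sketch (not part of the commit; assumes a Hopper GPU, a vLLM build that includes the new sparse CUTLASS kernels, and that it is run from benchmarks/cutlass_benchmarks/ so that utils.py is importable):

import torch

from utils import make_rand_sparse_tensors  # helper added by this commit
from vllm import _custom_ops as ops

m, n, k = 32, 4096, 4096
# Compressed B, 2:4 metadata, dense A, dense B.
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

# Sparse path: compressed B plus its 2:4 metadata.
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
                                   torch.bfloat16)
# Dense reference path.
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
print("sparse matches dense:", torch.allclose(out, out_ref))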
benchmarks/cutlass_benchmarks/utils.py (new file, 96 lines)
@@ -0,0 +1,96 @@
|
||||
# Cutlass bench utils
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(
|
||||
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.bfloat16)
|
||||
|
||||
|
||||
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.float16)
|
||||
|
||||
|
||||
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
|
||||
if dtype == torch.int8:
|
||||
return to_int8(a), to_int8(b)
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
return to_fp8(a), to_fp8(b)
|
||||
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
|
||||
def prune_to_2_4(tensor):
|
||||
# Reshape tensor to [N, 4] where N is number of groups of 4
|
||||
original_shape = tensor.shape
|
||||
reshaped = tensor.reshape(-1, 4)
|
||||
|
||||
# Get indices of top 2 absolute values in each group of 4
|
||||
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
|
||||
|
||||
# Create binary mask
|
||||
mask = torch.zeros_like(reshaped)
|
||||
mask.scatter_(dim=1,
|
||||
index=indices,
|
||||
src=torch.ones_like(indices, dtype=mask.dtype))
|
||||
|
||||
# Apply mask and reshape back
|
||||
pruned = reshaped * mask
|
||||
|
||||
# Turn all -0.0 to 0.0
|
||||
pruned[pruned == -0.0] = 0.0
|
||||
|
||||
return pruned.reshape(original_shape)
|
||||
|
||||
|
||||
def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
|
||||
b = prune_to_2_4(b.t()).t()
|
||||
|
||||
if dtype == torch.int8:
|
||||
a, b = to_int8(a), to_int8(b)
|
||||
elif dtype == torch.float8_e4m3fn:
|
||||
a, b = to_fp8(a), to_fp8(b)
|
||||
elif dtype == torch.float16:
|
||||
a, b = to_fp16(a), to_fp16(b)
|
||||
elif dtype == torch.bfloat16:
|
||||
a, b = to_bf16(a), to_bf16(b)
|
||||
else:
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
b_compressed, e = ops.cutlass_sparse_compress(b.t())
|
||||
|
||||
# Compressed B, Metadata, Original A, B
|
||||
return b_compressed, e, a, b
|
||||
|
||||
|
||||
def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype,
|
||||
m: int, n: int, k: int) -> \
|
||||
Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
|
||||
ABs = []
|
||||
for _ in range(num_tensors):
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
|
||||
if b_comp is not None:
|
||||
ABs.append(make_rand_sparse_tensors(dtype, m, n, k))
|
||||
BComps, Es, As, Bs = zip(*ABs)
|
||||
return list(BComps), list(Es), list(As), list(Bs)
|
||||
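The 2:4 structure that prune_to_2_4 enforces (at most two non-zero values in every contiguous group of four elements) can be checked directly with a short sketch like the one below (illustrative only; assumes it is run next to benchmarks/cutlass_benchmarks/utils.py):

import torch

from utils import prune_to_2_4  # helper added by this commit

x = torch.randn((128, 256))  # prune_to_2_4 is pure PyTorch, so CPU is fine here
pruned = prune_to_2_4(x)

# Count non-zeros in each group of four consecutive elements.
nonzeros_per_group = (pruned.reshape(-1, 4) != 0).sum(dim=1)
assert torch.all(nonzeros_per_group <= 2), "2:4 structure violated"
print("max non-zeros per group of 4:", int(nonzeros_per_group.max()))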
@@ -8,6 +8,7 @@ from typing import Callable, Iterable, List, Tuple
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from utils import make_rand_tensors
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
@@ -17,31 +18,6 @@ DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
DEFAULT_TP_SIZES = [1]
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(
|
||||
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
|
||||
if dtype == torch.int8:
|
||||
return to_int8(a), to_int8(b)
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
return to_fp8(a), to_fp8(b)
|
||||
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
|
||||
# bench
|
||||
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||
@@ -386,4 +362,4 @@ Benchmark Cutlass GEMM.
|
||||
model_parser.set_defaults(func=run_model_bench)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.func(args)
|
||||
args.func(args)
|
||||
@@ -40,4 +40,4 @@ WEIGHT_SHAPES = {
|
||||
([8192, 57344], 1),
|
||||
([28672, 8192], 0),
|
||||
],
|
||||
}
|
||||
}
|
||||
csrc/core/math.hpp (new file, 7 lines)
@@ -0,0 +1,7 @@
#include <climits>
#include <iostream>

inline uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
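next_pow_2 rounds a 32-bit integer up to the nearest power of two and is what the kernel dispatch heuristics further down use to bucket the batch dimension. A quick Python equivalent (illustrative only, not part of the commit):

def next_pow_2(num: int) -> int:
    # Mirrors the CHAR_BIT / __builtin_clz expression above for 32-bit inputs.
    if num <= 1:
        return num
    return 1 << (num - 1).bit_length()

assert [next_pow_2(n) for n in (1, 2, 3, 64, 65)] == [1, 2, 4, 64, 128]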
csrc/cutlass_extensions/common.cpp (new file, 11 lines)
@@ -0,0 +1,11 @@
#include "cutlass_extensions/common.hpp"

int32_t get_sm_version_num() {
  int32_t major_capability, minor_capability;
  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
                         0);
  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
                         0);
  int32_t version_num = major_capability * 10 + minor_capability;
  return version_num;
}
csrc/cutlass_extensions/common.hpp (new file, 35 lines)
@@ -0,0 +1,35 @@
#pragma once

#include "cutlass/cutlass.h"
#include <climits>
#include "cuda_runtime.h"
#include <iostream>

/**
 * Helper function for checking CUTLASS errors
 */
#define CUTLASS_CHECK(status)                       \
  {                                                 \
    cutlass::Status error = status;                 \
    TORCH_CHECK(error == cutlass::Status::kSuccess, \
                cutlassGetStatusString(error));     \
  }

/**
 * Panic wrapper for unwinding CUDA runtime errors
 */
#define CUDA_CHECK(status)                                        \
  {                                                               \
    cudaError_t error = status;                                   \
    TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
  }

inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
  int max_shared_mem_per_block_opt_in = 0;
  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin,
                         device);
  return max_shared_mem_per_block_opt_in;
}

int32_t get_sm_version_num();
@@ -36,13 +36,13 @@ struct ScaledEpilogueBase {
   // Don't want to support nullptr by default
   template <typename T, bool EnableNullPtr = false>
   using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
-      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
       Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

   // Don't want to support nullptr by default
   template <typename T, bool EnableNullPtr = false>
   using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
-      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
       Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

   // This utility function constructs the arguments for the load descriptors
@@ -162,6 +162,15 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& azp_adj,
                            c10::optional<torch::Tensor> const& azp,
                            c10::optional<torch::Tensor> const& bias);
+
+void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
+                              torch::Tensor const& b, torch::Tensor const& e,
+                              torch::Tensor const& a_scales,
+                              torch::Tensor const& b_scales,
+                              c10::optional<torch::Tensor> const& bias);
+
+bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
+                                   torch::Tensor& e, torch::Tensor const& a);
 #endif

 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
@@ -1,27 +0,0 @@
-#pragma once
-
-#include "cutlass/cutlass.h"
-#include <climits>
-
-/**
- * Helper function for checking CUTLASS errors
- */
-#define CUTLASS_CHECK(status)                        \
-  {                                                  \
-    TORCH_CHECK(status == cutlass::Status::kSuccess, \
-                cutlassGetStatusString(status))      \
-  }
-
-inline uint32_t next_pow_2(uint32_t const num) {
-  if (num <= 1) return num;
-  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
-}
-
-inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
-  int max_shared_mem_per_block_opt_in = 0;
-  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
-                         cudaDevAttrMaxSharedMemoryPerBlockOptin,
-                         device);
-  return max_shared_mem_per_block_opt_in;
-}
-
@@ -21,7 +21,8 @@
 #include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
 #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

-#include "common.hpp"
+#include "core/math.hpp"
+#include "cutlass_extensions/common.hpp"
 // clang-format on

 using namespace cute;
@@ -24,7 +24,8 @@
 #include "cutlass/gemm/collective/collective_builder.hpp"

 #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
-#include "common.hpp"
+#include "core/math.hpp"
+#include "cutlass_extensions/common.hpp"
 // clang-format on

 using namespace cute;
@@ -3,6 +3,8 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/all.h>

+#include "cutlass_extensions/common.hpp"
+
 void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
@@ -79,16 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
   return false;
 }

-int32_t get_sm_version_num() {
-  int32_t major_capability, minor_capability;
-  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
-                         0);
-  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
-                         0);
-  int32_t version_num = major_capability * 10 + minor_capability;
-  return version_num;
-}
-
 void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales,
csrc/sparse/cutlass/sparse_compressor_c3x.cu (new file, 163 lines)
@@ -0,0 +1,163 @@
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include "sparse_scaled_mm_c3x.cuh"
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
using namespace vllm;
|
||||
|
||||
/// Make A structured sparse by replacing elements with 0 and compress it
|
||||
template <typename ElementA_, typename ElementAcc_>
|
||||
bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta,
|
||||
torch::Tensor const& a) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
|
||||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(a.dim() == 2)
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
|
||||
TORCH_CHECK(a.stride(1) == 1)
|
||||
|
||||
int m = a.size(0);
|
||||
int k = a.size(1);
|
||||
|
||||
// Sparse kernel setup; this kernel is not used for matmul,
|
||||
// but just for setting up the compressor utility
|
||||
// A matrix configuration
|
||||
using ElementA = ElementA_;
|
||||
using LayoutTagA = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
// B matrix configuration
|
||||
using ElementB = ElementA;
|
||||
using LayoutTagB = cutlass::layout::ColumnMajor;
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
// C/D matrix configuration
|
||||
using ElementC = float;
|
||||
using LayoutTagC = cutlass::layout::ColumnMajor;
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = ElementAcc_;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using TileShapeRef = Shape<_128, _128, _64>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using KernelSchedule = typename std::conditional<
|
||||
std::is_same_v<ElementA, cutlass::float_e4m3_t>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>::type;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
|
||||
using ProblemShape = Shape<int, int, int, int>;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC,
|
||||
AlignmentC, ElementC, LayoutTagC, AlignmentC,
|
||||
EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA,
|
||||
LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB,
|
||||
ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel =
|
||||
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
|
||||
using StrideE = StrideA;
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
|
||||
// The n (=1) dimension does not matter for the compressor
|
||||
typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1};
|
||||
|
||||
using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA;
|
||||
using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE;
|
||||
|
||||
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
|
||||
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
|
||||
|
||||
// Offline compressor kernel
|
||||
using CompressorUtility =
|
||||
cutlass::transform::kernel::StructuredSparseCompressorUtility<
|
||||
ProblemShape, ElementA, LayoutTagA, SparseConfig>;
|
||||
|
||||
using CompressorKernel =
|
||||
cutlass::transform::kernel::StructuredSparseCompressor<
|
||||
ProblemShape, ElementA, LayoutTagA, SparseConfig,
|
||||
cutlass::arch::Sm90>;
|
||||
|
||||
using Compressor =
|
||||
cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
|
||||
auto [M, N, K, L] = prob_shape;
|
||||
|
||||
StrideA stride_A;
|
||||
stride_A =
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
|
||||
CompressorUtility compressor_utility(prob_shape, stride_A);
|
||||
|
||||
int ME = compressor_utility.get_metadata_m_physical();
|
||||
int KE = compressor_utility.get_metadata_k_physical();
|
||||
int KC = compressor_utility.get_tensorA_k_physical();
|
||||
|
||||
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
|
||||
|
||||
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
|
||||
auto a_meta_ptr = static_cast<typename Gemm::CollectiveMainloop::ElementE*>(
|
||||
a_meta.data_ptr());
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
typename Compressor::Arguments arguments{
|
||||
prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}};
|
||||
|
||||
Compressor compressor_op;
|
||||
size_t workspace_size = Compressor::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
CUTLASS_CHECK(compressor_op.can_implement(arguments));
|
||||
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(compressor_op.run());
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
|
||||
torch::Tensor const& a) {
|
||||
if (a.dtype() == torch::kBFloat16) {
|
||||
return cutlass_sparse_compress<cutlass::bfloat16_t, float>(a_nzs, a_meta,
|
||||
a);
|
||||
} else if (a.dtype() == torch::kFloat16) {
|
||||
return cutlass_sparse_compress<cutlass::half_t, float>(a_nzs, a_meta, a);
|
||||
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
|
||||
return cutlass_sparse_compress<cutlass::float_e4m3_t, float>(a_nzs, a_meta,
|
||||
a);
|
||||
} else if (a.dtype() == torch::kInt8) {
|
||||
return cutlass_sparse_compress<int8_t, int32_t>(a_nzs, a_meta, a);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
csrc/sparse/cutlass/sparse_compressor_entry.cu (new file, 42 lines)
@@ -0,0 +1,42 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "cutlass_extensions/common.hpp"

#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                                  torch::Tensor const& a);
#endif

bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                                   torch::Tensor const& a) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2);
  TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) &&
              a_nzs.size(1) * 2 == a.size(1) &&
              a_meta.size(1) * 2 * 4 == a.size(1));
  // Considering elemsPerMetaElem = 8b / 2b_per_nz = 4

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 &&
              a_meta.stride(1) == 1);  // Row-major
  TORCH_CHECK(a.stride(0) % 8 == 0);   // 8 Byte Alignment for Compression

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
  int32_t version_num = get_sm_version_num();

  // Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
  if (version_num >= 90) {
    return cutlass_sparse_compress_sm90(a_nzs, a_meta, a);
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_sparse_mm for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}
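The conformality checks in cutlass_sparse_compress_entry encode the 2:4 shape contract: the compressed tensor keeps two values out of every four (k/2 columns), and each metadata byte covers four kept values at 2 bits each (k/8 columns), matching the elemsPerMetaElem comment. A small sketch of those relationships (illustrative only; the helper name is hypothetical):

def sparse_compress_shapes(m: int, k: int) -> dict:
    """Expected shapes for a row-major (m, k) operand passed to the compressor."""
    assert k % 8 == 0  # needed so both k // 2 and k // 8 are exact
    return {
        "a": (m, k),            # dense input
        "a_nzs": (m, k // 2),   # kept non-zero values (2 of every 4)
        "a_meta": (m, k // 8),  # 2-bit position metadata packed into bytes
    }

print(sparse_compress_shapes(16, 4096))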
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu (new file, 303 lines)
@@ -0,0 +1,303 @@
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
#include "sparse_scaled_mm_c3x.cuh"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
using namespace vllm;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM64 =
|
||||
typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM128 =
|
||||
typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM256 =
|
||||
typename sm90_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM512 =
|
||||
typename sm90_fp8_config_M512<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
|
||||
using Cutlass3xGemm1 =
|
||||
typename sm90_fp8_config_1<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemm2 =
|
||||
typename sm90_fp8_config_2<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemm3 =
|
||||
typename sm90_fp8_config_3<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemm4 =
|
||||
typename sm90_fp8_config_4<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemm5 =
|
||||
typename sm90_fp8_config_5<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemm6 =
|
||||
typename sm90_fp8_config_6<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemm7 =
|
||||
typename sm90_fp8_config_7<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemm8 =
|
||||
typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
|
||||
uint32_t const n = bt_nzs.size(0);
|
||||
uint32_t const m = a.size(0); // Batch size
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 64) {
|
||||
if (n == 28672) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm2>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (n == 4096 || n == 6144) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm1>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
} else if (mp2 <= 128) {
|
||||
if (n == 4096) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm3>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (n == 28672) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm5>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (n == 6144) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm4>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
} else if (mp2 <= 256) {
|
||||
if (n == 4096) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm6>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (n == 28672) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (n == 6144) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
} else {
|
||||
if (n == 6144 || n == 28672) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (n == 4096) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise the default heuristic
|
||||
if (mp2 <= 64) {
|
||||
// n in [1, 64]
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// n in (64, 128]
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 256) {
|
||||
// n in (128, 256]
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM256>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// n in (256, inf)
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM512>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
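The FP8 dispatch above buckets the batch dimension m by its next power of two (floored at 64) and special-cases the N values that appear in common Llama-style layers (4096, 6144, 28672) before falling back to generic M-bucketed configurations. A simplified Python sketch of that selection logic (configuration names are placeholders, not the real CUTLASS types):

def _next_pow_2(num: int) -> int:
    return num if num <= 1 else 1 << (num - 1).bit_length()

def select_fp8_sparse_config(m: int, n: int) -> str:
    mp2 = max(64, _next_pow_2(m))

    # Shape-specific configurations for frequently benchmarked N values.
    special = {
        64: {28672: "gemm2", 4096: "gemm1", 6144: "gemm1"},
        128: {4096: "gemm3", 28672: "gemm5", 6144: "gemm4"},
        256: {4096: "gemm6", 28672: "gemm8", 6144: "gemm7"},
    }
    if mp2 <= 64 and n in special[64]:
        return special[64][n]
    if 64 < mp2 <= 128 and n in special[128]:
        return special[128][n]
    if 128 < mp2 <= 256 and n in special[256]:
        return special[256][n]
    if mp2 > 256 and n in (6144, 28672):
        return "gemm8"
    if mp2 > 256 and n == 4096:
        return "gemm7"

    # Otherwise the default heuristic, keyed only on the batch-size bucket.
    if mp2 <= 64:
        return "gemm_m64"
    if mp2 <= 128:
        return "gemm_m128"
    if mp2 <= 256:
        return "gemm_m256"
    return "gemm_m512"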
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::half_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
|
||||
// m in (128, inf)
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::bfloat16_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
|
||||
// m in (128, inf)
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM128 =
|
||||
typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM64 =
|
||||
typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM32NBig =
|
||||
typename sm90_int8_config_M32_NBig<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM32NSmall =
|
||||
typename sm90_int8_config_M32_NSmall<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
bool const is_small_n = n < 8192;
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 32) {
|
||||
// m in [1, 32]
|
||||
if (is_small_n) {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NSmall>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NBig>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
} else if (mp2 <= 64) {
|
||||
// m in (32, 64]
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// m in (64, 128]
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// m in (128, inf)
|
||||
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||
if (a.dtype() == torch::kInt8) {
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else if (a.dtype() == torch::kFloat16) {
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t, cutlass::half_t,
|
||||
Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else { // a.dtype() == torch::kBFloat16
|
||||
TORCH_CHECK(a.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, bt_nzs, bt_meta,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, bt_nzs, bt_meta, b_scales, a_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, bt_nzs, bt_meta, b_scales, a_scales);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh (new file, 496 lines)
@@ -0,0 +1,496 @@
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cutlass_extensions/cute_utils.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/*
|
||||
This file defines sparse quantized GEMM operations using the CUTLASS 3.x API,
|
||||
for NVIDIA GPUs with sm90a (Hopper) or later.
|
||||
*/
|
||||
|
||||
namespace {
|
||||
|
||||
// A wrapper for the GEMM kernel that is used to guard against compilation on
|
||||
// architectures that will never use the kernel. The purpose of this is to
|
||||
// reduce the size of the compiled binary.
|
||||
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
// into code that will be executed on the device where it is defined.
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
using GemmUniversalMode = cutlass::gemm::GemmUniversalMode;
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule, typename AccType,
          typename TileSchedule = cutlass::gemm::PersistentScheduler,
          GemmUniversalMode Mode_ = GemmUniversalMode::kGemm>
struct cutlass_sparse_3x_gemm {
  static const GemmUniversalMode Mode = Mode_;
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;
  using ElementAcc = AccType;

  using EpilogueDescriptor =
      cutlass::epilogue::collective::detail::EpilogueDescriptor<
          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
          ElementD, EpilogueSchedule>;

  using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;

  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  using LayoutD = LayoutC;
  using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
  using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;

  using LayoutC_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutC>::type;
  using LayoutD_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutD>::type;

  using EVTCompute = typename Epilogue::EVTCompute;

  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentCD =
      128 / cutlass::sizeof_bits<ElementD>::value;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAcc, ElementAcc, ElementC, LayoutC_Transpose, AlignmentCD,
          ElementD, LayoutD_Transpose, AlignmentCD, EpilogueSchedule,
          EVTCompute>::CollectiveOp;

  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  // clang-format off
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
          ElementAB, cutlass::layout::RowMajor, AlignmentA,
          ElementAB, cutlass::layout::ColumnMajor, AlignmentB,
          ElementAcc, TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
  // clang-format on

  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      TileSchedule>>;

  struct GemmKernel : public KernelType {};
};

template <typename Gemm, typename... EpilogueArgs>
void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& bt_nzs,
                                torch::Tensor const& bt_meta,
                                EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  // Interface stride expected from the argument a (will get transposed)
  // We compute C^T = B^T * A^T, but we assume B is transposed before
  // compression and hence the bt_* naming
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
  using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
  using LayoutD = cutlass::layout::RowMajor;

  using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
  using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;

  auto layout_A = make_cute_layout<StrideA>(a, "A");
  auto layout_D = make_cute_layout<StrideD>(out, "D");

  // Transpose A and D
  // A doesn't need to be transposed since cutlass expects a NxK matrix
  // for B (which is At)
  auto stride_At = layout_A.stride();
  auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();

  using GemmKernel = typename Gemm::GemmKernel;
  typename GemmKernel::ProblemShape prob_shape{
      static_cast<int>(bt_nzs.size(0)), static_cast<int>(size<0>(layout_A)),
      static_cast<int>(size<1>(layout_A)), 1};

  using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
  using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;

  LayoutB b_layout = SparseConfig::fill_layoutA(prob_shape);
  LayoutE e_layout = SparseConfig::fill_layoutE(prob_shape);

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(bt_nzs.data_ptr());
  auto e_ptr = static_cast<ElementE*>(bt_meta.data_ptr());
  typename GemmKernel::MainloopArguments mainloop_args{
      b_ptr, b_layout, a_ptr, stride_At, e_ptr, e_layout};

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      c_ptr, stride_Dt, c_ptr, stride_Dt};

  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape, mainloop_args, epilogue_args};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default {};

template <typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<half_t, OutType, Epilogue> {
  // M in (128, inf)
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<half_t, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<cutlass::bfloat16_t, OutType, Epilogue> {
  // M in (128, inf)
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue, TileShape,
                             ClusterShape, KernelSchedule, EpilogueSchedule,
                             float>;
};

//////////////////////// Cherry-Picking Kernels ////////////////////////
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_1 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_8, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_2 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_128, _64, _256>;
  using ClusterShape = Shape<_8, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_3 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _2, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_4 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_8, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_5 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _256>;
  using ClusterShape = Shape<_8, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_6 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_1, _2, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_7 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_128, _128, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_8 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_128, _256, _128>;
  using ClusterShape = Shape<_8, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};
////////////////////////////////////////////////////////////////////////

template <typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<cutlass::float_e4m3_t, OutType, Epilogue> {
  // M in (128, inf)
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_1, _2, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<cutlass::float_e4m3_t, OutType, Epilogue,
                             TileShape, ClusterShape, KernelSchedule,
                             EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // M in [1, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _1, _1>;

  using TileSchedule = cutlass::gemm::PersistentScheduler;

  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float,
                             TileSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
  // M in (64, 128]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_1, _1, _1>;

  using TileSchedule = cutlass::gemm::PersistentScheduler;

  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float,
                             TileSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M256 {
  // M in (128, 256]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_128, _128, _256>;
  using ClusterShape = Shape<_1, _1, _1>;

  using TileSchedule = cutlass::gemm::PersistentScheduler;

  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float,
                             TileSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M512 {
  // M in (256, ]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_128, _128, _256>;
  using ClusterShape = Shape<_1, _1, _1>;

  using TileSchedule = cutlass::gemm::PersistentScheduler;

  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float,
                             TileSchedule>;
};

template <typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<int8_t, OutType, Epilogue> {
  // For M > 128 and any N
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<int8_t, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M128 {
  // For M in (64, 128] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M64 {
  // For M in (32, 64] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NBig {
  // For M in [1, 32] and N >= 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_1, _4, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NSmall {
  // For M in [1, 32] and N < 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, int32_t>;
};

} // namespace
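The comments in cutlass_sparse_gemm_caller above rely on a simple identity: the kernel evaluates C^T = B^T * A^T, so the compressed, transposed weight (bt_nzs / bt_meta) can occupy the sparse "A" operand slot while the dense activation is passed through the "B" slot. A minimal PyTorch sketch of that identity (illustrative only, not part of the commit):

    import torch

    m, k, n = 4, 8, 6
    a = torch.randn(m, k)  # dense activations
    b = torch.randn(k, n)  # weight; 2:4-sparse and compressed in the real path

    c = a @ b              # what the caller ultimately returns
    c_t = b.t() @ a.t()    # what the kernel actually evaluates

    assert torch.allclose(c, c_t.t(), atol=1e-6)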
csrc/sparse/cutlass/sparse_scaled_mm_entry.cu (new file, 59 lines)
@ -0,0 +1,59 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "cutlass_extensions/common.hpp"

#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                                   torch::Tensor const& b,
                                   torch::Tensor const& e,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   c10::optional<torch::Tensor> const& bias);
#endif

void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& bt_nzs,
                              torch::Tensor const& bt_meta,
                              torch::Tensor const& a_scales,
                              torch::Tensor const& b_scales,
                              c10::optional<torch::Tensor> const& bias) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && bt_nzs.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(1) == bt_nzs.size(0) && bt_nzs.size(1) * 2 == a.size(1) &&
              a.size(0) == c.size(0));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == bt_nzs.size(0));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && bt_nzs.stride(1) == 1 &&
              c.stride(1) == 1);                // Row-major
  TORCH_CHECK(c.stride(0) % 16 == 0);           // 16 Byte Alignment
  TORCH_CHECK(bt_nzs.stride(0) % 16 == 0);      // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == bt_nzs.size(0) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
  int32_t version_num = get_sm_version_num();

  // Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
  if (version_num >= 90) {
    cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
                                  bias);
    return;
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_sparse_mm for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}
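The conformality checks above pin down the expected shapes: `a` is (M, K) row-major, `bt_nzs` holds the 2:4-compressed transposed weight with shape (N, K/2), and `c` is (M, N). A short sketch with hypothetical sizes (illustrative only):

    M, N, K = 32, 64, 128

    a_shape = (M, K)            # dense activations
    bt_nzs_shape = (N, K // 2)  # compressed B^T: half of K survives 2:4 pruning
    c_shape = (M, N)            # output

    assert c_shape[1] == bt_nzs_shape[0]      # c.size(1) == bt_nzs.size(0)
    assert bt_nzs_shape[1] * 2 == a_shape[1]  # bt_nzs.size(1) * 2 == a.size(1)
    assert a_shape[0] == c_shape[0]           # a.size(0) == c.size(0)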
@ -321,6 +321,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
  ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);

  // CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
  // quantization, as well as bias
  ops.def(
      "cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
      " Tensor bt_nzs,"
      " Tensor bt_meta, Tensor a_scales,"
      " Tensor b_scales, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);

  // CUTLASS sparse matrix compressor
  ops.def(
      "cutlass_sparse_compress_entry(Tensor! a_nzs, Tensor! a_meta,"
      " Tensor a) -> bool");
  ops.impl("cutlass_sparse_compress_entry", &cutlass_sparse_compress_entry);

  // Mamba selective scan kernel
  ops.def(
      "selective_scan_fwd(Tensor! u, Tensor! delta,"

@ -83,7 +83,7 @@ exclude = [
]

[tool.codespell]
ignore-words-list = "dout, te, indicies, subtile"
ignore-words-list = "dout, te, indicies, subtile, ElementE"
skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"

[tool.isort]

tests/kernels/test_semi_structured.py (new file, 131 lines)
@ -0,0 +1,131 @@
"""Tests for sparse cutlass kernels

Run `pytest tests/kernels/test_semi_structured.py`.
"""
from typing import Optional, Tuple, Type

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.Tensor):
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


def rand_int8(shape: tuple, device: str = "cuda"):
    return to_int8(torch.rand(shape, device=device) * 255 - 128)


def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.bfloat16)


def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.float16)


def prune_to_2_4(tensor):
    # Reshape tensor to [N, 4] where N is number of groups of 4
    original_shape = tensor.shape
    reshaped = tensor.reshape(-1, 4)

    # Get indices of top 2 absolute values in each group of 4
    _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)

    # Create binary mask
    mask = torch.zeros_like(reshaped)
    mask.scatter_(dim=1,
                  index=indices,
                  src=torch.ones_like(indices, dtype=mask.dtype))

    # Apply mask and reshape back
    pruned = reshaped * mask

    # Turn all -0.0 to 0.0
    pruned[pruned == -0.0] = 0.0

    return pruned.reshape(original_shape)


def make_rand_sparse_tensors(
    dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    a = torch.randn((m, k), device='cuda') * 5
    b = torch.randn((n, k), device='cuda').t() * 5

    b = prune_to_2_4(b.t()).t()

    if dtype == torch.int8:
        a, b = to_int8(a), to_int8(b)
    elif dtype == torch.float8_e4m3fn:
        a, b = to_fp8(a), to_fp8(b)
    elif dtype == torch.float16:
        a, b = to_fp16(a), to_fp16(b)
    elif dtype == torch.bfloat16:
        a, b = to_bf16(a), to_bf16(b)
    else:
        raise ValueError("unsupported dtype")

    b_compressed, e = ops.cutlass_sparse_compress(b.t())

    # Compressed B, Metadata, Original A, B
    return b_compressed, e, a, b


def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype],
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    output = (scale_a * (scale_b * (torch.mm(
        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
    if bias is not None:
        output = output + bias

    return output


@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():
    big_m = 1024
    m, n, k = 512, 512, 512

    # Create tensors
    b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
                                                     big_m, n, k)
    a = whole_a[0:m, 0:k]
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
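For reference, prune_to_2_4 above keeps the two largest-magnitude entries in every contiguous group of four. A toy illustration with made-up values (not part of the test file; assumes the helpers defined above are in scope):

    row = torch.tensor([[0.1, -3.0, 0.2, 4.0, 1.0, 0.0, -2.0, 0.5]])
    pruned = prune_to_2_4(row)
    # -> [[0.0, -3.0, 0.0, 4.0, 1.0, 0.0, -2.0, 0.0]]
    assert (pruned.view(-1, 4) != 0).sum(dim=1).le(2).all()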
@ -10,9 +10,11 @@ from compressed_tensors.quantization import QuantizationType

from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
    CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
    CompressedTensors24, CompressedTensorsLinearMethod,
    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
    CompressedTensorsWNA16)
from vllm.platforms import current_platform


@pytest.mark.parametrize(
@ -208,3 +210,98 @@ def test_compressed_tensors_kv_cache(vllm_runner):
    with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
        output = llm.generate_greedy("Hello world!", max_tokens=20)
        assert output


@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(qkv_proj.scheme, CompressedTensors24)

    assert qkv_proj.scheme.weight_quant.strategy == weight_strategy
    assert qkv_proj.scheme.input_quant.strategy == input_strategy
    assert qkv_proj.scheme.quantized
    assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
    sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
    assert sparsity_map.get("Linear").format == "dense"
    assert sparsity_map.get("Linear").sparsity_structure == "2:4"


@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel",
     "token"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
     "channel", "tensor"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", "tensor",
     "tensor"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
     "tensor", "token"),
])
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj
        assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
        _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
     "channel", "token"),
    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing", "tensor",
     "tensor"),
    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
     "tensor", "token"),
])
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj
        assert qkv_proj.scheme.weights_dtype == torch.int8
        _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize(
    "args_2of4",
    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")])
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
    model = args_2of4
    with vllm_runner(model) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj
        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensors24)

        assert qkv_proj.scheme.weight_quant is None
        assert qkv_proj.scheme.input_quant is None
        assert not qkv_proj.scheme.quantized
        assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
        sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
        assert sparsity_map.get("Linear").format == "dense"
        assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output

@ -21,6 +21,8 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main, 90
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main, 90
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main

@ -26,6 +26,10 @@ do
  export QUANTIZATION=${array[0]}
  export MODEL_NAME=${array[1]}
  export REVISION=${array[2]}
  # If array length is larger than 3, then MIN_CAPABILITY is provided
  if [ ${#array[@]} -gt 3 ]; then
    export MIN_CAPABILITY=${array[3]}
  fi
  pytest -s weight_loading/test_weight_loading.py || LOCAL_SUCCESS=$?

  if [[ $LOCAL_SUCCESS == 0 ]]; then

@ -1,14 +1,21 @@
import os

import pytest
import torch

from vllm.platforms import current_platform

MAX_MODEL_LEN = 1024
MODEL_NAME = os.environ.get("MODEL_NAME",
                            "robertgshaw2/zephyr-7b-beta-channelwise-gptq")
REVISION = os.environ.get("REVISION", "main")
QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")
MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "89")


@pytest.mark.skipif(
    not current_platform.has_device_capability(int(MIN_CAPABILITY)),
    reason="Current system does not have minimum capability.")
def test_weight_loading(vllm_runner):
    """
    Test parameter weight loading with tp>1.
@ -552,6 +552,109 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
    return out


def cutlass_sparse_compress(a: torch.Tensor) \
        -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compresses a sparse matrix for use with Cutlass sparse operations.

    This function takes a dense tensor and compresses it into two components:
    non-zero elements and metadata. The compressed representation is compatible
    with Cutlass sparse kernels.

    Args:
        a (torch.Tensor):
            The input tensor to be compressed. Must have one of the following data types:
            - `torch.int8`
            - `torch.float8_e4m3fn`
            - `torch.bfloat16`
            - `torch.float16`

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            A tuple containing:
            - `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
            - `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.

    Raises:
        ValueError: If the compression operation fails.

    Notes:
        - The `a_meta` tensor has a data type of `torch.uint8`.
        - Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`).
        - The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor.
        - The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`.
    """
    assert (a.dtype in [
        torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16
    ])
    assert (a.is_contiguous())

    # a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4
    elemsPerMetaElem = 4

    m = a.shape[0]
    k = a.shape[1]
    assert (k % 2 == 0)
    a_nzs = torch.empty((m, k // 2), dtype=a.dtype, device=a.device)
    a_meta = torch.empty((m, k // 2 // elemsPerMetaElem),
                         dtype=torch.uint8,
                         device=a.device)

    if not (torch.ops._C.cutlass_sparse_compress_entry(a_nzs, a_meta, a)):
        raise ValueError

    assert (a_nzs.is_contiguous())
    assert (a_meta.is_contiguous())

    return a_nzs, a_meta
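A small usage sketch of the shapes documented above (hypothetical sizes; any tensor that already satisfies the 2:4 pattern along its last dimension works):

    import torch
    from vllm import _custom_ops as ops

    n, k = 64, 128
    w = torch.randn(n, k, device="cuda")
    w[:, 2::4] = 0.0  # zero two of every four entries along k ...
    w[:, 3::4] = 0.0  # ... giving a valid 2:4 pattern
    w = w.to(torch.float8_e4m3fn)

    w_nzs, w_meta = ops.cutlass_sparse_compress(w)
    assert w_nzs.shape == (n, k // 2)   # surviving values
    assert w_meta.shape == (n, k // 8)  # one uint8 metadata byte per 4 kept values
    assert w_meta.dtype == torch.uint8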
def cutlass_scaled_sparse_mm(
        a: torch.Tensor,
        bt_nzs: torch.Tensor,
        bt_meta: torch.Tensor,
        scale_a: torch.Tensor,
        scale_b: torch.Tensor,
        out_dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Performs a scaled sparse matrix multiplication using Cutlass.

    Steps:
    1. Create a dense matrix `a` of shape (m, k) on the CUDA device:
       `a = torch.randn((m, k), device='cuda')`.

    2. Create a dense matrix `b` of shape (k, n) on the CUDA device:
       `b = torch.randn((k, n), device='cuda')`.

    3. Prune matrix `b` to 2:4 sparsity along the specified dimension:
       `b = prune_to_2_4(b, dim=0)`.

    4. Compress the transposed sparse matrix `b.t()`:
       `bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`.

    5. Perform sparse matrix multiplication using the compressed matrix,
       applying scaling factors for `a` and `b`, and the output data type:
       `out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`.

    Returns:
        - The result of the scaled sparse matrix multiplication.
    """
    assert (bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.shape[0] == bt_nzs.shape[0] \
        and bias.dtype == out_dtype

    m = a.shape[0]
    n = bt_nzs.shape[0]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_sparse_mm(out, a, bt_nzs, bt_meta, scale_a,
                                          scale_b, bias)

    return out
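Putting the docstring's steps together as a runnable sketch (made-up sizes; the 2:4 pattern is built by hand instead of calling a pruning helper, and both scales are set to 1 for simplicity):

    import torch
    from vllm import _custom_ops as ops

    m, k, n = 32, 128, 64
    a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)

    # Build B^T of shape (n, k) with a hand-made 2:4 pattern along k:
    # keep positions 0 and 1 of every group of four, zero positions 2 and 3.
    bt = torch.randn(n, k, device="cuda")
    bt[:, 2::4] = 0.0
    bt[:, 3::4] = 0.0
    bt = bt.to(torch.float8_e4m3fn).contiguous()

    bt_nzs, bt_meta = ops.cutlass_sparse_compress(bt)  # (n, k//2), (n, k//8)
    scale_a = torch.ones(1, 1, device="cuda")
    scale_b = torch.ones(1, 1, device="cuda")

    out = ops.cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b,
                                       out_dtype=torch.bfloat16)
    assert out.shape == (m, n)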

# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,

@ -1,7 +1,9 @@
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, List, Literal, Optional, cast

import torch
from compressed_tensors.config import CompressionFormat
from compressed_tensors.config import (CompressionFormat,
                                        SparsityCompressionConfig,
                                        SparsityStructure)
from compressed_tensors.quantization import (QuantizationArgs,
                                             QuantizationStrategy,
                                             QuantizationType)
@ -15,7 +17,7 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
    CompressedTensorsMoEMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
    W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
    CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
@ -27,20 +29,29 @@ from vllm.platforms import current_platform

__all__ = ["CompressedTensorsLinearMethod"]

SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]


class CompressedTensorsConfig(QuantizationConfig):

    def __init__(self,
                 target_scheme_map: Dict[str, Any],
                 ignore: List[str],
                 quant_format: str,
                 kv_cache_scheme: Optional[Dict[str, Any]] = None):
    def __init__(
        self,
        target_scheme_map: Dict[str, Any],
        ignore: List[str],
        quant_format: str,
        sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
        kv_cache_scheme: Optional[Dict[str, Any]] = None,
        config: Optional[Dict[str, Any]] = None,
    ):

        self.ignore = ignore
        self.quant_format = quant_format
        # Map from [target -> scheme]
        self.target_scheme_map = target_scheme_map
        self.kv_cache_scheme = kv_cache_scheme
        self.sparsity_scheme_map = sparsity_scheme_map
        self.config = config

    def get_linear_method(self) -> "CompressedTensorsLinearMethod":
        return CompressedTensorsLinearMethod(self)
@ -78,8 +89,50 @@ class CompressedTensorsConfig(QuantizationConfig):

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
        ignore: List[str] = cast(List[str], config.get("ignore", []))
        quant_format = cast(str, config.get("format"))
        target_scheme_map = cls._quantization_scheme_map_from_config(
            config=config)
        sparsity_scheme_map = cls._sparsity_scheme_map_from_config(
            config=config)

        return cls(
            target_scheme_map=target_scheme_map,
            ignore=ignore,
            quant_format=quant_format,
            sparsity_scheme_map=sparsity_scheme_map,
            config=config,
        )

    @classmethod
    def _sparsity_scheme_map_from_config(
            cls, config: Dict[str,
                              Any]) -> Dict[str, SparsityCompressionConfig]:
        """
        :param config: The `quantization_config` dictionary from config.json
        :return: A dictionary mapping target layer names to their corresponding
            sparsity compression configurations
        """
        if (sparsity_config := config.get(SPARSITY_CONFIG_NAME)) is None:
            return dict()

        sparsity_config = SparsityCompressionConfig.model_validate(
            sparsity_config)
        sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
            target: sparsity_config
            for target in sparsity_config.targets or list()
        }
        return sparse_scheme_map

    @classmethod
    def _quantization_scheme_map_from_config(
            cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
        """
        :param config: The `quantization_config` dictionary from config.json
        :return: A dictionary mapping target layer names to their corresponding
            quantization_args for weights and input activations
        """
        target_scheme_map: Dict[str, Any] = dict()
        ignore = cast(List[str], config.get("ignore"))
        quant_format = cast(str, config.get("format"))

        # The quant_config has multiple config_groups, each containing
@ -90,12 +143,14 @@ class CompressedTensorsConfig(QuantizationConfig):
        # details follow the structure defined by the QuantizationArgs
        # pydantic model, which is used to verify the structure of the
        # quant_config and also store the details for later use.
        for _, quant_config in config["config_groups"].items():

        config_groups = config.get("config_groups", dict())
        for _, quant_config in config_groups.items():
            targets = quant_config.get("targets")
            for target in targets:
                target_scheme_map[target] = {}
                target_scheme_map[target][
                    "weights"] = QuantizationArgs.parse_obj(
                    "weights"] = QuantizationArgs.model_validate(
                        quant_config.get("weights"))

                target_scheme_map[target]["input_activations"] = None
@ -110,13 +165,9 @@ class CompressedTensorsConfig(QuantizationConfig):
                        "weights"].type == QuantizationType.FLOAT
                else:
                    target_scheme_map[target][
                        "input_activations"] = QuantizationArgs.parse_obj(
                        "input_activations"] = QuantizationArgs.model_validate( # noqa: E501
                            quant_config.get("input_activations"))

        return cls(target_scheme_map=target_scheme_map,
                   ignore=ignore,
                   quant_format=quant_format,
                   kv_cache_scheme=config.get("kv_cache_scheme"))
        return target_scheme_map

    @classmethod
    def get_config_filenames(cls) -> List[str]:
@ -315,23 +366,105 @@ class CompressedTensorsConfig(QuantizationConfig):
        # TODO (@robertgshaw): add compressed-tensors as dep
        # so we do not have to re-write these functions
        # need to make accelerate optional in ct to do this
        matched_target = find_matched_target(
            layer_name=layer_name,
            module=layer,
            targets=self.target_scheme_map.keys())

        # Find the quant_scheme
        scheme_dict = self.target_scheme_map[matched_target]
        scheme = self._get_scheme_from_parts(
            weight_quant=scheme_dict["weights"],
            input_quant=scheme_dict["input_activations"])
        # Will be empty for models with only sparsity
        if self.target_scheme_map:
            matched_target = find_matched_target(
                layer_name=layer_name,
                module=layer,
                targets=self.target_scheme_map.keys())

            scheme_dict = self.target_scheme_map[matched_target]
            weight_quant = scheme_dict.get("weights")
            input_quant = scheme_dict.get("input_activations")
        elif self.sparsity_scheme_map:
            matched_target = find_matched_target(
                layer_name=layer_name,
                module=layer,
                targets=self.sparsity_scheme_map.keys())
            weight_quant = None
            input_quant = None

        # For models with sparsity, assumes that the sparse layers are also
        # quantized for cutlass 2:4 support
        sparsity_scheme: Optional[
            SparsityCompressionConfig] = self.sparsity_scheme_map.get(
                matched_target)

        if self.supports_cutlass_24(weight_quant=weight_quant,
                                    input_quant=input_quant,
                                    sparsity_scheme=sparsity_scheme):
            # Have a valid sparsity scheme
            # Validate layer is supported by Cutlass 2:4 Kernel
            scheme = CompressedTensors24(quantized=weight_quant is not None
                                         or input_quant is not None,
                                         weight_quant=weight_quant,
                                         input_quant=input_quant)
        else:
            # Find the quant_scheme
            scheme = self._get_scheme_from_parts( # type: ignore
                weight_quant=weight_quant,
                input_quant=input_quant,
            )

        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
        self._check_scheme_supported(scheme.get_min_capability())

        return scheme

    @staticmethod
    def supports_cutlass_24(
            weight_quant: Optional[QuantizationArgs],
            input_quant: Optional[QuantizationArgs],
            sparsity_scheme: Optional[SparsityCompressionConfig] = None
    ) -> bool:
        """
        Check if the layer is supported by the Cutlass 2:4 Kernel
        Conditions:
            - Overarching condition: Sparsity Structure is 2:4
            - Unquantized cases are supported
            - Weight-only quantization is not supported
            - Supported weight quantization strategies are TENSOR and CHANNEL
            - Supported input quantization strategies are TENSOR and TOKEN
            - Only 8 bit quantization is supported

        :return: True if the layer is supported by the Cutlass 2:4 Kernel
            False otherwise
        """
        is_valid_sparsity = (sparsity_scheme is not None
                             and sparsity_scheme.sparsity_structure
                             == SparsityStructure.TWO_FOUR.value
                             and sparsity_scheme.format == "dense")
        if not is_valid_sparsity:
            return False

        # Unquantized cases are supported
        if weight_quant is None and input_quant is None:
            return True

        # Weight-only quantization is not supported
        if weight_quant is not None and input_quant is None:
            return False

        supported_weight_quant_strategies = [
            QuantizationStrategy.TENSOR.value,
            QuantizationStrategy.CHANNEL.value
        ]

        assert weight_quant is not None
        assert input_quant is not None
        if weight_quant.strategy not in supported_weight_quant_strategies:
            return False

        supported_input_quant_strategies = [
            QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value
        ]

        if input_quant.strategy not in supported_input_quant_strategies:
            return False

        return weight_quant.num_bits == input_quant.num_bits == 8
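The decision table above can be exercised directly through the static method. A hedged sketch (field names for the compressed-tensors pydantic models are assumed from the attribute accesses above and may differ across library versions):

    from compressed_tensors.config import SparsityCompressionConfig
    from compressed_tensors.quantization import (QuantizationArgs,
                                                 QuantizationStrategy)

    # Assumed constructor fields: format, sparsity_structure, targets.
    sparsity = SparsityCompressionConfig(format="dense",
                                         sparsity_structure="2:4",
                                         targets=["Linear"])
    w8 = QuantizationArgs(num_bits=8, strategy=QuantizationStrategy.CHANNEL)
    a8 = QuantizationArgs(num_bits=8, strategy=QuantizationStrategy.TOKEN,
                          dynamic=True)

    CompressedTensorsConfig.supports_cutlass_24(w8, a8, sparsity)      # True
    CompressedTensorsConfig.supports_cutlass_24(None, None, sparsity)  # True, unquantized 2:4
    CompressedTensorsConfig.supports_cutlass_24(w8, None, sparsity)    # False, weight-only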

class CompressedTensorsLinearMethod(LinearMethodBase):

@ -7,13 +7,12 @@ from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
                                       CompressedTensorsWNA16)

from .compressed_tensors_24 import CompressedTensors24 # isort: skip

__all__ = [
    "CompressedTensorsScheme",
    "CompressedTensorsWNA16",
    "CompressedTensorsW8A16Fp8",
    "CompressedTensorsW4A16Sparse24",
    "CompressedTensorsW8A8Int8",
    "CompressedTensorsW8A8Fp8",
    "WNA16_SUPPORTED_BITS",
    "W4A16SPARSE24_SUPPORTED_BITS",
    "CompressedTensorsScheme", "CompressedTensorsWNA16",
    "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
    "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
    "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
    "CompressedTensors24"
]

@ -0,0 +1,203 @@
from typing import Callable, List, Optional

import torch
from compressed_tensors.quantization import (QuantizationArgs,
                                             QuantizationStrategy,
                                             QuantizationType)

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    convert_to_channelwise)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           ChannelQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)

__all__ = ["CompressedTensors24"]


class CompressedTensors24(CompressedTensorsScheme):

    def __init__(self,
                 quantized: bool = False,
                 weight_quant: Optional[QuantizationArgs] = None,
                 input_quant: Optional[QuantizationArgs] = None):

        self.quantized = quantized
        self.weight_quant = weight_quant
        self.input_quant = input_quant

    @classmethod
    def get_min_capability(cls) -> int:
        # Only cutlass 3.x kernels are implemented so far
        return 90

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):

        self.output_dtype = params_dtype
        layer.logical_widths = output_partition_sizes
        self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

        # parameter to store uncompressed weight
        weight = ModelWeightParameter(data=torch.empty(
            sum(output_partition_sizes),
            input_size_per_partition,
            dtype=self.weights_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)

        # Check if quantized, not just 2:4 Sparse
        if self.quantized:
            if (self.weight_quant and self.weight_quant.strategy
                    == QuantizationStrategy.CHANNEL.value):
                weight_scale = ChannelQuantScaleParameter(
                    data=torch.empty((sum(output_partition_sizes), 1),
                                     dtype=torch.float32),
                    output_dim=0,
                    weight_loader=weight_loader)
            else:
                assert (self.weight_quant and self.weight_quant.strategy
                        == QuantizationStrategy.TENSOR.value)
                weight_scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader)

            layer.register_parameter("weight_scale", weight_scale)

            # input quant will be non-none
            if self.input_quant and not self.input_quant.dynamic:
                # register input quant scale
                assert (self.input_quant.strategy ==
                        QuantizationStrategy.TENSOR.value)
                input_scale = BasevLLMParameter(data=torch.empty(
                    1, dtype=torch.float32),
                                                weight_loader=weight_loader)

                layer.register_parameter("input_scale", input_scale)

        else:
            # for sparse-only, pass in 1 for weight/input scales
            weight_scale = torch.nn.Parameter(data=torch.ones(
                1, dtype=torch.float32),
                                              requires_grad=False)
            input_scale = torch.nn.Parameter(data=torch.ones(
                1, dtype=torch.float32),
                                             requires_grad=False)
            layer.register_parameter("input_scale", input_scale)
            layer.register_parameter("weight_scale", weight_scale)

        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """
        Compress weights after loading. Store compressed weight and meta
        tensor

        :post-condition: layer.weight and layer.meta are
            set to the compressed weight and meta tensor in the
            format expected by the Cutlass kernels
        :param layer: The layer with the weights to be processed

        """
        # torch.compile workaround
        if hasattr(layer, "input_scale"):
            layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                   requires_grad=False)

        if self.weight_quant:
            if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
                layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
                    weight_scale=layer.weight_scale,
                    logical_widths=layer.logical_widths),
                                                        requires_grad=False)
            else:
                # torch.compile workaround
                layer.weight_scale = torch.nn.Parameter(
                    layer.weight_scale.data, requires_grad=False)

        w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
        layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
        layer.meta = torch.nn.Parameter(meta, requires_grad=False)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Returns the output tensor for the layer with 2:4
        sparse compressed weights, given the input tensor
        and bias

        :param layer: The layer with 2:4 sparse compressed
            weights to be used for the computation
        :param x: The input tensor to the layer
        :param bias: The bias to be added to the output tensor
        :return: The output tensor of the layer
        """
        if self.quantized:
            scale = None
            if hasattr(layer, "input_scale"):
                scale = layer.input_scale

            if self.weights_dtype == torch.int8:
                ops_output = ops.scaled_int8_quant(x, scale=scale)
                q_input = ops_output[0]
                input_scale = ops_output[1]
            else:
                assert self.weights_dtype == torch.float8_e4m3fn
                if scale is not None:
                    q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
                else:
                    q_input, input_scale = ops.scaled_fp8_quant(
                        x, use_per_token_if_dynamic=True)

        else:
            # Not quantized, nothing to do with the input_scales, use as is
            input_scale = layer.input_scale
            q_input = x

        out = ops.cutlass_scaled_sparse_mm(a=q_input,
                                           bt_nzs=layer.weight,
                                           bt_meta=layer.meta,
                                           scale_a=input_scale,
                                           scale_b=layer.weight_scale,
                                           out_dtype=self.output_dtype,
                                           bias=bias)
        assert out.is_contiguous()
        return out

    def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
        if not self.quantized:
            return params_dtype

        assert self.weight_quant is not None
        assert self.input_quant is not None

        is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8

        if not is_8_bits:
            raise ValueError("Cutlass only supports 8-bit quantization")

        if (self.weight_quant.type == QuantizationType.FLOAT
                and self.input_quant.type == QuantizationType.FLOAT):
            return torch.float8_e4m3fn

        if (self.weight_quant.type == QuantizationType.INT
                and self.input_quant.type == QuantizationType.INT):
            return torch.int8

        raise ValueError("Quantization type not supported by Cutlass")


def check_24(tensor):
    new_tensor = tensor.view(-1, 4)
    zero_counts = (new_tensor == 0).sum(dim=1)
    return (zero_counts >= 2).all().item()
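check_24 above is a quick structural test: it returns True only if every contiguous group of four values contains at least two zeros. A tiny illustration with made-up values:

    import torch

    ok = torch.tensor([1.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0, 4.0])
    bad = torch.tensor([1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 3.0, 4.0])
    assert check_24(ok) is True
    assert check_24(bad) is False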