[Kernel]: Cutlass 2:4 Sparsity + FP8/Int8 Quant Support (#10995)
Co-authored-by: Faraz Shahsavan <faraz.shahsavan@gmail.com>
Co-authored-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>

parent: f04e407e6b
commit: 60508ffda9
@ -206,7 +206,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

   # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
-  set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use")
+  set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")

   # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
   if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     FetchContent_Declare(
         cutlass
         GIT_REPOSITORY https://github.com/nvidia/cutlass.git
-        GIT_TAG v3.5.1
+        GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227
         GIT_PROGRESS TRUE

         # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
         # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
         # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
-        GIT_SHALLOW TRUE
+        GIT_SHALLOW FALSE
     )
   endif()
   FetchContent_MakeAvailable(cutlass)
@ -241,7 +241,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/custom_all_reduce.cu"
     "csrc/permute_cols.cu"
-    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
+    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+    "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
+    "csrc/sparse/cutlass/sparse_compressor_entry.cu"
+    "csrc/cutlass_extensions/common.cpp")

   set_gencode_flags_for_srcs(
     SRCS "${VLLM_EXT_SRC}"
@ -271,11 +274,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   endif()

   #
-  # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
+  # The cutlass_scaled_mm cutlass_scaled_sparse_mm, and cutlass_compressor kernels
+  # For Hopper (c3x, i.e. CUTLASS 3.x) require
   # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
   cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
+    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
+        "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
+        "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@ -284,12 +290,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
   else()
     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
-      message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
+      message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is "
                      "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
-                     "later if you intend on running FP8 quantized models on "
+                     "later if you intend on running FP8 sparse or quantized models on "
                      "Hopper.")
     else()
-      message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
+      message(STATUS "Not building cutlass_c3x as no compatible archs found "
                      "in CUDA target architectures")
     endif()

@ -404,7 +410,7 @@ define_gpu_extension_target(
   SOURCES ${VLLM_EXT_SRC}
   COMPILE_FLAGS ${VLLM_GPU_FLAGS}
   ARCHITECTURES ${VLLM_GPU_ARCHES}
-  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
+  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
   USE_SABI 3
   WITH_SOABI)
benchmarks/cutlass_benchmarks/sparse_benchmarks.py (new file, 384 lines)
@ -0,0 +1,384 @@
import argparse
import copy
import itertools
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import make_rand_sparse_tensors
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]


# bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
             **kwargs) -> TMeasurement:
    min_run_time = 1

    globals = {
        "args": args,
        "kwargs": kwargs,
        "fn": fn,
    }
    return TBenchmark.Timer(
        stmt="fn(*args, **kwargs)",
        globals=globals,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)


def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
               sub_label: str) -> Iterable[TMeasurement]:
    assert dtype == torch.int8
    b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)

    out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
                                       torch.bfloat16)
    out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)

    if not torch.allclose(out, out_ref):
        print("Incorrect results")
        print(out)
        print(out_ref)
    else:
        print("Correct results")

    timers = []
    # pytorch impl - bfloat16
    timers.append(
        bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
                 torch.mm, a.to(dtype=torch.bfloat16),
                 b.to(dtype=torch.bfloat16)))

    # pytorch impl - float16
    timers.append(
        bench_fn(label, sub_label,
                 "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
                 a.to(dtype=torch.float16), b.to(dtype=torch.float16)))

    # cutlass impl
    timers.append(
        bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
                 ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
                 torch.bfloat16))

    # cutlass with bias
    timers.append(
        bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
                 ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
                 bias))

    # cutlass sparse impl
    timers.append(
        bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm",
                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
                 scale_b, torch.bfloat16))

    # cutlass sparse with bias
    timers.append(
        bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
                 scale_b, torch.bfloat16, bias))

    return timers


def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
              sub_label: str) -> Iterable[TMeasurement]:
    assert dtype == torch.float8_e4m3fn
    b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n,
                                                     k)
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)

    out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
                                       torch.bfloat16)
    out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)

    if not torch.allclose(out, out_ref):
        print("Incorrect results")
        print(out)
        print(out_ref)
    else:
        print("Correct results")

    timers = []

    # pytorch impl w. bf16
    timers.append(
        bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
                 torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
                 b.to(dtype=torch.bfloat16, device="cuda")))

    # pytorch impl: bf16 output, without fp8 fast accum
    timers.append(
        bench_fn(label,
                 sub_label,
                 "pytorch_fp8_fp8_bf16_scaled_mm",
                 torch._scaled_mm,
                 a,
                 b,
                 scale_a=scale_a,
                 scale_b=scale_b,
                 out_dtype=torch.bfloat16))

    # pytorch impl: bf16 output, with fp8 fast accum
    timers.append(
        bench_fn(label,
                 sub_label,
                 "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
                 torch._scaled_mm,
                 a,
                 b,
                 scale_a=scale_a,
                 scale_b=scale_b,
                 out_dtype=torch.bfloat16,
                 use_fast_accum=True))

    # pytorch impl: fp16 output, without fp8 fast accum
    timers.append(
        bench_fn(label,
                 sub_label,
                 "pytorch_fp8_fp8_fp16_scaled_mm",
                 torch._scaled_mm,
                 a,
                 b,
                 scale_a=scale_a,
                 scale_b=scale_b,
                 out_dtype=torch.float16))

    # pytorch impl: fp16 output, with fp8 fast accum
    timers.append(
        bench_fn(label,
                 sub_label,
                 "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
                 torch._scaled_mm,
                 a,
                 b,
                 scale_a=scale_a,
                 scale_b=scale_b,
                 out_dtype=torch.float16,
                 use_fast_accum=True))

    # cutlass impl: bf16 output
    timers.append(
        bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
                 ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
                 torch.bfloat16))

    # cutlass impl: bf16 output
    timers.append(
        bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
                 scale_b, torch.bfloat16))

    # cutlass impl: fp16 output
    timers.append(
        bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
                 scale_b, torch.float16))

    # cutlass impl: bf16 output, with bias
    timers.append(
        bench_fn(label, sub_label,
                 "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
                 scale_b, torch.bfloat16, bias))

    # cutlass impl: fp16 output, with bias
    timers.append(
        bench_fn(label, sub_label,
                 "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
                 ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
                 scale_b, torch.float16, bias.to(dtype=torch.float16)))

    return timers


def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
          sub_label: str) -> Iterable[TMeasurement]:
    if dtype == torch.int8:
        return bench_int8(dtype, m, k, n, label, sub_label)
    if dtype == torch.float8_e4m3fn:
        return bench_fp8(dtype, m, k, n, label, sub_label)
    raise ValueError("unsupported type")


# runner
def print_timers(timers: Iterable[TMeasurement]):
    compare = TBenchmark.Compare(timers)
    compare.print()


def run(dtype: torch.dtype,
        MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
    results = []
    for m, k, n in MKNs:
        timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
                       f"MKN=({m}x{k}x{n})")
        print_timers(timers)
        results.extend(timers)

    return results


# output makers
def make_output(data: Iterable[TMeasurement],
                MKNs: Iterable[Tuple[int, int, int]],
                base_description: str,
                timestamp=None):
    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    timestamp = int(time.time()) if timestamp is None else timestamp
    with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
        pkl.dump(data, f)


# argparse runners


def run_square_bench(args):
    dim_sizes = list(
        range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
    dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
    n = len(dim_sizes)
    Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
    Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
    Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
    MKNs = list(zip(Ms, Ks, Ns))
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"range_bench-{args.dtype}")


def run_model_bench(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
        KNs = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KNs.append(KN)
        return KNs

    model_bench_data = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        Ms = args.batch_sizes
        KNs = model_shapes(model, tp_size)
        MKNs = []
        for m in Ms:
            for k, n in KNs:
                MKNs.append((m, k, n))

        data = run(args.dtype, MKNs)
        model_bench_data.append(data)

    # Print all results
    for data, model_tp in zip(model_bench_data, models_tps):
        model, tp_size = model_tp
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print_timers(data)

    timestamp = int(time.time())

    all_data = []
    for d in model_bench_data:
        all_data.extend(d)
    # pickle all data
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
        pkl.dump(all_data, f)


if __name__ == '__main__':

    def to_torch_dtype(dt):
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        raise ValueError("unsupported dtype")

    parser = FlexibleArgumentParser(
        description="""
Benchmark Cutlass GEMM.

    To run square GEMMs:
        python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64

    To run constant N and K and sweep M:
        python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384

    To run dimensions from a model:
        python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1

    Output:
        - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument("--dtype",
                        type=to_torch_dtype,
                        required=True,
                        help="Available options are ['int8', 'fp8']")
    subparsers = parser.add_subparsers(dest="cmd")

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--dim-start", type=int, required=True)
    range_parser.add_argument("--dim-end", type=int, required=True)
    range_parser.add_argument("--dim-increment", type=int, required=True)
    range_parser.add_argument("--m-constant", type=int, default=None)
    range_parser.add_argument("--n-constant", type=int, default=None)
    range_parser.add_argument("--k-constant", type=int, default=None)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument("--models",
                              nargs="+",
                              type=str,
                              default=DEFAULT_MODELS,
                              choices=WEIGHT_SHAPES.keys())
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--batch-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_BATCH_SIZES)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)
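For reference, the flow this benchmark times reduces to a few lines. This is an illustrative sketch, not part of the commit; it assumes this branch of vLLM is installed, a Hopper (SM90) GPU, and that it is run from benchmarks/cutlass_benchmarks/ so the utils module added below is importable:

import torch
from utils import make_rand_sparse_tensors  # helper defined later in this commit
from vllm import _custom_ops as ops

# Compressed 2:4 weight (b_compressed), sparsity metadata (e), and the dense
# reference operands (a, b) for an fp8 GEMM with M=16, N=4096, K=4096.
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, 16, 4096, 4096)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
                                   torch.bfloat16)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
print("sparse/dense match:", torch.allclose(out, out_ref))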
benchmarks/cutlass_benchmarks/utils.py (new file, 96 lines)
@ -0,0 +1,96 @@
# Cutlass bench utils
from typing import Iterable, Tuple

import torch

import vllm._custom_ops as ops


def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.bfloat16)


def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.float16)


def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
                      k: int) -> Tuple[torch.Tensor, torch.Tensor]:
    a = torch.randn((m, k), device='cuda') * 5
    b = torch.randn((n, k), device='cuda').t() * 5

    if dtype == torch.int8:
        return to_int8(a), to_int8(b)
    if dtype == torch.float8_e4m3fn:
        return to_fp8(a), to_fp8(b)

    raise ValueError("unsupported dtype")


def prune_to_2_4(tensor):
    # Reshape tensor to [N, 4] where N is number of groups of 4
    original_shape = tensor.shape
    reshaped = tensor.reshape(-1, 4)

    # Get indices of top 2 absolute values in each group of 4
    _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)

    # Create binary mask
    mask = torch.zeros_like(reshaped)
    mask.scatter_(dim=1,
                  index=indices,
                  src=torch.ones_like(indices, dtype=mask.dtype))

    # Apply mask and reshape back
    pruned = reshaped * mask

    # Turn all -0.0 to 0.0
    pruned[pruned == -0.0] = 0.0

    return pruned.reshape(original_shape)


def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
                             k: int) -> Tuple[torch.Tensor, torch.Tensor]:
    a = torch.randn((m, k), device='cuda') * 5
    b = torch.randn((n, k), device='cuda').t() * 5

    b = prune_to_2_4(b.t()).t()

    if dtype == torch.int8:
        a, b = to_int8(a), to_int8(b)
    elif dtype == torch.float8_e4m3fn:
        a, b = to_fp8(a), to_fp8(b)
    elif dtype == torch.float16:
        a, b = to_fp16(a), to_fp16(b)
    elif dtype == torch.bfloat16:
        a, b = to_bf16(a), to_bf16(b)
    else:
        raise ValueError("unsupported dtype")

    b_compressed, e = ops.cutlass_sparse_compress(b.t())

    # Compressed B, Metadata, Original A, B
    return b_compressed, e, a, b


def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype,
                               m: int, n: int, k: int) -> \
        Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
    ABs = []
    for _ in range(num_tensors):
        b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
        if b_comp is not None:
            ABs.append(make_rand_sparse_tensors(dtype, m, n, k))
    BComps, Es, As, Bs = zip(*ABs)
    return list(BComps), list(Es), list(As), list(Bs)
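A quick way to convince yourself that prune_to_2_4 produces a valid 2:4 pattern is to count non-zeros per group of four. This check is illustrative and not part of the commit; it assumes a CUDA device and the same working directory as above:

import torch
from utils import prune_to_2_4  # helper defined above

t = torch.randn((128, 256), device="cuda")
pruned = prune_to_2_4(t)
# Every contiguous group of 4 values along the last dimension keeps at most
# 2 non-zeros after pruning.
nnz_per_group = (pruned.reshape(-1, 4) != 0).sum(dim=1)
assert torch.all(nnz_per_group <= 2)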
@ -8,6 +8,7 @@ from typing import Callable, Iterable, List, Tuple
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
+from utils import make_rand_tensors
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
@ -17,31 +18,6 @@ DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]

-# helpers
-
-
-def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
-    finfo = torch.finfo(torch.float8_e4m3fn)
-    return torch.round(tensor.clamp(
-        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
-
-
-def to_int8(tensor: torch.Tensor) -> torch.Tensor:
-    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
-
-
-def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
-                      k: int) -> Tuple[torch.Tensor, torch.Tensor]:
-    a = torch.randn((m, k), device='cuda') * 5
-    b = torch.randn((n, k), device='cuda').t() * 5
-
-    if dtype == torch.int8:
-        return to_int8(a), to_int8(b)
-    if dtype == torch.float8_e4m3fn:
-        return to_fp8(a), to_fp8(b)
-
-    raise ValueError("unsupported dtype")
-
-
# bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
@ -386,4 +362,4 @@ Benchmark Cutlass GEMM.
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)
@ -40,4 +40,4 @@ WEIGHT_SHAPES = {
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
}
csrc/core/math.hpp (new file, 7 lines)
@ -0,0 +1,7 @@
#include <climits>
#include <iostream>

inline uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
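next_pow_2 rounds its argument up to the next power of two and is used by the kernel-dispatch heuristics added later in this commit. An equivalent Python reference, illustrative only:

def next_pow_2(num: int) -> int:
    # Mirrors the bit trick above: 1 << (bit width - count of leading zeros of num - 1).
    if num <= 1:
        return num
    return 1 << (num - 1).bit_length()

assert [next_pow_2(x) for x in (1, 3, 64, 65)] == [1, 4, 64, 128]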
csrc/cutlass_extensions/common.cpp (new file, 11 lines)
@ -0,0 +1,11 @@
#include "cutlass_extensions/common.hpp"

int32_t get_sm_version_num() {
  int32_t major_capability, minor_capability;
  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
                         0);
  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
                         0);
  int32_t version_num = major_capability * 10 + minor_capability;
  return version_num;
}
csrc/cutlass_extensions/common.hpp (new file, 35 lines)
@ -0,0 +1,35 @@
#pragma once

#include "cutlass/cutlass.h"
#include <climits>
#include "cuda_runtime.h"
#include <iostream>

/**
 * Helper function for checking CUTLASS errors
 */
#define CUTLASS_CHECK(status)                       \
  {                                                 \
    cutlass::Status error = status;                 \
    TORCH_CHECK(error == cutlass::Status::kSuccess, \
                cutlassGetStatusString(error));     \
  }

/**
 * Panic wrapper for unwinding CUDA runtime errors
 */
#define CUDA_CHECK(status)                                        \
  {                                                               \
    cudaError_t error = status;                                   \
    TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
  }

inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
  int max_shared_mem_per_block_opt_in = 0;
  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin,
                         device);
  return max_shared_mem_per_block_opt_in;
}

int32_t get_sm_version_num();
@ -36,13 +36,13 @@ struct ScaledEpilogueBase {
  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
-      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
      Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
-      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
      Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

  // This utility function constructs the arguments for the load descriptors
@ -162,6 +162,15 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& azp_adj,
                           c10::optional<torch::Tensor> const& azp,
                           c10::optional<torch::Tensor> const& bias);
+
+void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
+                              torch::Tensor const& b, torch::Tensor const& e,
+                              torch::Tensor const& a_scales,
+                              torch::Tensor const& b_scales,
+                              c10::optional<torch::Tensor> const& bias);
+
+bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
+                                   torch::Tensor& e, torch::Tensor const& a);
#endif

void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
@ -1,27 +0,0 @@
-#pragma once
-
-#include "cutlass/cutlass.h"
-#include <climits>
-
-/**
- * Helper function for checking CUTLASS errors
- */
-#define CUTLASS_CHECK(status)                        \
-  {                                                  \
-    TORCH_CHECK(status == cutlass::Status::kSuccess, \
-                cutlassGetStatusString(status))      \
-  }
-
-inline uint32_t next_pow_2(uint32_t const num) {
-  if (num <= 1) return num;
-  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
-}
-
-inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
-  int max_shared_mem_per_block_opt_in = 0;
-  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
-                         cudaDevAttrMaxSharedMemoryPerBlockOptin,
-                         device);
-  return max_shared_mem_per_block_opt_in;
-}
@ -21,7 +21,8 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

-#include "common.hpp"
+#include "core/math.hpp"
+#include "cutlass_extensions/common.hpp"
// clang-format on

using namespace cute;
@ -24,7 +24,8 @@
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
-#include "common.hpp"
+#include "core/math.hpp"
+#include "cutlass_extensions/common.hpp"
// clang-format on

using namespace cute;
@ -3,6 +3,8 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

+#include "cutlass_extensions/common.hpp"
+
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
@ -79,16 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
  return false;
}

-int32_t get_sm_version_num() {
-  int32_t major_capability, minor_capability;
-  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
-                         0);
-  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
-                         0);
-  int32_t version_num = major_capability * 10 + minor_capability;
-  return version_num;
-}
-
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
csrc/sparse/cutlass/sparse_compressor_c3x.cu (new file, 163 lines)
@ -0,0 +1,163 @@
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

#include "sparse_scaled_mm_c3x.cuh"

#include "cutlass/numeric_conversion.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
// clang-format on

using namespace cute;
using namespace vllm;

/// Make A structured sparse by replacing elements with 0 and compress it
template <typename ElementA_, typename ElementAcc_>
bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                             torch::Tensor const& a) {
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
              a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
  TORCH_CHECK(a.dim() == 2)
  // Check for strides and alignment
  TORCH_CHECK(a.stride(0) % 4 == 0)  // Required for semi-structured sparsity
  TORCH_CHECK(a.stride(1) == 1)

  int m = a.size(0);
  int k = a.size(1);

  // Sparse kernel setup; this kernel is not used for matmul,
  // but just for setting up the compressor utility
  // A matrix configuration
  using ElementA = ElementA_;
  using LayoutTagA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
  // B matrix configuration
  using ElementB = ElementA;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
  // C/D matrix configuration
  using ElementC = float;
  using LayoutTagC = cutlass::layout::ColumnMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
  // Core kernel configurations
  using ElementAccumulator = ElementAcc_;
  using TileShape = Shape<_128, _128, _128>;
  using TileShapeRef = Shape<_128, _128, _64>;
  using ClusterShape = Shape<_1, _2, _1>;
  using KernelSchedule = typename std::conditional<
      std::is_same_v<ElementA, cutlass::float_e4m3_t>,
      cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
      cutlass::gemm::KernelTmaWarpSpecialized>::type;

  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
  using ProblemShape = Shape<int, int, int, int>;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC,
          AlignmentC, ElementC, LayoutTagC, AlignmentC,
          EpilogueSchedule>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA,
          LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB,
          ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
                                           CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
  using StrideE = StrideA;

  using StrideA = Stride<int64_t, Int<1>, int64_t>;

  // The n (=1) dimension does not matter for the compressor
  typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1};

  using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA;
  using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE;

  using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
  using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;

  // Offline compressor kernel
  using CompressorUtility =
      cutlass::transform::kernel::StructuredSparseCompressorUtility<
          ProblemShape, ElementA, LayoutTagA, SparseConfig>;

  using CompressorKernel =
      cutlass::transform::kernel::StructuredSparseCompressor<
          ProblemShape, ElementA, LayoutTagA, SparseConfig,
          cutlass::arch::Sm90>;

  using Compressor =
      cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;

  auto [M, N, K, L] = prob_shape;

  StrideA stride_A;
  stride_A =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));

  CompressorUtility compressor_utility(prob_shape, stride_A);

  int ME = compressor_utility.get_metadata_m_physical();
  int KE = compressor_utility.get_metadata_k_physical();
  int KC = compressor_utility.get_tensorA_k_physical();

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());

  auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
  auto a_meta_ptr = static_cast<typename Gemm::CollectiveMainloop::ElementE*>(
      a_meta.data_ptr());

  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = 0;
  hw_info.sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
          hw_info.device_id);
  typename Compressor::Arguments arguments{
      prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}};

  Compressor compressor_op;
  size_t workspace_size = Compressor::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  CUTLASS_CHECK(compressor_op.can_implement(arguments));
  CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
  CUTLASS_CHECK(compressor_op.run());
  CUDA_CHECK(cudaDeviceSynchronize());

  return true;
}

bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                                  torch::Tensor const& a) {
  if (a.dtype() == torch::kBFloat16) {
    return cutlass_sparse_compress<cutlass::bfloat16_t, float>(a_nzs, a_meta,
                                                               a);
  } else if (a.dtype() == torch::kFloat16) {
    return cutlass_sparse_compress<cutlass::half_t, float>(a_nzs, a_meta, a);
  } else if (a.dtype() == torch::kFloat8_e4m3fn) {
    return cutlass_sparse_compress<cutlass::float_e4m3_t, float>(a_nzs, a_meta,
                                                                 a);
  } else if (a.dtype() == torch::kInt8) {
    return cutlass_sparse_compress<int8_t, int32_t>(a_nzs, a_meta, a);
  }
  return false;
}
csrc/sparse/cutlass/sparse_compressor_entry.cu (new file, 42 lines)
@ -0,0 +1,42 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "cutlass_extensions/common.hpp"

#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                                  torch::Tensor const& a);
#endif

bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                                   torch::Tensor const& a) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2);
  TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) &&
              a_nzs.size(1) * 2 == a.size(1) &&
              a_meta.size(1) * 2 * 4 == a.size(1));
  // Considering elemsPerMetaElem = 8b / 2b_per_nz = 4

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 &&
              a_meta.stride(1) == 1);  // Row-major
  TORCH_CHECK(a.stride(0) % 8 == 0);   // 8 Byte Alignment for Compression

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
  int32_t version_num = get_sm_version_num();

  // Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
  if (version_num >= 90) {
    return cutlass_sparse_compress_sm90(a_nzs, a_meta, a);
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_sparse_mm for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}
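The conformality checks above pin down the shapes the compressor expects: for a row-major (m, k) input, a_nzs holds the k/2 kept values per row, and each uint8 metadata element covers 4 kept values (2 bits each). An illustrative Python restatement of that contract (the function name is mine, not part of the commit):

def compressed_shapes(m: int, k: int):
    a_nzs_shape = (m, k // 2)   # a_nzs.size(1) * 2 == a.size(1)
    a_meta_shape = (m, k // 8)  # a_meta.size(1) * 2 * 4 == a.size(1)
    return a_nzs_shape, a_meta_shape

assert compressed_shapes(16, 4096) == ((16, 2048), (16, 512))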
303
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
Normal file
303
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
Normal file
@ -0,0 +1,303 @@
|
|||||||
|
// clang-format will break include orders
|
||||||
|
// clang-format off
|
||||||
|
#include <cudaTypedefs.h>
|
||||||
|
|
||||||
|
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||||
|
#include "sparse_scaled_mm_c3x.cuh"
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
using namespace vllm;
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue,
|
||||||
|
typename... EpilogueArgs>
|
||||||
|
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& bt_nzs,
|
||||||
|
torch::Tensor const& bt_meta,
|
||||||
|
EpilogueArgs&&... args) {
|
||||||
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
|
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||||
|
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||||
|
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
|
||||||
|
|
||||||
|
using Cutlass3xGemmDefault =
|
||||||
|
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemmM64 =
|
||||||
|
typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemmM128 =
|
||||||
|
typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemmM256 =
|
||||||
|
typename sm90_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemmM512 =
|
||||||
|
typename sm90_fp8_config_M512<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
|
||||||
|
using Cutlass3xGemm1 =
|
||||||
|
typename sm90_fp8_config_1<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemm2 =
|
||||||
|
typename sm90_fp8_config_2<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemm3 =
|
||||||
|
typename sm90_fp8_config_3<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemm4 =
|
||||||
|
typename sm90_fp8_config_4<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemm5 =
|
||||||
|
typename sm90_fp8_config_5<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemm6 =
|
||||||
|
typename sm90_fp8_config_6<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemm7 =
|
||||||
|
typename sm90_fp8_config_7<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemm8 =
|
||||||
|
typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
|
||||||
|
uint32_t const n = bt_nzs.size(0);
|
||||||
|
uint32_t const m = a.size(0); // Batch size
|
||||||
|
uint32_t const mp2 =
|
||||||
|
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
|
||||||
|
|
||||||
|
if (mp2 <= 64) {
|
||||||
|
if (n == 28672) {
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemm2>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
} else if (n == 4096 || n == 6144) {
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemm1>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
}
|
||||||
|
} else if (mp2 <= 128) {
|
||||||
|
if (n == 4096) {
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemm3>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
} else if (n == 28672) {
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemm5>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
} else if (n == 6144) {
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemm4>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
}
|
||||||
|
} else if (mp2 <= 256) {
|
||||||
|
if (n == 4096) {
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemm6>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
} else if (n == 28672) {
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
} else if (n == 6144) {
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (n == 6144 || n == 28672) {
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
} else if (n == 4096) {
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise the default heuristic
|
||||||
|
if (mp2 <= 64) {
|
||||||
|
// n in [1, 64]
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
} else if (mp2 <= 128) {
|
||||||
|
// n in (64, 128]
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
} else if (mp2 <= 256) {
|
||||||
|
// n in (128, 256]
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemmM256>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
} else {
|
||||||
|
// n in (256, inf)
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemmM512>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue,
|
||||||
|
typename... EpilogueArgs>
|
||||||
|
void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& bt_nzs,
|
||||||
|
torch::Tensor const& bt_meta,
|
||||||
|
EpilogueArgs&&... args) {
|
||||||
|
static_assert(std::is_same<InType, cutlass::half_t>());
|
||||||
|
TORCH_CHECK(a.dtype() == torch::kFloat16);
|
||||||
|
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||||
|
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
|
||||||
|
|
||||||
|
using Cutlass3xGemmDefault =
|
||||||
|
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
|
||||||
|
// m in (128, inf)
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue,
|
||||||
|
typename... EpilogueArgs>
|
||||||
|
void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& bt_nzs,
|
||||||
|
torch::Tensor const& bt_meta,
|
||||||
|
EpilogueArgs&&... args) {
|
||||||
|
static_assert(std::is_same<InType, cutlass::bfloat16_t>());
|
||||||
|
TORCH_CHECK(a.dtype() == torch::kBFloat16);
|
||||||
|
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
|
||||||
|
TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
|
||||||
|
|
||||||
|
using Cutlass3xGemmDefault =
|
||||||
|
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||||
|
|
||||||
|
// m in (128, inf)
|
||||||
|
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
|
||||||
|
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue,
|
||||||
|
typename... EpilogueArgs>
|
||||||
|
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||||
|
torch::Tensor const& bt_nzs,
|
||||||
|
                                    torch::Tensor const& bt_meta,
                                    EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);

  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NBig =
      typename sm90_int8_config_M32_NBig<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NSmall =
      typename sm90_int8_config_M32_NSmall<InType, OutType,
                                           Epilogue>::Cutlass3xGemm;

  uint32_t const n = out.size(1);
  bool const is_small_n = n < 8192;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NSmall>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    } else {
      return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NBig>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  }
}

template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& bt_nzs,
                                            torch::Tensor const& bt_meta,
                                            EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                             Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else if (a.dtype() == torch::kFloat8_e4m3fn) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::bfloat16_t, Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::half_t, Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else if (a.dtype() == torch::kFloat16) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t,
                                             cutlass::bfloat16_t, Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t, cutlass::half_t,
                                             Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {  // a.dtype() == torch::kBFloat16
    TORCH_CHECK(a.dtype() == torch::kBFloat16);
    TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
                                             cutlass::bfloat16_t, Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
                                             cutlass::half_t, Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}

void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& bt_nzs,
                                   torch::Tensor const& bt_meta,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
        out, a, bt_nzs, bt_meta, b_scales, a_scales, *bias);
  } else {
    return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>(
        out, a, bt_nzs, bt_meta, b_scales, a_scales);
  }
}

#endif
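Note: the int8 dispatch above chooses a kernel configuration from the next power of two of M and from whether N is below 8192. A minimal Python sketch of that selection logic, added here for illustration only (the string labels stand in for the Cutlass3xGemm* typedefs and are not part of the diff):

# Standalone sketch: mirrors the m/n bucketing used by the int8 dispatch above.
def next_pow_2(x: int) -> int:
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

def select_int8_config(m: int, n: int) -> str:
    is_small_n = n < 8192
    mp2 = max(32, next_pow_2(m))
    if mp2 <= 32:
        return "M32_NSmall" if is_small_n else "M32_NBig"
    if mp2 <= 64:
        return "M64"
    if mp2 <= 128:
        return "M128"
    return "Default"

assert select_int8_config(16, 4096) == "M32_NSmall"
assert select_int8_config(16, 16384) == "M32_NBig"
assert select_int8_config(100, 4096) == "M128"
assert select_int8_config(4096, 4096) == "Default"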
496  csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh  Normal file
@ -0,0 +1,496 @@
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>

#include "cutlass/cutlass.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "core/math.hpp"
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/common.hpp"
#include "cutlass_extensions/torch_utils.hpp"
// clang-format on

using namespace cute;

/*
  This file defines sparse quantized GEMM operations using the CUTLASS 3.x API,
  for NVIDIA GPUs with sm90a (Hopper) or later.
*/

namespace {

// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

using GemmUniversalMode = cutlass::gemm::GemmUniversalMode;

template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule, typename AccType,
          typename TileSchedule = cutlass::gemm::PersistentScheduler,
          GemmUniversalMode Mode_ = GemmUniversalMode::kGemm>
struct cutlass_sparse_3x_gemm {
  static const GemmUniversalMode Mode = Mode_;
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;
  using ElementAcc = AccType;

  using EpilogueDescriptor =
      cutlass::epilogue::collective::detail::EpilogueDescriptor<
          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
          ElementD, EpilogueSchedule>;

  using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;

  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  using LayoutD = LayoutC;
  using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
  using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;

  using LayoutC_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutC>::type;
  using LayoutD_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutD>::type;

  using EVTCompute = typename Epilogue::EVTCompute;

  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentCD =
      128 / cutlass::sizeof_bits<ElementD>::value;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAcc, ElementAcc, ElementC, LayoutC_Transpose, AlignmentCD,
          ElementD, LayoutD_Transpose, AlignmentCD, EpilogueSchedule,
          EVTCompute>::CollectiveOp;

  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  // clang-format off
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
          ElementAB, cutlass::layout::RowMajor, AlignmentA,
          ElementAB, cutlass::layout::ColumnMajor, AlignmentB,
          ElementAcc, TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
  // clang-format on

  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      TileSchedule>>;

  struct GemmKernel : public KernelType {};
};

template <typename Gemm, typename... EpilogueArgs>
void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& bt_nzs,
                                torch::Tensor const& bt_meta,
                                EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  // Interface stride expected from the argument a (will get transposed)
  // We compute C^T = B^T * A^T, but we assume B is transposed before
  // compression and hence the bt_* naming
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
  using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
  using LayoutD = cutlass::layout::RowMajor;

  using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
  using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;

  auto layout_A = make_cute_layout<StrideA>(a, "A");
  auto layout_D = make_cute_layout<StrideD>(out, "D");

  // Transpose A and D
  // A doesn't need to be transposed since cutlass expects a NxK matrix
  // for B (which is At)
  auto stride_At = layout_A.stride();
  auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();

  using GemmKernel = typename Gemm::GemmKernel;
  typename GemmKernel::ProblemShape prob_shape{
      static_cast<int>(bt_nzs.size(0)), static_cast<int>(size<0>(layout_A)),
      static_cast<int>(size<1>(layout_A)), 1};

  using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
  using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;

  LayoutB b_layout = SparseConfig::fill_layoutA(prob_shape);
  LayoutE e_layout = SparseConfig::fill_layoutE(prob_shape);

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(bt_nzs.data_ptr());
  auto e_ptr = static_cast<ElementE*>(bt_meta.data_ptr());
  typename GemmKernel::MainloopArguments mainloop_args{
      b_ptr, b_layout, a_ptr, stride_At, e_ptr, e_layout};

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      c_ptr, stride_Dt, c_ptr, stride_Dt};

  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape, mainloop_args, epilogue_args};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}

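Note: the caller above builds the problem shape as {N, M, K, 1} because the kernel computes C^T = B^T * A^T, with the 2:4-compressed B^T feeding the "A" operand. A small Python sketch of that shape bookkeeping, added for illustration only (not part of the diff):

# Illustration only: shape relationship used by cutlass_sparse_gemm_caller.
def sparse_problem_shape(a_shape, bt_nzs_shape):
    m, k = a_shape            # activations, row-major (M x K)
    n, k_half = bt_nzs_shape  # 2:4-compressed B^T, N x K/2 kept values
    assert k_half * 2 == k, "compressed B^T must hold K/2 values per row"
    return (n, m, k, 1)       # matches GemmKernel::ProblemShape{N, M, K, 1}

print(sparse_problem_shape((512, 4096), (2048, 2048)))  # -> (2048, 512, 4096, 1)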
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default {};

template <typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<half_t, OutType, Epilogue> {
  // M in (128, inf)
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<half_t, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<cutlass::bfloat16_t, OutType, Epilogue> {
  // M in (128, inf)
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue, TileShape,
                             ClusterShape, KernelSchedule, EpilogueSchedule,
                             float>;
};

//////////////////////// Cherry-Picking Kernels ////////////////////////
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_1 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_8, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_2 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_128, _64, _256>;
  using ClusterShape = Shape<_8, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_3 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _2, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_4 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_8, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_5 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _256>;
  using ClusterShape = Shape<_8, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_6 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_1, _2, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_7 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_128, _128, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_8 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_128, _256, _128>;
  using ClusterShape = Shape<_8, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float>;
};
////////////////////////////////////////////////////////////////////////

template <typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<cutlass::float_e4m3_t, OutType, Epilogue> {
  // M in (128, inf)
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_1, _2, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<cutlass::float_e4m3_t, OutType, Epilogue,
                             TileShape, ClusterShape, KernelSchedule,
                             EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // M in [1, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _1, _1>;

  using TileSchedule = cutlass::gemm::PersistentScheduler;

  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float,
                             TileSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
  // M in (64, 128]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_1, _1, _1>;

  using TileSchedule = cutlass::gemm::PersistentScheduler;

  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float,
                             TileSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M256 {
  // M in (128, 256]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_128, _128, _256>;
  using ClusterShape = Shape<_1, _1, _1>;

  using TileSchedule = cutlass::gemm::PersistentScheduler;

  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float,
                             TileSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M512 {
  // M in (256, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_128, _128, _256>;
  using ClusterShape = Shape<_1, _1, _1>;

  using TileSchedule = cutlass::gemm::PersistentScheduler;

  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float,
                             TileSchedule>;
};

template <typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<int8_t, OutType, Epilogue> {
  // For M > 128 and any N
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<int8_t, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M128 {
  // For M in (64, 128] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M64 {
  // For M in (32, 64] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NBig {
  // For M in [1, 32] and N >= 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_1, _4, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NSmall {
  // For M in [1, 32] and N < 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, int32_t>;
};

}  // namespace
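Note: the AlignmentA/AlignmentB/AlignmentCD constants in cutlass_sparse_3x_gemm above reduce to 128 bits divided by the element width. A trivial sketch of that arithmetic, for illustration only (not part of the diff):

# Alignment in elements implied by "128 / cutlass::sizeof_bits<Element>::value".
SIZEOF_BITS = {"int8": 8, "fp8_e4m3": 8, "fp16": 16, "bf16": 16}
alignment = {name: 128 // bits for name, bits in SIZEOF_BITS.items()}
print(alignment)  # {'int8': 16, 'fp8_e4m3': 16, 'fp16': 8, 'bf16': 8}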
59  csrc/sparse/cutlass/sparse_scaled_mm_entry.cu  Normal file
@ -0,0 +1,59 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "cutlass_extensions/common.hpp"

#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                                   torch::Tensor const& b,
                                   torch::Tensor const& e,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   c10::optional<torch::Tensor> const& bias);
#endif

void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& bt_nzs,
                              torch::Tensor const& bt_meta,
                              torch::Tensor const& a_scales,
                              torch::Tensor const& b_scales,
                              c10::optional<torch::Tensor> const& bias) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && bt_nzs.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(1) == bt_nzs.size(0) && bt_nzs.size(1) * 2 == a.size(1) &&
              a.size(0) == c.size(0));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == bt_nzs.size(0));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && bt_nzs.stride(1) == 1 &&
              c.stride(1) == 1);            // Row-major
  TORCH_CHECK(c.stride(0) % 16 == 0);       // 16 Byte Alignment
  TORCH_CHECK(bt_nzs.stride(0) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == bt_nzs.size(0) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
  int32_t version_num = get_sm_version_num();

  // Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
  if (version_num >= 90) {
    cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
                                  bias);
    return;
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_sparse_mm for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}
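Note: the TORCH_CHECKs above pin down the shape contract between the activation, the compressed weight, and the output. A plain-Python sketch of the same conformality checks, for illustration only (not part of the diff):

# Sketch of the shape contract: a is (M x K), bt_nzs is (N x K/2), c is (M x N).
def check_sparse_mm_shapes(a_shape, bt_nzs_shape, c_shape):
    m, k = a_shape
    n, k_half = bt_nzs_shape
    assert c_shape == (m, n), "out must be M x N"
    assert k_half * 2 == k, "bt_nzs holds K/2 kept values per output column"

check_sparse_mm_shapes((512, 512), (512, 256), (512, 512))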
@ -321,6 +321,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
  ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);

  // CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
  // quantization, as well as bias
  ops.def(
      "cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
      " Tensor bt_nzs,"
      " Tensor bt_meta, Tensor a_scales,"
      " Tensor b_scales, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);

  // CUTLASS sparse matrix compressor
  ops.def(
      "cutlass_sparse_compress_entry(Tensor! a_nzs, Tensor! a_meta,"
      " Tensor a) -> bool");
  ops.impl("cutlass_sparse_compress_entry", &cutlass_sparse_compress_entry);

  // Mamba selective scan kernel
  ops.def(
      "selective_scan_fwd(Tensor! u, Tensor! delta,"
@ -83,7 +83,7 @@
]

[tool.codespell]
ignore-words-list = "dout, te, indicies, subtile"
ignore-words-list = "dout, te, indicies, subtile, ElementE"
skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"

[tool.isort]
131  tests/kernels/test_semi_structured.py  Normal file
@ -0,0 +1,131 @@
"""Tests for sparse cutlass kernels

Run `pytest tests/kernels/test_semi_structured.py`.
"""
from typing import Optional, Tuple, Type

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.Tensor):
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


def rand_int8(shape: tuple, device: str = "cuda"):
    return to_int8(torch.rand(shape, device=device) * 255 - 128)


def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.bfloat16)


def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to(dtype=torch.float16)


def prune_to_2_4(tensor):
    # Reshape tensor to [N, 4] where N is number of groups of 4
    original_shape = tensor.shape
    reshaped = tensor.reshape(-1, 4)

    # Get indices of top 2 absolute values in each group of 4
    _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)

    # Create binary mask
    mask = torch.zeros_like(reshaped)
    mask.scatter_(dim=1,
                  index=indices,
                  src=torch.ones_like(indices, dtype=mask.dtype))

    # Apply mask and reshape back
    pruned = reshaped * mask

    # Turn all -0.0 to 0.0
    pruned[pruned == -0.0] = 0.0

    return pruned.reshape(original_shape)

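Note: a small hypothetical helper (not part of the diff) that checks the 2:4 property prune_to_2_4 above is meant to produce; it assumes the torch import at the top of this test file:

# Hypothetical check: at most 2 non-zeros in every contiguous group of 4
# values along the flattened last dimension.
def assert_2_of_4(tensor: torch.Tensor) -> None:
    groups = tensor.reshape(-1, 4)
    nnz_per_group = (groups != 0).sum(dim=1)
    assert bool((nnz_per_group <= 2).all()), "tensor is not 2:4 sparse"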
def make_rand_sparse_tensors(
        dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    a = torch.randn((m, k), device='cuda') * 5
    b = torch.randn((n, k), device='cuda').t() * 5

    b = prune_to_2_4(b.t()).t()

    if dtype == torch.int8:
        a, b = to_int8(a), to_int8(b)
    elif dtype == torch.float8_e4m3fn:
        a, b = to_fp8(a), to_fp8(b)
    elif dtype == torch.float16:
        a, b = to_fp16(a), to_fp16(b)
    elif dtype == torch.bfloat16:
        a, b = to_bf16(a), to_bf16(b)
    else:
        raise ValueError("unsupported dtype")

    b_compressed, e = ops.cutlass_sparse_compress(b.t())

    # Compressed B, Metadata, Original A, B
    return b_compressed, e, a, b


def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype],
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    output = (scale_a * (scale_b * (torch.mm(
        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
    if bias is not None:
        output = output + bias

    return output


@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():
    big_m = 1024
    m, n, k = 512, 512, 512

    # Create tensors
    b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
                                                     big_m, n, k)
    a = whole_a[0:m, 0:k]
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
@ -10,9 +10,11 @@ from compressed_tensors.quantization import QuantizationType

from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
    CompressedTensors24, CompressedTensorsLinearMethod,
    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
    CompressedTensorsWNA16)
from vllm.platforms import current_platform


@pytest.mark.parametrize(
@ -208,3 +210,98 @@ def test_compressed_tensors_kv_cache(vllm_runner):
    with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
        output = llm.generate_greedy("Hello world!", max_tokens=20)
        assert output


@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(qkv_proj.scheme, CompressedTensors24)

    assert qkv_proj.scheme.weight_quant.strategy == weight_strategy
    assert qkv_proj.scheme.input_quant.strategy == input_strategy
    assert qkv_proj.scheme.quantized
    assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
    sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
    assert sparsity_map.get("Linear").format == "dense"
    assert sparsity_map.get("Linear").sparsity_structure == "2:4"


@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel",
     "token"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
     "channel", "tensor"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", "tensor",
     "tensor"),
    ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
     "tensor", "token"),
])
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj
        assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
        _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
     "channel", "token"),
    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing", "tensor",
     "tensor"),
    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
     "tensor", "token"),
])
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj
        assert qkv_proj.scheme.weights_dtype == torch.int8
        _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize(
    "args_2of4",
    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")])
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
    model = args_2of4
    with vllm_runner(model) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj
        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensors24)

        assert qkv_proj.scheme.weight_quant is None
        assert qkv_proj.scheme.input_quant is None
        assert not qkv_proj.scheme.quantized
        assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
        sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
        assert sparsity_map.get("Linear").format == "dense"
        assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output
@ -21,6 +21,8 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main, 90
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main, 90
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
@ -26,6 +26,10 @@ do
  export QUANTIZATION=${array[0]}
  export MODEL_NAME=${array[1]}
  export REVISION=${array[2]}
  # If array length is larger than 3, then MIN_CAPABILITY is provided
  if [ ${#array[@]} -gt 3 ]; then
    export MIN_CAPABILITY=${array[3]}
  fi
  pytest -s weight_loading/test_weight_loading.py || LOCAL_SUCCESS=$?

  if [[ $LOCAL_SUCCESS == 0 ]]; then
@ -1,14 +1,21 @@
import os

import pytest
import torch

from vllm.platforms import current_platform

MAX_MODEL_LEN = 1024
MODEL_NAME = os.environ.get("MODEL_NAME",
                            "robertgshaw2/zephyr-7b-beta-channelwise-gptq")
REVISION = os.environ.get("REVISION", "main")
QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")
MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "89")


@pytest.mark.skipif(
    not current_platform.has_device_capability(int(MIN_CAPABILITY)),
    reason="Current system does not have minimum capability.")
def test_weight_loading(vllm_runner):
    """
    Test parameter weight loading with tp>1.
@ -552,6 +552,109 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
    return out


def cutlass_sparse_compress(a: torch.Tensor) \
        -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compresses a sparse matrix for use with Cutlass sparse operations.

    This function takes a dense tensor and compresses it into two components:
    non-zero elements and metadata. The compressed representation is compatible
    with Cutlass sparse kernels.

    Args:
        a (torch.Tensor):
            The input tensor to be compressed. Must have one of the following data types:
            - `torch.int8`
            - `torch.float8_e4m3fn`
            - `torch.bfloat16`
            - `torch.float16`

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            A tuple containing:
            - `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
            - `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.

    Raises:
        ValueError: If the compression operation fails.

    Notes:
        - The `a_meta` tensor has a data type of `torch.uint8`.
        - Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`).
        - The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor.
        - The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`.
    """
    assert (a.dtype in [
        torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16
    ])
    assert (a.is_contiguous())

    # a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4
    elemsPerMetaElem = 4

    m = a.shape[0]
    k = a.shape[1]
    assert (k % 2 == 0)
    a_nzs = torch.empty((m, k // 2), dtype=a.dtype, device=a.device)
    a_meta = torch.empty((m, k // 2 // elemsPerMetaElem),
                         dtype=torch.uint8,
                         device=a.device)

    if not (torch.ops._C.cutlass_sparse_compress_entry(a_nzs, a_meta, a)):
        raise ValueError

    assert (a_nzs.is_contiguous())
    assert (a_meta.is_contiguous())

    return a_nzs, a_meta

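Note: a worked example (illustration only, not part of the diff) of the shapes documented above; compressing a 2:4-sparse (m, k) tensor keeps k // 2 values per row and packs 4 kept values into each uint8 metadata byte:

# Shape arithmetic for cutlass_sparse_compress on an (m, k) input.
m, k = 4096, 11008
elems_per_meta = 4
nzs_shape = (m, k // 2)                     # (4096, 5504)
meta_shape = (m, k // 2 // elems_per_meta)  # (4096, 1376)
print(nzs_shape, meta_shape)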
def cutlass_scaled_sparse_mm(
        a: torch.Tensor,
        bt_nzs: torch.Tensor,
        bt_meta: torch.Tensor,
        scale_a: torch.Tensor,
        scale_b: torch.Tensor,
        out_dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Performs a scaled sparse matrix multiplication using Cutlass.

    Steps:
    1. Create a dense matrix `a` of shape (m, k) on the CUDA device:
       `a = torch.randn((m, k), device='cuda')`.

    2. Create a dense matrix `b` of shape (k, n) on the CUDA device:
       `b = torch.randn((k, n), device='cuda')`.

    3. Prune matrix `b` to 2:4 sparsity along the specified dimension:
       `b = prune_to_2_4(b, dim=0)`.

    4. Compress the transposed sparse matrix `b.t()`:
       `bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`.

    5. Perform sparse matrix multiplication using the compressed matrix,
       applying scaling factors for `a` and `b`, and the output data type:
       `out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`.

    Returns:
        - The result of the scaled sparse matrix multiplication.
    """
    assert (bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.shape[0] == bt_nzs.shape[0] \
        and bias.dtype == out_dtype

    m = a.shape[0]
    n = bt_nzs.shape[0]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_sparse_mm(out, a, bt_nzs, bt_meta, scale_a,
                                          scale_b, bias)

    return out

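Note: an end-to-end sketch (illustration only, not part of the diff) following the steps in the docstring above; the prune_rows_2_4 helper is a hypothetical stand-in for the 2:4 pruning done by the test utilities:

import torch
from vllm import _custom_ops as ops

def prune_rows_2_4(x: torch.Tensor) -> torch.Tensor:
    # Keep the 2 largest magnitudes in every group of 4 along the last dim.
    groups = x.reshape(-1, 4)
    _, idx = torch.topk(groups.abs(), k=2, dim=1)
    mask = torch.zeros_like(groups).scatter_(1, idx, 1.0)
    return (groups * mask).reshape(x.shape)

m, n, k = 512, 512, 512
a = torch.randn((m, k), device="cuda", dtype=torch.float16)
b = torch.randn((k, n), device="cuda", dtype=torch.float16)
b = prune_rows_2_4(b.t()).t()                     # 2:4 sparsity along k
bt_nzs, bt_meta = ops.cutlass_sparse_compress(b.t())
scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)
out = ops.cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b,
                                   out_dtype=torch.float16)
print(out.shape)  # torch.Size([512, 512])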
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
|
|||||||
from typing import Any, Dict, List, Optional, cast
|
from typing import Any, Dict, List, Literal, Optional, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compressed_tensors.config import CompressionFormat
|
from compressed_tensors.config import (CompressionFormat,
|
||||||
|
SparsityCompressionConfig,
|
||||||
|
SparsityStructure)
|
||||||
from compressed_tensors.quantization import (QuantizationArgs,
|
from compressed_tensors.quantization import (QuantizationArgs,
|
||||||
QuantizationStrategy,
|
QuantizationStrategy,
|
||||||
QuantizationType)
|
QuantizationType)
|
||||||
@ -15,7 +17,7 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
|
|||||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
|
||||||
CompressedTensorsMoEMethod)
|
CompressedTensorsMoEMethod)
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||||
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
|
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
|
||||||
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
|
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
|
||||||
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
|
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
|
||||||
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
|
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
|
||||||
@ -27,20 +29,29 @@ from vllm.platforms import current_platform
|
|||||||
|
|
||||||
__all__ = ["CompressedTensorsLinearMethod"]
|
__all__ = ["CompressedTensorsLinearMethod"]
|
||||||
|
|
||||||
|
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
|
||||||
|
QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]
|
||||||
|
|
||||||
|
|
||||||
class CompressedTensorsConfig(QuantizationConfig):
|
class CompressedTensorsConfig(QuantizationConfig):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
target_scheme_map: Dict[str, Any],
|
self,
|
||||||
ignore: List[str],
|
target_scheme_map: Dict[str, Any],
|
||||||
quant_format: str,
|
ignore: List[str],
|
||||||
kv_cache_scheme: Optional[Dict[str, Any]] = None):
|
quant_format: str,
|
||||||
|
sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
|
||||||
|
kv_cache_scheme: Optional[Dict[str, Any]] = None,
|
||||||
|
config: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
|
||||||
self.ignore = ignore
|
self.ignore = ignore
|
||||||
self.quant_format = quant_format
|
self.quant_format = quant_format
|
||||||
# Map from [target -> scheme]
|
# Map from [target -> scheme]
|
||||||
self.target_scheme_map = target_scheme_map
|
self.target_scheme_map = target_scheme_map
|
||||||
self.kv_cache_scheme = kv_cache_scheme
|
self.kv_cache_scheme = kv_cache_scheme
|
||||||
|
self.sparsity_scheme_map = sparsity_scheme_map
|
||||||
|
self.config = config
|
||||||
|
|
||||||
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
|
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
|
||||||
return CompressedTensorsLinearMethod(self)
|
return CompressedTensorsLinearMethod(self)
|
||||||
@ -78,8 +89,50 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
|
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
|
||||||
|
ignore: List[str] = cast(List[str], config.get("ignore", []))
|
||||||
|
quant_format = cast(str, config.get("format"))
|
||||||
|
target_scheme_map = cls._quantization_scheme_map_from_config(
|
||||||
|
config=config)
|
||||||
|
sparsity_scheme_map = cls._sparsity_scheme_map_from_config(
|
||||||
|
config=config)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
target_scheme_map=target_scheme_map,
|
||||||
|
ignore=ignore,
|
||||||
|
quant_format=quant_format,
|
||||||
|
sparsity_scheme_map=sparsity_scheme_map,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _sparsity_scheme_map_from_config(
|
||||||
|
cls, config: Dict[str,
|
||||||
|
Any]) -> Dict[str, SparsityCompressionConfig]:
|
||||||
|
"""
|
||||||
|
:param config: The `quantization_config` dictionary from config.json
|
||||||
|
:return: A dictionary mapping target layer names to their corresponding
|
||||||
|
sparsity compression configurations
|
||||||
|
"""
|
||||||
|
if (sparsity_config := config.get(SPARSITY_CONFIG_NAME)) is None:
|
||||||
|
return dict()
|
||||||
|
|
||||||
|
sparsity_config = SparsityCompressionConfig.model_validate(
|
||||||
|
sparsity_config)
|
||||||
|
sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
|
||||||
|
target: sparsity_config
|
||||||
|
for target in sparsity_config.targets or list()
|
||||||
|
}
|
||||||
|
return sparse_scheme_map
|
||||||
|
|
||||||
|
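Note: a hypothetical example (not from the diff) of the kind of "sparsity_config" block this helper consumes and the target-to-config map it produces; the dict stands in for the validated SparsityCompressionConfig object:

sparsity_config = {
    "format": "dense",
    "sparsity_structure": "2:4",
    "targets": ["Linear"],
}
# The helper validates the block with SparsityCompressionConfig.model_validate
# and then maps every listed target to the same validated config:
sparse_scheme_map = {t: sparsity_config for t in sparsity_config["targets"] or []}
print(sparse_scheme_map.keys())  # dict_keys(['Linear'])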
@classmethod
|
||||||
|
def _quantization_scheme_map_from_config(
|
||||||
|
cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
|
||||||
|
"""
|
||||||
|
:param config: The `quantization_config` dictionary from config.json
|
||||||
|
:return: A dictionary mapping target layer names to their corresponding
|
||||||
|
quantization_args for weights and input activations
|
||||||
|
"""
|
||||||
target_scheme_map: Dict[str, Any] = dict()
|
target_scheme_map: Dict[str, Any] = dict()
|
||||||
ignore = cast(List[str], config.get("ignore"))
|
|
||||||
quant_format = cast(str, config.get("format"))
|
quant_format = cast(str, config.get("format"))
|
||||||
|
|
||||||
# The quant_config has multiple config_groups, each containing
|
# The quant_config has multiple config_groups, each containing
|
||||||
@ -90,12 +143,14 @@ class CompressedTensorsConfig(QuantizationConfig):
        # details follow the structure defined by the QuantizationArgs
        # pydantic model, which is used to verify the structure of the
        # quant_config and also store the details for later use.
        for _, quant_config in config["config_groups"].items():
        config_groups = config.get("config_groups", dict())
        for _, quant_config in config_groups.items():
            targets = quant_config.get("targets")
            for target in targets:
                target_scheme_map[target] = {}
                target_scheme_map[target][
                    "weights"] = QuantizationArgs.parse_obj(
                    "weights"] = QuantizationArgs.model_validate(
                        quant_config.get("weights"))

                target_scheme_map[target]["input_activations"] = None
@ -110,13 +165,9 @@ class CompressedTensorsConfig(QuantizationConfig):
                            "weights"].type == QuantizationType.FLOAT
                    else:
                        target_scheme_map[target][
                            "input_activations"] = QuantizationArgs.parse_obj(
                            "input_activations"] = QuantizationArgs.model_validate(  # noqa: E501
                                quant_config.get("input_activations"))
        return target_scheme_map
        return cls(target_scheme_map=target_scheme_map,
                   ignore=ignore,
                   quant_format=quant_format,
                   kv_cache_scheme=config.get("kv_cache_scheme"))

    @classmethod
    def get_config_filenames(cls) -> List[str]:
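To make the parsing above concrete, a hedged sketch of the `quantization_config` shape that `from_config` now consumes: `config_groups` feeds `_quantization_scheme_map_from_config` and the sparsity block feeds `_sparsity_scheme_map_from_config`. Field values are illustrative (an FP8 W8A8 group plus a dense-stored 2:4 sparsity block), assuming `SPARSITY_CONFIG_NAME == "sparsity_config"`; they are not copied from a real checkpoint.

# Illustrative only: a config.json `quantization_config` this parser would accept.
example_quantization_config = {
    "format": "float-quantized",
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {"num_bits": 8, "type": "float", "strategy": "channel"},
            "input_activations": {"num_bits": 8, "type": "float",
                                  "strategy": "token", "dynamic": True},
        },
    },
    "sparsity_config": {
        "format": "dense",
        "sparsity_structure": "2:4",
        "targets": ["Linear"],
    },
}
# from_config(example_quantization_config) would then produce roughly:
#   target_scheme_map   == {"Linear": {"weights": ..., "input_activations": ...}}
#   sparsity_scheme_map == {"Linear": SparsityCompressionConfig(...)}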
@ -315,23 +366,105 @@ class CompressedTensorsConfig(QuantizationConfig):
        # TODO (@robertgshaw): add compressed-tensors as dep
        # so we do not have to re-write these functions
        # need to make accelerate optional in ct to do this
        matched_target = find_matched_target(
            layer_name=layer_name,
            module=layer,
            targets=self.target_scheme_map.keys())

        # Find the quant_scheme
        scheme_dict = self.target_scheme_map[matched_target]
        scheme = self._get_scheme_from_parts(
            weight_quant=scheme_dict["weights"],
            input_quant=scheme_dict["input_activations"])
        # Will be empty for models with only sparsity
        if self.target_scheme_map:
            matched_target = find_matched_target(
                layer_name=layer_name,
                module=layer,
                targets=self.target_scheme_map.keys())

            scheme_dict = self.target_scheme_map[matched_target]
            weight_quant = scheme_dict.get("weights")
            input_quant = scheme_dict.get("input_activations")
        elif self.sparsity_scheme_map:
            matched_target = find_matched_target(
                layer_name=layer_name,
                module=layer,
                targets=self.sparsity_scheme_map.keys())
            weight_quant = None
            input_quant = None

        # For models with sparsity, assumes that the sparse layers are also
        # quantized for cutlass 2:4 support
        sparsity_scheme: Optional[
            SparsityCompressionConfig] = self.sparsity_scheme_map.get(
                matched_target)

        if self.supports_cutlass_24(weight_quant=weight_quant,
                                    input_quant=input_quant,
                                    sparsity_scheme=sparsity_scheme):
            # Have a valid sparsity scheme
            # Validate layer is supported by Cutlass 2:4 Kernel
            scheme = CompressedTensors24(quantized=weight_quant is not None
                                         or input_quant is not None,
                                         weight_quant=weight_quant,
                                         input_quant=input_quant)
        else:
            # Find the quant_scheme
            scheme = self._get_scheme_from_parts(  # type: ignore
                weight_quant=weight_quant,
                input_quant=input_quant,
            )

        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
        self._check_scheme_supported(scheme.get_min_capability())

        return scheme

    @staticmethod
    def supports_cutlass_24(
            weight_quant: Optional[QuantizationArgs],
            input_quant: Optional[QuantizationArgs],
            sparsity_scheme: Optional[SparsityCompressionConfig] = None
    ) -> bool:
        """
        Check if the layer is supported by the Cutlass 2:4 Kernel
        Conditions:
            - Overarching condition: Sparsity Structure is 2:4
            - Unquantized cases are supported
            - Weight only quantization is not-supported
            - Supported weight quantization strategies are TENSOR and CHANNEL
            - Supported input quantization strategies are TENSOR and TOKEN
            - Only 8 bit quantization is supported

        :return: True if the layer is supported by the Cutlass 2:4 Kernel
            False otherwise
        """
        is_valid_sparsity = (sparsity_scheme is not None
                             and sparsity_scheme.sparsity_structure
                             == SparsityStructure.TWO_FOUR.value
                             and sparsity_scheme.format == "dense")
        if not is_valid_sparsity:
            return False

        # Unquantized cases are supported
        if weight_quant is None and input_quant is None:
            return True

        # Weight only quantization is not-supported
        if weight_quant is not None and input_quant is None:
            return False

        supported_weight_quant_strategies = [
            QuantizationStrategy.TENSOR.value,
            QuantizationStrategy.CHANNEL.value
        ]

        assert weight_quant is not None
        assert input_quant is not None
        if weight_quant.strategy not in supported_weight_quant_strategies:
            return False

        supported_input_quant_strategies = [
            QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value
        ]

        if input_quant.strategy not in supported_input_quant_strategies:
            return False

        return weight_quant.num_bits == input_quant.num_bits == 8


class CompressedTensorsLinearMethod(LinearMethodBase):
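A small usage sketch of `supports_cutlass_24`, exercising the conditions listed in its docstring. The `QuantizationArgs` / `SparsityCompressionConfig` constructor keywords are assumptions about the compressed-tensors pydantic models, not code from this commit.

# Sketch: FP8 W8A8 on a dense-stored 2:4 layer passes; weight-only does not.
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig)

sparsity = SparsityCompressionConfig(format="dense", sparsity_structure="2:4",
                                     targets=["Linear"])            # assumed kwargs
w8 = QuantizationArgs(num_bits=8, type="float",
                      strategy=QuantizationStrategy.CHANNEL)
a8 = QuantizationArgs(num_bits=8, type="float",
                      strategy=QuantizationStrategy.TOKEN, dynamic=True)

assert CompressedTensorsConfig.supports_cutlass_24(
    weight_quant=w8, input_quant=a8, sparsity_scheme=sparsity)
assert not CompressedTensorsConfig.supports_cutlass_24(
    weight_quant=w8, input_quant=None, sparsity_scheme=sparsity)    # weight-only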
@ -7,13 +7,12 @@ from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
                                       CompressedTensorsWNA16)

from .compressed_tensors_24 import CompressedTensors24  # isort: skip

__all__ = [
    "CompressedTensorsScheme",
    "CompressedTensorsWNA16",
    "CompressedTensorsW8A16Fp8",
    "CompressedTensorsW4A16Sparse24",
    "CompressedTensorsW8A8Int8",
    "CompressedTensorsW8A8Fp8",
    "WNA16_SUPPORTED_BITS",
    "W4A16SPARSE24_SUPPORTED_BITS",
]
__all__ = [
    "CompressedTensorsScheme", "CompressedTensorsWNA16",
    "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
    "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
    "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
    "CompressedTensors24"
]
@ -0,0 +1,203 @@
from typing import Callable, List, Optional

import torch
from compressed_tensors.quantization import (QuantizationArgs,
                                             QuantizationStrategy,
                                             QuantizationType)

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    convert_to_channelwise)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           ChannelQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)

__all__ = ["CompressedTensors24"]


class CompressedTensors24(CompressedTensorsScheme):

    def __init__(self,
                 quantized: bool = False,
                 weight_quant: Optional[QuantizationArgs] = None,
                 input_quant: Optional[QuantizationArgs] = None):

        self.quantized = quantized
        self.weight_quant = weight_quant
        self.input_quant = input_quant

    @classmethod
    def get_min_capability(cls) -> int:
        # Only cutlass 3.x kernels are implemented so far
        return 90

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):

        self.output_dtype = params_dtype
        layer.logical_widths = output_partition_sizes
        self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

        # parameter to store uncompressed weight
        weight = ModelWeightParameter(data=torch.empty(
            sum(output_partition_sizes),
            input_size_per_partition,
            dtype=self.weights_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)

        # Check if quantized, not just 2:4 Sparse
        if self.quantized:
            if (self.weight_quant and self.weight_quant.strategy
                    == QuantizationStrategy.CHANNEL.value):
                weight_scale = ChannelQuantScaleParameter(
                    data=torch.empty((sum(output_partition_sizes), 1),
                                     dtype=torch.float32),
                    output_dim=0,
                    weight_loader=weight_loader)
            else:
                assert (self.weight_quant and self.weight_quant.strategy
                        == QuantizationStrategy.TENSOR.value)
                weight_scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader)

            layer.register_parameter("weight_scale", weight_scale)

            # input quant will be non-none
            if self.input_quant and not self.input_quant.dynamic:
                # register input quant scale
                assert (self.input_quant.strategy ==
                        QuantizationStrategy.TENSOR.value)
                input_scale = BasevLLMParameter(data=torch.empty(
                    1, dtype=torch.float32),
                                                weight_loader=weight_loader)

                layer.register_parameter("input_scale", input_scale)

        else:
            # for sparse-only, pass in 1 for weight/input scales
            weight_scale = torch.nn.Parameter(data=torch.ones(
                1, dtype=torch.float32),
                                              requires_grad=False)
            input_scale = torch.nn.Parameter(data=torch.ones(
                1, dtype=torch.float32),
                                             requires_grad=False)
            layer.register_parameter("input_scale", input_scale)
            layer.register_parameter("weight_scale", weight_scale)

        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """
        Compress weights after loading. Store compressed weight and meta
        tensor

        :post-condition: layer.w_compressed and layer.meta are
            set to the compressed weight and meta tensor in the
            format expected by the Cutlass kernels
        :param layer: The layer with the weights to be processed

        """
        # torch.compile workaround
        if hasattr(layer, "input_scale"):
            layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                   requires_grad=False)

        if self.weight_quant:
            if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
                layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
                    weight_scale=layer.weight_scale,
                    logical_widths=layer.logical_widths),
                                                        requires_grad=False)
            else:
                # torch.compile workaround
                layer.weight_scale = torch.nn.Parameter(
                    layer.weight_scale.data, requires_grad=False)

        w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
        layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
        layer.meta = torch.nn.Parameter(meta, requires_grad=False)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Returns the output tensor for the layer with 2:4
        sparse compressed weights, given the input tensor
        and bias

        :param layer: The layer with 2:4 sparse compressed
            weights to be used for the computation
        :param x: The input tensor to the layer
        :param bias: The bias to be added to the output tensor
        :return: The output tensor of the layer
        """
        if self.quantized:
            scale = None
            if hasattr(layer, "input_scale"):
                scale = layer.input_scale

            if self.weights_dtype == torch.int8:
                ops_output = ops.scaled_int8_quant(x, scale=scale)
                q_input = ops_output[0]
                input_scale = ops_output[1]
            else:
                assert self.weights_dtype == torch.float8_e4m3fn
                if scale is not None:
                    q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
                else:
                    q_input, input_scale = ops.scaled_fp8_quant(
                        x, use_per_token_if_dynamic=True)

        else:
            # Not quantized, nothing to do with the input_scales, use as is
            input_scale = layer.input_scale
            q_input = x

        out = ops.cutlass_scaled_sparse_mm(a=q_input,
                                           bt_nzs=layer.weight,
                                           bt_meta=layer.meta,
                                           scale_a=input_scale,
                                           scale_b=layer.weight_scale,
                                           out_dtype=self.output_dtype,
                                           bias=bias)
        assert out.is_contiguous()
        return out

    def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
        if not self.quantized:
            return params_dtype

        assert self.weight_quant is not None
        assert self.input_quant is not None

        is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8

        if not is_8_bits:
            raise ValueError("Cutlass only supports 8-bit quantization")

        if (self.weight_quant.type == QuantizationType.FLOAT
                and self.input_quant.type == QuantizationType.FLOAT):
            return torch.float8_e4m3fn

        if (self.weight_quant.type == QuantizationType.INT
                and self.input_quant.type == QuantizationType.INT):
            return torch.int8

        raise ValueError("Quantization type not supported by Cutlass")


def check_24(tensor):
    new_tensor = tensor.view(-1, 4)
    zero_counts = (new_tensor == 0).sum(dim=1)
    return (zero_counts >= 2).all().item()
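A short sanity-check sketch for `check_24` above, run next to its definition: the first tensor is hand-built so every contiguous group of four values contains at least two zeros, which is exactly what the helper verifies.

# Illustrative usage of check_24 (2:4 means >= 2 zeros per group of 4 values).
import torch

w = torch.tensor([[1.0, 0.0, 0.0, 3.0],
                  [0.0, 2.0, 5.0, 0.0]])
assert check_24(w)                       # every group of four has two zeros
assert not check_24(torch.ones(2, 4))    # a dense tensor fails the check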