diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 261f5829631e..3da4cecd7eef 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -5,14 +5,16 @@ import torch.utils.benchmark as benchmark from benchmark_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - MarlinWorkspace, marlin_24_quantize, marlin_quantize) + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + MarlinWorkspace, marlin_quantize) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( + marlin_24_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( gptq_pack, quantize_weights, sort_weights) from vllm.utils import FlexibleArgumentParser diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index 92ddcb209b69..3bd6680cf813 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -5,19 +5,21 @@ Run `pytest tests/kernels/marlin/test_marlin_gemm.py`. import pytest import torch +from tests.quantization.utils import is_quant_method_supported from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS, - marlin_permute_scales) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) -from vllm.model_executor.layers.quantization.utils.marlin_perms import ( - marlin_perm) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize, - marlin_quantize, marlin_weights, pack_fp8_to_int32) + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS, + marlin_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + pack_fp8_to_int32) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + MarlinWorkspace, get_weight_perm, marlin_quantize, marlin_weights) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( + marlin_24_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( gptq_pack, quantize_weights, sort_weights) @@ -42,11 +44,16 @@ MNK_FACTORS = [ DTYPES = [torch.float16, torch.bfloat16] +def compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + def rand_data(shape, dtype=torch.float16): return torch.randn(shape, dtype=dtype, device="cuda") -@pytest.mark.skipif(not is_marlin_supported(), +@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), reason="Marlin is 
not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @@ -93,8 +100,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Pack to Marlin format - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, - marlin_perm[num_bits]) + weight_perm = get_weight_perm(num_bits) + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.gptq_marlin_repack( @@ -109,7 +116,7 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, assert torch.allclose(marlin_q_w_1, marlin_q_w_2) -@pytest.mark.skipif(not is_marlin_supported(), +@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @@ -174,7 +181,7 @@ def test_marlin_gemm( assert max_diff < 0.04 -@pytest.mark.skipif(not is_marlin_supported(), +@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS) @@ -222,7 +229,7 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors): assert max_diff < 0.04 -@pytest.mark.skipif(not is_marlin_supported(), +@pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @@ -268,13 +275,10 @@ def test_fp8_marlin_gemm( # expand it to channelwise scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda") # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, - size_k=size_k, - size_n=size_n, - group_size=-1, - num_bits=8, - ) + marlin_scales = marlin_permute_scales(s=scales, + size_k=size_k, + size_n=size_n, + group_size=-1) workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 96223a247657..888e20e51a84 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -6,7 +6,6 @@ Run `pytest tests/quantization/test_compressed_tensors.py`. 
import pytest import torch -from vllm import SamplingParams from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, @@ -57,12 +56,14 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args): assert qkv_proj.weight_scale.dtype is torch.float32 assert qkv_proj.input_scale.dtype is torch.float32 + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output + def test_compressed_tensors_no_enforce_eager(vllm_runner): model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change" with vllm_runner(model_path) as llm: - sampling_params = SamplingParams() - output = llm.generate("Hello world!", sampling_params=sampling_params) + output = llm.generate_greedy("Hello my name is", max_tokens=20) assert output @@ -84,13 +85,16 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args): assert qkv_proj.scheme.strategy == strategy assert qkv_proj.weight.dtype is torch.int8 + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output + @pytest.mark.parametrize( "wNa16_args", [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8), ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8), ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)]) -def test_compressed_tensors_w4a16(vllm_runner, wNa16_args): +def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): model, strategy, group, pack_factor = wNa16_args with vllm_runner(model) as llm: model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 @@ -101,12 +105,15 @@ def test_compressed_tensors_w4a16(vllm_runner, wNa16_args): assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16) assert qkv_proj.scheme.strategy == strategy - assert qkv_proj.scheme.group_size == group + assert qkv_proj.scheme.group_size == (-1 if group is None else group) assert qkv_proj.weight_packed.dtype is torch.int32 assert qkv_proj.weight_scale.dtype is torch.float16 assert qkv_proj.weight_packed.pack_factor == pack_factor + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output + def test_compressed_tensors_w4a16_marlin24(vllm_runner): model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t" @@ -120,8 +127,7 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner): assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24) assert qkv_proj.weight_packed.dtype is torch.int32 - sampling_params = SamplingParams() - output = llm.generate("Hello world!", sampling_params=sampling_params) + output = llm.generate_greedy("Hello my name is", max_tokens=20) assert output @@ -142,6 +148,5 @@ def test_compressed_tensors_fp8(vllm_runner): assert len(qkv_proj.input_scale.shape) == 0 assert len(qkv_proj.weight_scale.shape) == 0 - sampling_params = SamplingParams() - output = llm.generate("Hello world!", sampling_params=sampling_params) + output = llm.generate_greedy("Hello my name is", max_tokens=20) assert output diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 2243260053ef..ed9fa73c175a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -6,9 +6,10 @@ from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState, - marlin_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + apply_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, + marlin_permute_scales, replace_tensor, verify_marlin_supported, + verify_marlin_supports_shape) from vllm.model_executor.utils import set_weight_attrs __all__ = ["CompressedTensorsWNA16"] @@ -22,29 +23,40 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): num_bits: int, group_size: Optional[int] = None): self.num_bits = num_bits + self.pack_factor = 32 // self.num_bits self.strategy = strategy - self.group_size = group_size - if self.strategy == "group" and self.group_size is None: - raise ValueError( - "group_size must be given when using strategy group") + self.group_size: int + if group_size is None: + if self.strategy != "channel": + raise ValueError( + "Marlin kernels require group quantization or " + "channelwise quantization, but found no group " + "size and strategy is not channelwise.") + self.group_size = -1 + else: + self.group_size = group_size - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - pass + # Verify supported on platform. + verify_marlin_supported(num_bits=self.num_bits, + group_size=self.group_size, + is_sym=True) def create_weights(self, layer: torch.nn.Module, input_size: int, output_partition_sizes: List[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): - - pack_factor = 32 // self.num_bits output_size_per_partition = sum(output_partition_sizes) - if self.group_size is not None: - group_size = self.group_size - else: - group_size = input_size + # If group_size is -1, we are in channelwise case. + group_size = input_size if self.group_size == -1 else self.group_size + + verify_marlin_supports_shape( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + input_size=input_size, + group_size=group_size) weight_scale_dim = None scales_and_zp_size = input_size // group_size @@ -57,7 +69,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): weight = Parameter( torch.empty( output_size_per_partition, - input_size_per_partition // pack_factor, + input_size_per_partition // self.pack_factor, dtype=torch.int32, ), requires_grad=False, @@ -68,7 +80,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): "input_dim": 1, "output_dim": 0, "packed_dim": 1, - "pack_factor": pack_factor, + "pack_factor": self.pack_factor, "weight_loader": weight_loader }) layer.register_parameter("weight_packed", weight) @@ -103,73 +115,48 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.marlin_state = GPTQMarlinState.REPACK - layer.is_k_full = True layer.group_size = group_size - max_workspace_size = ( - output_size_per_partition // - GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL + # Checkpoints are serialized in compressed-tensors format, which is + # different from marlin format. 
Handle repacking here. + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = layer.weight_packed.device - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - requires_grad=False) - layer.workspace = workspace + # Allocate marlin workspace. + layer.workspace = marlin_make_workspace( + layer.output_size_per_partition, device) + + # Act-order not supported in compressed-tensors yet, so set to empty. + layer.g_idx = marlin_make_empty_g_idx(device) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + # Repack weights from compressed-tensors format to marlin format. + marlin_qweight = ops.gptq_marlin_repack( + layer.weight_packed.t().contiguous(), + perm=layer.g_idx_sort_indices, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + num_bits=self.num_bits) + replace_tensor(layer, "weight_packed", marlin_qweight) + + # Permute scales from compressed-tensors format to marlin format. + marlin_scales = marlin_permute_scales( + layer.weight_scale.squeeze().t().contiguous(), + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + group_size=layer.group_size) + replace_tensor(layer, "weight_scale", marlin_scales) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): - reshaped_x = x.reshape(-1, x.shape[-1]) - - size_m = reshaped_x.shape[0] - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - out_shape = x.shape[:-1] + (part_size_n, ) - - if layer.marlin_state == GPTQMarlinState.REPACK: - layer.marlin_state = GPTQMarlinState.READY - - # Newly generated tensors need to replace existing tensors that are - # already registered as parameters by vLLM (and won't be freed) - def replace_tensor(name, new_t): - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - cur_device = layer.weight_packed.device - - # Reset g_idx related tensors - layer.g_idx = Parameter(torch.empty(0, - dtype=torch.int, - device=cur_device), - requires_grad=False) - layer.g_idx_sort_indices = Parameter(torch.empty( - 0, dtype=torch.int, device=cur_device), - requires_grad=False) - - # Repack weights - marlin_qweight = ops.gptq_marlin_repack( - layer.weight_packed.t().contiguous(), layer.g_idx_sort_indices, - part_size_k, part_size_n, self.num_bits) - - replace_tensor("weight_packed", marlin_qweight) - - # Permute scales - scales_size_k = part_size_k - scales_size_n = part_size_n - - marlin_scales = marlin_permute_scales( - layer.weight_scale.squeeze().t().contiguous(), scales_size_k, - scales_size_n, layer.group_size, self.num_bits) - replace_tensor("weight_scale", marlin_scales) - - output = ops.gptq_marlin_gemm(reshaped_x, layer.weight_packed, - layer.weight_scale, layer.g_idx, - layer.g_idx_sort_indices, - layer.workspace, self.num_bits, size_m, - part_size_n, part_size_k, - layer.is_k_full) - return output.reshape(out_shape) + return apply_marlin_linear( + input=x, + weight=layer.weight_packed, + weight_scale=layer.weight_scale, + g_idx=layer.g_idx, + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=layer.workspace, + num_bits=self.num_bits, + output_size_per_partition=layer.output_size_per_partition, + input_size_per_partition=layer.input_size_per_partition, + is_k_full=True) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 8dba9019f94c..0c2d2bd3fabe 100644 --- 
a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, apply_fp8_linear, create_per_tensor_scale_param, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 6b971f73d45b..7b808f5216d5 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,5 +1,3 @@ -import enum -from enum import Enum from typing import Any, Dict, List, Optional import torch @@ -12,46 +10,14 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_K, - GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, - GPTQ_MARLIN_SUPPORTED_NUM_BITS, GPTQ_MARLIN_SUPPORTED_SYM, - GPTQ_MARLIN_TILE) + check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace, + marlin_permute_scales, marlin_sort_g_idx, replace_tensor, + verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.platforms import current_platform logger = init_logger(__name__) -# Permutations for Marlin scale shuffling -def get_scale_perms(num_bits: int): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -def get_pack_factor(num_bits: int): - assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS - ), f"Unsupported num_bits = {num_bits}" - return 32 // num_bits - - -def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, - group_size: int, num_bits: int): - scale_perm, scale_perm_single = get_scale_perms(num_bits) - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - class GPTQMarlinConfig(QuantizationConfig): """Config class for GPTQ Marlin""" @@ -63,33 +29,16 @@ class GPTQMarlinConfig(QuantizationConfig): desc_act = False self.weight_bits = weight_bits + self.pack_factor = 32 // self.weight_bits # packed into int32 self.group_size = group_size self.desc_act = desc_act self.is_sym = is_sym self.lm_head_quantized = lm_head_quantized - # Verify - if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS: - raise ValueError( - f"Marlin does not support weight_bits = {self.weight_bits}. 
" - f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} " - "are supported.") - if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: - raise ValueError( - f"Marlin does not support group_size = {self.group_size}. " - f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.") - if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM: - raise ValueError( - f"Marlin does not support is_sym = {self.is_sym}. " - f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.") - - # Init - self.pack_factor = get_pack_factor(weight_bits) - self.tile_size = GPTQ_MARLIN_TILE - self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N - self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K - self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL + # Verify supported on platform. + verify_marlin_supported(num_bits=self.weight_bits, + group_size=self.group_size, + is_sym=self.is_sym) def __repr__(self) -> str: return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " @@ -168,21 +117,10 @@ class GPTQMarlinConfig(QuantizationConfig): or desc_act is None): return False - # If the capability of the device is too low, cannot convert. - major, minor = current_platform.get_device_capability() - device_capability = major * 10 + minor - if device_capability < cls.get_min_capability(): - return False - - # Otherwise, can convert if model satisfies marlin constraints. - return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS - and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES - and sym in GPTQ_MARLIN_SUPPORTED_SYM) - - -class GPTQMarlinState(Enum): - REPACK = enum.auto() - READY = enum.auto() + return check_marlin_supported(num_bits=num_bits, + group_size=group_size, + is_sym=sym, + min_capability=cls.get_min_capability()) class GPTQMarlinLinearMethod(LinearMethodBase): @@ -206,6 +144,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase): **extra_weight_attrs, ) -> None: del output_size + output_size_per_partition = sum(output_partition_sizes) # Normalize group_size if self.quant_config.group_size != -1: @@ -213,31 +152,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase): else: group_size = input_size - # Validate dtype - if params_dtype not in [torch.float16, torch.bfloat16]: - raise ValueError(f"The params dtype must be float16 " - f"or bfloat16, but got {params_dtype}") - - # Validate output_size_per_partition - output_size_per_partition = sum(output_partition_sizes) - if output_size_per_partition % self.quant_config.min_thread_n != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {self.quant_config.min_thread_n}.") - - # Validate input_size_per_partition - if input_size_per_partition % self.quant_config.min_thread_k != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {self.quant_config.min_thread_k}.") - - if (group_size < input_size - and input_size_per_partition % group_size != 0): - raise ValueError( - f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}.") + verify_marlin_supports_shape( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + input_size=input_size, + group_size=group_size) # Detect sharding of scales/zp @@ -303,11 +222,6 @@ class GPTQMarlinLinearMethod(LinearMethodBase): }, ) - g_idx_sort_indices = torch.empty( - g_idx.shape, - dtype=torch.int32, - ) - # Scales scales = Parameter( torch.empty( @@ -347,25 +261,50 @@ 
class GPTQMarlinLinearMethod(LinearMethodBase): }, ) - # Allocate marlin workspace - max_workspace_size = ( - output_size_per_partition // - self.quant_config.min_thread_n) * self.quant_config.max_parallel - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - requires_grad=False) - layer.register_parameter("qweight", qweight) layer.register_parameter("g_idx", g_idx) layer.register_parameter("scales", scales) layer.register_parameter("qzeros", qzeros) - layer.g_idx_sort_indices = g_idx_sort_indices - layer.workspace = workspace layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition layer.input_size = input_size layer.is_k_full = is_k_full - layer.marlin_state = GPTQMarlinState.REPACK + + # Checkpoints are serialized in AutoGPTQ format, which is different from the + # marlin format. This function is called after the weights are loaded. + # Here, we handle the repacking, including the activation reordering case. + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = layer.qweight.device + # Allocate marlin workspace + layer.workspace = marlin_make_workspace( + layer.output_size_per_partition, device) + + # Handle sorting for activation reordering if needed. + if self.quant_config.desc_act: + g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + replace_tensor(layer, "g_idx", g_idx) + else: + layer.g_idx = marlin_make_empty_g_idx(device) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + # Repack weights from autogptq format to marlin format. + marlin_qweight = ops.gptq_marlin_repack( + layer.qweight, + perm=layer.g_idx_sort_indices, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "qweight", marlin_qweight) + + # Permute scales from autogptq format to marlin format. 
+ marlin_scales = marlin_permute_scales( + layer.scales, + size_k=(layer.input_size if self.quant_config.desc_act else + layer.input_size_per_partition), + size_n=layer.output_size_per_partition, + group_size=self.quant_config.group_size) + replace_tensor(layer, "scales", marlin_scales) def apply( self, @@ -374,87 +313,19 @@ class GPTQMarlinLinearMethod(LinearMethodBase): bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: reshaped_x = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (layer.output_size_per_partition, ) - size_m = reshaped_x.shape[0] - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - full_size_k = layer.input_size - - out_shape = x.shape[:-1] + (part_size_n, ) - - if layer.marlin_state == GPTQMarlinState.REPACK: - layer.marlin_state = GPTQMarlinState.READY - - # Newly generated tensors need to replace existing tensors that are - # already registered as parameters by vLLM (and won't be freed) - def replace_tensor(name, new_t): - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - cur_device = layer.qweight.device - - # Process act_order - if self.quant_config.desc_act: - # Get sorting based on g_idx - g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int) - - sorted_g_idx = layer.g_idx[g_idx_sort_indices] - - replace_tensor("g_idx", sorted_g_idx) - replace_tensor("g_idx_sort_indices", g_idx_sort_indices) - - else: - # Reset g_idx related tensors - layer.g_idx = Parameter( - torch.empty(0, dtype=torch.int, device=cur_device), - requires_grad=False, - ) - layer.g_idx_sort_indices = Parameter( - torch.empty(0, dtype=torch.int, device=cur_device), - requires_grad=False, - ) - - # Repack weights - marlin_qweight = ops.gptq_marlin_repack( - layer.qweight, - layer.g_idx_sort_indices, - part_size_k, - part_size_n, - self.quant_config.weight_bits, - ) - replace_tensor("qweight", marlin_qweight) - - # Permute scales - scales_size_k = part_size_k - scales_size_n = part_size_n - if self.quant_config.desc_act: - scales_size_k = full_size_k - - marlin_scales = marlin_permute_scales( - layer.scales, - scales_size_k, - scales_size_n, - self.quant_config.group_size, - self.quant_config.weight_bits, - ) - replace_tensor("scales", marlin_scales) - - output = ops.gptq_marlin_gemm( - reshaped_x, - layer.qweight, - layer.scales, - layer.g_idx, - layer.g_idx_sort_indices, - layer.workspace, - self.quant_config.weight_bits, - size_m, - part_size_n, - part_size_k, - layer.is_k_full, - ) + output = ops.gptq_marlin_gemm(reshaped_x, + layer.qweight, + layer.scales, + g_idx=layer.g_idx, + perm=layer.g_idx_sort_indices, + workspace=layer.workspace, + num_bits=self.quant_config.weight_bits, + size_m=reshaped_x.shape[0], + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + is_k_full=layer.is_k_full) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py b/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py deleted file mode 100644 index 93f65a20d4e4..000000000000 --- a/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py +++ /dev/null @@ -1,60 +0,0 @@ -"""This file is used for /tests and /benchmarks""" -from typing import Dict, List - -import numpy -import torch - - -# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501 -# -# Marlin works on [16*2,64] tiles. 
The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 -# with the tensor-core format that is described here: -# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 -# -# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 -# (without the need to use ldmatrix instructions) # noqa: E501 -def get_perms_24(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - col_o = col // 2 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + - 4 * block) - for j in range(4): - perm_list.extend([p + 1 * j for p in perm1]) - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) - scale_perm_single: List[int] = [] - for i in range(8): - scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) - return perm, scale_perm, scale_perm_single - - -marlin_24_perm: Dict[int, torch.Tensor] = {} -marlin_24_scale_perm: Dict[int, List[int]] = {} -marlin_24_scale_perm_single: Dict[int, List[int]] = {} -for num_bits in [4, 8]: - perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits) - marlin_24_perm[num_bits] = perm_24 - marlin_24_scale_perm[num_bits] = scale_perm_24 - marlin_24_scale_perm_single[num_bits] = scale_perm_single_24 diff --git a/vllm/model_executor/layers/quantization/utils/marlin_perms.py b/vllm/model_executor/layers/quantization/utils/marlin_perms.py deleted file mode 100644 index db5e6857a884..000000000000 --- a/vllm/model_executor/layers/quantization/utils/marlin_perms.py +++ /dev/null @@ -1,60 +0,0 @@ -"""This file is used for /tests and /benchmarks""" -from typing import Dict, List - -import numpy -import torch - - -# Precompute permutations for Marlin weight and scale shuffling # noqa: E501 -# -# Marlin works on [16,64] tiles. 
The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 -# with the tensor-core format that is described here: -# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 -# -# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 -# (without the need to use ldmatrix instructions) # noqa: E501 -def get_perms(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return perm, scale_perm, scale_perm_single - - -marlin_perm: Dict[int, torch.Tensor] = {} -marlin_scale_perm: Dict[int, List[int]] = {} -marlin_scale_perm_single: Dict[int, List[int]] = {} -for num_bits in [4, 8]: - perm, scale_perm, scale_perm_single = get_perms(num_bits) - marlin_perm[num_bits] = perm - marlin_scale_perm[num_bits] = scale_perm - marlin_scale_perm_single[num_bits] = scale_perm_single diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 9886245269ad..612c5fd20093 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -1,21 +1,9 @@ -"""This file is used for /tests and /benchmarks""" -import random -from typing import Optional +from typing import List, Optional, Tuple -import numpy import torch from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.utils.format_24 import ( - mask_creator, sparse_semi_structured_from_dense_cutlass) -from vllm.model_executor.layers.quantization.utils.marlin_24_perms import ( - marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single) -from vllm.model_executor.layers.quantization.utils.marlin_perms import ( - marlin_perm, marlin_scale_perm, marlin_scale_perm_single) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - get_pack_factor, quantize_weights, sort_weights) from vllm.platforms import current_platform -from vllm.utils import print_warning_once GPTQ_MARLIN_TILE = 16 GPTQ_MARLIN_MIN_THREAD_N = 64 @@ -25,135 +13,110 @@ GPTQ_MARLIN_MAX_PARALLEL = 16 GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8] GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] GPTQ_MARLIN_SUPPORTED_SYM = [True] +GTPQ_MARLIN_UNSUPPORTED_GROUP_SIZE_ACT_ORDER = [-1] -def is_marlin_supported(): - capability = current_platform.get_device_capability() - return capability[0] >= 8 +def check_marlin_supported(num_bits: int, group_size: int, is_sym: bool, + min_capability: int) -> bool: + + # If the capability of the device is too 
low, cannot convert. + major, minor = current_platform.get_device_capability() + device_capability = major * 10 + minor + if device_capability < min_capability: + return False + + return (device_capability >= min_capability + and num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS + and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES + and is_sym in GPTQ_MARLIN_SUPPORTED_SYM) -def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: - # For GPUs that lack FP8 hardware support, we can leverage the - # Marlin kernel for fast weight-only FP8 quantization +def verify_marlin_supported(num_bits: int, group_size: Optional[int], + is_sym: bool) -> None: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n, ) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) + if num_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS: + raise ValueError( + f"Marlin does not support weight_bits = {num_bits}. " + f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} " + "are supported.") + if (group_size is None + or group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES): + raise ValueError( + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") + if is_sym not in GPTQ_MARLIN_SUPPORTED_SYM: + raise ValueError( + f"Marlin does not support is_sym = is_sym. " + f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.") -def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None: - print_warning_once( - "Your GPU does not have native support for FP8 computation but " - "FP8 quantization is being used. Weight-only FP8 compression will " - "be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads.") +def verify_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") - device = layer.weight.device + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") - # WEIGHTS - # Repack weights to gptq format (packed int32 elements) - packed_gptq_qweight = pack_fp8_to_int32(layer.weight) + if (group_size < input_size + and input_size_per_partition % group_size != 0): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." 
+ "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=packed_gptq_qweight, - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - # WEIGHT SCALES - # Currently Marlin doesn't support per-tensor scales, so we - # expand it to channelwise - scales = layer.weight_scale.repeat(1, part_size_n).to( - layer.orig_dtype).to(device) - # Permute scales - num_bits = 8 - marlin_scales = marlin_permute_scales( - s=scales, - size_k=part_size_k, - size_n=part_size_n, - group_size=-1, - scale_perm=marlin_scale_perm[num_bits], - scale_perm_single=marlin_scale_perm_single[num_bits]) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - # Allocate marlin workspace - max_workspace_size = (part_size_n // +def marlin_make_workspace(output_size_per_partition: int, + device: torch.device) -> torch.Tensor: + max_workspace_size = (output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device=device, - requires_grad=False) - layer.workspace = workspace + return torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) -def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): - assert q_w.shape == (size_k, size_n) - assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" - assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" - - # Permute weights to 16x64 marlin tiles - q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) - q_w = q_w.permute((0, 2, 1, 3)) - q_w = q_w.reshape((size_k // tile, size_n * tile)) - - q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) - - return q_w +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) -def marlin_weights(q_w, size_k, size_n, num_bits, perm): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=numpy.uint32) - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) - - return q_packed +def marlin_sort_g_idx( + g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices -def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm, - scale_perm_single): +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() if group_size < size_k and group_size != -1: s = s.reshape((-1, len(scale_perm)))[:, scale_perm] else: @@ -163,180 +126,44 @@ def 
marlin_permute_scales(s, size_k, size_n, group_size, scale_perm, return s -def marlin_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, - act_order: bool, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, - act_order) - - # For act_order, sort the "weights" and "g_idx" so that group ids are - # increasing - sort_indices = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) - - # Reformat to marlin - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, - marlin_perm[num_bits]) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, - marlin_scale_perm[num_bits], - marlin_scale_perm_single[num_bits]) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list +# Newly generated tensors need to replace existing tensors that are +# already registered as parameters by vLLM (and won't be freed) +def replace_tensor(layer: torch.nn.Module, name: str, + new_t: torch.Tensor) -> None: + # It is important to use resize_() here since it ensures + # the same buffer is reused + getattr(layer, name).resize_(new_t.shape) + getattr(layer, name).copy_(new_t) + del new_t -def inject_24(w, size_k, size_n): - assert w.shape == (size_k, size_n) +def apply_marlin_linear(input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition, ) - mask = mask_creator(w.t()).t().cuda().bool() + output = ops.gptq_marlin_gemm(reshaped_x, + weight, + weight_scale, + g_idx, + g_idx_sort_indices, + workspace, + num_bits, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full) - return (mask * w).contiguous(), mask.contiguous() + if bias is not None: + output.add_(bias) # In-place add - -def check_24(w, num_rows_to_sample=50, _verbose=False): - BLOCK_SIZE = 4 - MAX_NON_ZEROS = 2 - - w = w.t().contiguous() - - print("check_24: w.shape = {}".format(w.shape)) - - num_rows, num_cols = w.shape - sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) - if _verbose: - print(f"Sampled row idxs = {sampled_row_idxs}") - - total_segments = 0 - non_24_segments = 0 - for i in sampled_row_idxs: - for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): - total_segments += 1 - block = w[i, j:j + BLOCK_SIZE] - num_nonzero = torch.count_nonzero(block) - if num_nonzero > MAX_NON_ZEROS: - print("i = {} j = {} block = {}".format(i, j, block)) - non_24_segments += 1 - - print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") - - -def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): - assert q_24.shape == (size_k, size_n) - - # Remove zp to normalize over 0 - max_q_val = (1 << num_bits) - 1 - zp = (max_q_val + 1) // 2 - q_24_no_zp = q_24 - zp - - # Compress - q_24_no_zp = q_24_no_zp.t().contiguous() - q_24_no_zp_comp, meta = 
sparse_semi_structured_from_dense_cutlass( - q_24_no_zp) - q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() - - # Restore zp - q_24_comp = q_24_no_zp_comp + zp - - # Resize meta to its actual shape (without moving any data) - meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) - - return q_24_comp, meta - - -def marlin_24_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Inject 2:4 sparsity - w_24, mask_24 = inject_24(w, size_k, size_n) - - # Quantize - w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24, - num_bits, - group_size, - act_order=False) - - # Compress quantized weight - q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, - num_bits) - size_k_comp = size_k // 2 - - # Reformat to marlin - marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, - num_bits, marlin_24_perm[num_bits]) - marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size, - marlin_24_scale_perm[num_bits], - marlin_24_scale_perm_single[num_bits]) - - # Create result - res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list - - -def compute_max_diff(output, output_ref): - return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref)) - - -class MarlinWorkspace: - - def __init__(self, out_features, min_thread_n, max_parallel): - assert (out_features % min_thread_n == 0), ( - "out_features = {} is undivisible by min_thread_n = {}".format( - out_features, min_thread_n)) - - max_workspace_size = ((out_features // min_thread_n) * max_parallel) - - self.scratch = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda") - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: - """ - Repack FP8 weights to gptq format (packed int32 elements) - """ - assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) - - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) - - # Pack 4 uint8 values into one int32 - packed = (byte_tensor[:, 0].to(torch.int32) | - (byte_tensor[:, 1].to(torch.int32) << 8) | - (byte_tensor[:, 2].to(torch.int32) << 16) | - (byte_tensor[:, 3].to(torch.int32) << 24)) - - return packed.view(fp8_tensor.shape[0] // 4, - *fp8_tensor.shape[1:]).contiguous() + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 000000000000..e93eb747ba2e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,109 @@ +from typing import Optional + +import torch + +import vllm._custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils import print_warning_once + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + + +def is_fp8_marlin_supported(): + capability = current_platform.get_device_capability() + return capability[0] >= 8 + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 
hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None: + print_warning_once( + "Your GPU does not have native support for FP8 computation but " + "FP8 quantization is being used. Weight-only FP8 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32( + layer.weight), + perm=torch.empty(0, + dtype=torch.int, + device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Currently Marlin doesn't support per-tensor scales, so we + # expand it to channelwise + scales = layer.weight_scale.repeat(1, part_size_n).to( + layer.orig_dtype).to(device) + # Permute scales + marlin_scales = marlin_permute_scales(s=scales, + size_k=part_size_k, + size_n=part_size_n, + group_size=-1) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = (byte_tensor[:, 0].to(torch.int32) | + (byte_tensor[:, 1].to(torch.int32) << 8) | + (byte_tensor[:, 2].to(torch.int32) << 16) | + (byte_tensor[:, 3].to(torch.int32) << 24)) + + return packed.view(fp8_tensor.shape[0] // 4, + *fp8_tensor.shape[1:]).contiguous() diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py new file mode 100644 index 000000000000..1773748a0f22 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -0,0 +1,120 @@ +"""Utility functions used for tests and benchmarks""" + +from typing import List + +import numpy +import torch + +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales +from .quant_utils import get_pack_factor, quantize_weights, sort_weights + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert (out_features % min_thread_n == 0), ( + "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n)) + + max_workspace_size = ((out_features // min_thread_n) * max_parallel) + + self.scratch = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert 
q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int, + act_order: bool): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, + act_order) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/vllm/model_executor/layers/quantization/utils/format_24.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py similarity index 71% rename from vllm/model_executor/layers/quantization/utils/format_24.py rename to vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py index 01c8cf789204..648c32249a57 100644 --- a/vllm/model_executor/layers/quantization/utils/format_24.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py @@ -1,9 +1,14 @@ -# -# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es). 
-# +"""Utility functions used for tests and benchmarks""" +import random +from typing import List + +import numpy import torch +from .marlin_utils_test import marlin_weights +from .quant_utils import quantize_weights + # This is PyTorch implementation of main part of reorder_meta() # function, from tools/util/include/cutlass/util/host_reorder.h file @@ -306,3 +311,155 @@ def mask_creator(tensor): mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j:j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): + assert q_24.shape == (size_k, size_n) + + # Remove zp to normalize over 0 + max_q_val = (1 << num_bits) - 1 + zp = (max_q_val + 1) // 2 + q_24_no_zp = q_24 - zp + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( + q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore zp + q_24_comp = q_24_no_zp_comp + zp + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return 
s + + +def marlin_24_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24, + num_bits, + group_size, + act_order=False) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, + num_bits) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(num_bits) + marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, + num_bits, weight_perm) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list
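A minimal, CPU-only sanity sketch for the byte layout produced by pack_fp8_to_int32() (now in marlin_utils_fp8.py): four consecutive fp8 rows are packed little-endian into one int32 row, which is the layout prepare_fp8_layer_for_marlin() hands to gptq_marlin_repack(..., num_bits=8). This snippet is illustrative only and is not part of the patch or the test suite; it assumes a PyTorch build that provides torch.float8_e4m3fn, and the variable names (fp8_weight, packed, low_byte) are placeholders.

import torch

from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    pack_fp8_to_int32)

# Any (4*k, n) fp8 tensor works; 8 x 16 keeps the example small.
fp8_weight = torch.randn(8, 16).to(torch.float8_e4m3fn)
packed = pack_fp8_to_int32(fp8_weight)

# Four fp8 rows collapse into one int32 row.
assert packed.dtype == torch.int32
assert packed.shape == (2, 16)

# The low byte of each int32 holds the first fp8 row of its group of four.
low_byte = (packed & 0xFF).to(torch.uint8)
assert torch.equal(low_byte, fp8_weight[0::4].contiguous().view(torch.uint8))

The same check extends to bytes 1 through 3 by shifting right by 8, 16, and 24 before masking, mirroring the shifts used inside pack_fp8_to_int32().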