diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py new file mode 100644 index 0000000000000..1d40f4915a1be --- /dev/null +++ b/vllm/model_executor/layers/quantization/deepgemm.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +import importlib.util +import logging + +import torch + +from vllm.platforms import current_platform +from vllm.triton_utils import triton +from vllm.utils import direct_register_custom_op + +has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None +if has_deep_gemm: + import deep_gemm + +logger = logging.getLogger(__name__) + + +def prepare_block_fp8_matmul_inputs( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, +) -> tuple[int, int, int, torch.Tensor]: + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + assert A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 + assert B.is_contiguous() + assert Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N, ) + C = A.new_empty(C_shape, dtype=output_dtype) + + return M, N, K, C + + +def w8a8_block_fp8_matmul_deepgemm( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, + output_dtype) + # Deepgemm only supports output tensor type as bfloat16 + assert C.dtype == torch.bfloat16 + deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) + return C + + +def w8a8_block_fp8_matmul_deepgemm_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, + output_dtype) + return C + + +direct_register_custom_op( + op_name="w8a8_block_fp8_matmul_deepgemm", + op_func=w8a8_block_fp8_matmul_deepgemm, + mutates_args=[], + fake_impl=w8a8_block_fp8_matmul_deepgemm_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c785e0d1674da..b3042bfaed3d7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -402,6 +402,7 @@ class Fp8LinearMethod(LinearMethodBase): if self.block_quant: assert self.quant_config.weight_block_size is not None + return torch.ops.vllm.apply_w8a8_block_fp8_linear( input=x, weight=layer.weight, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 08dc99e07597b..3d67c09de58e8 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -3,12 +3,14 @@ # Adapted from https://github.com/sgl-project/sglang/pull/2575 import functools +import importlib.util import json import os from typing import Any, Callable, Optional, Union import torch +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -20,6 +22,7 @@ from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op logger = init_logger(__name__) +has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: @@ -98,6 +101,19 @@ def dispatch_w8a8_blockscale_func( return w8a8_block_fp8_matmul +def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor): + """ + Check if DeepGEMM should be used based on the output dtype and weight shape. + DeepGEMM is only supported for bfloat16 output dtype and weights with shape + divisible by 128. + """ + + return (current_platform.is_cuda() + and current_platform.is_device_capability(90) and has_deep_gemm + and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16 + and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) + + # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 def apply_w8a8_block_fp8_linear( @@ -114,6 +130,29 @@ def apply_w8a8_block_fp8_linear( # View input as 2D matrix for fp8 methods input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] + output_dtype = input.dtype + + if should_use_deepgemm(output_dtype, weight): + + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + + q_input, x_scale = per_token_group_quant_fp8( + input_2d, + block_size[1], + column_major_scales=True, + ) + + output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( + q_input, + weight, + x_scale, + weight_scale, + block_size, + output_dtype=output_dtype) + if bias is not None: + output += bias + return output.to(dtype=output_dtype).view(*output_shape) if current_platform.is_cuda(): if current_platform.has_device_capability(100): @@ -134,7 +173,6 @@ def apply_w8a8_block_fp8_linear( w8a8_blockscale_func = dispatch_w8a8_blockscale_func( use_cutlass, use_aiter_and_is_supported) - if use_cutlass: q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=use_cutlass)