diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index c1519fc177250..b60e5fc97ac73 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -9,6 +9,7 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType
+from vllm.utils.torch_utils import direct_register_custom_op
 
 logger = init_logger(__name__)
 
@@ -495,13 +496,20 @@ def awq_dequantize(
     return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy)
 
 
-def awq_gemm(
+def _awq_gemm(
     input: torch.Tensor,
     qweight: torch.Tensor,
     scales: torch.Tensor,
     qzeros: torch.Tensor,
     split_k_iters: int,
 ) -> torch.Tensor:
+    # num_tokens >= threshold
+    FP16_MATMUL_HEURISTIC_CONDITION = input.shape[0] >= 256
+
+    if FP16_MATMUL_HEURISTIC_CONDITION:
+        out = awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
+        return torch.matmul(input, out)
+
     if envs.VLLM_USE_TRITON_AWQ:
         from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
 
@@ -509,6 +517,25 @@ def awq_gemm(
     return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
 
 
+def _awq_gemm_fake_impl(
+    input: torch.Tensor,
+    qweight: torch.Tensor,
+    scales: torch.Tensor,
+    qzeros: torch.Tensor,
+    split_k_iters: int,
+) -> torch.Tensor:
+    M, N = input.shape[0], qweight.shape[1] * 8
+    return torch.empty((M, N), dtype=scales.dtype, device=input.device)
+
+
+direct_register_custom_op(
+    op_name="awq_gemm",
+    op_func=_awq_gemm,
+    fake_impl=_awq_gemm_fake_impl,
+)
+awq_gemm = torch.ops.vllm.awq_gemm
+
+
 # gptq
 def gptq_gemm(
     a: torch.Tensor,
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index ab68c5dca52c0..a8c7065aa2261 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -264,14 +264,7 @@ class AWQLinearMethod(LinearMethodBase):
         out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
         reshaped_x = x.reshape(-1, x.shape[-1])
 
-        # num_tokens >= threshold
-        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
-
-        if FP16_MATMUL_HEURISTIC_CONDITION:
-            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
-            out = torch.matmul(reshaped_x, out)
-        else:
-            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
+        out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
 
         if bias is not None:
             out.add_(bias)
         return out.reshape(out_shape)