From ba590e4e33a4d2328cf1d2a81c73d68d11ce4b68 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Tue, 16 Dec 2025 14:40:06 +0100 Subject: [PATCH] AWQ: Evaluate fused vs unfused GEMM on actual shape Before this PR, the condition ``` FP16_MATMUL_HEURISTIC_CONDITION = input.shape[0] >= 256 if FP16_MATMUL_HEURISTIC_CONDITION: out = awq_dequantize(qweight, scales, qzeros, 0, 0, 0) return torch.matmul(input, out) else: return awq_gemm(...) ``` was evaluated during `torch.compile` based on `max-num-batched-tokens`. By default, `max-num-batched-tokens` is over the threshold, which meant that `awq_gemm` was never taken, even when doing a single request decode. To evaluate the condition for each specific shape during cudaGraph capture, this PR wraps `awq_gemm` into a torch custom op, which shields it from being traced through. Signed-off-by: Matthias Gehre --- vllm/_custom_ops.py | 29 ++++++++++++++++++- .../model_executor/layers/quantization/awq.py | 9 +----- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 2319655008c50..ff709d7aedf44 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -9,6 +9,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.scalar_type import ScalarType +from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) @@ -495,13 +496,20 @@ def awq_dequantize( return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy) -def awq_gemm( +def _awq_gemm( input: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, split_k_iters: int, ) -> torch.Tensor: + # num_tokens >= threshold + FP16_MATMUL_HEURISTIC_CONDITION = input.shape[0] >= 256 + + if FP16_MATMUL_HEURISTIC_CONDITION: + out = awq_dequantize(qweight, scales, qzeros, 0, 0, 0) + return torch.matmul(input, out) + + if envs.VLLM_USE_TRITON_AWQ: from 
vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton @@ -509,6 +517,25 @@ def awq_gemm( return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters) +def _awq_gemm_fake_impl( + input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + split_k_iters: int, +) -> torch.Tensor: + M, N = input.shape[0], qweight.shape[1] * 8 + return torch.empty((M, N), dtype=scales.dtype, device=input.device) + + +direct_register_custom_op( + op_name="awq_gemm", + op_func=_awq_gemm, + fake_impl=_awq_gemm_fake_impl, +) +awq_gemm = torch.ops.vllm.awq_gemm + + # gptq def gptq_gemm( a: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index ab68c5dca52c0..a8c7065aa2261 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -264,14 +264,7 @@ class AWQLinearMethod(LinearMethodBase): out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,) reshaped_x = x.reshape(-1, x.shape[-1]) - # num_tokens >= threshold - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256 - - if FP16_MATMUL_HEURISTIC_CONDITION: - out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0) - out = torch.matmul(reshaped_x, out) - else: - out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor) + out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor) if bias is not None: out.add_(bias) return out.reshape(out_shape)