AWQ: Evaluate fused vs unfused GEMM on actual shape

Before this PR, the condition
```
FP16_MATMUL_HEURISTIC_CONDITION = input.shape[0] >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
    out = awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
    return torch.matmul(input, out)
else:
    return awq_gemm(...)
```
was evaluated during `torch.compile` based on `max-num-batched-tokens`.
By default, `max-num-batched-tokens` is over the threshold,
which meant that `awq_gemm` was never taken, even when doing a single
request decode.

To evaluate the condition for each specific shape during
cudaGraph capture, this PR wraps `awq_gemm` into a torch custom op,
which shields it from being traced through.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
This commit is contained in:
Matthias Gehre 2025-12-16 14:40:06 +01:00
parent 676db55eec
commit ba590e4e33
2 changed files with 29 additions and 9 deletions

View File

@@ -9,6 +9,7 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
@@ -495,13 +496,20 @@ def awq_dequantize(
return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy)
def awq_gemm(
def _awq_gemm(
input: torch.Tensor,
qweight: torch.Tensor,
scales: torch.Tensor,
qzeros: torch.Tensor,
split_k_iters: int,
) -> torch.Tensor:
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION = input.shape[0] >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
out = awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
return torch.matmul(input, out)
if envs.VLLM_USE_TRITON_AWQ:
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
@@ -509,6 +517,25 @@ def awq_gemm(
return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
def _awq_gemm_fake_impl(
input: torch.Tensor,
qweight: torch.Tensor,
scales: torch.Tensor,
qzeros: torch.Tensor,
split_k_iters: int,
) -> torch.Tensor:
M, N = input.shape[0], qweight.shape[1] * 8
return torch.empty((M, N), dtype=scales.dtype, device=input.device)
# Register _awq_gemm as a torch custom op so torch.compile / cudaGraph
# capture treat it as opaque: the fused-vs-unfused heuristic inside is
# then evaluated on the actual runtime shape instead of being baked in
# at trace time (see commit message). The fake impl supplies shape/dtype
# metadata for tracing without running the kernel.
direct_register_custom_op(
    op_name="awq_gemm",
    op_func=_awq_gemm,
    fake_impl=_awq_gemm_fake_impl,
)
# Public entry point now resolves through the registered custom op.
awq_gemm = torch.ops.vllm.awq_gemm
# gptq
def gptq_gemm(
a: torch.Tensor,

View File

@@ -264,14 +264,7 @@ class AWQLinearMethod(LinearMethodBase):
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
reshaped_x = x.reshape(-1, x.shape[-1])
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
out = torch.matmul(reshaped_x, out)
else:
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)