mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-24 01:03:44 +08:00
AWQ: Evaluate fused vs unfused GEMM on actual shape
Before this PR, the condition
```
FP16_MATMUL_HEURISTIC_CONDITION = input.shape[0] >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
out = awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
return torch.matmul(input, out)
else:
return awq_gemm(...)
```
was evaluated during `torch.compile` based on `max-num-batched-tokens`.
By default, `max-num-batched-tokens` is over the threshold,
which meant that `awq_gemm` was never taken, even when doing a single
request decode.
To evaluate the condition for each specific shape during during
cudaGraph capture, this PR wraps `awq_gemm` into a torch custom op,
which shields it from being traced through.
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
This commit is contained in:
parent
676db55eec
commit
ba590e4e33
@ -9,6 +9,7 @@ import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -495,13 +496,20 @@ def awq_dequantize(
|
||||
return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy)
|
||||
|
||||
|
||||
def awq_gemm(
|
||||
def _awq_gemm(
|
||||
input: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
qzeros: torch.Tensor,
|
||||
split_k_iters: int,
|
||||
) -> torch.Tensor:
|
||||
# num_tokens >= threshold
|
||||
FP16_MATMUL_HEURISTIC_CONDITION = input.shape[0] >= 256
|
||||
|
||||
if FP16_MATMUL_HEURISTIC_CONDITION:
|
||||
out = awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
||||
return torch.matmul(input, out)
|
||||
|
||||
if envs.VLLM_USE_TRITON_AWQ:
|
||||
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
|
||||
|
||||
@ -509,6 +517,25 @@ def awq_gemm(
|
||||
return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
|
||||
|
||||
|
||||
def _awq_gemm_fake_impl(
|
||||
input: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
qzeros: torch.Tensor,
|
||||
split_k_iters: int,
|
||||
) -> torch.Tensor:
|
||||
M, N = input.shape[0], qweight.shape[1] * 8
|
||||
return torch.empty((M, N), dtype=scales.dtype, device=input.device)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="awq_gemm",
|
||||
op_func=_awq_gemm,
|
||||
fake_impl=_awq_gemm_fake_impl,
|
||||
)
|
||||
awq_gemm = torch.ops.vllm.awq_gemm
|
||||
|
||||
|
||||
# gptq
|
||||
def gptq_gemm(
|
||||
a: torch.Tensor,
|
||||
|
||||
@ -264,14 +264,7 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
# num_tokens >= threshold
|
||||
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
|
||||
|
||||
if FP16_MATMUL_HEURISTIC_CONDITION:
|
||||
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
||||
out = torch.matmul(reshaped_x, out)
|
||||
else:
|
||||
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
|
||||
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
|
||||
if bias is not None:
|
||||
out.add_(bias)
|
||||
return out.reshape(out_shape)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user