mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 01:47:00 +08:00
AWQ: Evaluate fused vs unfused GEMM on actual shape
Before this PR, the condition
```
FP16_MATMUL_HEURISTIC_CONDITION = input.shape[0] >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
    out = awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
    return torch.matmul(input, out)
else:
    return awq_gemm(...)
```
was evaluated during `torch.compile` based on `max-num-batched-tokens`.
By default, `max-num-batched-tokens` is over the threshold,
so the `awq_gemm` branch was never taken, even when decoding a single
request.
To evaluate the condition for each specific shape during
CUDA graph capture, this PR wraps `awq_gemm` into a torch custom op,
which shields it from being traced through.
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
This commit is contained in:
parent
676db55eec
commit
ba590e4e33
@ -9,6 +9,7 @@ import vllm.envs as envs
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.scalar_type import ScalarType
|
from vllm.scalar_type import ScalarType
|
||||||
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -495,13 +496,20 @@ def awq_dequantize(
|
|||||||
return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy)
|
return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy)
|
||||||
|
|
||||||
|
|
||||||
def awq_gemm(
|
def _awq_gemm(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
qweight: torch.Tensor,
|
qweight: torch.Tensor,
|
||||||
scales: torch.Tensor,
|
scales: torch.Tensor,
|
||||||
qzeros: torch.Tensor,
|
qzeros: torch.Tensor,
|
||||||
split_k_iters: int,
|
split_k_iters: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
# num_tokens >= threshold
|
||||||
|
FP16_MATMUL_HEURISTIC_CONDITION = input.shape[0] >= 256
|
||||||
|
|
||||||
|
if FP16_MATMUL_HEURISTIC_CONDITION:
|
||||||
|
out = awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
||||||
|
return torch.matmul(input, out)
|
||||||
|
|
||||||
if envs.VLLM_USE_TRITON_AWQ:
|
if envs.VLLM_USE_TRITON_AWQ:
|
||||||
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
|
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
|
||||||
|
|
||||||
@ -509,6 +517,25 @@ def awq_gemm(
|
|||||||
return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
|
return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
|
||||||
|
|
||||||
|
|
||||||
|
def _awq_gemm_fake_impl(
|
||||||
|
input: torch.Tensor,
|
||||||
|
qweight: torch.Tensor,
|
||||||
|
scales: torch.Tensor,
|
||||||
|
qzeros: torch.Tensor,
|
||||||
|
split_k_iters: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
M, N = input.shape[0], qweight.shape[1] * 8
|
||||||
|
return torch.empty((M, N), dtype=scales.dtype, device=input.device)
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="awq_gemm",
|
||||||
|
op_func=_awq_gemm,
|
||||||
|
fake_impl=_awq_gemm_fake_impl,
|
||||||
|
)
|
||||||
|
awq_gemm = torch.ops.vllm.awq_gemm
|
||||||
|
|
||||||
|
|
||||||
# gptq
|
# gptq
|
||||||
def gptq_gemm(
|
def gptq_gemm(
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
|
|||||||
@ -264,14 +264,7 @@ class AWQLinearMethod(LinearMethodBase):
|
|||||||
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
|
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
|
||||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||||
|
|
||||||
# num_tokens >= threshold
|
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
|
||||||
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
|
|
||||||
|
|
||||||
if FP16_MATMUL_HEURISTIC_CONDITION:
|
|
||||||
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
|
||||||
out = torch.matmul(reshaped_x, out)
|
|
||||||
else:
|
|
||||||
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
out.add_(bias)
|
out.add_(bias)
|
||||||
return out.reshape(out_shape)
|
return out.reshape(out_shape)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user