Merge ba590e4e33a4d2328cf1d2a81c73d68d11ce4b68 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
Matthias Gehre 2025-12-25 08:22:28 +08:00 committed by GitHub
commit 9add9f13c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 29 additions and 9 deletions

View File

@ -9,6 +9,7 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
@ -495,13 +496,20 @@ def awq_dequantize(
return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy)
def awq_gemm(
def _awq_gemm(
input: torch.Tensor,
qweight: torch.Tensor,
scales: torch.Tensor,
qzeros: torch.Tensor,
split_k_iters: int,
) -> torch.Tensor:
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION = input.shape[0] >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
out = awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
return torch.matmul(input, out)
if envs.VLLM_USE_TRITON_AWQ:
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
@ -509,6 +517,25 @@ def awq_gemm(
return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
def _awq_gemm_fake_impl(
input: torch.Tensor,
qweight: torch.Tensor,
scales: torch.Tensor,
qzeros: torch.Tensor,
split_k_iters: int,
) -> torch.Tensor:
M, N = input.shape[0], qweight.shape[1] * 8
return torch.empty((M, N), dtype=scales.dtype, device=input.device)
# Register _awq_gemm as a PyTorch custom op (appears under the "vllm"
# namespace, per the alias below). The fake_impl supplies output
# shape/dtype metadata so torch.compile / fake-tensor tracing can handle
# the op without running the real kernel.
direct_register_custom_op(
    op_name="awq_gemm",
    op_func=_awq_gemm,
    fake_impl=_awq_gemm_fake_impl,
)
# Public entry point: callers go through the registered op handle rather
# than calling _awq_gemm directly, so dispatch/tracing hooks apply.
awq_gemm = torch.ops.vllm.awq_gemm
# gptq
def gptq_gemm(
a: torch.Tensor,

View File

@ -264,14 +264,7 @@ class AWQLinearMethod(LinearMethodBase):
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
reshaped_x = x.reshape(-1, x.shape[-1])
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
out = torch.matmul(reshaped_x, out)
else:
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)