mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-23 19:54:50 +08:00
Merge ba590e4e33a4d2328cf1d2a81c73d68d11ce4b68 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
9add9f13c9
@ -9,6 +9,7 @@ import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -495,13 +496,20 @@ def awq_dequantize(
|
||||
return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy)
|
||||
|
||||
|
||||
def awq_gemm(
|
||||
def _awq_gemm(
|
||||
input: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
qzeros: torch.Tensor,
|
||||
split_k_iters: int,
|
||||
) -> torch.Tensor:
|
||||
# num_tokens >= threshold
|
||||
FP16_MATMUL_HEURISTIC_CONDITION = input.shape[0] >= 256
|
||||
|
||||
if FP16_MATMUL_HEURISTIC_CONDITION:
|
||||
out = awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
||||
return torch.matmul(input, out)
|
||||
|
||||
if envs.VLLM_USE_TRITON_AWQ:
|
||||
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
|
||||
|
||||
@ -509,6 +517,25 @@ def awq_gemm(
|
||||
return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
|
||||
|
||||
|
||||
def _awq_gemm_fake_impl(
    input: torch.Tensor,
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    split_k_iters: int,
) -> torch.Tensor:
    """Fake (meta) implementation of ``awq_gemm`` for graph tracing.

    Allocates an uninitialized output with the metadata the real kernel
    would produce: one row per input token, and the unpacked output
    width of the quantized weight. No computation is performed.
    """
    num_tokens = input.shape[0]
    # qweight's second dim is the packed width; * 8 expands it to the
    # real output width (presumably 8 packed 4-bit values per element —
    # matches the real kernel's output shape).
    out_features = qweight.shape[1] * 8
    # Output adopts the scales' dtype (the dequantized compute dtype)
    # and lives on the same device as the activations.
    return torch.empty(
        (num_tokens, out_features), dtype=scales.dtype, device=input.device
    )
|
||||
|
||||
|
||||
# Register the eager implementation as the custom op ``vllm::awq_gemm``
# so that torch.compile can trace through it, using the fake impl above
# to supply output metadata during tracing. The module-level ``awq_gemm``
# name is rebound to the registered op so existing callers are unchanged.
direct_register_custom_op(
    op_name="awq_gemm",
    op_func=_awq_gemm,
    fake_impl=_awq_gemm_fake_impl,
)
awq_gemm = torch.ops.vllm.awq_gemm
|
||||
|
||||
|
||||
# gptq
|
||||
def gptq_gemm(
|
||||
a: torch.Tensor,
|
||||
|
||||
@ -264,14 +264,7 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
# num_tokens >= threshold
|
||||
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
|
||||
|
||||
if FP16_MATMUL_HEURISTIC_CONDITION:
|
||||
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
||||
out = torch.matmul(reshaped_x, out)
|
||||
else:
|
||||
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
|
||||
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
|
||||
if bias is not None:
|
||||
out.add_(bias)
|
||||
return out.reshape(out_shape)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user