Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-08 18:07:32 +08:00)
[Bugfix] awq_gemm: fix argument order swap (#30364)
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
parent 3224ea9915
commit e9add129ad
@@ -41,9 +41,9 @@ def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch):
     qweight = torch.randint(
         -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
     )
-    scales = torch.randint(
+    scales = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
+    qzeros = torch.randint(
         -2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32
     )
-    qzeros = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
     split_k_iters = 8
-    opcheck(torch.ops._C.awq_gemm, (input, qweight, qzeros, scales, split_k_iters))
+    opcheck(torch.ops._C.awq_gemm, (input, qweight, scales, qzeros, split_k_iters))
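
A minimal usage sketch of the corrected wrapper, for context. This is illustrative only: it assumes a CUDA build of vLLM with the compiled _C ops available, the shape of input is an assumption (the hunk above does not show how it is built), and torch.rand stands in for the test's torch.empty so the scale values are finite.

import torch

from vllm import _custom_ops as ops  # loads the compiled _C extension

# Shapes mirror the test above; the input shape is assumed.
input = torch.rand((4, 8192), device="cuda", dtype=torch.float16)
qweight = torch.randint(
    -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
)
scales = torch.rand((64, 2048), device="cuda", dtype=torch.float16)
qzeros = torch.randint(
    -2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32
)
split_k_iters = 8

# After this commit the wrapper takes scales before qzeros, matching the
# order the underlying kernel expects.
out = ops.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
print(out.shape)  # expected (4, 2048): each packed int32 column holds 8 4-bit weights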
@@ -498,15 +498,15 @@ def awq_dequantize(
 def awq_gemm(
     input: torch.Tensor,
     qweight: torch.Tensor,
-    qzeros: torch.Tensor,
     scales: torch.Tensor,
+    qzeros: torch.Tensor,
     split_k_iters: int,
 ) -> torch.Tensor:
     if envs.VLLM_USE_TRITON_AWQ:
         from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
 
-        return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
-    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
+        return awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)
+    return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
 
 
 # gptq
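
With the wrapper fixed, both dispatch targets receive scales before qzeros, so the compiled CUDA op and the Triton kernel can be called side by side with identical arguments. The sketch below is only illustrative: it assumes a CUDA build of vLLM with Triton installed, the tensor shapes are the same assumptions as in the previous sketch, and how closely the two backends agree numerically is not asserted here.

import torch

from vllm import _custom_ops  # noqa: F401  (registers torch.ops._C)
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton

input = torch.rand((4, 8192), device="cuda", dtype=torch.float16)
qweight = torch.randint(
    -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
)
scales = torch.rand((64, 2048), device="cuda", dtype=torch.float16)
qzeros = torch.randint(
    -2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32
)
split_k_iters = 8

# Same (scales, qzeros) order for both backends after this commit.
out_cuda = torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
out_triton = awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)
print((out_cuda - out_triton).abs().max())  # expected to be small, not bit-exact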
@@ -632,8 +632,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
     def _awq_gemm_fake(
         input: torch.Tensor,
         qweight: torch.Tensor,
-        qzeros: torch.Tensor,
         scales: torch.Tensor,
+        qzeros: torch.Tensor,
         split_k_iters: torch.SymInt,
     ) -> torch.Tensor:
         num_in_feats = input.size(0)
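
The fake (meta) implementation must declare the same parameter order as the real op, which is why _awq_gemm_fake is updated alongside the wrapper; the opcheck call in the test above compares the two. Below is a generic torch.library sketch of that pattern, not vLLM's actual registration code: the op name mylib::awq_gemm_demo is made up, and the output-shape formula is only an assumption about AWQ's int32 packing (8 4-bit values per column).

import torch
from torch.library import custom_op, register_fake


@custom_op("mylib::awq_gemm_demo", mutates_args=())
def awq_gemm_demo(
    input: torch.Tensor,
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    split_k_iters: int,
) -> torch.Tensor:
    # Stand-in for the real kernel: only the output shape matters here.
    return torch.zeros(
        (input.size(0), qweight.size(1) * 8), dtype=input.dtype, device=input.device
    )


@register_fake("mylib::awq_gemm_demo")
def _awq_gemm_demo_fake(input, qweight, scales, qzeros, split_k_iters):
    # Must mirror the real op's parameter order; produces only shape/dtype.
    num_in_feats = input.size(0)
    return input.new_empty((num_in_feats, qweight.size(1) * 8))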