[Bugfix] awq_gemm: fix argument order swap (#30364)

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Matthias Gehre authored 2025-12-14 11:15:37 +01:00 (committed by GitHub)
parent 3224ea9915
commit e9add129ad
2 changed files with 7 additions and 7 deletions
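
Why this is a bugfix: the underlying custom op `torch.ops._C.awq_gemm` takes its tensor arguments in the order (input, qweight, scales, qzeros, split_k_iters), but the Python wrapper declared and forwarded `qzeros` before `scales`. The two tensors are not interchangeable: per the opcheck test in the first hunk below, `scales` is float16 of shape (num_groups, out_features) while `qzeros` is packed int32 of shape (num_groups, out_features // 8), so the swap handed the kernel mismatched dtypes and shapes. Below is a minimal sketch of a call through the corrected wrapper, assuming a CUDA build of vLLM with its `_C` extension; the wrapper module path `vllm._custom_ops` and the input shape (2, 8192) are assumptions not shown on this page, while the weight/scale/zero shapes mirror the test.

    import torch

    from vllm import _custom_ops as ops  # assumed location of the awq_gemm wrapper

    # 4-bit AWQ packs 8 weights per int32: 256 packed columns -> 2048 out features;
    # 8192 in features with 64 groups implies a group size of 128.
    x = torch.rand((2, 8192), device="cuda", dtype=torch.float16)
    qweight = torch.randint(
        -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
    )
    scales = torch.rand((64, 2048), device="cuda", dtype=torch.float16)
    qzeros = torch.randint(
        -2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32
    )
    # Corrected order: scales (fp16) before qzeros (packed int32).
    out = ops.awq_gemm(x, qweight, scales, qzeros, 8)  # split_k_iters = 8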


@@ -41,9 +41,9 @@ def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch):
     qweight = torch.randint(
         -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
     )
-    scales = torch.randint(
+    scales = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
+    qzeros = torch.randint(
         -2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32
     )
-    qzeros = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
     split_k_iters = 8
-    opcheck(torch.ops._C.awq_gemm, (input, qweight, qzeros, scales, split_k_iters))
+    opcheck(torch.ops._C.awq_gemm, (input, qweight, scales, qzeros, split_k_iters))
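(Note the matching dtype swap above: with the corrected order, the fourth argument to the op is the packed int32 `qzeros` rather than the fp16 `scales`, so the test now builds `scales` as a float16 tensor of shape (64, 2048) and `qzeros` as random int32 of shape (64, 256), instead of the other way around.)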


@@ -498,15 +498,15 @@ def awq_dequantize(
 def awq_gemm(
     input: torch.Tensor,
     qweight: torch.Tensor,
-    qzeros: torch.Tensor,
     scales: torch.Tensor,
+    qzeros: torch.Tensor,
     split_k_iters: int,
 ) -> torch.Tensor:
     if envs.VLLM_USE_TRITON_AWQ:
         from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
 
-        return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
-    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
+        return awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)
+    return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
 
 
 # gptq
@@ -632,8 +632,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
     def _awq_gemm_fake(
         input: torch.Tensor,
         qweight: torch.Tensor,
-        qzeros: torch.Tensor,
         scales: torch.Tensor,
+        qzeros: torch.Tensor,
         split_k_iters: torch.SymInt,
     ) -> torch.Tensor:
         num_in_feats = input.size(0)
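
For context: `_awq_gemm_fake` is the fake (meta) implementation for `_C::awq_gemm`, presumably registered via `torch.library.register_fake`, which lets `opcheck` and `torch.compile` infer output metadata without launching the CUDA kernel; the hunk above cuts off after its first line. Below is a hedged sketch of what such a fake typically computes — the name `_awq_gemm_fake_sketch` and the body are illustrative, not this commit's actual code.

    import torch

    def _awq_gemm_fake_sketch(
        input: torch.Tensor,
        qweight: torch.Tensor,
        scales: torch.Tensor,
        qzeros: torch.Tensor,
        split_k_iters: int,
    ) -> torch.Tensor:
        num_in_feats = input.size(0)
        # A fake only needs the right shape/dtype/device for tracing:
        # 4-bit AWQ packs 8 weights per int32, so the logical number of
        # output features is qweight.size(1) * 8.
        return torch.empty(
            (num_in_feats, qweight.size(1) * 8),
            dtype=input.dtype,
            device=input.device,
        )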