[Bugfix] Try to handle older versions of pytorch (#9086)

This commit is contained in:
bnellnm 2024-10-08 17:28:12 -04:00 committed by GitHub
parent de24046fcd
commit bd37b9fbe2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 41 additions and 21 deletions

View File

@ -1,11 +1,14 @@
import os
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
reason="AWQ is not supported on this GPU type.")
def test_awq_dequantize_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
qweight = torch.randint(-2000000000,
@ -21,6 +24,8 @@ def test_awq_dequantize_opcheck():
(qweight, scales, zeros, split_k_iters, thx, thy))
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
reason="AWQ is not supported on this GPU type.")
def test_awq_gemm_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)

View File

@ -7,6 +7,7 @@ import torch
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
@ -21,6 +22,9 @@ from vllm.scalar_type import scalar_types
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.skipif(not (ops.supports_moe_ops
and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
reason="Marlin is not supported on this GPU type.")
def test_fused_marlin_moe_awq(
m: int,
n: int,

View File

@ -1,8 +1,9 @@
import contextlib
import functools
from typing import List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import torch.library
import vllm.envs as envs
from vllm._core_ext import ScalarType
@ -25,6 +26,16 @@ with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
supports_moe_ops = True
if TYPE_CHECKING:
def register_fake(fn):
return lambda name: fn
else:
try:
from torch.library import register_fake
except ImportError:
from torch.library import impl_abstract as register_fake
def hint_on_error(fn):
@ -266,7 +277,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
if hasattr(torch.ops._C, "gptq_gemm"):
@torch.library.register_fake("_C::gptq_gemm")
@register_fake("_C::gptq_gemm")
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
@ -301,7 +312,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@torch.library.register_fake("_C::gptq_marlin_24_gemm")
@register_fake("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor,
@ -309,7 +320,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
size_n: int, size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@torch.library.register_fake("_C::gptq_marlin_gemm")
@register_fake("_C::gptq_marlin_gemm")
def _gptq_marlin_gemm_fake(a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
@ -326,12 +337,12 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
use_fp32_reduce: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@torch.library.register_fake("_C::ggml_dequantize")
@register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
n: int) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device)
@torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
@register_fake("_C::ggml_mul_mat_vec_a8")
def _ggml_mul_mat_vec_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
@ -340,7 +351,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device)
@torch.library.register_fake("_C::ggml_mul_mat_a8")
@register_fake("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
@ -350,7 +361,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
batch = X.size(0)
return torch.empty((batch, row), dtype=torch.float16, device=W.device)
@torch.library.register_fake("_C::marlin_qqq_gemm")
@register_fake("_C::marlin_qqq_gemm")
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
@ -360,7 +371,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=torch.float16,
device=a.device)
@torch.library.register_fake("_C::marlin_gemm")
@register_fake("_C::marlin_gemm")
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int,
@ -369,7 +380,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=torch.float16,
device=a.device)
@torch.library.register_fake("_C::awq_dequantize")
@register_fake("_C::awq_dequantize")
def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
@ -380,7 +391,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=scales.dtype,
device=scales.device)
@torch.library.register_fake("_C::awq_gemm")
@register_fake("_C::awq_gemm")
def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor, scales: torch.Tensor,
split_k_iters: int) -> torch.Tensor:
@ -389,7 +400,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=input.dtype,
device=input.device).sum(0)
@torch.library.register_fake("_C::aqlm_gemm")
@register_fake("_C::aqlm_gemm")
def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: List[int],
@ -405,7 +416,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
output_sizes.append(-1)
return flat_output.reshape(tuple(output_sizes))
@torch.library.register_fake("_C::aqlm_dequant")
@register_fake("_C::aqlm_dequant")
def _aqlm_dequant_fake(
codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: List[int]) -> torch.Tensor:
@ -415,14 +426,14 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=codebooks.dtype,
device=codebooks.device)
@torch.library.register_fake("_C::fp8_marlin_gemm")
@register_fake("_C::fp8_marlin_gemm")
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
@torch.library.register_fake("_C::machete_gemm")
@register_fake("_C::machete_gemm")
def machete_gemm_fake(
a: torch.Tensor,
# Should be the tensor returned by machete_prepack_B
@ -440,13 +451,13 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
n = b_q.size(1)
return torch.empty((m, n), device=a.device, dtype=a.dtype)
@torch.library.register_fake("_C::machete_prepack_B")
@register_fake("_C::machete_prepack_B")
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor:
return torch.empty_like(b_q_weight,
memory_format=torch.contiguous_format)
@torch.library.register_fake("_C::causal_conv1d_fwd")
@register_fake("_C::causal_conv1d_fwd")
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
conv_states: Optional[torch.Tensor],
@ -456,7 +467,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
silu_activation: bool) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.register_fake("_C::causal_conv1d_update")
@register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(
x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], silu_activation: bool,
@ -464,7 +475,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.register_fake("_C::selective_scan_fwd")
@register_fake("_C::selective_scan_fwd")
def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
A: torch.Tensor, B: torch.Tensor,
C: torch.Tensor, D_: Optional[torch.Tensor],
@ -639,7 +650,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
if hasattr(torch.ops._C, "permute_cols"):
@torch.library.register_fake("_C::permute_cols")
@register_fake("_C::permute_cols")
def _permute_cols_fake(a: torch.Tensor,
perm: torch.Tensor) -> torch.Tensor:
return torch.empty_like(a)
@ -837,7 +848,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
@torch.library.register_fake("_moe_C::marlin_gemm_moe")
@register_fake("_moe_C::marlin_gemm_moe")
def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
sorted_ids: torch.Tensor,
topk_weights: torch.Tensor,