mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:34:57 +08:00
[Bugfix] Try to handle older versions of pytorch (#9086)
This commit is contained in:
parent
de24046fcd
commit
bd37b9fbe2
@ -1,11 +1,14 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
|
||||
reason="AWQ is not supported on this GPU type.")
|
||||
def test_awq_dequantize_opcheck():
|
||||
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
|
||||
qweight = torch.randint(-2000000000,
|
||||
@ -21,6 +24,8 @@ def test_awq_dequantize_opcheck():
|
||||
(qweight, scales, zeros, split_k_iters, thx, thy))
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
|
||||
reason="AWQ is not supported on this GPU type.")
|
||||
def test_awq_gemm_opcheck():
|
||||
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
|
||||
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
|
||||
|
||||
@ -7,6 +7,7 @@ import torch
|
||||
|
||||
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
|
||||
torch_moe_single)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
fused_marlin_moe, single_marlin_moe)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
@ -21,6 +22,9 @@ from vllm.scalar_type import scalar_types
|
||||
@pytest.mark.parametrize("e", [8, 64])
|
||||
@pytest.mark.parametrize("topk", [2, 6])
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
|
||||
@pytest.mark.skipif(not (ops.supports_moe_ops
|
||||
and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
def test_fused_marlin_moe_awq(
|
||||
m: int,
|
||||
n: int,
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import contextlib
|
||||
import functools
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.library
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._core_ext import ScalarType
|
||||
@ -25,6 +26,16 @@ with contextlib.suppress(ImportError):
|
||||
import vllm._moe_C # noqa: F401
|
||||
supports_moe_ops = True
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def register_fake(fn):
|
||||
return lambda name: fn
|
||||
else:
|
||||
try:
|
||||
from torch.library import register_fake
|
||||
except ImportError:
|
||||
from torch.library import impl_abstract as register_fake
|
||||
|
||||
|
||||
def hint_on_error(fn):
|
||||
|
||||
@ -266,7 +277,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
|
||||
if hasattr(torch.ops._C, "gptq_gemm"):
|
||||
|
||||
@torch.library.register_fake("_C::gptq_gemm")
|
||||
@register_fake("_C::gptq_gemm")
|
||||
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
b_gptq_qzeros: torch.Tensor,
|
||||
b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
|
||||
@ -301,7 +312,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
|
||||
if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
|
||||
@torch.library.register_fake("_C::gptq_marlin_24_gemm")
|
||||
@register_fake("_C::gptq_marlin_24_gemm")
|
||||
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
b_meta: torch.Tensor, b_scales: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
@ -309,7 +320,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
size_n: int, size_k: int) -> torch.Tensor:
|
||||
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
|
||||
|
||||
@torch.library.register_fake("_C::gptq_marlin_gemm")
|
||||
@register_fake("_C::gptq_marlin_gemm")
|
||||
def _gptq_marlin_gemm_fake(a: torch.Tensor,
|
||||
b_q_weight: torch.Tensor,
|
||||
b_scales: torch.Tensor,
|
||||
@ -326,12 +337,12 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
use_fp32_reduce: bool = False) -> torch.Tensor:
|
||||
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
|
||||
|
||||
@torch.library.register_fake("_C::ggml_dequantize")
|
||||
@register_fake("_C::ggml_dequantize")
|
||||
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
|
||||
n: int) -> torch.Tensor:
|
||||
return torch.empty((m, n), dtype=torch.float16, device=W.device)
|
||||
|
||||
@torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
|
||||
@register_fake("_C::ggml_mul_mat_vec_a8")
|
||||
def _ggml_mul_mat_vec_a8_fake(
|
||||
W: torch.Tensor,
|
||||
X: torch.Tensor,
|
||||
@ -340,7 +351,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
) -> torch.Tensor:
|
||||
return torch.empty((1, row), dtype=torch.float16, device=W.device)
|
||||
|
||||
@torch.library.register_fake("_C::ggml_mul_mat_a8")
|
||||
@register_fake("_C::ggml_mul_mat_a8")
|
||||
def _ggml_mul_mat_a8_fake(
|
||||
W: torch.Tensor,
|
||||
X: torch.Tensor,
|
||||
@ -350,7 +361,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
batch = X.size(0)
|
||||
return torch.empty((batch, row), dtype=torch.float16, device=W.device)
|
||||
|
||||
@torch.library.register_fake("_C::marlin_qqq_gemm")
|
||||
@register_fake("_C::marlin_qqq_gemm")
|
||||
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
s_tok: torch.Tensor, s_ch: torch.Tensor,
|
||||
s_group: torch.Tensor, workspace: torch.Tensor,
|
||||
@ -360,7 +371,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
dtype=torch.float16,
|
||||
device=a.device)
|
||||
|
||||
@torch.library.register_fake("_C::marlin_gemm")
|
||||
@register_fake("_C::marlin_gemm")
|
||||
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
b_scales: torch.Tensor, workspace: torch.Tensor,
|
||||
size_m: int, size_n: int,
|
||||
@ -369,7 +380,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
dtype=torch.float16,
|
||||
device=a.device)
|
||||
|
||||
@torch.library.register_fake("_C::awq_dequantize")
|
||||
@register_fake("_C::awq_dequantize")
|
||||
def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
|
||||
zeros: torch.Tensor, split_k_iters: int, thx: int,
|
||||
thy: int) -> torch.Tensor:
|
||||
@ -380,7 +391,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
dtype=scales.dtype,
|
||||
device=scales.device)
|
||||
|
||||
@torch.library.register_fake("_C::awq_gemm")
|
||||
@register_fake("_C::awq_gemm")
|
||||
def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
|
||||
qzeros: torch.Tensor, scales: torch.Tensor,
|
||||
split_k_iters: int) -> torch.Tensor:
|
||||
@ -389,7 +400,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
dtype=input.dtype,
|
||||
device=input.device).sum(0)
|
||||
|
||||
@torch.library.register_fake("_C::aqlm_gemm")
|
||||
@register_fake("_C::aqlm_gemm")
|
||||
def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
|
||||
codebooks: torch.Tensor, scales: torch.Tensor,
|
||||
codebook_partition_sizes: List[int],
|
||||
@ -405,7 +416,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
output_sizes.append(-1)
|
||||
return flat_output.reshape(tuple(output_sizes))
|
||||
|
||||
@torch.library.register_fake("_C::aqlm_dequant")
|
||||
@register_fake("_C::aqlm_dequant")
|
||||
def _aqlm_dequant_fake(
|
||||
codes: torch.Tensor, codebooks: torch.Tensor,
|
||||
codebook_partition_sizes: List[int]) -> torch.Tensor:
|
||||
@ -415,14 +426,14 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
dtype=codebooks.dtype,
|
||||
device=codebooks.device)
|
||||
|
||||
@torch.library.register_fake("_C::fp8_marlin_gemm")
|
||||
@register_fake("_C::fp8_marlin_gemm")
|
||||
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
b_scales: torch.Tensor, workspace: torch.Tensor,
|
||||
num_bits: int, size_m: int, size_n: int,
|
||||
size_k: int) -> torch.Tensor:
|
||||
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
|
||||
|
||||
@torch.library.register_fake("_C::machete_gemm")
|
||||
@register_fake("_C::machete_gemm")
|
||||
def machete_gemm_fake(
|
||||
a: torch.Tensor,
|
||||
# Should be the tensor returned by machete_prepack_B
|
||||
@ -440,13 +451,13 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
n = b_q.size(1)
|
||||
return torch.empty((m, n), device=a.device, dtype=a.dtype)
|
||||
|
||||
@torch.library.register_fake("_C::machete_prepack_B")
|
||||
@register_fake("_C::machete_prepack_B")
|
||||
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
|
||||
b_type: ScalarType) -> torch.Tensor:
|
||||
return torch.empty_like(b_q_weight,
|
||||
memory_format=torch.contiguous_format)
|
||||
|
||||
@torch.library.register_fake("_C::causal_conv1d_fwd")
|
||||
@register_fake("_C::causal_conv1d_fwd")
|
||||
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
|
||||
bias_: Optional[torch.Tensor],
|
||||
conv_states: Optional[torch.Tensor],
|
||||
@ -456,7 +467,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
silu_activation: bool) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
@torch.library.register_fake("_C::causal_conv1d_update")
|
||||
@register_fake("_C::causal_conv1d_update")
|
||||
def causal_conv1d_update_fake(
|
||||
x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
|
||||
bias_: Optional[torch.Tensor], silu_activation: bool,
|
||||
@ -464,7 +475,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
@torch.library.register_fake("_C::selective_scan_fwd")
|
||||
@register_fake("_C::selective_scan_fwd")
|
||||
def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
|
||||
A: torch.Tensor, B: torch.Tensor,
|
||||
C: torch.Tensor, D_: Optional[torch.Tensor],
|
||||
@ -639,7 +650,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
|
||||
|
||||
if hasattr(torch.ops._C, "permute_cols"):
|
||||
|
||||
@torch.library.register_fake("_C::permute_cols")
|
||||
@register_fake("_C::permute_cols")
|
||||
def _permute_cols_fake(a: torch.Tensor,
|
||||
perm: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(a)
|
||||
@ -837,7 +848,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
|
||||
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
|
||||
|
||||
@torch.library.register_fake("_moe_C::marlin_gemm_moe")
|
||||
@register_fake("_moe_C::marlin_gemm_moe")
|
||||
def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
|
||||
sorted_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user