[CI/Build][AMD] Fix import errors in tests/kernels/attention (#29032)
Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
commit 322cb02872 (parent 2c52c7fd9a)
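This commit guards optional GPU-backend imports in the attention kernel tests so that test collection no longer hard-fails on ROCm, where vllm_flash_attn and flashinfer are unavailable. The snippet below is a minimal sketch of the guard pattern the hunks apply; the module name optional_backend is a hypothetical placeholder, not part of the change.

# Minimal sketch of the module-level import guard applied throughout this commit.
# "optional_backend" is a hypothetical stand-in for vllm_flash_attn / flashinfer.
import pytest

from vllm.platforms import current_platform

try:
    import optional_backend
except ImportError:
    if current_platform.is_rocm():
        # At module scope, pytest.skip() requires allow_module_level=True;
        # it then skips every test in the file instead of erroring at collection.
        pytest.skip(
            "optional_backend is not available on ROCm.",
            allow_module_level=True,
        )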
@@ -7,11 +7,19 @@ import torch
 
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import cascade_attention, merge_attn_states
-from vllm.vllm_flash_attn import (
-    fa_version_unsupported_reason,
-    flash_attn_varlen_func,
-    is_fa_version_supported,
-)
+
+try:
+    from vllm.vllm_flash_attn import (
+        fa_version_unsupported_reason,
+        flash_attn_varlen_func,
+        is_fa_version_supported,
+    )
+except ImportError:
+    if current_platform.is_rocm():
+        pytest.skip(
+            "vllm_flash_attn is not supported for vLLM on ROCm.",
+            allow_module_level=True,
+        )
 
 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 192, 256]
@@ -6,11 +6,20 @@ import pytest
 import torch
 
 from vllm.platforms import current_platform
-from vllm.vllm_flash_attn import (
-    fa_version_unsupported_reason,
-    flash_attn_varlen_func,
-    is_fa_version_supported,
-)
+
+try:
+    from vllm.vllm_flash_attn import (
+        fa_version_unsupported_reason,
+        flash_attn_varlen_func,
+        is_fa_version_supported,
+    )
+except ImportError:
+    if current_platform.is_rocm():
+        pytest.skip(
+            "vllm_flash_attn is not supported for vLLM on ROCm.",
+            allow_module_level=True,
+        )
+
 
 NUM_HEADS = [(4, 4), (8, 2)]
 HEAD_SIZES = [40, 72, 80, 128, 256]
@@ -2,12 +2,20 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import flashinfer
 import pytest
-import torch
 
 from vllm.platforms import current_platform
 
+try:
+    import flashinfer
+except ImportError:
+    if current_platform.is_rocm():
+        pytest.skip(
+            "flashinfer is not supported for vLLM on ROCm.", allow_module_level=True
+        )
+
+import torch
 
 NUM_HEADS = [(32, 8), (6, 1)]
 HEAD_SIZES = [128, 256]
 BLOCK_SIZES = [16, 32]
 
@@ -3,7 +3,6 @@
 import pytest
 import torch
 import torch.nn.functional as F
-from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
 from torch import Tensor
 
 from vllm.platforms import current_platform
@@ -15,6 +14,8 @@ if not current_platform.has_device_capability(100):
         reason="FlashInfer MLA Requires compute capability of 10 or above.",
         allow_module_level=True,
     )
+else:
+    from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
 
 
 def ref_mla(
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import flashinfer
 import pytest
 import torch
 
@@ -16,6 +15,8 @@ if not current_platform.is_device_capability(100):
     pytest.skip(
         "This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
     )
+else:
+    import flashinfer
 
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 FP8_DTYPE = current_platform.fp8_dtype()
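The two files above use a second pattern: when a kernel needs a specific compute capability, the module-level skip fires first and the backend import is moved into the else: branch, so the import is only attempted on hardware that can actually run the test. A condensed sketch of that shape, reusing the capability value 100 (Blackwell) from the hunks above:

# Capability-gated import: the backend is only imported when the module-level
# skip does not fire, so collection succeeds where the kernel is unsupported.
import pytest

from vllm.platforms import current_platform

if not current_platform.has_device_capability(100):
    pytest.skip(
        "This kernel requires compute capability 10.0 or above.",
        allow_module_level=True,
    )
else:
    import flashinfer  # only reached when the capability check passes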
@@ -22,7 +22,14 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
 from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
 from vllm.model_executor.models.llama4 import Llama4MoE
 from vllm.platforms import current_platform
-from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+
+try:
+    from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+except ImportError:
+    if current_platform.is_rocm():
+        pytest.skip(
+            "flashinfer not supported for vLLM on ROCm", allow_module_level=True
+        )
 
 if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
     90
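As a quick sanity check on the intent of the change, one could verify that the affected modules at least collect cleanly on a ROCm build; a hypothetical snippet using the path from the commit title:

# Hypothetical check: on ROCm these modules should now be skipped at
# collection time rather than raising ImportError.
import sys

import pytest

# --collect-only exercises the module-level imports without running any tests.
sys.exit(pytest.main(["--collect-only", "-q", "tests/kernels/attention"]))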