mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-07 12:59:08 +08:00
[CI/Build][AMD] Fix import errors in tests/kernels/attention (#29032)
Signed-off-by: Randall Smith <ransmith@amd.com> Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
parent
2c52c7fd9a
commit
322cb02872
@ -7,11 +7,19 @@ import torch
|
|||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.attention.backends.flash_attn import cascade_attention, merge_attn_states
|
from vllm.v1.attention.backends.flash_attn import cascade_attention, merge_attn_states
|
||||||
from vllm.vllm_flash_attn import (
|
|
||||||
fa_version_unsupported_reason,
|
try:
|
||||||
flash_attn_varlen_func,
|
from vllm.vllm_flash_attn import (
|
||||||
is_fa_version_supported,
|
fa_version_unsupported_reason,
|
||||||
)
|
flash_attn_varlen_func,
|
||||||
|
is_fa_version_supported,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
pytest.skip(
|
||||||
|
"vllm_flash_attn is not supported for vLLM on ROCm.",
|
||||||
|
allow_module_level=True,
|
||||||
|
)
|
||||||
|
|
||||||
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
|
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
|
||||||
HEAD_SIZES = [128, 192, 256]
|
HEAD_SIZES = [128, 192, 256]
|
||||||
|
|||||||
@ -6,11 +6,20 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.vllm_flash_attn import (
|
|
||||||
fa_version_unsupported_reason,
|
try:
|
||||||
flash_attn_varlen_func,
|
from vllm.vllm_flash_attn import (
|
||||||
is_fa_version_supported,
|
fa_version_unsupported_reason,
|
||||||
)
|
flash_attn_varlen_func,
|
||||||
|
is_fa_version_supported,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
pytest.skip(
|
||||||
|
"vllm_flash_attn is not supported for vLLM on ROCm.",
|
||||||
|
allow_module_level=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
NUM_HEADS = [(4, 4), (8, 2)]
|
NUM_HEADS = [(4, 4), (8, 2)]
|
||||||
HEAD_SIZES = [40, 72, 80, 128, 256]
|
HEAD_SIZES = [40, 72, 80, 128, 256]
|
||||||
|
|||||||
@ -2,12 +2,20 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
|
||||||
import flashinfer
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
try:
|
||||||
|
import flashinfer
|
||||||
|
except ImportError:
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
pytest.skip(
|
||||||
|
"flashinfer is not supported for vLLM on ROCm.", allow_module_level=True
|
||||||
|
)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
NUM_HEADS = [(32, 8), (6, 1)]
|
NUM_HEADS = [(32, 8), (6, 1)]
|
||||||
HEAD_SIZES = [128, 256]
|
HEAD_SIZES = [128, 256]
|
||||||
BLOCK_SIZES = [16, 32]
|
BLOCK_SIZES = [16, 32]
|
||||||
|
|||||||
@ -3,7 +3,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -15,6 +14,8 @@ if not current_platform.has_device_capability(100):
|
|||||||
reason="FlashInfer MLA Requires compute capability of 10 or above.",
|
reason="FlashInfer MLA Requires compute capability of 10 or above.",
|
||||||
allow_module_level=True,
|
allow_module_level=True,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||||
|
|
||||||
|
|
||||||
def ref_mla(
|
def ref_mla(
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import flashinfer
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -16,6 +15,8 @@ if not current_platform.is_device_capability(100):
|
|||||||
pytest.skip(
|
pytest.skip(
|
||||||
"This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
|
"This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
import flashinfer
|
||||||
|
|
||||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||||
FP8_DTYPE = current_platform.fp8_dtype()
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
|||||||
@ -22,7 +22,14 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
|||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
|
||||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
|
||||||
|
try:
|
||||||
|
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||||
|
except ImportError:
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
pytest.skip(
|
||||||
|
"flashinfer not supported for vLLM on ROCm", allow_module_level=True
|
||||||
|
)
|
||||||
|
|
||||||
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
|
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
|
||||||
90
|
90
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user