mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 21:44:39 +08:00
[CI/Build] Make test_attention_selector.py run tests on correct platform (#29064)
Signed-off-by: Randall Smith <ransmith@amd.com> Signed-off-by: rasmith <Randall.Smith@amd.com> Co-authored-by: Randall Smith <ransmith@amd.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
3d84ef9054
commit
5e5a7eb16f
@ -7,6 +7,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.cpu import CpuPlatform
|
||||
from vllm.platforms.cuda import CudaPlatform
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
@ -47,9 +48,11 @@ DEVICE_MLA_BLOCK_SIZES = {
|
||||
|
||||
|
||||
def generate_params():
|
||||
is_rocm = current_platform.is_rocm()
|
||||
params = []
|
||||
device_list = ["cuda", "cpu"] if not is_rocm else ["hip", "cpu"]
|
||||
for use_mla in [True, False]:
|
||||
for device in ["cuda", "hip", "cpu"]:
|
||||
for device in device_list:
|
||||
backends = (
|
||||
DEVICE_MLA_BACKENDS[device]
|
||||
if use_mla
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user