[V0 Deprecation] Remove V0 FlashInfer attention backend (#22776)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-08-18 19:54:16 -07:00 committed by GitHub
parent 6603288736
commit 14006840ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 9 additions and 1133 deletions

View File

@ -12,7 +12,6 @@ import pytest
import torch import torch
from vllm import LLM, envs from vllm import LLM, envs
from vllm.platforms import current_platform
from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1 from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
from ..conftest import HfRunner, VllmRunner from ..conftest import HfRunner, VllmRunner
@ -78,11 +77,7 @@ def test_models(
"VLLM_USE_V1") and envs.VLLM_USE_V1: "VLLM_USE_V1") and envs.VLLM_USE_V1:
pytest.skip("enable_prompt_embeds is not supported in v1.") pytest.skip("enable_prompt_embeds is not supported in v1.")
if backend == "FLASHINFER" and current_platform.is_rocm(): if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
pytest.skip("Flashinfer does not support ROCm/HIP.")
if backend in ("XFORMERS",
"FLASHINFER") and model == "google/gemma-2-2b-it":
pytest.skip( pytest.skip(
f"{backend} does not support gemma2 with full context length.") f"{backend} does not support gemma2 with full context length.")
@ -141,8 +136,6 @@ def test_models(
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}), ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
("distilbert/distilgpt2", "ray", "", "A100", {}), ("distilbert/distilgpt2", "ray", "", "A100", {}),
("distilbert/distilgpt2", "mp", "", "A100", {}), ("distilbert/distilgpt2", "mp", "", "A100", {}),
("distilbert/distilgpt2", "mp", "FLASHINFER", "A100", {}),
("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100", {}),
]) ])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False]) @pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models_distributed( def test_models_distributed(

View File

@ -34,7 +34,7 @@ class TestSetting:
model_args=["--max-model-len", "2048"], model_args=["--max-model-len", "2048"],
pp_size=2, pp_size=2,
tp_size=2, tp_size=2,
attn_backend="FLASHINFER", attn_backend="FLASH_ATTN",
method="generate", method="generate",
fullgraph=True, fullgraph=True,
), ),

View File

@ -32,7 +32,7 @@ BLOCK_SIZE = 16
@pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) @pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"])
def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator, def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
batch_size, seed, backend, monkeypatch): batch_size, seed, backend, monkeypatch):
""" """
@ -43,8 +43,6 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
Additionally, we compare the results of the v1 and v2 managers. Additionally, we compare the results of the v1 and v2 managers.
""" """
if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.")
if backend == "XFORMERS" and current_platform.is_rocm(): if backend == "XFORMERS" and current_platform.is_rocm():
pytest.skip("Xformers does not support ROCm/HIP.") pytest.skip("Xformers does not support ROCm/HIP.")
@ -96,7 +94,7 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
@pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}]) @pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}])
@pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) @pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"])
def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
backend, monkeypatch): backend, monkeypatch):
""" """
@ -107,8 +105,6 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
The results with and without chunked prefill are not the same due to The results with and without chunked prefill are not the same due to
numerical instabilities. numerical instabilities.
""" """
if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.")
if backend == "XFORMERS" and current_platform.is_rocm(): if backend == "XFORMERS" and current_platform.is_rocm():
pytest.skip("Xformers does not support ROCm/HIP.") pytest.skip("Xformers does not support ROCm/HIP.")
override_backend_env_variable(monkeypatch, backend) override_backend_env_variable(monkeypatch, backend)

View File

@ -17,7 +17,6 @@ if TYPE_CHECKING:
]) ])
@pytest.mark.parametrize("ATTN_BACKEND", [ @pytest.mark.parametrize("ATTN_BACKEND", [
"FLASH_ATTN", "FLASH_ATTN",
"FLASHINFER",
]) ])
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_pp_cudagraph( def test_pp_cudagraph(

View File

@ -81,6 +81,9 @@ def test_env(
m.setenv(STR_BACKEND_ENV_VAR, name) m.setenv(STR_BACKEND_ENV_VAR, name)
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
if name == "FLASHINFER" and not use_v1:
pytest.skip("FlashInfer backend is only available on V1 engine")
if device == "cpu": if device == "cpu":
if not use_v1: if not use_v1:
pytest.skip("CPU backend only supports V1") pytest.skip("CPU backend only supports V1")

View File

@ -32,7 +32,7 @@ from ..utils import check_logprobs_close
# Due to low-precision numerical divergence, we only test logprob of 4 tokens # Due to low-precision numerical divergence, we only test logprob of 4 tokens
@pytest.mark.parametrize("max_tokens", [4]) @pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("enforce_eager", [True]) @pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"]) @pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"])
# NOTE: Increasing this in this suite will fail CI because we currently cannot # NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test. # reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1]) @pytest.mark.parametrize("tensor_parallel_size", [1])
@ -57,9 +57,6 @@ def test_models(
numerical sensitive kernels. numerical sensitive kernels.
""" """
if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.")
if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm(): if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm():
pytest.skip( pytest.skip(
f"{kv_cache_dtype} is currently not supported on ROCm/HIP.") f"{kv_cache_dtype} is currently not supported on ROCm/HIP.")

File diff suppressed because it is too large Load Diff

View File

@ -350,17 +350,7 @@ class CudaPlatformBase(Platform):
return FLEX_ATTENTION_V1 return FLEX_ATTENTION_V1
# Backends for V0 engine # Backends for V0 engine
if selected_backend == _Backend.FLASHINFER: if selected_backend == _Backend.XFORMERS:
logger.info("Using FlashInfer backend.")
if cls.has_device_capability(100):
from vllm.v1.attention.backends.utils import (
set_kv_cache_layout)
logger.info_once(
"Using HND KV cache layout on V1 engine by default for "
"Blackwell (SM 10.0) GPUs.")
set_kv_cache_layout("HND")
return "vllm.attention.backends.flashinfer.FlashInferBackend"
elif selected_backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.") logger.info("Using XFormers backend.")
return "vllm.attention.backends.xformers.XFormersBackend" return "vllm.attention.backends.xformers.XFormersBackend"
elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN: elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
@ -416,10 +406,6 @@ class CudaPlatformBase(Platform):
if (fp8_kv_cache and not flash_attn_supports_fp8()): if (fp8_kv_cache and not flash_attn_supports_fp8()):
logger.info( logger.info(
"Cannot use FlashAttention backend for FP8 KV cache.") "Cannot use FlashAttention backend for FP8 KV cache.")
logger.warning(
"Please use FlashInfer backend with FP8 KV Cache for "
"better performance by setting environment variable "
"VLLM_ATTENTION_BACKEND=FLASHINFER")
target_backend = _Backend.XFORMERS target_backend = _Backend.XFORMERS
except ImportError: except ImportError:
logger.info( logger.info(