mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 10:57:55 +08:00
[V0 Deprecation] Remove V0 FlashInfer attention backend (#22776)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
6603288736
commit
14006840ea
@ -12,7 +12,6 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import LLM, envs
|
from vllm import LLM, envs
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
|
from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
|
||||||
|
|
||||||
from ..conftest import HfRunner, VllmRunner
|
from ..conftest import HfRunner, VllmRunner
|
||||||
@ -78,11 +77,7 @@ def test_models(
|
|||||||
"VLLM_USE_V1") and envs.VLLM_USE_V1:
|
"VLLM_USE_V1") and envs.VLLM_USE_V1:
|
||||||
pytest.skip("enable_prompt_embeds is not supported in v1.")
|
pytest.skip("enable_prompt_embeds is not supported in v1.")
|
||||||
|
|
||||||
if backend == "FLASHINFER" and current_platform.is_rocm():
|
if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
|
||||||
pytest.skip("Flashinfer does not support ROCm/HIP.")
|
|
||||||
|
|
||||||
if backend in ("XFORMERS",
|
|
||||||
"FLASHINFER") and model == "google/gemma-2-2b-it":
|
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
f"{backend} does not support gemma2 with full context length.")
|
f"{backend} does not support gemma2 with full context length.")
|
||||||
|
|
||||||
@ -141,8 +136,6 @@ def test_models(
|
|||||||
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
|
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
|
||||||
("distilbert/distilgpt2", "ray", "", "A100", {}),
|
("distilbert/distilgpt2", "ray", "", "A100", {}),
|
||||||
("distilbert/distilgpt2", "mp", "", "A100", {}),
|
("distilbert/distilgpt2", "mp", "", "A100", {}),
|
||||||
("distilbert/distilgpt2", "mp", "FLASHINFER", "A100", {}),
|
|
||||||
("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100", {}),
|
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
||||||
def test_models_distributed(
|
def test_models_distributed(
|
||||||
|
|||||||
@ -34,7 +34,7 @@ class TestSetting:
|
|||||||
model_args=["--max-model-len", "2048"],
|
model_args=["--max-model-len", "2048"],
|
||||||
pp_size=2,
|
pp_size=2,
|
||||||
tp_size=2,
|
tp_size=2,
|
||||||
attn_backend="FLASHINFER",
|
attn_backend="FLASH_ATTN",
|
||||||
method="generate",
|
method="generate",
|
||||||
fullgraph=True,
|
fullgraph=True,
|
||||||
),
|
),
|
||||||
|
|||||||
@ -32,7 +32,7 @@ BLOCK_SIZE = 16
|
|||||||
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("batch_size", [5])
|
@pytest.mark.parametrize("batch_size", [5])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
|
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"])
|
||||||
def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
|
def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
|
||||||
batch_size, seed, backend, monkeypatch):
|
batch_size, seed, backend, monkeypatch):
|
||||||
"""
|
"""
|
||||||
@ -43,8 +43,6 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
|
|||||||
|
|
||||||
Additionally, we compare the results of the v1 and v2 managers.
|
Additionally, we compare the results of the v1 and v2 managers.
|
||||||
"""
|
"""
|
||||||
if backend == "FLASHINFER" and current_platform.is_rocm():
|
|
||||||
pytest.skip("Flashinfer does not support ROCm/HIP.")
|
|
||||||
if backend == "XFORMERS" and current_platform.is_rocm():
|
if backend == "XFORMERS" and current_platform.is_rocm():
|
||||||
pytest.skip("Xformers does not support ROCm/HIP.")
|
pytest.skip("Xformers does not support ROCm/HIP.")
|
||||||
|
|
||||||
@ -96,7 +94,7 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
|
|||||||
@pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}])
|
@pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}])
|
||||||
@pytest.mark.parametrize("batch_size", [5])
|
@pytest.mark.parametrize("batch_size", [5])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
|
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"])
|
||||||
def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
|
def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
|
||||||
backend, monkeypatch):
|
backend, monkeypatch):
|
||||||
"""
|
"""
|
||||||
@ -107,8 +105,6 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
|
|||||||
The results with and without chunked prefill are not the same due to
|
The results with and without chunked prefill are not the same due to
|
||||||
numerical instabilities.
|
numerical instabilities.
|
||||||
"""
|
"""
|
||||||
if backend == "FLASHINFER" and current_platform.is_rocm():
|
|
||||||
pytest.skip("Flashinfer does not support ROCm/HIP.")
|
|
||||||
if backend == "XFORMERS" and current_platform.is_rocm():
|
if backend == "XFORMERS" and current_platform.is_rocm():
|
||||||
pytest.skip("Xformers does not support ROCm/HIP.")
|
pytest.skip("Xformers does not support ROCm/HIP.")
|
||||||
override_backend_env_variable(monkeypatch, backend)
|
override_backend_env_variable(monkeypatch, backend)
|
||||||
|
|||||||
@ -17,7 +17,6 @@ if TYPE_CHECKING:
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("ATTN_BACKEND", [
|
@pytest.mark.parametrize("ATTN_BACKEND", [
|
||||||
"FLASH_ATTN",
|
"FLASH_ATTN",
|
||||||
"FLASHINFER",
|
|
||||||
])
|
])
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_pp_cudagraph(
|
def test_pp_cudagraph(
|
||||||
|
|||||||
@ -81,6 +81,9 @@ def test_env(
|
|||||||
m.setenv(STR_BACKEND_ENV_VAR, name)
|
m.setenv(STR_BACKEND_ENV_VAR, name)
|
||||||
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
|
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
|
||||||
|
|
||||||
|
if name == "FLASHINFER" and not use_v1:
|
||||||
|
pytest.skip("FlashInfer backend is only available on V1 engine")
|
||||||
|
|
||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
if not use_v1:
|
if not use_v1:
|
||||||
pytest.skip("CPU backend only supports V1")
|
pytest.skip("CPU backend only supports V1")
|
||||||
|
|||||||
@ -32,7 +32,7 @@ from ..utils import check_logprobs_close
|
|||||||
# Due to low-precision numerical divergence, we only test logprob of 4 tokens
|
# Due to low-precision numerical divergence, we only test logprob of 4 tokens
|
||||||
@pytest.mark.parametrize("max_tokens", [4])
|
@pytest.mark.parametrize("max_tokens", [4])
|
||||||
@pytest.mark.parametrize("enforce_eager", [True])
|
@pytest.mark.parametrize("enforce_eager", [True])
|
||||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
|
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"])
|
||||||
# NOTE: Increasing this in this suite will fail CI because we currently cannot
|
# NOTE: Increasing this in this suite will fail CI because we currently cannot
|
||||||
# reset distributed env properly. Use a value > 1 just when you test.
|
# reset distributed env properly. Use a value > 1 just when you test.
|
||||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||||
@ -57,9 +57,6 @@ def test_models(
|
|||||||
numerical sensitive kernels.
|
numerical sensitive kernels.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if backend == "FLASHINFER" and current_platform.is_rocm():
|
|
||||||
pytest.skip("Flashinfer does not support ROCm/HIP.")
|
|
||||||
|
|
||||||
if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm():
|
if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm():
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
f"{kv_cache_dtype} is currently not supported on ROCm/HIP.")
|
f"{kv_cache_dtype} is currently not supported on ROCm/HIP.")
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -350,17 +350,7 @@ class CudaPlatformBase(Platform):
|
|||||||
return FLEX_ATTENTION_V1
|
return FLEX_ATTENTION_V1
|
||||||
|
|
||||||
# Backends for V0 engine
|
# Backends for V0 engine
|
||||||
if selected_backend == _Backend.FLASHINFER:
|
if selected_backend == _Backend.XFORMERS:
|
||||||
logger.info("Using FlashInfer backend.")
|
|
||||||
if cls.has_device_capability(100):
|
|
||||||
from vllm.v1.attention.backends.utils import (
|
|
||||||
set_kv_cache_layout)
|
|
||||||
logger.info_once(
|
|
||||||
"Using HND KV cache layout on V1 engine by default for "
|
|
||||||
"Blackwell (SM 10.0) GPUs.")
|
|
||||||
set_kv_cache_layout("HND")
|
|
||||||
return "vllm.attention.backends.flashinfer.FlashInferBackend"
|
|
||||||
elif selected_backend == _Backend.XFORMERS:
|
|
||||||
logger.info("Using XFormers backend.")
|
logger.info("Using XFormers backend.")
|
||||||
return "vllm.attention.backends.xformers.XFormersBackend"
|
return "vllm.attention.backends.xformers.XFormersBackend"
|
||||||
elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
|
elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
|
||||||
@ -416,10 +406,6 @@ class CudaPlatformBase(Platform):
|
|||||||
if (fp8_kv_cache and not flash_attn_supports_fp8()):
|
if (fp8_kv_cache and not flash_attn_supports_fp8()):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Cannot use FlashAttention backend for FP8 KV cache.")
|
"Cannot use FlashAttention backend for FP8 KV cache.")
|
||||||
logger.warning(
|
|
||||||
"Please use FlashInfer backend with FP8 KV Cache for "
|
|
||||||
"better performance by setting environment variable "
|
|
||||||
"VLLM_ATTENTION_BACKEND=FLASHINFER")
|
|
||||||
target_backend = _Backend.XFORMERS
|
target_backend = _Backend.XFORMERS
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user