From 1ee5ead5f8f1c3c77b73effcb230ee02952fbe1f Mon Sep 17 00:00:00 2001 From: TJian Date: Thu, 7 Aug 2025 19:13:17 -0700 Subject: [PATCH] [ROCm] [V1] [SpecDec] Enable Speculative Decoding on ROCm V1 Engine (#21496) Signed-off-by: tjtanaa --- tests/utils.py | 16 ++++++++ tests/v1/attention/utils.py | 7 +++- tests/v1/e2e/test_spec_decode.py | 15 ++++++++ tests/v1/spec_decode/test_eagle.py | 55 +++++++++++++++++++++++----- tests/v1/spec_decode/test_max_len.py | 54 +++++++++++++++------------ vllm/v1/spec_decode/eagle.py | 22 ++++++++--- 6 files changed, 128 insertions(+), 41 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 1c1a1cc6014ec..741b4401cc213 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -986,3 +986,19 @@ def has_module_attribute(module_name, attribute_name): return hasattr(module, attribute_name) except ImportError: return False + + +def get_attn_backend_list_based_on_platform() -> list[str]: + if current_platform.is_cuda(): + return ["FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1", "TREE_ATTN"] + elif current_platform.is_rocm(): + attn_backend_list = ["TRITON_ATTN_VLLM_V1"] + try: + import aiter # noqa: F401 + attn_backend_list.append("FLASH_ATTN_VLLM_V1") + except Exception: + print("Skip FLASH_ATTN_VLLM_V1 on ROCm as aiter is not installed") + + return attn_backend_list + else: + raise ValueError("Unsupported platform") diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index e9e574501d63e..a4e38eb32f6a1 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -11,7 +11,7 @@ import torch from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig, LoadConfig, ModelConfig, ModelDType, ParallelConfig, SchedulerConfig, VllmConfig) -from vllm.platforms import _Backend +from vllm.platforms import _Backend, current_platform from vllm.utils import resolve_obj_by_qualname from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import FullAttentionSpec @@ -119,7 +119,10 @@ def get_attention_backend(backend_name: _Backend): """ backend_map = { _Backend.FLASH_ATTN_VLLM_V1: - "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", + ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + if current_platform.is_cuda() else + "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" + ), _Backend.FLASHINFER_VLLM_V1: "vllm.v1.attention.backends.flashinfer.FlashInferBackend", _Backend.FLEX_ATTENTION: diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 31f25e94c5b4b..4950faf826b86 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -8,10 +8,12 @@ from typing import Any, Union import pytest import torch +from tests.utils import get_attn_backend_list_based_on_platform from vllm import LLM, SamplingParams from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.image import VLM_IMAGES_DIR from vllm.distributed import cleanup_dist_env_and_memory +from vllm.platforms import current_platform def get_test_prompts(mm_enabled: bool): @@ -141,11 +143,14 @@ def test_ngram_correctness( marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), ], ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"]) +@pytest.mark.parametrize("attn_backend", + get_attn_backend_list_based_on_platform()) def test_eagle_correctness( monkeypatch: pytest.MonkeyPatch, sampling_config: SamplingParams, model_setup: tuple[str, str, str, int], mm_enabled: bool, + attn_backend: str, ): # Generate 
test prompts inside the function instead of using fixture test_prompts = get_test_prompts(mm_enabled) @@ -156,6 +161,16 @@ def test_eagle_correctness( ''' with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) + + if (attn_backend == "TRITON_ATTN_VLLM_V1" + and not current_platform.is_rocm()): + pytest.skip("TRITON_ATTN_VLLM_V1 does not support " + "multi-token eagle spec decode on current platform") + + if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm(): + m.setenv("VLLM_ROCM_USE_AITER", "1") + method, model_name, spec_model_name, tp_size = model_setup ref_llm = LLM(model=model_name, diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 73b47f8974397..2b4f8bd2a8b90 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -6,6 +6,7 @@ from unittest import mock import pytest import torch +from tests.utils import get_attn_backend_list_based_on_platform from tests.v1.attention.utils import (BatchSpec, _Backend, create_common_attn_metadata, create_standard_kv_cache_spec, @@ -120,17 +121,28 @@ def test_prepare_inputs(): assert torch.equal(token_indices, expected_token_indices) -@pytest.mark.parametrize("method,proposer_helper", [ - ("eagle", lambda k: _create_proposer("eagle", k)), - ("eagle3", lambda k: _create_proposer("eagle3", k)), -]) +@pytest.mark.parametrize("method", ["eagle", "eagle3"]) +@pytest.mark.parametrize("attn_backend", + get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("pp_size", [1, 2]) @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False]) @mock.patch('vllm.v1.spec_decode.eagle.get_pp_group') @mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config') @mock.patch('vllm.v1.spec_decode.eagle.get_model') def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, - proposer_helper, pp_size, use_distinct_embed_tokens): + attn_backend, pp_size, use_distinct_embed_tokens, + monkeypatch): + + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) + + if (attn_backend == "TRITON_ATTN_VLLM_V1" + and not current_platform.is_rocm()): + pytest.skip("TRITON_ATTN_VLLM_V1 does not support " + "multi-token eagle spec decode on current platform") + + if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm(): + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + # Setup draft model mock mock_model = mock.MagicMock() if use_distinct_embed_tokens: @@ -177,7 +189,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, target_model.lm_head = mock.MagicMock() # Create proposer using the helper function - proposer = proposer_helper(k=8) + proposer = _create_proposer(method, k=8) # Call the method under test proposer.load_model(target_model) @@ -201,10 +213,22 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, target_model.model.embed_tokens +@pytest.mark.parametrize("method", ["eagle", "eagle3"]) +@pytest.mark.parametrize("attn_backend", + get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) -@pytest.mark.parametrize("backend", - [_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN]) -def test_propose(num_speculative_tokens, backend): +def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): + + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) + + if (attn_backend == "TRITON_ATTN_VLLM_V1" + and not current_platform.is_rocm()): + 
pytest.skip("TRITON_ATTN_VLLM_V1 does not support " + "multi-token eagle spec decode on current platform") + + if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm(): + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + # Use GPU device device = torch.device(current_platform.device_type) @@ -303,7 +327,18 @@ def test_propose(num_speculative_tokens, backend): device=device) sampling_metadata = mock.MagicMock() - attn_metadata_builder_cls, _ = get_attention_backend(backend) + if attn_backend == "FLASH_ATTN_VLLM_V1": + attn_metadata_builder_cls, _ = get_attention_backend( + _Backend.FLASH_ATTN_VLLM_V1) + elif attn_backend == "TRITON_ATTN_VLLM_V1": + attn_metadata_builder_cls, _ = get_attention_backend( + _Backend.TRITON_ATTN_VLLM_V1) + elif attn_backend == "TREE_ATTN": + attn_metadata_builder_cls, _ = get_attention_backend( + _Backend.TREE_ATTN) + else: + raise ValueError(f"Unsupported attention backend: {attn_backend}") + attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), layer_names=proposer.attn_layer_names, diff --git a/tests/v1/spec_decode/test_max_len.py b/tests/v1/spec_decode/test_max_len.py index 9070d2b10f8b5..fef6a5421b435 100644 --- a/tests/v1/spec_decode/test_max_len.py +++ b/tests/v1/spec_decode/test_max_len.py @@ -4,7 +4,9 @@ import pytest +from tests.utils import get_attn_backend_list_based_on_platform from vllm import LLM, SamplingParams +from vllm.platforms import current_platform _PROMPTS = [ "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1", @@ -14,36 +16,40 @@ _PROMPTS = [ @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10]) -def test_ngram_max_len( - monkeypatch: pytest.MonkeyPatch, - num_speculative_tokens: int, -): - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - - llm = LLM( - model="facebook/opt-125m", - max_model_len=100, - enforce_eager=True, # For faster initialization. - speculative_config={ - "method": "ngram", - "prompt_lookup_max": 5, - "prompt_lookup_min": 3, - "num_speculative_tokens": num_speculative_tokens, - }, - ) - sampling_params = SamplingParams(max_tokens=100, ignore_eos=True) - llm.generate(_PROMPTS, sampling_params) +def test_ngram_max_len(num_speculative_tokens: int): + llm = LLM( + model="facebook/opt-125m", + max_model_len=100, + enforce_eager=True, # For faster initialization. + speculative_config={ + "method": "ngram", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, + "num_speculative_tokens": num_speculative_tokens, + }, + ) + sampling_params = SamplingParams(max_tokens=100, ignore_eos=True) + llm.generate(_PROMPTS, sampling_params) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10]) -def test_eagle_max_len( - monkeypatch: pytest.MonkeyPatch, - num_speculative_tokens: int, -): +@pytest.mark.parametrize("attn_backend", + get_attn_backend_list_based_on_platform()) +def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch, + num_speculative_tokens: int, attn_backend: str): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) + + if (attn_backend == "TRITON_ATTN_VLLM_V1" + and not current_platform.is_rocm()): + pytest.skip("TRITON_ATTN_VLLM_V1 does not support " + "multi-token eagle spec decode on current platform") + + if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm(): + m.setenv("VLLM_ROCM_USE_AITER", "1") + llm = LLM( model="meta-llama/Meta-Llama-3-8B-Instruct", enforce_eager=True, # For faster initialization. 
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 3c36971fe5b49..f75d76dd978fd 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -17,10 +17,14 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.rocm_aiter_fa import ( + AiterFlashAttentionMetadata) from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, TreeAttentionMetadataBuilder) +from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata @@ -230,11 +234,19 @@ class EagleProposer: # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. - # Currently, only FlashAttention and TreeAttention support multi-token - # eagle spec decode. This is because the code below - # makes assumptions about attn_metadata attributes available. - assert isinstance(attn_metadata, - (FlashAttentionMetadata, TreeAttentionMetadata)) + # On ROCm, both AiterFlashAttention and TritonAttention + # support multi-token eagle spec decode. + if current_platform.is_rocm(): + assert isinstance( + attn_metadata, + (TritonAttentionMetadata, AiterFlashAttentionMetadata, + FlashAttentionMetadata)) + else: + # Currently, only FlashAttention and TreeAttention support + # multi-token eagle spec decode. This is because the code below + # makes assumptions about attn_metadata attributes available. + assert isinstance(attn_metadata, + (FlashAttentionMetadata, TreeAttentionMetadata)) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids]
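
Usage sketch (not part of the applied diff; assumptions flagged): the snippet below mirrors the setup the new tests use on ROCm — V1 engine, an attention backend that supports multi-token eagle spec decode, and VLLM_ROCM_USE_AITER=1 when FLASH_ATTN_VLLM_V1 is selected. The draft model name and the num_speculative_tokens value are illustrative placeholders, not values taken from this patch.

# Minimal sketch of running EAGLE speculative decoding on the ROCm V1 engine,
# following the environment setup in tests/v1/e2e/test_spec_decode.py.
# The draft model below is a placeholder assumption, not pinned by this patch.
import os

os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN_VLLM_V1"
os.environ["VLLM_ROCM_USE_AITER"] = "1"  # aiter-backed flash attention on ROCm

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    enforce_eager=True,  # For faster initialization, as in the tests.
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",  # placeholder draft model
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(["The quick brown fox"],
                       SamplingParams(temperature=0, max_tokens=64))
print(outputs[0].outputs[0].text)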