[ROCm] [V1] [SpecDec] Enable Speculative Decoding on ROCm V1 Engine (#21496)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
TJian 2025-08-07 19:13:17 -07:00 committed by GitHub
parent acf8aeb79e
commit 1ee5ead5f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 128 additions and 41 deletions

View File

@ -986,3 +986,19 @@ def has_module_attribute(module_name, attribute_name):
return hasattr(module, attribute_name) return hasattr(module, attribute_name)
except ImportError: except ImportError:
return False return False
def get_attn_backend_list_based_on_platform() -> list[str]:
if current_platform.is_cuda():
return ["FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1", "TREE_ATTN"]
elif current_platform.is_rocm():
attn_backend_list = ["TRITON_ATTN_VLLM_V1"]
try:
import aiter # noqa: F401
attn_backend_list.append("FLASH_ATTN_VLLM_V1")
except Exception:
print("Skip FLASH_ATTN_VLLM_V1 on ROCm as aiter is not installed")
return attn_backend_list
else:
raise ValueError("Unsupported platform")

View File

@ -11,7 +11,7 @@ import torch
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig, from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
LoadConfig, ModelConfig, ModelDType, ParallelConfig, LoadConfig, ModelConfig, ModelDType, ParallelConfig,
SchedulerConfig, VllmConfig) SchedulerConfig, VllmConfig)
from vllm.platforms import _Backend from vllm.platforms import _Backend, current_platform
from vllm.utils import resolve_obj_by_qualname from vllm.utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
@ -119,7 +119,10 @@ def get_attention_backend(backend_name: _Backend):
""" """
backend_map = { backend_map = {
_Backend.FLASH_ATTN_VLLM_V1: _Backend.FLASH_ATTN_VLLM_V1:
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
if current_platform.is_cuda() else
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
),
_Backend.FLASHINFER_VLLM_V1: _Backend.FLASHINFER_VLLM_V1:
"vllm.v1.attention.backends.flashinfer.FlashInferBackend", "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
_Backend.FLEX_ATTENTION: _Backend.FLEX_ATTENTION:

View File

@ -8,10 +8,12 @@ from typing import Any, Union
import pytest import pytest
import torch import torch
from tests.utils import get_attn_backend_list_based_on_platform
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform
def get_test_prompts(mm_enabled: bool): def get_test_prompts(mm_enabled: bool):
@ -141,11 +143,14 @@ def test_ngram_correctness(
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
], ],
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"]) ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
def test_eagle_correctness( def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams, sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int], model_setup: tuple[str, str, str, int],
mm_enabled: bool, mm_enabled: bool,
attn_backend: str,
): ):
# Generate test prompts inside the function instead of using fixture # Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled) test_prompts = get_test_prompts(mm_enabled)
@ -156,6 +161,16 @@ def test_eagle_correctness(
''' '''
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if (attn_backend == "TRITON_ATTN_VLLM_V1"
and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
"multi-token eagle spec decode on current platform")
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1")
method, model_name, spec_model_name, tp_size = model_setup method, model_name, spec_model_name, tp_size = model_setup
ref_llm = LLM(model=model_name, ref_llm = LLM(model=model_name,

View File

@ -6,6 +6,7 @@ from unittest import mock
import pytest import pytest
import torch import torch
from tests.utils import get_attn_backend_list_based_on_platform
from tests.v1.attention.utils import (BatchSpec, _Backend, from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata, create_common_attn_metadata,
create_standard_kv_cache_spec, create_standard_kv_cache_spec,
@ -120,17 +121,28 @@ def test_prepare_inputs():
assert torch.equal(token_indices, expected_token_indices) assert torch.equal(token_indices, expected_token_indices)
@pytest.mark.parametrize("method,proposer_helper", [ @pytest.mark.parametrize("method", ["eagle", "eagle3"])
("eagle", lambda k: _create_proposer("eagle", k)), @pytest.mark.parametrize("attn_backend",
("eagle3", lambda k: _create_proposer("eagle3", k)), get_attn_backend_list_based_on_platform())
])
@pytest.mark.parametrize("pp_size", [1, 2]) @pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False]) @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group') @mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config') @mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model') @mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
proposer_helper, pp_size, use_distinct_embed_tokens): attn_backend, pp_size, use_distinct_embed_tokens,
monkeypatch):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if (attn_backend == "TRITON_ATTN_VLLM_V1"
and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
"multi-token eagle spec decode on current platform")
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
# Setup draft model mock # Setup draft model mock
mock_model = mock.MagicMock() mock_model = mock.MagicMock()
if use_distinct_embed_tokens: if use_distinct_embed_tokens:
@ -177,7 +189,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
target_model.lm_head = mock.MagicMock() target_model.lm_head = mock.MagicMock()
# Create proposer using the helper function # Create proposer using the helper function
proposer = proposer_helper(k=8) proposer = _create_proposer(method, k=8)
# Call the method under test # Call the method under test
proposer.load_model(target_model) proposer.load_model(target_model)
@ -201,10 +213,22 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
target_model.model.embed_tokens target_model.model.embed_tokens
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
@pytest.mark.parametrize("backend", def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
[_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN])
def test_propose(num_speculative_tokens, backend): monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if (attn_backend == "TRITON_ATTN_VLLM_V1"
and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
"multi-token eagle spec decode on current platform")
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
# Use GPU device # Use GPU device
device = torch.device(current_platform.device_type) device = torch.device(current_platform.device_type)
@ -303,7 +327,18 @@ def test_propose(num_speculative_tokens, backend):
device=device) device=device)
sampling_metadata = mock.MagicMock() sampling_metadata = mock.MagicMock()
attn_metadata_builder_cls, _ = get_attention_backend(backend) if attn_backend == "FLASH_ATTN_VLLM_V1":
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.FLASH_ATTN_VLLM_V1)
elif attn_backend == "TRITON_ATTN_VLLM_V1":
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.TRITON_ATTN_VLLM_V1)
elif attn_backend == "TREE_ATTN":
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.TREE_ATTN)
else:
raise ValueError(f"Unsupported attention backend: {attn_backend}")
attn_metadata_builder = attn_metadata_builder_cls( attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names, layer_names=proposer.attn_layer_names,

View File

@ -4,7 +4,9 @@
import pytest import pytest
from tests.utils import get_attn_backend_list_based_on_platform
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
_PROMPTS = [ _PROMPTS = [
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1", "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
@ -14,36 +16,40 @@ _PROMPTS = [
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10]) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_ngram_max_len( def test_ngram_max_len(num_speculative_tokens: int):
monkeypatch: pytest.MonkeyPatch, llm = LLM(
num_speculative_tokens: int, model="facebook/opt-125m",
): max_model_len=100,
with monkeypatch.context() as m: enforce_eager=True, # For faster initialization.
m.setenv("VLLM_USE_V1", "1") speculative_config={
"method": "ngram",
llm = LLM( "prompt_lookup_max": 5,
model="facebook/opt-125m", "prompt_lookup_min": 3,
max_model_len=100, "num_speculative_tokens": num_speculative_tokens,
enforce_eager=True, # For faster initialization. },
speculative_config={ )
"method": "ngram", sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
"prompt_lookup_max": 5, llm.generate(_PROMPTS, sampling_params)
"prompt_lookup_min": 3,
"num_speculative_tokens": num_speculative_tokens,
},
)
sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
llm.generate(_PROMPTS, sampling_params)
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10]) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_eagle_max_len( @pytest.mark.parametrize("attn_backend",
monkeypatch: pytest.MonkeyPatch, get_attn_backend_list_based_on_platform())
num_speculative_tokens: int, def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
): num_speculative_tokens: int, attn_backend: str):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if (attn_backend == "TRITON_ATTN_VLLM_V1"
and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
"multi-token eagle spec decode on current platform")
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1")
llm = LLM( llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct", model="meta-llama/Meta-Llama-3-8B-Instruct",
enforce_eager=True, # For faster initialization. enforce_eager=True, # For faster initialization.

View File

@ -17,10 +17,14 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata)
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
TreeAttentionMetadataBuilder) TreeAttentionMetadataBuilder)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
@ -230,11 +234,19 @@ class EagleProposer:
# one layer. Adapt this code to support multiple layers once # one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module. # there's a multi-layer MTP module.
# Currently, only FlashAttention and TreeAttention support multi-token # On ROCm, both AiterFlashAttention and TritonAttention
# eagle spec decode. This is because the code below # support multi-token eagle spec decode.
# makes assumptions about attn_metadata attributes available. if current_platform.is_rocm():
assert isinstance(attn_metadata, assert isinstance(
(FlashAttentionMetadata, TreeAttentionMetadata)) attn_metadata,
(TritonAttentionMetadata, AiterFlashAttentionMetadata,
FlashAttentionMetadata))
else:
# Currently, only FlashAttention and TreeAttention support
# multi-token eagle spec decode. This is because the code below
# makes assumptions about attn_metadata attributes available.
assert isinstance(attn_metadata,
(FlashAttentionMetadata, TreeAttentionMetadata))
# Generate the remaining draft tokens. # Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids] draft_token_ids_list = [draft_token_ids]