[ROCm] [V1] [SpecDec] Enable Speculative Decoding on ROCm V1 Engine (#21496)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
commit 1ee5ead5f8 (parent acf8aeb79e)
@@ -986,3 +986,19 @@ def has_module_attribute(module_name, attribute_name):
         return hasattr(module, attribute_name)
     except ImportError:
         return False
+
+
+def get_attn_backend_list_based_on_platform() -> list[str]:
+    if current_platform.is_cuda():
+        return ["FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1", "TREE_ATTN"]
+    elif current_platform.is_rocm():
+        attn_backend_list = ["TRITON_ATTN_VLLM_V1"]
+        try:
+            import aiter  # noqa: F401
+            attn_backend_list.append("FLASH_ATTN_VLLM_V1")
+        except Exception:
+            print("Skip FLASH_ATTN_VLLM_V1 on ROCm as aiter is not installed")
+
+        return attn_backend_list
+    else:
+        raise ValueError("Unsupported platform")
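The helper above drives the pytest parametrization used throughout the rest of this diff. A minimal sketch of how a test consumes it; the test name below is hypothetical, while the parametrize pattern and the VLLM_ATTENTION_BACKEND variable come from the hunks that follow:

import pytest

from tests.utils import get_attn_backend_list_based_on_platform


@pytest.mark.parametrize("attn_backend",
                         get_attn_backend_list_based_on_platform())
def test_backend_selection(attn_backend: str,
                           monkeypatch: pytest.MonkeyPatch):
    # Each backend name returned for the current platform becomes its own case.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
    assert attn_backend in {"FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1",
                            "TREE_ATTN"}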
@@ -11,7 +11,7 @@ import torch
 from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
                          LoadConfig, ModelConfig, ModelDType, ParallelConfig,
                          SchedulerConfig, VllmConfig)
-from vllm.platforms import _Backend
+from vllm.platforms import _Backend, current_platform
 from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
@@ -119,7 +119,10 @@ def get_attention_backend(backend_name: _Backend):
     """
     backend_map = {
         _Backend.FLASH_ATTN_VLLM_V1:
-        "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend",
+        ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+         if current_platform.is_cuda() else
+         "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
+         ),
         _Backend.FLASHINFER_VLLM_V1:
         "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
         _Backend.FLEX_ATTENTION:
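The map above stores dotted qualnames rather than classes; resolve_obj_by_qualname, imported in the previous hunk, turns the selected string into a backend class when get_attention_backend is called. A rough sketch of that resolution step (the local names qualname and backend_cls are illustrative only, and the exact internals of get_attention_backend are not shown in this diff):

from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname

# Choose the qualname the same way the updated map entry does, then import it.
qualname = ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
            if current_platform.is_cuda() else
            "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend")
backend_cls = resolve_obj_by_qualname(qualname)
print(backend_cls.__name__)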
@@ -8,10 +8,12 @@ from typing import Any, Union
 import pytest
 import torch

+from tests.utils import get_attn_backend_list_based_on_platform
 from vllm import LLM, SamplingParams
 from vllm.assets.base import VLLM_S3_BUCKET_URL
 from vllm.assets.image import VLM_IMAGES_DIR
 from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.platforms import current_platform


 def get_test_prompts(mm_enabled: bool):
@@ -141,11 +143,14 @@ def test_ngram_correctness(
              marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
     ],
     ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
+@pytest.mark.parametrize("attn_backend",
+                         get_attn_backend_list_based_on_platform())
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, str, int],
     mm_enabled: bool,
+    attn_backend: str,
 ):
     # Generate test prompts inside the function instead of using fixture
     test_prompts = get_test_prompts(mm_enabled)
@@ -156,6 +161,16 @@ def test_eagle_correctness(
     '''
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
+        if (attn_backend == "TRITON_ATTN_VLLM_V1"
+                and not current_platform.is_rocm()):
+            pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+                        "multi-token eagle spec decode on current platform")
+
+        if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+            m.setenv("VLLM_ROCM_USE_AITER", "1")
+
         method, model_name, spec_model_name, tp_size = model_setup

         ref_llm = LLM(model=model_name,
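The same three-step guard (export the backend, skip TRITON_ATTN_VLLM_V1 off ROCm, opt into aiter for FLASH_ATTN_VLLM_V1 on ROCm) recurs in every test this commit touches. A sketch of how it could be factored into one helper; setup_spec_decode_backend is a hypothetical name and is not part of the commit:

import pytest

from vllm.platforms import current_platform


def setup_spec_decode_backend(m: pytest.MonkeyPatch, attn_backend: str) -> None:
    """Export the backend choice, skipping combinations the tests do not cover."""
    m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

    if (attn_backend == "TRITON_ATTN_VLLM_V1"
            and not current_platform.is_rocm()):
        pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
                    "multi-token eagle spec decode on current platform")

    if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
        # Route FLASH_ATTN_VLLM_V1 to the aiter-backed implementation on ROCm.
        m.setenv("VLLM_ROCM_USE_AITER", "1")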
@@ -6,6 +6,7 @@ from unittest import mock
 import pytest
 import torch

+from tests.utils import get_attn_backend_list_based_on_platform
 from tests.v1.attention.utils import (BatchSpec, _Backend,
                                       create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
@@ -120,17 +121,28 @@ def test_prepare_inputs():
     assert torch.equal(token_indices, expected_token_indices)


-@pytest.mark.parametrize("method,proposer_helper", [
-    ("eagle", lambda k: _create_proposer("eagle", k)),
-    ("eagle3", lambda k: _create_proposer("eagle3", k)),
-])
+@pytest.mark.parametrize("method", ["eagle", "eagle3"])
+@pytest.mark.parametrize("attn_backend",
+                         get_attn_backend_list_based_on_platform())
 @pytest.mark.parametrize("pp_size", [1, 2])
 @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
 @mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
 @mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
 @mock.patch('vllm.v1.spec_decode.eagle.get_model')
 def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
-                    proposer_helper, pp_size, use_distinct_embed_tokens):
+                    attn_backend, pp_size, use_distinct_embed_tokens,
+                    monkeypatch):
+
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
+    if (attn_backend == "TRITON_ATTN_VLLM_V1"
+            and not current_platform.is_rocm()):
+        pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+                    "multi-token eagle spec decode on current platform")
+
+    if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     # Setup draft model mock
     mock_model = mock.MagicMock()
     if use_distinct_embed_tokens:
@@ -177,7 +189,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
         target_model.lm_head = mock.MagicMock()

     # Create proposer using the helper function
-    proposer = proposer_helper(k=8)
+    proposer = _create_proposer(method, k=8)

     # Call the method under test
     proposer.load_model(target_model)
@@ -201,10 +213,22 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
             target_model.model.embed_tokens


+@pytest.mark.parametrize("method", ["eagle", "eagle3"])
+@pytest.mark.parametrize("attn_backend",
+                         get_attn_backend_list_based_on_platform())
 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
-@pytest.mark.parametrize("backend",
-                         [_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN])
-def test_propose(num_speculative_tokens, backend):
+def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
+
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
+    if (attn_backend == "TRITON_ATTN_VLLM_V1"
+            and not current_platform.is_rocm()):
+        pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+                    "multi-token eagle spec decode on current platform")
+
+    if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     # Use GPU device
     device = torch.device(current_platform.device_type)
@@ -303,7 +327,18 @@ def test_propose(num_speculative_tokens, backend):
                                      device=device)
     sampling_metadata = mock.MagicMock()

-    attn_metadata_builder_cls, _ = get_attention_backend(backend)
+    if attn_backend == "FLASH_ATTN_VLLM_V1":
+        attn_metadata_builder_cls, _ = get_attention_backend(
+            _Backend.FLASH_ATTN_VLLM_V1)
+    elif attn_backend == "TRITON_ATTN_VLLM_V1":
+        attn_metadata_builder_cls, _ = get_attention_backend(
+            _Backend.TRITON_ATTN_VLLM_V1)
+    elif attn_backend == "TREE_ATTN":
+        attn_metadata_builder_cls, _ = get_attention_backend(
+            _Backend.TREE_ATTN)
+    else:
+        raise ValueError(f"Unsupported attention backend: {attn_backend}")
+
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
         layer_names=proposer.attn_layer_names,
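The backend strings used by the parametrization match the _Backend member names that appear elsewhere in this diff, so the if/elif chain above could collapse into a single lookup. A sketch of that alternative, assuming _Backend is a standard Enum and that get_attention_backend is the helper from the tests/v1/attention utils hunk earlier in this diff; builder_for is a hypothetical name:

from tests.v1.attention.utils import _Backend, get_attention_backend


def builder_for(attn_backend: str):
    # Enum lookup by member name replaces the explicit if/elif chain.
    try:
        backend_enum = _Backend[attn_backend]
    except KeyError as e:
        raise ValueError(
            f"Unsupported attention backend: {attn_backend}") from e
    attn_metadata_builder_cls, _ = get_attention_backend(backend_enum)
    return attn_metadata_builder_cls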
@@ -4,7 +4,9 @@

 import pytest

+from tests.utils import get_attn_backend_list_based_on_platform
 from vllm import LLM, SamplingParams
+from vllm.platforms import current_platform

 _PROMPTS = [
     "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
@@ -14,36 +16,40 @@ _PROMPTS = [

 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
-def test_ngram_max_len(
-    monkeypatch: pytest.MonkeyPatch,
-    num_speculative_tokens: int,
-):
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-
-        llm = LLM(
-            model="facebook/opt-125m",
-            max_model_len=100,
-            enforce_eager=True,  # For faster initialization.
-            speculative_config={
-                "method": "ngram",
-                "prompt_lookup_max": 5,
-                "prompt_lookup_min": 3,
-                "num_speculative_tokens": num_speculative_tokens,
-            },
-        )
-        sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
-        llm.generate(_PROMPTS, sampling_params)
+def test_ngram_max_len(num_speculative_tokens: int):
+    llm = LLM(
+        model="facebook/opt-125m",
+        max_model_len=100,
+        enforce_eager=True,  # For faster initialization.
+        speculative_config={
+            "method": "ngram",
+            "prompt_lookup_max": 5,
+            "prompt_lookup_min": 3,
+            "num_speculative_tokens": num_speculative_tokens,
+        },
+    )
+    sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
+    llm.generate(_PROMPTS, sampling_params)


 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
-def test_eagle_max_len(
-    monkeypatch: pytest.MonkeyPatch,
-    num_speculative_tokens: int,
-):
+@pytest.mark.parametrize("attn_backend",
+                         get_attn_backend_list_based_on_platform())
+def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
+                       num_speculative_tokens: int, attn_backend: str):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
+        if (attn_backend == "TRITON_ATTN_VLLM_V1"
+                and not current_platform.is_rocm()):
+            pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+                        "multi-token eagle spec decode on current platform")
+
+        if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+            m.setenv("VLLM_ROCM_USE_AITER", "1")
+
         llm = LLM(
             model="meta-llama/Meta-Llama-3-8B-Instruct",
             enforce_eager=True,  # For faster initialization.
@@ -17,10 +17,14 @@ from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.rocm_aiter_fa import (
+    AiterFlashAttentionMetadata)
 from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
                                                   TreeAttentionMetadataBuilder)
+from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -230,11 +234,19 @@ class EagleProposer:
         # one layer. Adapt this code to support multiple layers once
         # there's a multi-layer MTP module.

-        # Currently, only FlashAttention and TreeAttention support multi-token
-        # eagle spec decode. This is because the code below
-        # makes assumptions about attn_metadata attributes available.
-        assert isinstance(attn_metadata,
-                          (FlashAttentionMetadata, TreeAttentionMetadata))
+        # On ROCm, both AiterFlashAttention and TritonAttention
+        # support multi-token eagle spec decode.
+        if current_platform.is_rocm():
+            assert isinstance(
+                attn_metadata,
+                (TritonAttentionMetadata, AiterFlashAttentionMetadata,
+                 FlashAttentionMetadata))
+        else:
+            # Currently, only FlashAttention and TreeAttention support
+            # multi-token eagle spec decode. This is because the code below
+            # makes assumptions about attn_metadata attributes available.
+            assert isinstance(attn_metadata,
+                              (FlashAttentionMetadata, TreeAttentionMetadata))

         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
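An equivalent shape for the new check would compute the allowed metadata types once, at import time, instead of branching at the assert site. A sketch only, not what the commit does; _ALLOWED_SPEC_DECODE_METADATA is a hypothetical name, and the imports mirror the ones added at the top of eagle.py in this diff:

from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionMetadata
from vllm.v1.attention.backends.tree_attn import TreeAttentionMetadata
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata

# Metadata classes whose attributes the multi-token eagle loop relies on.
_ALLOWED_SPEC_DECODE_METADATA = (
    (TritonAttentionMetadata, AiterFlashAttentionMetadata,
     FlashAttentionMetadata) if current_platform.is_rocm() else
    (FlashAttentionMetadata, TreeAttentionMetadata))

# Inside EagleProposer.propose() the assertion would then read:
#     assert isinstance(attn_metadata, _ALLOWED_SPEC_DECODE_METADATA)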