mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 01:42:17 +08:00
[ROCm] [V1] [SpecDec] Enable Speculative Decoding on ROCm V1 Engine (#21496)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent
acf8aeb79e
commit
1ee5ead5f8
@ -986,3 +986,19 @@ def has_module_attribute(module_name, attribute_name):
|
|||||||
return hasattr(module, attribute_name)
|
return hasattr(module, attribute_name)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_attn_backend_list_based_on_platform() -> list[str]:
|
||||||
|
if current_platform.is_cuda():
|
||||||
|
return ["FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1", "TREE_ATTN"]
|
||||||
|
elif current_platform.is_rocm():
|
||||||
|
attn_backend_list = ["TRITON_ATTN_VLLM_V1"]
|
||||||
|
try:
|
||||||
|
import aiter # noqa: F401
|
||||||
|
attn_backend_list.append("FLASH_ATTN_VLLM_V1")
|
||||||
|
except Exception:
|
||||||
|
print("Skip FLASH_ATTN_VLLM_V1 on ROCm as aiter is not installed")
|
||||||
|
|
||||||
|
return attn_backend_list
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported platform")
|
||||||
|
|||||||
@ -11,7 +11,7 @@ import torch
|
|||||||
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
|
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
|
||||||
LoadConfig, ModelConfig, ModelDType, ParallelConfig,
|
LoadConfig, ModelConfig, ModelDType, ParallelConfig,
|
||||||
SchedulerConfig, VllmConfig)
|
SchedulerConfig, VllmConfig)
|
||||||
from vllm.platforms import _Backend
|
from vllm.platforms import _Backend, current_platform
|
||||||
from vllm.utils import resolve_obj_by_qualname
|
from vllm.utils import resolve_obj_by_qualname
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||||
@ -119,7 +119,10 @@ def get_attention_backend(backend_name: _Backend):
|
|||||||
"""
|
"""
|
||||||
backend_map = {
|
backend_map = {
|
||||||
_Backend.FLASH_ATTN_VLLM_V1:
|
_Backend.FLASH_ATTN_VLLM_V1:
|
||||||
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend",
|
("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||||
|
if current_platform.is_cuda() else
|
||||||
|
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
|
||||||
|
),
|
||||||
_Backend.FLASHINFER_VLLM_V1:
|
_Backend.FLASHINFER_VLLM_V1:
|
||||||
"vllm.v1.attention.backends.flashinfer.FlashInferBackend",
|
"vllm.v1.attention.backends.flashinfer.FlashInferBackend",
|
||||||
_Backend.FLEX_ATTENTION:
|
_Backend.FLEX_ATTENTION:
|
||||||
|
|||||||
@ -8,10 +8,12 @@ from typing import Any, Union
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from tests.utils import get_attn_backend_list_based_on_platform
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||||
from vllm.assets.image import VLM_IMAGES_DIR
|
from vllm.assets.image import VLM_IMAGES_DIR
|
||||||
from vllm.distributed import cleanup_dist_env_and_memory
|
from vllm.distributed import cleanup_dist_env_and_memory
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
def get_test_prompts(mm_enabled: bool):
|
def get_test_prompts(mm_enabled: bool):
|
||||||
@ -141,11 +143,14 @@ def test_ngram_correctness(
|
|||||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||||
],
|
],
|
||||||
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
|
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
|
||||||
|
@pytest.mark.parametrize("attn_backend",
|
||||||
|
get_attn_backend_list_based_on_platform())
|
||||||
def test_eagle_correctness(
|
def test_eagle_correctness(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
sampling_config: SamplingParams,
|
sampling_config: SamplingParams,
|
||||||
model_setup: tuple[str, str, str, int],
|
model_setup: tuple[str, str, str, int],
|
||||||
mm_enabled: bool,
|
mm_enabled: bool,
|
||||||
|
attn_backend: str,
|
||||||
):
|
):
|
||||||
# Generate test prompts inside the function instead of using fixture
|
# Generate test prompts inside the function instead of using fixture
|
||||||
test_prompts = get_test_prompts(mm_enabled)
|
test_prompts = get_test_prompts(mm_enabled)
|
||||||
@ -156,6 +161,16 @@ def test_eagle_correctness(
|
|||||||
'''
|
'''
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("VLLM_USE_V1", "1")
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
|
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||||
|
|
||||||
|
if (attn_backend == "TRITON_ATTN_VLLM_V1"
|
||||||
|
and not current_platform.is_rocm()):
|
||||||
|
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
|
||||||
|
"multi-token eagle spec decode on current platform")
|
||||||
|
|
||||||
|
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
|
||||||
|
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||||
|
|
||||||
method, model_name, spec_model_name, tp_size = model_setup
|
method, model_name, spec_model_name, tp_size = model_setup
|
||||||
|
|
||||||
ref_llm = LLM(model=model_name,
|
ref_llm = LLM(model=model_name,
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from unittest import mock
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from tests.utils import get_attn_backend_list_based_on_platform
|
||||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||||
create_common_attn_metadata,
|
create_common_attn_metadata,
|
||||||
create_standard_kv_cache_spec,
|
create_standard_kv_cache_spec,
|
||||||
@ -120,17 +121,28 @@ def test_prepare_inputs():
|
|||||||
assert torch.equal(token_indices, expected_token_indices)
|
assert torch.equal(token_indices, expected_token_indices)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("method,proposer_helper", [
|
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
|
||||||
("eagle", lambda k: _create_proposer("eagle", k)),
|
@pytest.mark.parametrize("attn_backend",
|
||||||
("eagle3", lambda k: _create_proposer("eagle3", k)),
|
get_attn_backend_list_based_on_platform())
|
||||||
])
|
|
||||||
@pytest.mark.parametrize("pp_size", [1, 2])
|
@pytest.mark.parametrize("pp_size", [1, 2])
|
||||||
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
|
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
|
||||||
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
|
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
|
||||||
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
|
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
|
||||||
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
|
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
|
||||||
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||||
proposer_helper, pp_size, use_distinct_embed_tokens):
|
attn_backend, pp_size, use_distinct_embed_tokens,
|
||||||
|
monkeypatch):
|
||||||
|
|
||||||
|
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||||
|
|
||||||
|
if (attn_backend == "TRITON_ATTN_VLLM_V1"
|
||||||
|
and not current_platform.is_rocm()):
|
||||||
|
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
|
||||||
|
"multi-token eagle spec decode on current platform")
|
||||||
|
|
||||||
|
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
|
||||||
|
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||||
|
|
||||||
# Setup draft model mock
|
# Setup draft model mock
|
||||||
mock_model = mock.MagicMock()
|
mock_model = mock.MagicMock()
|
||||||
if use_distinct_embed_tokens:
|
if use_distinct_embed_tokens:
|
||||||
@ -177,7 +189,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
|||||||
target_model.lm_head = mock.MagicMock()
|
target_model.lm_head = mock.MagicMock()
|
||||||
|
|
||||||
# Create proposer using the helper function
|
# Create proposer using the helper function
|
||||||
proposer = proposer_helper(k=8)
|
proposer = _create_proposer(method, k=8)
|
||||||
|
|
||||||
# Call the method under test
|
# Call the method under test
|
||||||
proposer.load_model(target_model)
|
proposer.load_model(target_model)
|
||||||
@ -201,10 +213,22 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
|||||||
target_model.model.embed_tokens
|
target_model.model.embed_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
|
||||||
|
@pytest.mark.parametrize("attn_backend",
|
||||||
|
get_attn_backend_list_based_on_platform())
|
||||||
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
|
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
|
||||||
@pytest.mark.parametrize("backend",
|
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||||
[_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN])
|
|
||||||
def test_propose(num_speculative_tokens, backend):
|
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||||
|
|
||||||
|
if (attn_backend == "TRITON_ATTN_VLLM_V1"
|
||||||
|
and not current_platform.is_rocm()):
|
||||||
|
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
|
||||||
|
"multi-token eagle spec decode on current platform")
|
||||||
|
|
||||||
|
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
|
||||||
|
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||||
|
|
||||||
# Use GPU device
|
# Use GPU device
|
||||||
device = torch.device(current_platform.device_type)
|
device = torch.device(current_platform.device_type)
|
||||||
|
|
||||||
@ -303,7 +327,18 @@ def test_propose(num_speculative_tokens, backend):
|
|||||||
device=device)
|
device=device)
|
||||||
sampling_metadata = mock.MagicMock()
|
sampling_metadata = mock.MagicMock()
|
||||||
|
|
||||||
attn_metadata_builder_cls, _ = get_attention_backend(backend)
|
if attn_backend == "FLASH_ATTN_VLLM_V1":
|
||||||
|
attn_metadata_builder_cls, _ = get_attention_backend(
|
||||||
|
_Backend.FLASH_ATTN_VLLM_V1)
|
||||||
|
elif attn_backend == "TRITON_ATTN_VLLM_V1":
|
||||||
|
attn_metadata_builder_cls, _ = get_attention_backend(
|
||||||
|
_Backend.TRITON_ATTN_VLLM_V1)
|
||||||
|
elif attn_backend == "TREE_ATTN":
|
||||||
|
attn_metadata_builder_cls, _ = get_attention_backend(
|
||||||
|
_Backend.TREE_ATTN)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported attention backend: {attn_backend}")
|
||||||
|
|
||||||
attn_metadata_builder = attn_metadata_builder_cls(
|
attn_metadata_builder = attn_metadata_builder_cls(
|
||||||
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
||||||
layer_names=proposer.attn_layer_names,
|
layer_names=proposer.attn_layer_names,
|
||||||
|
|||||||
@ -4,7 +4,9 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from tests.utils import get_attn_backend_list_based_on_platform
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
_PROMPTS = [
|
_PROMPTS = [
|
||||||
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
|
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
|
||||||
@ -14,36 +16,40 @@ _PROMPTS = [
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
|
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
|
||||||
def test_ngram_max_len(
|
def test_ngram_max_len(num_speculative_tokens: int):
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
llm = LLM(
|
||||||
num_speculative_tokens: int,
|
model="facebook/opt-125m",
|
||||||
):
|
max_model_len=100,
|
||||||
with monkeypatch.context() as m:
|
enforce_eager=True, # For faster initialization.
|
||||||
m.setenv("VLLM_USE_V1", "1")
|
speculative_config={
|
||||||
|
"method": "ngram",
|
||||||
llm = LLM(
|
"prompt_lookup_max": 5,
|
||||||
model="facebook/opt-125m",
|
"prompt_lookup_min": 3,
|
||||||
max_model_len=100,
|
"num_speculative_tokens": num_speculative_tokens,
|
||||||
enforce_eager=True, # For faster initialization.
|
},
|
||||||
speculative_config={
|
)
|
||||||
"method": "ngram",
|
sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
|
||||||
"prompt_lookup_max": 5,
|
llm.generate(_PROMPTS, sampling_params)
|
||||||
"prompt_lookup_min": 3,
|
|
||||||
"num_speculative_tokens": num_speculative_tokens,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
|
|
||||||
llm.generate(_PROMPTS, sampling_params)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
|
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
|
||||||
def test_eagle_max_len(
|
@pytest.mark.parametrize("attn_backend",
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
get_attn_backend_list_based_on_platform())
|
||||||
num_speculative_tokens: int,
|
def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
|
||||||
):
|
num_speculative_tokens: int, attn_backend: str):
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("VLLM_USE_V1", "1")
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
|
|
||||||
|
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||||
|
|
||||||
|
if (attn_backend == "TRITON_ATTN_VLLM_V1"
|
||||||
|
and not current_platform.is_rocm()):
|
||||||
|
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
|
||||||
|
"multi-token eagle spec decode on current platform")
|
||||||
|
|
||||||
|
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
|
||||||
|
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||||
|
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
model="meta-llama/Meta-Llama-3-8B-Instruct",
|
model="meta-llama/Meta-Llama-3-8B-Instruct",
|
||||||
enforce_eager=True, # For faster initialization.
|
enforce_eager=True, # For faster initialization.
|
||||||
|
|||||||
@ -17,10 +17,14 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.models import supports_multimodal
|
from vllm.model_executor.models import supports_multimodal
|
||||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import is_pin_memory_available
|
from vllm.utils import is_pin_memory_available
|
||||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||||
|
from vllm.v1.attention.backends.rocm_aiter_fa import (
|
||||||
|
AiterFlashAttentionMetadata)
|
||||||
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
|
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
|
||||||
TreeAttentionMetadataBuilder)
|
TreeAttentionMetadataBuilder)
|
||||||
|
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
@ -230,11 +234,19 @@ class EagleProposer:
|
|||||||
# one layer. Adapt this code to support multiple layers once
|
# one layer. Adapt this code to support multiple layers once
|
||||||
# there's a multi-layer MTP module.
|
# there's a multi-layer MTP module.
|
||||||
|
|
||||||
# Currently, only FlashAttention and TreeAttention support multi-token
|
# On ROCm, both AiterFlashAttention and TritonAttention
|
||||||
# eagle spec decode. This is because the code below
|
# support multi-token eagle spec decode.
|
||||||
# makes assumptions about attn_metadata attributes available.
|
if current_platform.is_rocm():
|
||||||
assert isinstance(attn_metadata,
|
assert isinstance(
|
||||||
(FlashAttentionMetadata, TreeAttentionMetadata))
|
attn_metadata,
|
||||||
|
(TritonAttentionMetadata, AiterFlashAttentionMetadata,
|
||||||
|
FlashAttentionMetadata))
|
||||||
|
else:
|
||||||
|
# Currently, only FlashAttention and TreeAttention support
|
||||||
|
# multi-token eagle spec decode. This is because the code below
|
||||||
|
# makes assumptions about attn_metadata attributes available.
|
||||||
|
assert isinstance(attn_metadata,
|
||||||
|
(FlashAttentionMetadata, TreeAttentionMetadata))
|
||||||
|
|
||||||
# Generate the remaining draft tokens.
|
# Generate the remaining draft tokens.
|
||||||
draft_token_ids_list = [draft_token_ids]
|
draft_token_ids_list = [draft_token_ids]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user