[Spec Decode] Enable FlashInfer Spec Decoding (#25196)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Co-authored-by: lhsjohn <huashuoli@tencent.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

parent 0e54bbe108
commit 177c37e960
@@ -9,7 +9,8 @@ from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.v1.attention.backends.utils import (UBatchSlice,
_make_metadata_with_slice,
slice_query_start_locs,
split_attn_metadata)
split_attn_metadata,
split_decodes_and_prefills)
from vllm.v1.worker.ubatch_utils import create_ubatch_slices

@@ -158,6 +159,112 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))


def apply_split_decodes_and_prefills(query_lens: list[int],
decode_threshold: int,
require_uniform: bool):
"""Helper function to apply split_decodes_and_prefills and return
the results."""
device = torch.device("cpu")
seq_lens = [10 * (i + 1) for i in range(len(query_lens))]
common_metadata = create_common_attn_metadata(BatchSpec(
seq_lens=seq_lens, query_lens=query_lens),
block_size=16,
device=device)
return split_decodes_and_prefills(common_metadata,
decode_threshold=decode_threshold,
require_uniform=require_uniform)


def test_split_decodes_and_prefills_nonuniform_all_ones():
query_lens = [1, 1, 1]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 1, False))
assert num_decodes == 3
assert num_prefills == 0
assert num_decode_tokens == 3
assert num_prefill_tokens == 0


def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
query_lens = [1, 2, 1, 3, 2, 1, 2]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, False))
assert num_decodes == 7
assert num_prefills == 0
assert num_decode_tokens == sum(query_lens)
assert num_prefill_tokens == 0


def test_split_decodes_and_prefills_nonuniform_all_prefills():
query_lens = [4, 5, 6, 7]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, False))
assert num_decodes == 0
assert num_prefills == 4
assert num_decode_tokens == 0
assert num_prefill_tokens == sum(query_lens)


def test_split_decodes_and_prefills_nonuniform_mixed_batch():
query_lens = [2, 1, 3, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, False))
assert num_decodes == 4 # 2, 1, 3, 4 are all <= 4
assert num_prefills == 4 # 5, 6, 7, 8 are all > 4
assert num_decode_tokens == 10 # 2 + 1 + 3 + 4
assert num_prefill_tokens == 26 # 5 + 6 + 7 + 8


def test_split_decodes_and_prefills_uniform_all_ones():
query_lens = [1, 1, 1]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 1, True))
assert num_decodes == 3
assert num_prefills == 0
assert num_decode_tokens == 3
assert num_prefill_tokens == 0


def test_split_decodes_and_prefills_uniform_all_short_decodes():
query_lens = [2, 2, 1, 3, 2, 1, 2]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, True))
assert num_decodes == 2
assert num_prefills == 5
assert num_decode_tokens == 4
assert num_prefill_tokens == (1 + 3 + 2 + 1 + 2)


def test_split_decodes_and_prefills_uniform_all_prefills():
query_lens = [4, 5, 6, 7]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, True))
assert num_decodes == 0
assert num_prefills == 4
assert num_decode_tokens == 0
assert num_prefill_tokens == sum(query_lens)


def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
query_lens = [2, 2, 2, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, True))
assert num_decodes == 3 # 2, 2, 2 are all <= 4 and uniform
assert num_prefills == 5 # 4 is non-uniform (!= 2); 5, 6, 7, 8 are > 4
assert num_decode_tokens == 6 # 2 + 2 + 2
assert num_prefill_tokens == 30 # 4 + 5 + 6 + 7 + 8


def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
query_lens = [2, 1, 2, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, True))
assert num_decodes == 1 # only the first 2 is taken as decode
assert num_prefills == 7 # 1, 2, 4, 5, 6, 7, 8 are all > 4 or non-uniform
assert num_decode_tokens == 2 # only the first 2
assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens


@pytest.mark.parametrize(
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
[
@@ -181,6 +181,12 @@ def force_use_trtllm_attention() -> Optional[bool]:
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)


def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
"""Check if the current configuration supports TRTLLM attention."""
has_trtllm = supports_trtllm_attention()
return has_trtllm and (num_qo_heads % num_kv_heads == 0)


def use_trtllm_attention(
num_qo_heads: int,
num_kv_heads: int,

@@ -188,7 +194,9 @@ def use_trtllm_attention(
max_seq_len: int,
kv_cache_dtype: str,
q_dtype: torch.dtype,
is_prefill: bool,
has_sinks: bool = False,
has_spec: bool = False,
) -> bool:
"""Return ``True`` if TRTLLM attention is used."""
force_use_trtllm = force_use_trtllm_attention()

@@ -214,6 +222,12 @@ def use_trtllm_attention(
)
return False

if has_spec and not is_prefill:
# Speculative decoding requires TRTLLM attention for decodes
logger.info_once(
"Using TRTLLM attention (enabled for speculative decoding).")
return True

# Must use TRTLLM attention if query is FP8 quantized
if q_dtype == current_platform.fp8_dtype():
if has_sinks:

@@ -391,6 +405,7 @@ __all__ = [
"has_flashinfer_cutlass_fused_moe",
"has_nvidia_artifactory",
"supports_trtllm_attention",
"can_use_trtllm_attention",
"use_trtllm_attention",
"flashinfer_disable_q_quantization",
"flashinfer_scaled_fp4_mm",
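The eligibility check added in can_use_trtllm_attention reduces to kernel availability plus an even grouped-query ratio. A tiny standalone illustration of the head-count part only (the availability call is omitted here; the helper name is invented for the example):

    def heads_divide_evenly(num_qo_heads: int, num_kv_heads: int) -> bool:
        # Same divisibility condition as can_use_trtllm_attention above:
        # every KV head must serve a whole number of query heads (GQA group size).
        return num_qo_heads % num_kv_heads == 0

    assert heads_divide_evenly(32, 8)      # group size 4: eligible if TRTLLM kernels exist
    assert not heads_divide_evenly(30, 8)  # 30 % 8 != 0: falls back to non-TRTLLM paths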
@@ -25,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
from vllm.utils.flashinfer import (can_use_trtllm_attention,
flashinfer_disable_q_quantization,
supports_trtllm_attention,
use_trtllm_attention)
from vllm.v1.attention.backends.flash_attn import use_cascade_attention

@@ -223,6 +224,7 @@ class FlashInferMetadata:

# For flashinfer trtllm batch decode
max_q_len: int
max_q_len_prefill: int
max_seq_len: int
seq_lens: torch.Tensor
block_table_tensor: torch.Tensor

@@ -250,7 +252,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):

@@ -302,6 +304,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
else:
self.q_data_type = self.model_config.dtype

supports_spec_as_decode = \
can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
self._init_reorder_batch_threshold(1, supports_spec_as_decode)

self._cascade_wrapper = None # Wrapper for cascade attention

# Global hyperparameters shared by all attention layers

@@ -416,7 +422,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)
decode_threshold=self.reorder_batch_threshold,
require_uniform=True)

page_size = self.page_size
max_q_len = common_attn_metadata.max_query_len

@@ -491,20 +498,25 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_last_page_len_np,
)

uses_spec_reorder = self.reorder_batch_threshold > 1
prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
self.num_kv_heads,
num_prefill_tokens,
max_seq_len,
self.cache_dtype,
self.q_data_type,
has_sinks=self.has_sinks)
is_prefill=True,
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder)
decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
self.num_kv_heads,
num_decode_tokens,
max_seq_len,
self.cache_dtype,
self.q_data_type,
has_sinks=self.has_sinks)
is_prefill=False,
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder)
if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
raise NotImplementedError(
"FlashInfer backend currently does not support attention "

@@ -521,6 +533,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
q_data_type=self.q_data_type,
slot_mapping=common_attn_metadata.slot_mapping,
max_q_len=max_q_len,
max_q_len_prefill=max_q_len,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table_tensor=block_table_tensor,

@@ -577,6 +590,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
prefill_start]
paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]

# Recompute max_q_len for the slice of requests we are using
# for prefills. This can be different from max_q_len when
# we have a non-uniform batch with some short decodes offloaded
# to the prefill pathway
query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
attn_metadata.max_q_len_prefill = \
int(query_lens_prefill.max().item())

if not attn_metadata.prefill_use_trtllm:
attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu,
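The recomputation above matters because, with require_uniform=True, short but non-uniform decode queries can be routed to the prefill path, so the prefill slice can have a different maximum query length than the batch-wide max_q_len. A minimal sketch with invented values:

    import torch

    # Hypothetical reordered batch: three uniform 2-token decode requests followed
    # by one 1-token request that was pushed to the prefill path for uniformity.
    query_lens = torch.tensor([2, 2, 2, 1])
    num_decodes = 3

    max_q_len = int(query_lens.max())                        # 2 (whole batch)
    max_q_len_prefill = int(query_lens[num_decodes:].max())  # 1 (prefill slice only)
    assert (max_q_len, max_q_len_prefill) == (2, 1)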
@@ -607,7 +629,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_decodes <= self._decode_cudagraph_max_bs)
if use_cudagraph:
num_input_tokens = (
self.vllm_config.pad_for_cudagraph(num_decodes))
self.vllm_config.pad_for_cudagraph(num_decode_tokens))
# Carefully fulfill the padding region with reasonable value
# on cpu.
# Make sure paged_kv_indptr_cpu is not decreasing

@@ -621,7 +643,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_decodes:num_input_tokens].fill_(1)

else:
num_input_tokens = num_decodes
num_input_tokens = num_decode_tokens

attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph)

@@ -842,6 +864,9 @@ class FlashInferImpl(AttentionImpl):
output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
return output

# When using spec decoding, num_decodes can be < num_decode_tokens
# because some decode requests may have more than one query token.
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens
num_prefill_tokens = attn_metadata.num_prefill_tokens
@@ -874,8 +899,8 @@
prefill_query = prefill_query.contiguous()
workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_prefill = attn_metadata.block_table_tensor[
num_decode_tokens:]
seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]
num_decodes:]
seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]

# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND"

@@ -919,7 +944,7 @@
workspace_buffer=workspace_buffer,
block_tables=mock_block_table,
seq_lens=seq_lens_prefill,
max_q_len=attn_metadata.max_q_len,
max_q_len=attn_metadata.max_q_len_prefill,
max_kv_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,

@@ -976,6 +1001,14 @@
assert self.o_sf_scale is None
out = output[:num_decode_tokens]

if num_decode_tokens % attn_metadata.num_decodes != 0:
# This gets triggered when the dummy_run forces
# attention to be initialized with q_len = 0
q_len_per_req = 1
else:
q_len_per_req = \
num_decode_tokens // attn_metadata.num_decodes

trtllm_batch_decode_with_kv_cache(
query=decode_query,
kv_cache=kv_cache_permute,

@@ -989,7 +1022,7 @@
sinks=self.sinks,
o_sf_scale=self.o_sf_scale,
out=out,
)
q_len_per_req=q_len_per_req)
return output_padded
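To make the decode-side bookkeeping concrete: in a uniform spec-decode batch every decode request carries the same number of query tokens, so num_decode_tokens is a multiple of num_decodes and the division recovers the per-request query length passed to the kernel. A small example with made-up numbers:

    # Hypothetical spec-decode batch: 4 decode requests, 2 speculative tokens each,
    # so every decode request contributes 1 verified + 2 draft = 3 query tokens.
    num_decodes = 4
    num_decode_tokens = 12

    assert num_decode_tokens % num_decodes == 0
    q_len_per_req = num_decode_tokens // num_decodes  # -> 3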
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention."""
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional

import torch

@@ -62,7 +62,7 @@ class GDNAttentionMetadataBuilder(

cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):

@@ -76,7 +76,7 @@ class GDNAttentionMetadataBuilder(
else:
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
self.reorder_batch_threshold = self.num_spec + 1 # type: ignore[misc]
self._init_reorder_batch_threshold(1, self.use_spec_decode)

self.use_full_cuda_graph = \
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar

import torch

@@ -35,7 +34,7 @@ class LinearAttentionMetadata:
class LinearAttentionMetadataBuilder(
AttentionMetadataBuilder[LinearAttentionMetadata]):

reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
@@ -16,7 +16,7 @@ M = TypeVar("M")


class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
@@ -190,7 +190,7 @@ return curr_o @ W_O
import functools
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import ClassVar, Generic, Optional, TypeVar, Union
from typing import Generic, Optional, TypeVar, Union

import torch
from tqdm import tqdm

@@ -434,7 +434,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1

@staticmethod
def determine_chunked_prefill_workspace_size(
@@ -64,7 +64,7 @@ class FlashAttnMLAMetadataBuilder(
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH

reorder_batch_threshold: ClassVar[int] = 512
reorder_batch_threshold: int = 512

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):

@@ -99,7 +99,7 @@ class FlashAttnMLAMetadataBuilder(

# TODO(lucas): Until we add support for the DCP custom masking we need
# to restrict decodes to q_len == 1 when DCP is enabled.
self.__class__.reorder_batch_threshold = 1 \
self.reorder_batch_threshold = 1 \
if get_dcp_group().world_size > 1 else self.reorder_batch_threshold

def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional

import torch

@@ -41,7 +41,7 @@ class ShortConvAttentionMetadata:
class ShortConvAttentionMetadataBuilder(
AttentionMetadataBuilder[ShortConvAttentionMetadata]):

reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
@@ -236,7 +236,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold: ClassVar[Optional[int]] = None
reorder_batch_threshold: Optional[int] = None

@abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],

@@ -246,6 +246,22 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
self.vllm_config = vllm_config
self.device = device

def _init_reorder_batch_threshold(
self,
reorder_batch_threshold: int = 1,
supports_spec_as_decode: bool = False) -> None:
self.reorder_batch_threshold = reorder_batch_threshold
if self.reorder_batch_threshold is not None \
and supports_spec_as_decode:
# If the backend supports spec-as-decode kernels, then we can set
# the reorder_batch_threshold based on the number of speculative
# tokens from the config.
speculative_config = self.vllm_config.speculative_config
if (speculative_config is not None
and speculative_config.num_speculative_tokens is not None):
self.reorder_batch_threshold = \
1 + speculative_config.num_speculative_tokens

@abstractmethod
def build(self,
common_prefix_len: int,
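In other words, a builder that opts in via supports_spec_as_decode widens its decode threshold from 1 to one verified token plus the configured draft length. A small illustration of the resulting value (the speculative token count is hypothetical):

    # With num_speculative_tokens = 3, each decode request can carry up to
    # 1 + 3 = 4 query tokens and still be routed through the decode path.
    num_speculative_tokens = 3
    reorder_batch_threshold = 1 + num_speculative_tokens
    assert reorder_batch_threshold == 4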
@@ -703,9 +719,9 @@ def subclass_attention_backend(


def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
) -> tuple[int, int, int, int]:
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
require_uniform: bool = False) -> tuple[int, int, int, int]:
"""
Assuming a reordered batch, finds the boundary between prefill and decode
requests.

@@ -714,6 +730,9 @@ def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata object containing the
batch metadata.
decode_threshold: The maximum query length to be considered a decode.
require_uniform: If True, requires that all decode requests have the
same query length. When set, some queries may be considered prefills
even if they are <= decode_threshold, in order to ensure uniformity.

Returns:
num_decodes: The number of decode requests.

@@ -726,11 +745,20 @@ def split_decodes_and_prefills(
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu

if max_query_len <= decode_threshold:
if max_query_len <= decode_threshold and \
(not require_uniform or decode_threshold <= 1):
return num_reqs, 0, num_tokens, 0

query_lens = query_start_loc[1:] - query_start_loc[:-1]
is_prefill = query_lens > decode_threshold
if query_lens[0].item() > decode_threshold:
# first request is not decode, so no decode requests
return 0, num_reqs, 0, num_tokens

if require_uniform:
is_prefill = query_lens != query_lens[0]
else:
is_prefill = query_lens > decode_threshold

if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
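The require_uniform rule can be traced by hand; the standalone sketch below mirrors only the decision logic shown above (it is independent of the vLLM metadata types and omits the early-return corner cases) and reproduces the mixed-batch cases from the new tests:

    import torch

    def split_sketch(query_lens, decode_threshold, require_uniform):
        # Simplified stand-in for split_decodes_and_prefills: returns the number
        # of leading requests that would be treated as decodes in a reordered batch.
        lens = torch.tensor(query_lens)
        if lens[0].item() > decode_threshold:
            return 0
        is_prefill = (lens != lens[0]) if require_uniform else (lens > decode_threshold)
        if not torch.any(is_prefill):
            return len(query_lens)
        return int(torch.nonzero(is_prefill)[0].item())

    assert split_sketch([2, 1, 3, 4, 5, 6, 7, 8], 4, False) == 4  # all lengths <= 4 stay decodes
    assert split_sketch([2, 1, 2, 4, 5, 6, 7, 8], 4, True) == 1   # only the leading 2 is uniform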
@@ -806,6 +834,38 @@ def reorder_batch_to_split_decodes_and_prefills(
return modified_batch


def reshape_query_for_spec_decode(query: torch.Tensor,
batch_size: int) -> torch.Tensor:
"""
Reshapes the query tensor for the specified batch size, so that
it has shape (batch_size, seq_len, num_heads, head_dim).
"""
assert query.dim() == 3, f"query must be 3D, got {query.dim()}D"
total_tokens = query.shape[0]
num_heads = query.shape[1]
head_dim = query.shape[2]
assert total_tokens % batch_size == 0, (
f"{total_tokens=} is not divisible by {batch_size=}")
seq_len = total_tokens // batch_size
return query.view(batch_size, seq_len, num_heads, head_dim)


def reshape_attn_output_for_spec_decode(
attn_output: torch.Tensor) -> torch.Tensor:
"""
Reshapes the attention output tensor, so that
the batch_size and seq_len dimensions are combined.
"""
if attn_output.dim() == 3:
# Already in the correct shape
return attn_output
assert attn_output.dim() == 4, \
f"attn_output must be 4D, got {attn_output.dim()}D"
total_tokens = attn_output.shape[0] * attn_output.shape[1]
return attn_output.view(total_tokens, attn_output.shape[2],
attn_output.shape[3])


KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
('logits_indices_padded', Optional[torch.Tensor], None),
('num_logits_indices', int, 0),
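These two helpers are shape-only round trips between the token-flat layout and the per-request layout that spec-as-decode kernels expect. A quick shape check with illustrative sizes (the numbers are invented; only the view semantics match the helpers above):

    import torch

    batch_size, q_len_per_req, num_heads, head_dim = 4, 3, 8, 64
    flat_query = torch.zeros(batch_size * q_len_per_req, num_heads, head_dim)

    # (12, 8, 64) -> (4, 3, 8, 64): one row of query tokens per decode request,
    # as reshape_query_for_spec_decode produces.
    per_req = flat_query.view(batch_size, q_len_per_req, num_heads, head_dim)

    # (4, 3, 8, 64) -> (12, 8, 64): collapse back to the flat token dimension,
    # which is what reshape_attn_output_for_spec_decode does for 4D outputs.
    flat_again = per_req.view(batch_size * q_len_per_req, num_heads, head_dim)
    assert flat_again.shape == flat_query.shape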
@@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention."""

from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional
from typing import TYPE_CHECKING, Optional

import torch

@@ -197,7 +197,7 @@ class XFormersAttentionMetadata:
class XFormersAttentionMetadataBuilder(
AttentionMetadataBuilder[XFormersAttentionMetadata]):

reorder_batch_threshold: ClassVar[int] = 1
reorder_batch_threshold: int = 1

def __init__(
self,
@@ -3,7 +3,7 @@
import ast
from dataclasses import replace
from importlib.util import find_spec
from typing import Optional, Protocol
from typing import Optional

import numpy as np
import torch

@@ -37,17 +37,6 @@ logger = init_logger(__name__)
PADDING_SLOT_ID = -1


class EagleAttentionMetadata(Protocol):
# Required attributes
num_actual_tokens: int
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor


class EagleProposer:

def __init__(

@@ -120,7 +109,7 @@ class EagleProposer:
with_numpy=True)

# Determine allowed attention backends once during initialization.
self.allowed_attn_types: tuple[type, ...]
self.allowed_attn_types: Optional[tuple] = None
if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend

@@ -129,9 +118,6 @@ class EagleProposer:
AiterFlashAttentionMetadata)
rocm_types.append(AiterFlashAttentionMetadata)
self.allowed_attn_types = tuple(rocm_types)
else:
self.allowed_attn_types = (FlashAttentionMetadata,
TreeAttentionMetadata)

# Parse the speculative token tree.
spec_token_tree = self.speculative_config.speculative_token_tree

@@ -266,7 +252,8 @@ class EagleProposer:

draft_token_ids = logits.argmax(dim=-1)

if not isinstance(attn_metadata, self.allowed_attn_types):
if self.allowed_attn_types is not None and \
not isinstance(attn_metadata, self.allowed_attn_types):
raise ValueError(
f"Unsupported attention metadata type for speculative "
"decoding with num_speculative_tokens > 1: "