[SpecDec] Remove Batch Expansion (2/3) (#9298)

commit 89feb4c84d
parent ec10cb8511
@@ -1,3 +1,6 @@
+import random
+from typing import List
+
 import pytest
 import torch
 
@@ -10,31 +13,45 @@ from vllm.worker.worker import Worker
 from .utils import create_batch, create_worker
 
 
-def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
+def create_proposal(propose_lens: List[int], vocab_size: int,
                     device: str) -> SpeculativeProposals:
-    proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
+    batch_size = len(propose_lens)
+    max_propose_len = max(propose_lens)
+    proposal_probs = torch.rand((batch_size, max_propose_len, vocab_size),
                                 device=device)
-    proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
-    proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
+
+    proposal_token_ids = torch.full((batch_size, max_propose_len),
+                                    fill_value=-1,
+                                    device=device)
+    for i in range(batch_size):
+        proposal_token_ids[i][:propose_lens[i]] = torch.argmax(
+            proposal_probs[i][:propose_lens[i]], dim=-1)
+
+    propose_lens = torch.tensor(propose_lens, device=device)
     return SpeculativeProposals(proposal_token_ids, proposal_probs,
-                                proposal_lens)
+                                propose_lens)
 
 
 def assert_score_equal(score1: SpeculativeScores,
                        score2: SpeculativeScores) -> None:
     assert torch.allclose(score1.probs, score2.probs)
     assert torch.allclose(score1.logprobs, score2.logprobs)
-    assert torch.equal(score1.token_ids, score2.token_ids)
+    assert torch.equal(
+        score1.token_ids,
+        score2.token_ids), f"{score1.token_ids}, {score2.token_ids}"
 
 
 @pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
 @pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
-@pytest.mark.parametrize('propose_len', [1, 3, 5])
+@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
+@pytest.mark.parametrize('mixed_propose_len', [True])
 @pytest.mark.parametrize('device', ['cuda'])
-def test_scoroer(model_name: str, batch_size: int, propose_len: int,
-                 device: str) -> None:
+def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
+                mixed_propose_len: bool, device: str) -> None:
     """
-    Compare the batch expansion scorer and mqa scorer return the same score
+    Compare the batch expansion scorer and mqa scorer return the same score.
+    We test for both queries with the same propose length and different
+    propose length.
     """
     seed = 0
     block_size = 32
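
As a side note, here is a minimal standalone sketch (not part of this commit, assuming only torch on CPU) of the padding scheme the rewritten create_proposal above relies on: positions beyond each request's proposal length keep the fill value -1, and a zero-length proposal leaves its whole row padded.

import torch

# Toy illustration of ragged-proposal padding: requests may propose
# different numbers of tokens, including zero.
propose_lens = [3, 0, 2]
vocab_size = 8
max_propose_len = max(propose_lens)

proposal_probs = torch.rand(len(propose_lens), max_propose_len, vocab_size)
proposal_token_ids = torch.full((len(propose_lens), max_propose_len),
                                fill_value=-1)
for i, n in enumerate(propose_lens):
    proposal_token_ids[i, :n] = torch.argmax(proposal_probs[i, :n], dim=-1)

print(proposal_token_ids)  # row 1 stays all -1: its proposal length is 0
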
@@ -46,13 +63,22 @@ def test_scoroer(model_name: str, batch_size: int, propose_len: int,
         should_modify_greedy_probs_inplace = True
 
     vocab_size = scorer_worker.vocab_size
-    proposals = create_proposal(batch_size, propose_len, vocab_size, device)
+
+    if not mixed_propose_len:
+        propose_lens = [max_propose_len] * batch_size
+    else:
+        non_zero_cnt = random.randint(0, batch_size)
+        propose_lens = [max_propose_len
+                        ] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
+        random.shuffle(propose_lens)
+
+    proposals = create_proposal(propose_lens, vocab_size, device)
     seq_group_metadatalist, _, _ = create_batch(batch_size,
-                                                propose_len,
+                                                max_propose_len,
                                                 block_size=block_size,
                                                 num_gpu_blocks=num_gpu_blocks)
     requests = ExecuteModelRequest(seq_group_metadatalist,
-                                   num_lookahead_slots=propose_len)
+                                   num_lookahead_slots=max_propose_len)
 
     batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
                                                       vocab_size)
@@ -186,11 +186,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
     # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
     use_cuda_graph: bool
 
-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int] = None
+    # Max number of query tokens for among request in the batch.
+    max_decode_query_len: Optional[int] = None
 
     _cached_prefill_metadata: Optional[
         "BlocksparseFlashAttentionMetadata"] = None
@@ -111,11 +111,8 @@ class FlashAttentionMetadata(AttentionMetadata):
     # Maximum query length in the batch.
     max_query_len: Optional[int]
 
-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int]
+    # Max number of query tokens among request in the batch.
+    max_decode_query_len: Optional[int]
 
     # Maximum sequence length among prefill batch. 0 if there are decoding
     # requests only.
@@ -173,9 +170,9 @@ class FlashAttentionMetadata(AttentionMetadata):
             slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
-            decode_query_len=0,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
+            max_decode_query_len=0,
             max_decode_seq_len=0,
             query_start_loc=self.query_start_loc[:self.num_prefills + 1],
             seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
@@ -202,12 +199,14 @@ class FlashAttentionMetadata(AttentionMetadata):
             slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
-            decode_query_len=self.decode_query_len,
+            max_decode_query_len=self.max_decode_query_len,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
-            query_start_loc=None,
-            seq_start_loc=None,
+            query_start_loc=self.query_start_loc[self.num_prefills:]
+            if self.query_start_loc is not None else None,
+            seq_start_loc=self.seq_start_loc[self.num_prefills:]
+            if self.seq_start_loc is not None else None,
             context_lens_tensor=None,
             block_tables=self.block_tables[self.num_prefills:],
             use_cuda_graph=self.use_cuda_graph,
@@ -413,9 +412,9 @@ class FlashAttentionMetadataBuilder(
         max_query_len = max(query_lens)
         decode_query_lens = query_lens[self.num_prefills:]
         if len(decode_query_lens) > 0:
-            decode_query_len = max(decode_query_lens)
+            max_decode_query_len = max(decode_query_lens)
         else:
-            decode_query_len = 1
+            max_decode_query_len = 1
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
@@ -468,7 +467,7 @@ class FlashAttentionMetadataBuilder(
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
-            decode_query_len=decode_query_len,
+            max_decode_query_len=max_decode_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
             query_start_loc=query_start_loc,
@@ -714,20 +713,37 @@ def unified_flash_attention(
 
     if decode_meta := attn_metadata.decode_metadata:
         # Decoding run.
-        _, num_head, head_dim = decode_query.shape
-        decode_query = decode_query.reshape(-1, decode_meta.decode_query_len,
-                                            num_head, head_dim)
-        decode_output = flash_attn_with_kvcache(
-            q=decode_query,
-            k_cache=key_cache,
-            v_cache=value_cache,
-            block_table=decode_meta.block_tables,
-            cache_seqlens=decode_meta.seq_lens_tensor,
-            softmax_scale=softmax_scale,
-            causal=True,
-            alibi_slopes=alibi_slopes,
-            softcap=logits_soft_cap,
-        ).squeeze(1)
+        # Use flash_attn_varlen_func kernel for speculative decoding
+        # because different queries might have different lengths.
+        assert decode_meta.max_decode_query_len is not None
+        if decode_meta.max_decode_query_len > 1:
+            decode_output = flash_attn_varlen_func(
+                q=decode_query,
+                k=key_cache,
+                v=value_cache,
+                cu_seqlens_q=decode_meta.query_start_loc,
+                max_seqlen_q=decode_meta.max_decode_query_len,
+                cu_seqlens_k=decode_meta.seq_start_loc,
+                max_seqlen_k=decode_meta.max_decode_seq_len,
+                softmax_scale=softmax_scale,
+                causal=True,
+                alibi_slopes=alibi_slopes,
+                softcap=logits_soft_cap,
+                block_table=decode_meta.block_tables,
+            )
+        else:
+            # Use flash_attn_with_kvcache for normal decoding.
+            decode_output = flash_attn_with_kvcache(
+                q=decode_query.unsqueeze(1),
+                k_cache=key_cache,
+                v_cache=value_cache,
+                block_table=decode_meta.block_tables,
+                cache_seqlens=decode_meta.seq_lens_tensor,
+                softmax_scale=softmax_scale,
+                causal=True,
+                alibi_slopes=alibi_slopes,
+                softcap=logits_soft_cap,
+            ).squeeze(1)
 
     if prefill_output is None:
         assert decode_output is not None
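
For context, the varlen path above packs all decode queries into one flat tensor and describes their boundaries through query_start_loc / cu_seqlens_q, which is the exclusive prefix sum of the per-request decode query lengths. A hedged, standalone sketch of that bookkeeping (the helper name is illustrative, not vLLM's API):

import torch

def build_query_start_loc(decode_query_lens):
    # Exclusive prefix sum: request i's packed query rows live at
    # start[i]:start[i + 1] in the flattened decode batch.
    start = torch.zeros(len(decode_query_lens) + 1, dtype=torch.int32)
    start[1:] = torch.cumsum(torch.tensor(decode_query_lens), dim=0)
    return start

# Three decode requests carrying 2, 1 and 3 speculative query tokens.
print(build_query_start_loc([2, 1, 3]))  # tensor([0, 2, 3, 6], dtype=torch.int32)
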
@@ -739,7 +755,6 @@ def unified_flash_attention(
     # Chunked prefill does not work with speculative decoding.
     # Therefore, the query length for decode should be 1 in chunked prefill.
     assert decode_meta is not None
-    assert decode_meta.decode_query_len == 1
     decode_output = decode_output.squeeze(1)
     output = torch.cat([prefill_output, decode_output], dim=0)
     return output.view(num_tokens, hidden_size)
@@ -121,11 +121,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
     # so far).
     context_lens_tensor: Optional[torch.Tensor]
 
-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int] = None
+    # Max number of query tokens among request in the batch.
+    max_decode_query_len: Optional[int] = None
 
     _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
     _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
@@ -313,7 +313,7 @@ class CommonAttentionState(AttentionState):
             seq_lens=None,
             seq_lens_tensor=self._graph_seq_lens[:batch_size],
             max_query_len=1,
-            decode_query_len=1,
+            max_decode_query_len=1,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.runner.max_seq_len_to_capture,
             query_start_loc=None,
@@ -118,11 +118,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
 
-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int] = None
+    # Max number of query tokens among request in the batch.
+    max_decode_query_len: Optional[int] = None
 
     # (batch_size + 1,). The cumulative subquery lengths of the sequences in
     # the batch, used to index into subquery. E.g., if the subquery length
@@ -18,6 +18,7 @@ class MQAScorer(SpeculativeScorer):
         target_seq_id_start = max(
             get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1
         all_proposal_tokens = proposals.proposal_token_ids.tolist()
+        all_proposal_lengths = proposals.proposal_lens.tolist()
         for i, seq_group_metadata in enumerate(
                 execute_model_req.seq_group_metadata_list):
             seq_data_dict = seq_group_metadata.seq_data
@@ -27,7 +28,8 @@ class MQAScorer(SpeculativeScorer):
             seq_data: SequenceData = seq_data_dict[seq_id]
             prompt_token_ids = seq_data.get_prompt_token_ids()
             output_token_ids = seq_data.get_output_token_ids()
-            proposal_token_ids = all_proposal_tokens[i]
+            proposal_token_ids = all_proposal_tokens[
+                i][:all_proposal_lengths[i]]
             new_output_token_ids = [*output_token_ids, *proposal_token_ids]
 
             target_seq_id = target_seq_id_start + i
@@ -62,18 +64,42 @@ class MQAScorer(SpeculativeScorer):
 
         target_sampler_output = target_sampler_output[0]
 
-        bs, k = proposals.proposal_token_ids.shape
-        all_tokens = target_sampler_output.sampled_token_ids.reshape(bs, k + 1)
-        all_probs = target_sampler_output.sampled_token_probs.reshape(
-            bs, k + 1, self._vocab_size)
-        all_logprobs = target_sampler_output.logprobs.reshape(
-            bs, k + 1, self._vocab_size)
+        k = execute_model_req.num_lookahead_slots
+        bs = len(execute_model_req.seq_group_metadata_list)
+        target_token_ids = target_sampler_output.sampled_token_ids
+        target_probs = target_sampler_output.sampled_token_probs
+        target_logprobs = target_sampler_output.logprobs
+        # If all requests have the same number of query tokens, we can avoid
+        # the for loop to build output for better performance.
+        if min(all_proposal_lengths) == k:
+            bs, _ = proposals.proposal_token_ids.shape
+            all_tokens = target_token_ids.reshape(bs, k + 1)
+            all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
+            all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
+        else:
+            all_tokens = target_token_ids.new_full(size=(bs, k + 1),
+                                                   fill_value=-1)
+            all_probs = target_probs.new_zeros(*all_tokens.shape,
+                                               self._vocab_size)
+            all_logprobs = target_logprobs.new_full(size=all_probs.shape,
+                                                    fill_value=-float("inf"))
+            target_token_ids = target_token_ids.flatten()
+            start_loc = 0
+            for i, proposed_len in enumerate(all_proposal_lengths):
+                output_len = proposed_len + 1
+                end_loc = start_loc + output_len
+                all_tokens[
+                    i, :output_len] = target_token_ids[start_loc:end_loc]
+                all_probs[i, :output_len] = target_probs[start_loc:end_loc]
+                all_logprobs[
+                    i, :output_len] = target_logprobs[start_loc:end_loc]
+                start_loc = end_loc
 
         hidden_states = None
         if target_sampler_output.hidden_states is not None:
             hidden_states = target_sampler_output.hidden_states.reshape(
                 bs, (k + 1), -1)
 
         return SpeculativeScores(probs=all_probs,
                                  token_ids=all_tokens,
                                  logprobs=all_logprobs,
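
A minimal standalone sketch (not the committed code) of the ragged re-packing performed in the else branch above: the target model scores proposed_len + 1 rows per request back to back, and those flat rows are copied into padded (batch_size, k + 1) tensors, with unused token slots left at -1.

import torch

# Stand-in for the flat target-model outputs: each request contributes
# proposed_len + 1 scored positions, laid out back to back.
all_proposal_lengths = [2, 0, 3]
k = max(all_proposal_lengths)          # stand-in for num_lookahead_slots
vocab_size = 5
flat_len = sum(n + 1 for n in all_proposal_lengths)

target_token_ids = torch.arange(flat_len)
target_probs = torch.rand(flat_len, vocab_size)

bs = len(all_proposal_lengths)
all_tokens = target_token_ids.new_full((bs, k + 1), fill_value=-1)
all_probs = target_probs.new_zeros(bs, k + 1, vocab_size)

start_loc = 0
for i, proposed_len in enumerate(all_proposal_lengths):
    end_loc = start_loc + proposed_len + 1
    all_tokens[i, :proposed_len + 1] = target_token_ids[start_loc:end_loc]
    all_probs[i, :proposed_len + 1] = target_probs[start_loc:end_loc]
    start_loc = end_loc

print(all_tokens)  # padded slots stay -1; row 1 only keeps its bonus token
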
@@ -190,12 +190,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                     "[Speculative Decoding] Disabling MQA scorer as the "
                     "MQA is only available with flash attn backend.")
 
-            if ngram_prompt_lookup_max > 0:
-                disable_mqa_scorer = True
-                logger.info(
-                    "[Speculative Decoding] Disabling MQA scorer as the "
-                    "NGramWorker does not support MQA scorer.")
-
             if "model_config" in draft_worker_kwargs and \
                 draft_worker_kwargs["model_config"].max_model_len < \
                     scorer_worker.model_config.max_model_len: