Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-17 16:35:01 +08:00)
[Speculative Decoding] Fixing hidden states handling in batch expansion (#7508)
Commit 312f761232 (parent e54ebc2f8f)
diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py
@@ -288,7 +288,8 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
         ensure_all_accepted=ensure_all_accepted)


-def run_equality_correctness_test(baseline_llm_generator,
+def run_equality_correctness_test(
+        baseline_llm_generator,
         test_llm_generator,
         batch_size,
         max_output_len,
@@ -296,7 +297,8 @@ def run_equality_correctness_test(baseline_llm_generator,
         temperature: float,
         seeded: bool,
         print_tokens: bool = False,
-        ensure_all_accepted: bool = False):
+        ensure_all_accepted: bool = False,
+        expected_acceptance_rate: Optional[float] = None):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
    the same when temperature is zero (or when temperature is > 0 and seeded).
@@ -357,5 +359,10 @@ def run_equality_correctness_test(baseline_llm_generator,
            print(f'{i=} {spec_token_ids=}')
        assert baseline_token_ids == spec_token_ids
+
+    print(f'{acceptance_rate=}')

    if ensure_all_accepted:
        assert acceptance_rate == 1.0
+
+    if expected_acceptance_rate is not None:
+        assert acceptance_rate >= expected_acceptance_rate - 1e-2
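The new `expected_acceptance_rate` argument turns the helper into a soft regression check: instead of requiring every proposal to be accepted, it lets the measured rate undershoot the expectation by up to one percentage point. A minimal sketch of the assertion's behavior, with made-up numbers:

# Toy illustration of the tolerance in the new assertion (values assumed).
acceptance_rate = 0.475
expected_acceptance_rate = 0.48

# Passes: the rate may undershoot the expectation by up to 1e-2.
assert acceptance_rate >= expected_acceptance_rate - 1e-2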
diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py
@@ -82,6 +82,48 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
                                   force_output_len=True)


+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+    },
+])
+@pytest.mark.parametrize("output_len", [2048])
+@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("seed", [1])
+def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator,
+                                 batch_size: int, output_len: int):
+    """Verify acceptance rate with different batch size and large output
+    length."""
+    run_equality_correctness_test(baseline_llm_generator,
+                                  test_llm_generator,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  temperature=0.0,
+                                  seeded=True,
+                                  force_output_len=True,
+                                  expected_acceptance_rate=0.48)
+
+
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
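For readers unfamiliar with stacked `@pytest.mark.parametrize` decorators: each decorator multiplies the case count, so the new test runs once per combination of `batch_size` and `seed`. A standalone sketch of that expansion (a hypothetical test, not part of the commit):

import pytest

# Two batch sizes x one seed = two generated test cases, mirroring how
# test_mlp_e2e_acceptance_rate above is expanded by pytest.
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_cross_product(batch_size: int, seed: int):
    assert batch_size in (1, 32)
    assert seed == 1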
diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py
@@ -1,6 +1,6 @@
 from array import array
 from itertools import chain, count
-from typing import Iterator, List, Tuple
+from typing import Iterator, List, Optional, Tuple

 import torch

@@ -88,7 +88,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]

-        all_tokens, all_probs, spec_logprobs = self._contract_batch(
+        (all_tokens, all_probs, spec_logprobs,
+         all_hidden_states) = self._contract_batch(
             contracted_bs=len(execute_model_req.seq_group_metadata_list),
             target_sampler_output=target_sampler_output,
             proposals=proposals,
@@ -102,7 +103,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             probs=all_probs,
             token_ids=all_tokens,
             logprobs=spec_logprobs,
-            hidden_states=target_sampler_output.hidden_states,
+            hidden_states=all_hidden_states,
         )

     def _expand_batch(
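The `hidden_states=all_hidden_states` line in the second hunk is the core of the fix: the scorer previously forwarded `target_sampler_output.hidden_states`, which still has one row per scored token of the expanded batch, while consumers expect one entry per original sequence. A toy-shape sketch of the mismatch (sizes and names assumed for illustration, not vLLM's API):

import torch

# Assumed toy sizes: 2 speculative sequences scored over k + 1 positions
# each, plus 2 non-speculative sequences; contracted batch size is 4.
contracted_bs, k, hidden_dim = 4, 2, 8
expanded_rows = 2 * (k + 1) + 2

raw_hidden = torch.randn(expanded_rows, hidden_dim)         # pre-fix shape
contracted = torch.randn(contracted_bs, k + 1, hidden_dim)  # post-fix shape

# The expanded tensor cannot be indexed per original sequence.
assert raw_hidden.shape[0] != contracted.shape[0]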
@@ -147,8 +148,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
     def _contract_batch(
             self, contracted_bs: int, target_sampler_output: SamplerOutput,
             proposals: SpeculativeProposals, num_scoring_tokens: int,
-            non_spec_indices: List[int], spec_indices: List[int],
-            k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+            non_spec_indices: List[int], spec_indices: List[int], k: int
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
+               Optional[torch.Tensor]]:
         """Contract the expanded batch back into its original size.

         This maps the scores of speculative tokens back to their original
         sequences.
@@ -156,9 +158,10 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         contracted_bs is the original batch size, and the batch size that the
         target_sampler_output will be contracted to.
         """
-        (target_token_ids, target_probs, target_logprobs,
+        (target_token_ids, target_probs, target_logprobs, target_hidden_states,
          non_spec_target_token_ids, non_spec_target_probs,
-         non_spec_target_logprobs) = self._split_scoring_output(
+         non_spec_target_logprobs,
+         non_spec_target_hidden_states) = self._split_scoring_output(
             target_sampler_output, num_scoring_tokens)

         # Map distinct sequences used to score each token
@@ -176,23 +179,40 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
                                              self._vocab_size)
         target_logprobs = target_logprobs.reshape(target_probs.shape)

+        if target_hidden_states is not None:
+            target_hidden_states = target_hidden_states.reshape(
+                spec_expanded_bs, k + 1, target_hidden_states.shape[-1])
+
         all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                                fill_value=-1)
         all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
         all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                 fill_value=-float("inf"))

+        if target_sampler_output.hidden_states is not None:
+            all_hidden_states = target_hidden_states.new_zeros(
+                size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))
+        else:
+            all_hidden_states = None
+
         if non_spec_indices:
             all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
             all_probs[non_spec_indices, :1, :] = non_spec_target_probs
             all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs

+            if all_hidden_states is not None:
+                all_hidden_states[
+                    non_spec_indices, :1, :] = non_spec_target_hidden_states
+
         if spec_indices:
             all_tokens[spec_indices] = target_token_ids
             all_probs[spec_indices] = target_probs
             all_logprobs[spec_indices] = target_logprobs

-        return all_tokens, all_probs, all_logprobs
+            if all_hidden_states is not None:
+                all_hidden_states[spec_indices] = target_hidden_states
+
+        return all_tokens, all_probs, all_logprobs, all_hidden_states

     def _create_scoring_model_input(
             self,
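The hunk above mirrors for hidden states what `_contract_batch` already did for tokens, probs, and logprobs: reshape the speculative rows to `[spec_bs, k + 1, hidden_dim]`, allocate a tensor at the contracted batch size, and scatter both halves back by index. A self-contained sketch with toy dimensions (all names and sizes here are illustrative, not the real method):

import torch

contracted_bs, k, hidden_dim = 4, 2, 8          # assumed toy sizes
spec_indices, non_spec_indices = [0, 2], [1, 3]

spec_bs = len(spec_indices)
target_hidden = torch.randn(spec_bs * (k + 1), hidden_dim)
# One row per scored token -> one [k + 1, hidden_dim] block per sequence.
target_hidden = target_hidden.reshape(spec_bs, k + 1, hidden_dim)
non_spec_hidden = torch.randn(len(non_spec_indices), 1, hidden_dim)

all_hidden = target_hidden.new_zeros(contracted_bs, k + 1, hidden_dim)
# Non-speculative sequences contribute a single position each.
all_hidden[non_spec_indices, :1, :] = non_spec_hidden
# Speculative sequences fill all k + 1 positions.
all_hidden[spec_indices] = target_hidden
assert all_hidden.shape == (contracted_bs, k + 1, hidden_dim)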
@@ -327,8 +347,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):

     def _split_scoring_output(
             self, sampler_output: SamplerOutput, num_scoring_tokens: int
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-               torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
+               Optional[torch.Tensor], torch.Tensor, torch.Tensor,
+               torch.Tensor, Optional[torch.Tensor]]:
         """Split the target model output into speculative and non-speculative
         output.
         """
@@ -353,24 +374,37 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             non_spec_logprobs,
         ) = sampler_output.logprobs.split(split_sizes)

+        if sampler_output.hidden_states is not None:
+            (
+                spec_hidden_states,
+                non_spec_hidden_states,
+            ) = sampler_output.hidden_states.split(split_sizes)
+        else:
+            spec_hidden_states, non_spec_hidden_states = None, None
+
         # Convert scores to tensors.
         sampler_output.sampled_token_probs = spec_probs
         sampler_output.sampled_token_ids = spec_sampled_tokens
         sampler_output.logprobs = spec_logprobs
-        (target_token_ids, target_probs,
-         target_logprobs) = sampler_output_to_torch([sampler_output], True)
+        sampler_output.hidden_states = spec_hidden_states
+        (target_token_ids, target_probs, target_logprobs,
+         target_hidden_states) = sampler_output_to_torch([sampler_output],
+                                                         True)

         # Convert non-speculative output tokens to tensors.
         sampler_output.sampled_token_probs = non_spec_probs
         sampler_output.sampled_token_ids = non_spec_sampled_tokens
         sampler_output.logprobs = non_spec_logprobs
+        sampler_output.hidden_states = non_spec_hidden_states
         (non_spec_target_token_ids, non_spec_target_probs,
-         non_spec_target_logprobs) = sampler_output_to_torch([sampler_output],
-                                                             True)
+         non_spec_target_logprobs,
+         non_spec_target_hidden_states) = sampler_output_to_torch(
+             [sampler_output], True)

         return (target_token_ids, target_probs, target_logprobs,
-                non_spec_target_token_ids, non_spec_target_probs,
-                non_spec_target_logprobs)
+                target_hidden_states, non_spec_target_token_ids,
+                non_spec_target_probs, non_spec_target_logprobs,
+                non_spec_target_hidden_states)

     def _create_target_seq_id_iterator(
             self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
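`_split_scoring_output` now routes hidden states through the same `Tensor.split` path as tokens, probs, and logprobs. A minimal sketch of that splitting step (toy tensors; the real code splits a field of a `SamplerOutput`):

import torch

num_scoring_tokens, num_non_spec = 6, 2        # assumed toy sizes
hidden = torch.randn(num_scoring_tokens + num_non_spec, 16)

split_sizes = (num_scoring_tokens, num_non_spec)
spec_hidden, non_spec_hidden = hidden.split(split_sizes)
assert spec_hidden.shape == (6, 16)
assert non_spec_hidden.shape == (2, 16)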
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
@@ -646,9 +646,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         hidden_states = proposal_scores.hidden_states
         if hidden_states is not None:
             # Contract hidden states based on accepted tokens
-            hs_size = hidden_states.shape[1]
-            hidden_states = hidden_states.reshape(-1, max_proposal_len + 1,
-                                                  hs_size)
+            hs_size = hidden_states.shape[-1]
+
             accepted_index = accepted_token_ids + 1  # Convert -1 to 0
             accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
             index = accepted_index[:, None, None].expand(-1, 1, hs_size)
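With the expansion-time reshape removed (hidden states now arrive already shaped `[batch, k + 1, hidden_dim]`), only the last dimension is read for `hs_size`. The surrounding index arithmetic is worth unpacking: rejected slots in `accepted_token_ids` hold `-1`, so shifting by one and counting non-zeros per row yields the position of the last accepted token. A sketch of that selection (toy values; the final `gather` is my assumption about the downstream use):

import torch

# Row 0 accepts two tokens, row 1 accepts one; -1 marks rejection.
accepted_token_ids = torch.tensor([[5, 9, -1],
                                   [7, -1, -1]])
hs_size = 4
hidden_states = torch.randn(2, 3, hs_size)       # [batch, k + 1, hidden]

accepted_index = accepted_token_ids + 1          # Convert -1 to 0
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # -> [1, 0]
index = accepted_index[:, None, None].expand(-1, 1, hs_size)
last_accepted = hidden_states.gather(1, index).squeeze(1)      # [batch, hidden]
assert last_accepted.shape == (2, hs_size)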
diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py
@@ -242,7 +242,7 @@ class Top1Proposer(SpeculativeProposer):
             return proposal_tokens, proposal_probs, proposal_lens_tensor

         sampler_output = maybe_sampler_output
-        proposal_tokens, proposal_probs, _ = sampler_output_to_torch(
+        proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
             sampler_output, sampler_transposed)

         # Now, reformat the output GPU tensors such that each sequence has
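The one-character change from `_` to `*_` is what keeps this call site compatible: `sampler_output_to_torch` now returns four values, and the starred target absorbs everything past the first two. A quick illustration (hypothetical stand-in function):

def to_torch_like():
    # Stand-in for the new four-element return value.
    return "tokens", "probs", "logprobs", None

proposal_tokens, proposal_probs, *_ = to_torch_like()
assert proposal_tokens == "tokens"
assert proposal_probs == "probs"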
diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py
@@ -123,7 +123,7 @@ def split_batch_by_proposal_len(

 def sampler_output_to_torch(
     sampler_output_list: List[SamplerOutput], sampler_transposed: bool
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     """Utility function which converts a list of SamplerOutput to tensors.

     sampler_transposed here is used as the indicator for whether
@@ -169,7 +169,23 @@ def sampler_output_to_torch(
     if sampler_transposed:
         sampled_token_ids = sampled_token_ids.transpose(0, 1)

-    return sampled_token_ids, sampled_token_probs, sampled_token_logprobs
+    if sampler_output_list[0].hidden_states is not None:
+        # shape: [batch_size, num_sampler_output, hidden_dim]
+        sampled_hidden_states = torch.stack(
+            [
+                sampler_output.hidden_states
+                for sampler_output in sampler_output_list
+            ],
+            dim=0,
+        )
+
+        if sampler_transposed:
+            sampled_hidden_states = sampled_hidden_states.transpose(0, 1)
+    else:
+        sampled_hidden_states = None
+
+    return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs,
+            sampled_hidden_states)


 def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
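The stacking step added to `sampler_output_to_torch` combines per-step hidden states into one tensor, then swaps the leading axes when the caller asked for the transposed layout. A toy sketch of those two tensor ops (plain tensors standing in for `SamplerOutput`s):

import torch

num_steps, batch_size, hidden_dim = 3, 2, 4    # assumed toy sizes
per_step = [torch.randn(batch_size, hidden_dim) for _ in range(num_steps)]

stacked = torch.stack(per_step, dim=0)          # [steps, batch, hidden]
transposed = stacked.transpose(0, 1)            # [batch, steps, hidden]
assert transposed.shape == (batch_size, num_steps, hidden_dim)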