[Speculative Decoding] Fixing hidden states handling in batch expansion (#7508)

Abhinav Goyal 2024-08-20 06:28:14 +05:30 committed by GitHub
parent e54ebc2f8f
commit 312f761232
6 changed files with 139 additions and 41 deletions

View File

@@ -288,15 +288,17 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
ensure_all_accepted=ensure_all_accepted)
def run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
temperature: float,
seeded: bool,
print_tokens: bool = False,
ensure_all_accepted: bool = False):
def run_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
temperature: float,
seeded: bool,
print_tokens: bool = False,
ensure_all_accepted: bool = False,
expected_acceptance_rate: Optional[float] = None):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, i.e. that the outputs are exactly
the same when temperature is zero (or when temperature is > 0 and seeded).
@@ -357,5 +359,10 @@ def run_equality_correctness_test(baseline_llm_generator,
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids
print(f'{acceptance_rate=}')
if ensure_all_accepted:
assert acceptance_rate == 1.0
if expected_acceptance_rate is not None:
assert acceptance_rate >= expected_acceptance_rate - 1e-2
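Note: the new expected_acceptance_rate argument asserts a lower bound with a small tolerance instead of the exact-equality check used by ensure_all_accepted. A minimal sketch of the same check outside the test harness (check_acceptance_rate is an illustrative helper name, not part of this change):

from typing import Optional

def check_acceptance_rate(acceptance_rate: float,
                          ensure_all_accepted: bool = False,
                          expected_acceptance_rate: Optional[float] = None,
                          tol: float = 1e-2) -> None:
    # Exact check: every proposed token must have been accepted.
    if ensure_all_accepted:
        assert acceptance_rate == 1.0
    # Lower-bound check: allow a small tolerance for run-to-run noise.
    if expected_acceptance_rate is not None:
        assert acceptance_rate >= expected_acceptance_rate - tol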

View File

@@ -82,6 +82,48 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
},
])
@pytest.mark.parametrize("output_len", [2048])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify acceptance rate with different batch size and large output
length."""
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
temperature=0.0,
seeded=True,
force_output_len=True,
expected_acceptance_rate=0.48)
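For context, the acceptance rate is the fraction of draft tokens that the target model accepts. A toy computation (numbers are illustrative, not taken from this test):

# 3 speculative steps with k=4 draft tokens each; the target accepts 3, 1, 2.
accepted = [3, 1, 2]
proposed = [4, 4, 4]
acceptance_rate = sum(accepted) / sum(proposed)  # 6 / 12 = 0.5
assert acceptance_rate >= 0.48 - 1e-2  # would pass expected_acceptance_rate=0.48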
@pytest.mark.parametrize(
"common_llm_kwargs",
[{

View File

@@ -1,6 +1,6 @@
from array import array
from itertools import chain, count
from typing import Iterator, List, Tuple
from typing import Iterator, List, Optional, Tuple
import torch
@@ -88,21 +88,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
all_tokens, all_probs, spec_logprobs = self._contract_batch(
contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)
(all_tokens, all_probs, spec_logprobs,
all_hidden_states) = self._contract_batch(
contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)
return SpeculativeScores(
probs=all_probs,
token_ids=all_tokens,
logprobs=spec_logprobs,
hidden_states=target_sampler_output.hidden_states,
hidden_states=all_hidden_states,
)
def _expand_batch(
@@ -145,10 +146,11 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
num_scoring_tokens)
def _contract_batch(
self, contracted_bs: int, target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals, num_scoring_tokens: int,
non_spec_indices: List[int], spec_indices: List[int],
k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
self, contracted_bs: int, target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals, num_scoring_tokens: int,
non_spec_indices: List[int], spec_indices: List[int], k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
@@ -156,9 +158,10 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
contracted_bs is the original batch size, i.e. the batch size that
target_sampler_output will be contracted to.
"""
(target_token_ids, target_probs, target_logprobs,
(target_token_ids, target_probs, target_logprobs, target_hidden_states,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs) = self._split_scoring_output(
non_spec_target_logprobs,
non_spec_target_hidden_states) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)
# Map distinct sequences used to score each token
@@ -176,23 +179,40 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
self._vocab_size)
target_logprobs = target_logprobs.reshape(target_probs.shape)
if target_hidden_states is not None:
target_hidden_states = target_hidden_states.reshape(
spec_expanded_bs, k + 1, target_hidden_states.shape[-1])
all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
fill_value=-1)
all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
all_logprobs = target_logprobs.new_full(size=all_probs.shape,
fill_value=-float("inf"))
if target_sampler_output.hidden_states is not None:
all_hidden_states = target_hidden_states.new_zeros(
size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))
else:
all_hidden_states = None
if non_spec_indices:
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
if all_hidden_states is not None:
all_hidden_states[
non_spec_indices, :1, :] = non_spec_target_hidden_states
if spec_indices:
all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs
all_logprobs[spec_indices] = target_logprobs
return all_tokens, all_probs, all_logprobs
if all_hidden_states is not None:
all_hidden_states[spec_indices] = target_hidden_states
return all_tokens, all_probs, all_logprobs, all_hidden_states
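The hidden states now follow the same contraction path as token ids and probabilities: reshape the expanded speculative rows to [spec_bs, k + 1, hidden_dim] and scatter both halves into a buffer sized for the original batch. A standalone sketch of that scatter, with made-up sizes:

import torch

contracted_bs, k, hidden_dim = 4, 3, 8
spec_indices, non_spec_indices = [0, 2], [1, 3]

# Expanded spec rows: one row per (sequence, position); non-spec sequences
# contribute a single position each.
target_hidden_states = torch.randn(len(spec_indices) * (k + 1), hidden_dim)
non_spec_target_hidden_states = torch.randn(len(non_spec_indices), 1, hidden_dim)

# Reshape the spec rows back to a per-sequence [bs, k + 1, hidden] layout.
target_hidden_states = target_hidden_states.reshape(
    len(spec_indices), k + 1, hidden_dim)

# Scatter both halves into a buffer sized for the original batch.
all_hidden_states = target_hidden_states.new_zeros(
    contracted_bs, k + 1, hidden_dim)
all_hidden_states[non_spec_indices, :1, :] = non_spec_target_hidden_states
all_hidden_states[spec_indices] = target_hidden_states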
def _create_scoring_model_input(
self,
@@ -327,8 +347,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def _split_scoring_output(
self, sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor, Optional[torch.Tensor]]:
"""Split the target model output into speculative and non-speculative
output.
"""
@@ -353,24 +374,37 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
non_spec_logprobs,
) = sampler_output.logprobs.split(split_sizes)
if sampler_output.hidden_states is not None:
(
spec_hidden_states,
non_spec_hidden_states,
) = sampler_output.hidden_states.split(split_sizes)
else:
spec_hidden_states, non_spec_hidden_states = None, None
# Convert scores to tensors.
sampler_output.sampled_token_probs = spec_probs
sampler_output.sampled_token_ids = spec_sampled_tokens
sampler_output.logprobs = spec_logprobs
(target_token_ids, target_probs,
target_logprobs) = sampler_output_to_torch([sampler_output], True)
sampler_output.hidden_states = spec_hidden_states
(target_token_ids, target_probs, target_logprobs,
target_hidden_states) = sampler_output_to_torch([sampler_output],
True)
# Convert non-speculative output tokens to tensors.
sampler_output.sampled_token_probs = non_spec_probs
sampler_output.sampled_token_ids = non_spec_sampled_tokens
sampler_output.logprobs = non_spec_logprobs
sampler_output.hidden_states = non_spec_hidden_states
(non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs) = sampler_output_to_torch([sampler_output],
True)
non_spec_target_logprobs,
non_spec_target_hidden_states) = sampler_output_to_torch(
[sampler_output], True)
return (target_token_ids, target_probs, target_logprobs,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs)
target_hidden_states, non_spec_target_token_ids,
non_spec_target_probs, non_spec_target_logprobs,
non_spec_target_hidden_states)
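The hidden-state split mirrors the existing token/prob/logprob splits: the first num_scoring_tokens rows belong to speculative sequences and the remainder to non-speculative ones. A minimal sketch of the same tensor.split call (sizes are illustrative):

import torch

num_scoring_tokens, num_non_spec, hidden_dim = 8, 2, 16
hidden_states = torch.randn(num_scoring_tokens + num_non_spec, hidden_dim)

split_sizes = (num_scoring_tokens, num_non_spec)
spec_hidden_states, non_spec_hidden_states = hidden_states.split(split_sizes)
# spec_hidden_states:     [8, 16] -> later reshaped per sequence in _contract_batch
# non_spec_hidden_states: [2, 16] -> one row per non-speculative sequence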
def _create_target_seq_id_iterator(
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:

View File

@@ -646,9 +646,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
hidden_states = proposal_scores.hidden_states
if hidden_states is not None:
# Contract hidden states based on accepted tokens
hs_size = hidden_states.shape[1]
hidden_states = hidden_states.reshape(-1, max_proposal_len + 1,
hs_size)
hs_size = hidden_states.shape[-1]
accepted_index = accepted_token_ids + 1 # Convert -1 to 0
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
index = accepted_index[:, None, None].expand(-1, 1, hs_size)
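With the scorer now returning hidden states already contracted to [batch, k + 1, hidden_dim], the worker no longer needs the reshape and reads the hidden size from the last dimension. The index math selects, per sequence, the position of the last accepted token (-1 marks a rejected slot). A sketch of that selection; the final gather is an assumption about the line that follows this hunk:

import torch

# Toy batch of 3 sequences with k + 1 = 4 candidate tokens each; -1 marks
# positions rejected by the target model.
accepted_token_ids = torch.tensor([[5, 7, -1, -1],
                                   [3, -1, -1, -1],
                                   [9, 2, 4, 8]])
hidden_states = torch.randn(3, 4, 16)  # [batch, k + 1, hidden_dim]

hs_size = hidden_states.shape[-1]
accepted_index = accepted_token_ids + 1                        # rejected slots become 0
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # last accepted position
index = accepted_index[:, None, None].expand(-1, 1, hs_size)
# Assumed follow-up (not shown in the hunk): gather one hidden-state row per
# sequence at its last accepted position.
last_accepted = hidden_states.gather(1, index).squeeze(1)  # [batch, hidden_dim]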

View File

@@ -242,7 +242,7 @@ class Top1Proposer(SpeculativeProposer):
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs, _ = sampler_output_to_torch(
proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
sampler_output, sampler_transposed)
# Now, reformat the output GPU tensors such that each sequence has

View File

@@ -123,7 +123,7 @@ def split_batch_by_proposal_len(
def sampler_output_to_torch(
sampler_output_list: List[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
@@ -169,7 +169,23 @@ def sampler_output_to_torch(
if sampler_transposed:
sampled_token_ids = sampled_token_ids.transpose(0, 1)
return sampled_token_ids, sampled_token_probs, sampled_token_logprobs
if sampler_output_list[0].hidden_states is not None:
# shape: [batch_size, num_sampler_output, hidden_dim]
sampled_hidden_states = torch.stack(
[
sampler_output.hidden_states
for sampler_output in sampler_output_list
],
dim=0,
)
if sampler_transposed:
sampled_hidden_states = sampled_hidden_states.transpose(0, 1)
else:
sampled_hidden_states = None
return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs,
sampled_hidden_states)
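When hidden states are present, sampler_output_to_torch stacks one entry per SamplerOutput and transposes them the same way as the token and probability tensors. A toy run showing the resulting shape (sizes are made up, values are random):

import torch

batch_size, num_steps, hidden_dim = 2, 3, 4
per_step_hidden = [torch.randn(batch_size, hidden_dim) for _ in range(num_steps)]

# Stack along a new leading dim, then transpose to match the token/prob layout.
sampled_hidden_states = torch.stack(per_step_hidden, dim=0)  # [steps, batch, hidden]
sampler_transposed = True
if sampler_transposed:
    sampled_hidden_states = sampled_hidden_states.transpose(0, 1)
assert sampled_hidden_states.shape == (batch_size, num_steps, hidden_dim)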
def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,