[Speculative decoding] Improve n-gram efficiency (#4724)
commit ce532ff45c
parent 8bc68e198c
@@ -34,8 +34,8 @@ def test_ngram_algo_correctness_for_single_no_match():
         max_proposal_len=20,
     )
 
-    # set ngram window (0, 3], which is window=1/2/3
-    ngram_worker.set_ngram_window_size(0, 3)
+    # set ngram window [1, 3], which is window=1/2/3
+    ngram_worker.set_ngram_window_size(1, 3)
 
     prompts = [
         # shall find no candidate
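The window bounds are inclusive on both ends after this change: set_ngram_window_size(1, 3) tries n-gram sizes 3, 2, then 1, which is why the comment moves from the half-open (0, 3] to [1, 3]. A minimal sketch of the resulting iteration order, mirroring the worker's range(max, min - 1, -1) loop shown further down in this commit (plain Python, not the vLLM API):

    # Candidate n-gram sizes are walked from the largest window down to the
    # smallest, with both bounds included.
    ngram_prompt_lookup_min = 1
    ngram_prompt_lookup_max = 3
    sizes = list(range(ngram_prompt_lookup_max, ngram_prompt_lookup_min - 1, -1))
    assert sizes == [3, 2, 1]  # window = 1/2/3, largest tried first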
@@ -90,8 +90,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
         max_proposal_len=20,
     )
 
-    # set ngram window (0, 3], which is window=1/2/3
-    ngram_worker.set_ngram_window_size(0, 3)
+    # set ngram window [1, 3], which is window=1/2/3
+    ngram_worker.set_ngram_window_size(1, 3)
 
     prompts = [
         # shall find no candidate
@@ -128,11 +128,12 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
     assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
     assert proposals.proposal_lens.shape == torch.Size([5])
 
+    # the first sequence has no match so proposal_len should be overwritten to 0
     assert proposals.proposal_lens.tolist(
-    ) == [proposal_len for _ in range(4)] + [0]
+    ) == [0] + [proposal_len for _ in range(3)] + [0]
 
     for i in range(proposal_len):
-        assert proposals.proposal_token_ids[0][i] == 0
+        assert proposals.proposal_token_ids[0][i] == -1
         assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1]
         assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3]
         assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5]
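For reference, the new expectations can be pictured with toy tensors: a sequence that finds no n-gram match keeps a proposal length of 0 and its token-id row is filled with the placeholder -1 rather than 0. A small illustration of that layout (shapes and values only, not the SpeculativeProposals API):

    import torch

    k = 5  # proposal_len
    # Batch of 5 sequences; the first and last find no match.
    proposal_lens = torch.tensor([0, k, k, k, 0])
    proposal_token_ids = torch.full((5, k), -1, dtype=torch.long)
    proposal_token_ids[1:4] = torch.arange(k).repeat(3, 1)  # real proposals here
    assert proposal_lens.tolist() == [0] + [k] * 3 + [0]
    assert bool((proposal_token_ids[0] == -1).all())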
@@ -167,8 +168,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
         max_proposal_len=20,
     )
 
-    # set ngram window (0, 3], which is window=1/2/3
-    ngram_worker.set_ngram_window_size(0, 3)
+    # set ngram window [0, 3], which is window=1/2/3
+    ngram_worker.set_ngram_window_size(1, 3)
 
     prompts = [
         # shall find candidate 12,13,14,15,16
@@ -784,12 +784,15 @@ class SpeculativeConfig:
             draft_quantization = None
 
         if speculative_model == "[ngram]":
-            assert (ngram_prompt_lookup_max is not None
-                    and ngram_prompt_lookup_max > 0)
             if ngram_prompt_lookup_min is None:
-                ngram_prompt_lookup_min = 0
-            else:
-                assert ngram_prompt_lookup_max > ngram_prompt_lookup_min
+                ngram_prompt_lookup_min = 1
+            if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
+                raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
+            if ngram_prompt_lookup_min < 1:
+                raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
+            if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
+                raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
+                                 f"larger than {ngram_prompt_lookup_max=}")
 
         # TODO: current we still need extract vocab_size from target model
         # config, in future, we may try refactor it out, and set
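The asserts are replaced with explicit validation, and an unset minimum now defaults to 1 instead of 0, so a window of exactly one token is still valid. A standalone sketch of the same checks, using a hypothetical helper name (the real logic stays inline in SpeculativeConfig):

    from typing import Optional, Tuple

    def validate_ngram_window(ngram_prompt_lookup_min: Optional[int],
                              ngram_prompt_lookup_max: Optional[int]
                              ) -> Tuple[int, int]:
        # Default the minimum to 1, then reject non-positive or inverted windows
        # with ValueError instead of assert, matching the checks above.
        if ngram_prompt_lookup_min is None:
            ngram_prompt_lookup_min = 1
        if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
            raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
        if ngram_prompt_lookup_min < 1:
            raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
        if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
            raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
                             f"larger than {ngram_prompt_lookup_max=}")
        return ngram_prompt_lookup_min, ngram_prompt_lookup_max

    assert validate_ngram_window(None, 3) == (1, 3)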
@@ -77,9 +77,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
         """
         self._raise_if_unsupported(execute_model_req)
 
-        arr = []
         has_spec_out = False
-        for seq_group_metadata in execute_model_req.seq_group_metadata_list:
+        token_id_list = []
+        token_prob_list = []
+        for idx, seq_group_metadata in enumerate(
+                execute_model_req.seq_group_metadata_list):
             seq_data = next(iter(seq_group_metadata.seq_data.values()))
 
             input_ids = torch.as_tensor(seq_data.get_token_ids(),
@@ -89,59 +91,64 @@ class NGramWorker(LoraNotSupportedWorkerBase):
 
             for ngram_size in range(
                     min(self.ngram_prompt_lookup_max, input_length - 1),
-                    self.ngram_prompt_lookup_min,
+                    self.ngram_prompt_lookup_min - 1,
                     -1,
             ):
-                ngram_tensor = input_ids[-1 * ngram_size:]
-                windows = input_ids.unfold(dimension=0,
-                                           size=ngram_size,
-                                           step=1)
-                matches = (windows == ngram_tensor).all(dim=1)
-                match_indices = matches.nonzero(as_tuple=True)[0]
-                if match_indices.size()[0] > 1:
-                    has_spec_out = True
-                    res = seq_data.get_token_ids()
-                    res = res[match_indices[0] + ngram_size:match_indices[0] +
-                              ngram_size + sample_len]
-                    res_len = len(res)
-                    # pad 0 towards output as sample_len tokens required
-                    res += [0] * (sample_len - res_len)
-
+                ngram_tensor = input_ids[-ngram_size:]
+                proposal_start_idx = None
+                if ngram_size == 1:
+                    # Do not match itself and do not use unfold and all
+                    matches = (input_ids[:-1] == ngram_tensor)
+                else:
+                    windows = input_ids.unfold(dimension=0,
+                                               size=ngram_size,
+                                               step=1)
+                    # Do not match itself
+                    matches = (windows[:-1] == ngram_tensor).all(dim=-1)
+
+                # first_match includes "values" (bool), indicating whether
+                # the match is found, and "indices", indicating the index
+                # of the first match.
+                # Note that "first_match.values.item()" triggers GPU-CPU
+                # sync so it is a bit inefficient, but we have not found
+                # a better way to do this.
+                first_match = matches.max(dim=-1)
+                if first_match.values.item():
+                    proposal_start_idx = first_match.indices.add_(ngram_size)
+                    spec_indices = (
+                        proposal_start_idx).repeat(sample_len) + torch.arange(
+                            sample_len, device=self.device)
+                    spec_indices.clamp_(max=input_ids.shape[-1] - 1)
+                    res = input_ids.gather(dim=-1, index=spec_indices)
+                    token_id_list.append(res)
+                    token_prob_list.append(
+                        torch.nn.functional.one_hot(
+                            res,
+                            num_classes=self.vocab_size).to(torch.float32))
+                    has_spec_out = True
                     break
             else:
-                # if no candidate found, fill with 0
-                res = [0] * sample_len
-
-            arr.append(res)
+                token_id_list.append(None)
+                token_prob_list.append(None)
 
         if not has_spec_out:
             return None, False
 
-        outputs = []
-        token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device)
-        indices = token_ids.unsqueeze(2)
-
-        token_probs = torch.zeros(
-            (len(execute_model_req.seq_group_metadata_list), sample_len,
-             self.vocab_size),
-            dtype=torch.float32,
-            device=self.device,
-        )
-        token_probs.scatter_(2, indices, 1)
-        token_logprobs = torch.zeros(
-            (len(execute_model_req.seq_group_metadata_list), sample_len,
-             self.vocab_size),
-            dtype=torch.float32,
-            device=self.device,
-        )
-        for i in range(len(execute_model_req.seq_group_metadata_list)):
-            outputs.append(
-                SamplerOutput(
-                    outputs=None,
-                    sampled_token_probs=token_probs[i],
-                    logprobs=token_logprobs[i],
-                    sampled_token_ids=token_ids[i],
-                ))
+        outputs: List[Optional[SamplerOutput]] = []
+        for idx in range(len(execute_model_req.seq_group_metadata_list)):
+            if token_id_list[idx] is None:
+                outputs.append(None)
+            else:
+                outputs.append(
+                    SamplerOutput(
+                        outputs=None,
+                        sampled_token_probs=token_prob_list[idx],
+                        logprobs=torch.zeros((sample_len, self.vocab_size),
+                                             dtype=torch.float32,
+                                             device=self.device),
+                        sampled_token_ids=token_id_list[idx],
+                    ))
 
         return outputs, False
 
     def get_spec_proposals(
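The rewritten loop keeps everything on-device: sliding windows are built with Tensor.unfold (or compared directly when ngram_size == 1), the first occurrence of the trailing n-gram is taken from matches.max, and the proposal is read back with gather, with indices clamped so a match near the end of the prompt repeats the final token instead of being padded with 0. A self-contained sketch of that lookup on a toy CPU tensor (the function name and signature are illustrative, not the NGramWorker API):

    from typing import Optional
    import torch

    def ngram_lookup(input_ids: torch.Tensor, ngram_size: int,
                     sample_len: int) -> Optional[torch.Tensor]:
        # Trailing n-gram we try to find earlier in the sequence.
        ngram_tensor = input_ids[-ngram_size:]
        if ngram_size == 1:
            # Direct comparison; drop the last position so the trailing token
            # cannot match itself.
            matches = input_ids[:-1] == ngram_tensor
        else:
            windows = input_ids.unfold(dimension=0, size=ngram_size, step=1)
            # Drop the final window (the trailing n-gram itself).
            matches = (windows[:-1] == ngram_tensor).all(dim=-1)

        first_match = matches.max(dim=-1)  # .values: found?  .indices: where
        if not first_match.values.item():  # .item() forces a GPU-CPU sync
            return None

        start = first_match.indices + ngram_size
        idx = start.repeat(sample_len) + torch.arange(sample_len)
        idx.clamp_(max=input_ids.shape[-1] - 1)  # repeat last token at the tail
        return input_ids.gather(dim=-1, index=idx)

    tokens = torch.tensor([1, 2, 3, 4, 5, 6, 2, 3])
    print(ngram_lookup(tokens, ngram_size=2, sample_len=3))  # tensor([4, 5, 6])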
@@ -73,6 +73,14 @@ class Top1Proposer(SpeculativeProposer):
                 execute_model_req=nonzero_execute_model_req,
                 sample_len=proposal_len,
             )
+            (
+                proposal_lens,
+                maybe_sampler_output,
+                nonzero_proposal_len_indices,
+            ) = self._remove_no_proposal_seqs(proposal_lens,
+                                              maybe_sampler_output,
+                                              nonzero_proposal_len_indices,
+                                              transposed)
         else:
             # If no sequences can be speculated, set sampler output to None.
             maybe_sampler_output = None
@@ -140,6 +148,61 @@ class Top1Proposer(SpeculativeProposer):
             nonzero_proposal_len_indices,
         )
 
+    def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
+                                 nonzero_proposal_len_indices, transposed):
+        """Remove sequences from nonzero_proposal_len_indices and reset
+        their proposal_len to 0 the draft worker does not provide a proposal
+        (maybe_sampler_output=None). This can avoid scoring overheads.
+        """
+
+        # If maybe_sampler_output is None, then the draft worker did not
+        # provide a proposal for any sequence and thus no action needed.
+        # Also we do not support transposed maybe_sampler_output for now
+        # because it seems not straightforward for draft workers outputting
+        # transposed sampler outputs to handle the case of no proposal.
+        if maybe_sampler_output is None or transposed:
+            return (proposal_lens, maybe_sampler_output,
+                    nonzero_proposal_len_indices)
+
+        new_proposal_lens: List[int] = []
+        new_nonzero_proposal_len_indices: List[int] = []
+        new_maybe_sampler_output: List[SamplerOutput] = []
+        nonzero_proposal_len_idx_ptr = 0
+        seq_idx = 0
+        while seq_idx < len(
+                proposal_lens) and nonzero_proposal_len_idx_ptr < len(
+                    nonzero_proposal_len_indices):
+            if seq_idx < nonzero_proposal_len_indices[
+                    nonzero_proposal_len_idx_ptr]:
+                # Sequence is not in the original nonzero_proposal_len_indices,
+                # meaning that it has a proposal length of 0 before sending to
+                # the draft worker.
+                assert proposal_lens[seq_idx] == 0
+                new_proposal_lens.append(0)
+            else:
+                # Sequence is in the original nonzero_proposal_len_indices
+                if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None:
+                    # but does not have a proposal from the draft worker.
+                    new_proposal_lens.append(0)
+                else:
+                    # and has a proposal from the draft worker. Add it to the
+                    # new nonzero proposal list and keep the sampler output.
+                    new_proposal_lens.append(proposal_lens[seq_idx])
+                    new_nonzero_proposal_len_indices.append(seq_idx)
+                    new_maybe_sampler_output.append(
+                        maybe_sampler_output[nonzero_proposal_len_idx_ptr])
+                nonzero_proposal_len_idx_ptr += 1
+            seq_idx += 1
+
+        # The remaining sequences should have proposal length of 0.
+        new_proposal_lens.extend(proposal_lens[seq_idx:])
+
+        # We assume sampler_output will not be a list of all Nones.
+        # In this case this function should not be called.
+        assert new_maybe_sampler_output
+        return (new_proposal_lens, new_maybe_sampler_output,
+                new_nonzero_proposal_len_indices)
+
     def _merge_outputs(
         self,
         batch_size: int,
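_remove_no_proposal_seqs walks proposal_lens and the draft outputs with two pointers: any sequence whose per-sequence SamplerOutput came back as None has its proposal length reset to 0 and is dropped from nonzero_proposal_len_indices, so the scorer never runs on it. A small worked example of the intended effect (strings stand in for SamplerOutput objects):

    # Batch of 4 sequences. Sequence 1 was already skipped upfront
    # (proposal_len 0); the draft worker proposed for sequences 0 and 3 but
    # returned None for sequence 2.
    proposal_lens = [5, 0, 5, 5]
    nonzero_proposal_len_indices = [0, 2, 3]
    maybe_sampler_output = ["out0", None, "out3"]  # one entry per nonzero index

    # After _remove_no_proposal_seqs, sequence 2 is reset to 0 and dropped:
    #   proposal_lens                -> [5, 0, 0, 5]
    #   maybe_sampler_output         -> ["out0", "out3"]
    #   nonzero_proposal_len_indices -> [0, 3]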