Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-21 03:15:29 +08:00)
[Performance][Spec Decode] Optimize ngram lookup performance (#9333)
parent 5b8a1fde84
commit 8345045833
@@ -67,9 +67,16 @@ class NGramWorker(NonLLMProposerWorkerBase):
                 execute_model_req.seq_group_metadata_list):
             seq_data = next(iter(seq_group_metadata.seq_data.values()))
 
+            seq_len = seq_data.get_len()
+            # When seq_len is less than 3072 (3K), we use CPU to perform
+            # the ngram match. Otherwise, we use the device specified in
+            # the model config (normally GPU). 3072 is a rough threshold
+            # based on profiling on H100, and it can be adjusted based
+            # on the actual performance on different hardware.
+            cur_device = "cpu" if seq_len < 3072 else self.device
             input_ids = torch.as_tensor(seq_data.get_token_ids(),
                                         dtype=torch.long,
-                                        device=self.device)
+                                        device=cur_device)
             input_length = seq_data.get_len()
 
             for ngram_size in range(
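
The core of this change is the length-based device dispatch added above: short sequences run the ngram match on CPU, where the work is cheap enough that avoiding GPU kernel launches and device syncs is a net win. A minimal standalone sketch of that heuristic, assuming the commit's 3072-token cutoff (the constant and helper names below are illustrative, not part of vLLM):

    import torch

    # Rough cutoff profiled on H100 per the commit's comment; below it the
    # ngram match runs on CPU to avoid GPU launch/sync overhead. Tune per
    # hardware. (Constant and function names here are illustrative only.)
    CPU_NGRAM_THRESHOLD = 3072

    def pick_ngram_device(seq_len: int,
                          model_device: torch.device) -> torch.device:
        """Return the device on which the ngram match should run."""
        if seq_len < CPU_NGRAM_THRESHOLD:
            return torch.device("cpu")
        return torch.device(model_device)

The second hunk then makes everything downstream of the match agree with that choice:
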
@@ -91,17 +98,15 @@ class NGramWorker(NonLLMProposerWorkerBase):
                 # first_match includes "values" (bool), indicating whether
                 # the match is found, and "indices", indicating the index
                 # of the first match.
-                # Note that "first_match.values.item()" triggers GPU-CPU
-                # sync so it is a bit inefficient, but we have not found
-                # a better way to do this.
                 first_match = matches.max(dim=-1)
                 if first_match.values.item():
                     proposal_start_idx = first_match.indices.add_(ngram_size)
                     spec_indices = (
                         proposal_start_idx).repeat(sample_len) + torch.arange(
-                            sample_len, device=self.device)
+                            sample_len, device=cur_device)
                     spec_indices.clamp_(max=input_ids.shape[-1] - 1)
-                    res = input_ids.gather(dim=-1, index=spec_indices)
+                    res = input_ids.gather(dim=-1,
+                                           index=spec_indices).to(self.device)
                     token_id_list.append(res)
                     token_prob_list.append(
                         torch.nn.functional.one_hot(
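
Two behavioral points in this hunk: the arange that builds spec_indices now lives on cur_device, so the index arithmetic stays on whichever device ran the match, and only the small gathered result is copied back with .to(self.device). The stale caveat about first_match.values.item() forcing a GPU-CPU sync is dropped, since on the new CPU path .item() costs nothing. A self-contained sketch of the whole lookup step under this scheme (the unfold-based window construction is an assumption for illustration; vLLM's actual match computation differs):

    from typing import Optional

    import torch

    def ngram_propose(input_ids: torch.Tensor, ngram_size: int,
                      sample_len: int,
                      target_device: torch.device) -> Optional[torch.Tensor]:
        """Find the first earlier occurrence of the trailing ngram and
        propose the tokens that followed it. `input_ids` may live on CPU
        or GPU (see the device heuristic above); only the result crosses
        devices, mirroring the commit's `.to(self.device)`. Assumes
        len(input_ids) > ngram_size."""
        ngram = input_ids[-ngram_size:]
        # Sliding windows over the prefix, dropping the final token so the
        # trailing ngram cannot match itself.
        windows = input_ids[:-1].unfold(dimension=-1, size=ngram_size, step=1)
        matches = (windows == ngram).all(dim=-1)
        first_match = matches.max(dim=-1)
        if not first_match.values.item():  # syncs only on the GPU path
            return None
        proposal_start_idx = first_match.indices + ngram_size
        spec_indices = proposal_start_idx.repeat(sample_len) + torch.arange(
            sample_len, device=input_ids.device)
        spec_indices.clamp_(max=input_ids.shape[-1] - 1)
        # Gather on the match device; move only `sample_len` token ids back.
        return input_ids.gather(dim=-1, index=spec_indices).to(target_device)

Under this split, a short prompt's entire match loop runs on CPU tensors, and only the sample_len proposed token ids are transferred per sequence, rather than the full input_ids plus every intermediate tensor of the match.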