mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 20:28:42 +08:00
[Performance][Spec Decode] Optimize ngram lookup performance (#9333)
This commit is contained in:
parent
5b8a1fde84
commit
8345045833
@ -67,9 +67,16 @@ class NGramWorker(NonLLMProposerWorkerBase):
|
||||
execute_model_req.seq_group_metadata_list):
|
||||
seq_data = next(iter(seq_group_metadata.seq_data.values()))
|
||||
|
||||
seq_len = seq_data.get_len()
|
||||
# When seq_len is less than 3072 (3K), we use CPU to perform
|
||||
# the ngram match. Otherwise, we use the device specified in
|
||||
# the model config (normally GPU). 3072 is a rough threshold
|
||||
# based on profiling on H100, and it can be adjusted based
|
||||
# on the actual performance on different hardware.
|
||||
cur_device = "cpu" if seq_len < 3072 else self.device
|
||||
input_ids = torch.as_tensor(seq_data.get_token_ids(),
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
device=cur_device)
|
||||
input_length = seq_data.get_len()
|
||||
|
||||
for ngram_size in range(
|
||||
@ -91,17 +98,15 @@ class NGramWorker(NonLLMProposerWorkerBase):
|
||||
# first_match includes "values" (bool), indicating whether
|
||||
# the match is found, and "indices", indicating the index
|
||||
# of the first match.
|
||||
# Note that "first_match.values.item()" triggers GPU-CPU
|
||||
# sync so it is a bit inefficient, but we have not found
|
||||
# a better way to do this.
|
||||
first_match = matches.max(dim=-1)
|
||||
if first_match.values.item():
|
||||
proposal_start_idx = first_match.indices.add_(ngram_size)
|
||||
spec_indices = (
|
||||
proposal_start_idx).repeat(sample_len) + torch.arange(
|
||||
sample_len, device=self.device)
|
||||
sample_len, device=cur_device)
|
||||
spec_indices.clamp_(max=input_ids.shape[-1] - 1)
|
||||
res = input_ids.gather(dim=-1, index=spec_indices)
|
||||
res = input_ids.gather(dim=-1,
|
||||
index=spec_indices).to(self.device)
|
||||
token_id_list.append(res)
|
||||
token_prob_list.append(
|
||||
torch.nn.functional.one_hot(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user