[Performance][Spec Decode] Optimize ngram lookup performance (#9333)

commit 8345045833 (parent 5b8a1fde84)
Author: Lily Liu, 2024-10-16 12:37:45 -07:00 (committed by GitHub)


@@ -67,9 +67,16 @@ class NGramWorker(NonLLMProposerWorkerBase):
                 execute_model_req.seq_group_metadata_list):
             seq_data = next(iter(seq_group_metadata.seq_data.values()))
 
+            seq_len = seq_data.get_len()
+            # When seq_len is less than 3072 (3K), we use CPU to perform
+            # the ngram match. Otherwise, we use the device specified in
+            # the model config (normally GPU). 3072 is a rough threshold
+            # based on profiling on H100, and it can be adjusted based
+            # on the actual performance on different hardware.
+            cur_device = "cpu" if seq_len < 3072 else self.device
             input_ids = torch.as_tensor(seq_data.get_token_ids(),
                                         dtype=torch.long,
-                                        device=self.device)
+                                        device=cur_device)
             input_length = seq_data.get_len()
 
             for ngram_size in range(
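
For context, here is a minimal standalone sketch of the device-selection
heuristic this hunk introduces. The 3072 threshold and the CPU/GPU split
mirror the diff; "pick_ngram_device" and the example prompt are hypothetical
names used only for illustration, not part of the commit:

    import torch

    # Short sequences stay on CPU, where the n-gram scan (and the later
    # .item() check) avoids kernel-launch and device-sync overhead; long
    # sequences use the model's configured device (normally a GPU).
    NGRAM_CPU_THRESHOLD = 3072  # rough H100-profiled value from the commit

    def pick_ngram_device(seq_len: int,
                          model_device: torch.device) -> torch.device:
        if seq_len < NGRAM_CPU_THRESHOLD:
            return torch.device("cpu")
        return model_device

    model_device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")
    token_ids = list(range(500))  # a short prompt, so matched on CPU
    cur_device = pick_ngram_device(len(token_ids), model_device)
    input_ids = torch.as_tensor(token_ids, dtype=torch.long,
                                device=cur_device)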
@@ -91,17 +98,15 @@ class NGramWorker(NonLLMProposerWorkerBase):
                 # first_match includes "values" (bool), indicating whether
                 # the match is found, and "indices", indicating the index
                 # of the first match.
-                # Note that "first_match.values.item()" triggers GPU-CPU
-                # sync so it is a bit inefficient, but we have not found
-                # a better way to do this.
                 first_match = matches.max(dim=-1)
                 if first_match.values.item():
                     proposal_start_idx = first_match.indices.add_(ngram_size)
                     spec_indices = (
                         proposal_start_idx).repeat(sample_len) + torch.arange(
-                            sample_len, device=self.device)
+                            sample_len, device=cur_device)
                     spec_indices.clamp_(max=input_ids.shape[-1] - 1)
-                    res = input_ids.gather(dim=-1, index=spec_indices)
+                    res = input_ids.gather(dim=-1,
+                                           index=spec_indices).to(self.device)
                     token_id_list.append(res)
                     token_prob_list.append(
                         torch.nn.functional.one_hot(
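
To make the lookup concrete, below is a runnable standalone sketch of the
match-and-gather pattern from this hunk: find the first earlier occurrence
of the trailing ngram, then propose the "sample_len" tokens that followed
it. The sliding-window construction via "unfold" is an assumption about the
surrounding (unshown) code; the other variable names mirror the diff:

    import torch

    input_ids = torch.tensor([5, 6, 7, 1, 2, 3, 9, 9, 5, 6, 7])
    ngram_size, sample_len = 3, 2
    ngram = input_ids[-ngram_size:]

    # Compare every sliding window against the trailing ngram, excluding
    # the final window, which is the trailing ngram itself.
    windows = input_ids.unfold(dimension=-1, size=ngram_size, step=1)
    matches = (windows[:-1] == ngram).all(dim=-1)

    # As in the diff: max over a bool tensor yields the first True and its
    # index; .item() turns the match flag into a Python value.
    first_match = matches.max(dim=-1)
    if first_match.values.item():
        proposal_start_idx = first_match.indices.add_(ngram_size)
        spec_indices = proposal_start_idx.repeat(sample_len) + torch.arange(
            sample_len)
        spec_indices.clamp_(max=input_ids.shape[-1] - 1)
        res = input_ids.gather(dim=-1, index=spec_indices)
        print(res)  # tensor([1, 2]): the tokens after the earlier [5, 6, 7]

On the final gather, note the commit's trailing ".to(self.device)": when the
match ran on CPU, only the small proposal tensor is copied back to the model
device, rather than the whole token sequence.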