Improve _prune_hidden_states micro-benchmark (#707)

This commit is contained in:
Aman Gupta Karmani 2023-08-31 00:28:43 -04:00 committed by GitHub
parent 0080d8329d
commit 28873a2799
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -100,7 +100,8 @@ def _prune_hidden_states(
start_idx += prompt_len
last_token_indicies.extend(
range(start_idx, start_idx + input_metadata.num_generation_tokens))
-return hidden_states[last_token_indicies]
+return hidden_states.index_select(
+0, torch.tensor(last_token_indicies, device=hidden_states.device))
def _get_penalties(