From 28873a2799ddfdd0624edd4619e6fbeeb49cd02c Mon Sep 17 00:00:00 2001 From: Aman Gupta Karmani Date: Thu, 31 Aug 2023 00:28:43 -0400 Subject: [PATCH] Improve _prune_hidden_states micro-benchmark (#707) --- vllm/model_executor/layers/sampler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 9717e008972f..1d8036310323 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -100,7 +100,8 @@ def _prune_hidden_states( start_idx += prompt_len last_token_indicies.extend( range(start_idx, start_idx + input_metadata.num_generation_tokens)) - return hidden_states[last_token_indicies] + return hidden_states.index_select( + 0, torch.tensor(last_token_indicies, device=hidden_states.device)) def _get_penalties(