diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 9717e008972f..1d8036310323 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -100,7 +100,8 @@ def _prune_hidden_states( start_idx += prompt_len last_token_indicies.extend( range(start_idx, start_idx + input_metadata.num_generation_tokens)) - return hidden_states[last_token_indicies] + return hidden_states.index_select( + 0, torch.tensor(last_token_indicies, device=hidden_states.device)) def _get_penalties(