Fix 1D query issue from _prune_hidden_states (#3539)

This commit is contained in:
SangBin Cho 2024-03-21 17:49:06 +09:00 committed by GitHub
parent 6ebd02bdef
commit 3bbff9e5ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -77,7 +77,6 @@ def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
return hidden_states.index_select(0,
sampling_metadata.selected_token_indices)