diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py index 23684a2c55ce..3541ef89bfc1 100644 --- a/tests/v1/engine/utils.py +++ b/tests/v1/engine/utils.py @@ -5,6 +5,7 @@ import random from dataclasses import dataclass from typing import TypeAlias +import numpy as np import torch from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -369,9 +370,9 @@ class MockEngineCore: self.generated_logprobs_raw[req_idx][token_idx] ) logprobs = LogprobsLists( - [logprobs_token_ids_], - [logprobs_], - [sampled_token_ranks_], + np.array([logprobs_token_ids_]), + np.array([logprobs_]), + np.array([sampled_token_ranks_]), ) else: logprobs = None diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py index 4c5955d7ee2e..b618d2347265 100644 --- a/vllm/v1/engine/logprobs.py +++ b/vllm/v1/engine/logprobs.py @@ -74,7 +74,12 @@ class LogprobsProcessor: token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists - for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst): + for rank_np, logprobs_np, token_ids_np in zip( + ranks_lst, logprobs_lst, token_ids_lst + ): + rank = rank_np.tolist() + logprobs = logprobs_np.tolist() + token_ids = token_ids_np.tolist() # Detokenize (non-incrementally). decoded_tokens = ( NONES diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index b5cba96e1026..5f65e4ee0d1f 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import TYPE_CHECKING, NamedTuple +import numpy as np import torch if TYPE_CHECKING: @@ -15,11 +16,11 @@ else: class LogprobsLists(NamedTuple): # [num_reqs x num_generated_tokens, max_num_logprobs + 1] - logprob_token_ids: list[list[int]] + logprob_token_ids: np.ndarray # [num_reqs x num_generated_tokens, max_num_logprobs + 1] - logprobs: list[list[float]] + logprobs: np.ndarray # [num_reqs x num_generated_tokens] - sampled_token_ranks: list[int] + sampled_token_ranks: np.ndarray # [num_reqs] # Used for slicing the logprobs in cases like speculative # decoding where the number of generated tokens may be @@ -60,9 +61,9 @@ class LogprobsTensors(NamedTuple): def tolists(self, cu_num_generated_tokens: list[int] | None = None): return LogprobsLists( - self.logprob_token_ids.tolist(), - self.logprobs.tolist(), - self.selected_token_ranks.tolist(), + self.logprob_token_ids.cpu().numpy(), + self.logprobs.cpu().numpy(), + self.selected_token_ranks.cpu().numpy(), cu_num_generated_tokens, )