mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 12:44:27 +08:00
Fix cu_num_generated_tokens slicing logic in LogprobsLists.slice() method (#28214)
Signed-off-by: Bradley <bradley.b.pitt@gmail.com>
This commit is contained in:
parent
636efd10a5
commit
4a8d6bd168
101
tests/v1/test_outputs.py
Normal file
101
tests/v1/test_outputs.py
Normal file
@ -0,0 +1,101 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from unittest import TestCase
|
||||
|
||||
from vllm.v1.outputs import LogprobsLists
|
||||
|
||||
|
||||
class TestLogprobsLists(TestCase):
|
||||
def setUp(self):
|
||||
self.logprobsLists = LogprobsLists(
|
||||
logprob_token_ids=[
|
||||
[1, 2], # Request 0 token 0
|
||||
[3, 4], # Request 0 token 1
|
||||
[5, 6], # Request 1 token 0
|
||||
[7, 8], # Request 1 token 1
|
||||
[9, 10], # Request 1 token 2
|
||||
[11, 12], # Request 2 token 0
|
||||
[13, 14], # Request 2 token 1
|
||||
[15, 16], # Request 2 token 2
|
||||
[17, 18], # Request 2 token 3
|
||||
],
|
||||
logprobs=[
|
||||
[0.1, 0.2],
|
||||
[0.3, 0.4],
|
||||
[0.5, 0.6],
|
||||
[0.7, 0.8],
|
||||
[0.9, 1.0],
|
||||
[1.1, 1.2],
|
||||
[1.3, 1.4],
|
||||
[1.5, 1.6],
|
||||
[1.7, 1.8],
|
||||
],
|
||||
sampled_token_ranks=[1, 3, 5, 7, 9, 11, 13, 15, 17],
|
||||
cu_num_generated_tokens=[0, 2, 5, 9],
|
||||
)
|
||||
|
||||
def test_slice_without_cu_num_generated_tokens(self):
|
||||
"""Test slicing without cu_num_generated_tokens"""
|
||||
logprobsLists = LogprobsLists(
|
||||
logprob_token_ids=[[1], [2], [3]],
|
||||
logprobs=[[0.1], [0.2], [0.3]],
|
||||
sampled_token_ranks=[1, 2, 3],
|
||||
cu_num_generated_tokens=None,
|
||||
)
|
||||
|
||||
sliced = logprobsLists.slice(1, 3)
|
||||
assert sliced.logprob_token_ids == [[2], [3]]
|
||||
assert sliced.logprobs == [[0.2], [0.3]]
|
||||
assert sliced.sampled_token_ranks == [2, 3]
|
||||
assert sliced.cu_num_generated_tokens is None
|
||||
|
||||
def test_slice_from_start(self):
|
||||
"""Test slicing from the start position"""
|
||||
sliced = self.logprobsLists.slice(0, 2)
|
||||
assert len(sliced.logprob_token_ids) == 5
|
||||
assert sliced.logprob_token_ids == [
|
||||
[1, 2],
|
||||
[3, 4],
|
||||
[5, 6],
|
||||
[7, 8],
|
||||
[9, 10],
|
||||
]
|
||||
assert sliced.cu_num_generated_tokens == [0, 2, 5]
|
||||
|
||||
def test_slice_from_middle(self):
|
||||
"""Test slicing from the middle position"""
|
||||
sliced = self.logprobsLists.slice(1, 3)
|
||||
assert len(sliced.logprob_token_ids) == 7
|
||||
assert sliced.logprob_token_ids == [
|
||||
[5, 6],
|
||||
[7, 8],
|
||||
[9, 10],
|
||||
[11, 12],
|
||||
[13, 14],
|
||||
[15, 16],
|
||||
[17, 18],
|
||||
]
|
||||
assert sliced.cu_num_generated_tokens == [0, 3, 7]
|
||||
|
||||
def test_slice_single_request(self):
|
||||
"""Test slicing a single request"""
|
||||
sliced = self.logprobsLists.slice(1, 2)
|
||||
assert len(sliced.logprob_token_ids) == 3
|
||||
assert sliced.logprob_token_ids == [[5, 6], [7, 8], [9, 10]]
|
||||
assert sliced.cu_num_generated_tokens == [0, 3]
|
||||
|
||||
def test_slice_last_request(self):
|
||||
"""Test slicing the last request"""
|
||||
sliced = self.logprobsLists.slice(2, 3)
|
||||
assert len(sliced.logprob_token_ids) == 4
|
||||
assert sliced.logprob_token_ids == [[11, 12], [13, 14], [15, 16], [17, 18]]
|
||||
assert sliced.cu_num_generated_tokens == [0, 4]
|
||||
|
||||
def test_slice_all_requests(self):
|
||||
"""Test slicing all requests (full slice)"""
|
||||
sliced = self.logprobsLists.slice(0, 3)
|
||||
assert len(sliced.logprob_token_ids) == 9 # All tokens
|
||||
assert sliced.logprob_token_ids == self.logprobsLists.logprob_token_ids
|
||||
assert (
|
||||
sliced.cu_num_generated_tokens == self.logprobsLists.cu_num_generated_tokens
|
||||
)
|
||||
@ -30,16 +30,23 @@ class LogprobsLists(NamedTuple):
|
||||
if self.cu_num_generated_tokens:
|
||||
start = self.cu_num_generated_tokens[start_req_idx]
|
||||
end = self.cu_num_generated_tokens[end_req_idx]
|
||||
# Recompute cumulative array starting from 0
|
||||
cu_num_offset = self.cu_num_generated_tokens[start_req_idx]
|
||||
sliced_cu_num_generated_tokens = [
|
||||
cu_num - cu_num_offset
|
||||
for cu_num in self.cu_num_generated_tokens[
|
||||
start_req_idx : end_req_idx + 1
|
||||
]
|
||||
]
|
||||
else:
|
||||
start = start_req_idx
|
||||
end = end_req_idx
|
||||
sliced_cu_num_generated_tokens = None
|
||||
return LogprobsLists(
|
||||
self.logprob_token_ids[start:end],
|
||||
self.logprobs[start:end],
|
||||
self.sampled_token_ranks[start:end],
|
||||
self.cu_num_generated_tokens[start_req_idx:end_req_idx]
|
||||
if self.cu_num_generated_tokens
|
||||
else None,
|
||||
sliced_cu_num_generated_tokens,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user