[BugFix] Fix spec decoding max_tokens scheduling perf issue (#29542)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-11-28 04:52:23 -08:00 committed by GitHub
parent 953d9c820b
commit 8e7a891602
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 38 deletions

View File

@ -43,7 +43,7 @@ class TestLogprobsLists(TestCase):
cu_num_generated_tokens=None, cu_num_generated_tokens=None,
) )
sliced = logprobsLists.slice(1, 3) sliced = logprobsLists.slice_request(1, num_positions=2)
assert sliced.logprob_token_ids == [[2], [3]] assert sliced.logprob_token_ids == [[2], [3]]
assert sliced.logprobs == [[0.2], [0.3]] assert sliced.logprobs == [[0.2], [0.3]]
assert sliced.sampled_token_ranks == [2, 3] assert sliced.sampled_token_ranks == [2, 3]
@ -51,7 +51,7 @@ class TestLogprobsLists(TestCase):
def test_slice_from_start(self): def test_slice_from_start(self):
"""Test slicing from the start position""" """Test slicing from the start position"""
sliced = self.logprobsLists.slice(0, 2) sliced = self.logprobsLists.slice_request(0, num_positions=5)
assert len(sliced.logprob_token_ids) == 5 assert len(sliced.logprob_token_ids) == 5
assert sliced.logprob_token_ids == [ assert sliced.logprob_token_ids == [
[1, 2], [1, 2],
@ -60,11 +60,11 @@ class TestLogprobsLists(TestCase):
[7, 8], [7, 8],
[9, 10], [9, 10],
] ]
assert sliced.cu_num_generated_tokens == [0, 2, 5] assert sliced.cu_num_generated_tokens is None
def test_slice_from_middle(self): def test_slice_from_middle(self):
"""Test slicing from the middle position""" """Test slicing from the middle position"""
sliced = self.logprobsLists.slice(1, 3) sliced = self.logprobsLists.slice_request(1, num_positions=7)
assert len(sliced.logprob_token_ids) == 7 assert len(sliced.logprob_token_ids) == 7
assert sliced.logprob_token_ids == [ assert sliced.logprob_token_ids == [
[5, 6], [5, 6],
@ -75,27 +75,25 @@ class TestLogprobsLists(TestCase):
[15, 16], [15, 16],
[17, 18], [17, 18],
] ]
assert sliced.cu_num_generated_tokens == [0, 3, 7] assert sliced.cu_num_generated_tokens is None
def test_slice_single_request(self): def test_slice_single_request(self):
"""Test slicing a single request""" """Test slicing a single request"""
sliced = self.logprobsLists.slice(1, 2) sliced = self.logprobsLists.slice_request(1, num_positions=3)
assert len(sliced.logprob_token_ids) == 3 assert len(sliced.logprob_token_ids) == 3
assert sliced.logprob_token_ids == [[5, 6], [7, 8], [9, 10]] assert sliced.logprob_token_ids == [[5, 6], [7, 8], [9, 10]]
assert sliced.cu_num_generated_tokens == [0, 3] assert sliced.cu_num_generated_tokens is None
def test_slice_last_request(self): def test_slice_last_request(self):
"""Test slicing the last request""" """Test slicing the last request"""
sliced = self.logprobsLists.slice(2, 3) sliced = self.logprobsLists.slice_request(2, num_positions=4)
assert len(sliced.logprob_token_ids) == 4 assert len(sliced.logprob_token_ids) == 4
assert sliced.logprob_token_ids == [[11, 12], [13, 14], [15, 16], [17, 18]] assert sliced.logprob_token_ids == [[11, 12], [13, 14], [15, 16], [17, 18]]
assert sliced.cu_num_generated_tokens == [0, 4] assert sliced.cu_num_generated_tokens is None
def test_slice_all_requests(self): def test_slice_all_requests(self):
"""Test slicing all requests (full slice)""" """Test slicing all requests (full slice)"""
sliced = self.logprobsLists.slice(0, 3) sliced = self.logprobsLists.slice_request(0, num_positions=9)
assert len(sliced.logprob_token_ids) == 9 # All tokens assert len(sliced.logprob_token_ids) == 9 # All tokens
assert sliced.logprob_token_ids == self.logprobsLists.logprob_token_ids assert sliced.logprob_token_ids == self.logprobsLists.logprob_token_ids
assert ( assert sliced.cu_num_generated_tokens is None
sliced.cu_num_generated_tokens == self.logprobsLists.cu_num_generated_tokens
)

View File

@ -234,11 +234,15 @@ class Scheduler(SchedulerInterface):
num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
# Make sure the input position does not exceed the max model len or num_spec_placeholders = max(0, request.num_output_placeholders - 1)
# request's max_tokens.
# This is necessary when using spec decoding and/or async scheduling.
max_total_tokens = min( max_total_tokens = min(
request.num_prompt_tokens + request.max_tokens, self.max_model_len # Avoid scheduling tokens that we're sure won't will be needed based on
# request.max_tokens. For this calculation we assume placeholder
# speculated output tokens are rejected.
request.num_prompt_tokens + request.max_tokens + num_spec_placeholders,
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
self.max_model_len,
) )
num_new_tokens = min( num_new_tokens = min(
num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens
@ -1089,7 +1093,7 @@ class Scheduler(SchedulerInterface):
and request.sampling_params.logprobs is not None and request.sampling_params.logprobs is not None
and logprobs and logprobs
): ):
new_logprobs = logprobs.slice(req_index, req_index + 1) new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))
if new_token_ids and self.structured_output_manager.should_advance(request): if new_token_ids and self.structured_output_manager.should_advance(request):
struct_output_request = request.structured_output_request struct_output_request = request.structured_output_request

View File

@ -29,27 +29,15 @@ class LogprobsLists(NamedTuple):
# different for each request. # different for each request.
cu_num_generated_tokens: list[int] | None = None cu_num_generated_tokens: list[int] | None = None
def slice(self, start_req_idx: int, end_req_idx: int): def slice_request(self, req_idx: int, num_positions: int):
if self.cu_num_generated_tokens: if self.cu_num_generated_tokens is not None:
start = self.cu_num_generated_tokens[start_req_idx] req_idx = self.cu_num_generated_tokens[req_idx]
end = self.cu_num_generated_tokens[end_req_idx] end_idx = req_idx + num_positions
# Recompute cumulative array starting from 0
cu_num_offset = self.cu_num_generated_tokens[start_req_idx]
sliced_cu_num_generated_tokens = [
cu_num - cu_num_offset
for cu_num in self.cu_num_generated_tokens[
start_req_idx : end_req_idx + 1
]
]
else:
start = start_req_idx
end = end_req_idx
sliced_cu_num_generated_tokens = None
return LogprobsLists( return LogprobsLists(
self.logprob_token_ids[start:end], self.logprob_token_ids[req_idx:end_idx],
self.logprobs[start:end], self.logprobs[req_idx:end_idx],
self.sampled_token_ranks[start:end], self.sampled_token_ranks[req_idx:end_idx],
sliced_cu_num_generated_tokens, None,
) )