[BugFix] Fix spec decoding max_tokens scheduling perf issue (#29542)
Signed-off-by: Nick Hill <nhill@redhat.com>
commit 8e7a891602 (parent: 953d9c820b)
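In brief, as reflected in the hunks below: LogprobsLists.slice(start_req_idx, end_req_idx) is replaced by slice_request(req_idx, num_positions), which maps the request index through cu_num_generated_tokens to a flat token offset and returns a result without a cumulative array; and the scheduler's max_tokens bound now adds back the unverified speculative placeholder tokens, which are assumed rejected, so they no longer count against request.max_tokens when capping how much work gets scheduled.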
@@ -43,7 +43,7 @@ class TestLogprobsLists(TestCase):
             cu_num_generated_tokens=None,
         )

-        sliced = logprobsLists.slice(1, 3)
+        sliced = logprobsLists.slice_request(1, num_positions=2)
         assert sliced.logprob_token_ids == [[2], [3]]
         assert sliced.logprobs == [[0.2], [0.3]]
         assert sliced.sampled_token_ranks == [2, 3]
@@ -51,7 +51,7 @@ class TestLogprobsLists(TestCase):

     def test_slice_from_start(self):
         """Test slicing from the start position"""
-        sliced = self.logprobsLists.slice(0, 2)
+        sliced = self.logprobsLists.slice_request(0, num_positions=5)
         assert len(sliced.logprob_token_ids) == 5
         assert sliced.logprob_token_ids == [
             [1, 2],
@@ -60,11 +60,11 @@ class TestLogprobsLists(TestCase):
             [7, 8],
             [9, 10],
         ]
-        assert sliced.cu_num_generated_tokens == [0, 2, 5]
+        assert sliced.cu_num_generated_tokens is None

     def test_slice_from_middle(self):
         """Test slicing from the middle position"""
-        sliced = self.logprobsLists.slice(1, 3)
+        sliced = self.logprobsLists.slice_request(1, num_positions=7)
         assert len(sliced.logprob_token_ids) == 7
         assert sliced.logprob_token_ids == [
             [5, 6],
@@ -75,27 +75,25 @@ class TestLogprobsLists(TestCase):
             [15, 16],
             [17, 18],
         ]
-        assert sliced.cu_num_generated_tokens == [0, 3, 7]
+        assert sliced.cu_num_generated_tokens is None

     def test_slice_single_request(self):
         """Test slicing a single request"""
-        sliced = self.logprobsLists.slice(1, 2)
+        sliced = self.logprobsLists.slice_request(1, num_positions=3)
         assert len(sliced.logprob_token_ids) == 3
         assert sliced.logprob_token_ids == [[5, 6], [7, 8], [9, 10]]
-        assert sliced.cu_num_generated_tokens == [0, 3]
+        assert sliced.cu_num_generated_tokens is None

     def test_slice_last_request(self):
         """Test slicing the last request"""
-        sliced = self.logprobsLists.slice(2, 3)
+        sliced = self.logprobsLists.slice_request(2, num_positions=4)
         assert len(sliced.logprob_token_ids) == 4
         assert sliced.logprob_token_ids == [[11, 12], [13, 14], [15, 16], [17, 18]]
-        assert sliced.cu_num_generated_tokens == [0, 4]
+        assert sliced.cu_num_generated_tokens is None

     def test_slice_all_requests(self):
         """Test slicing all requests (full slice)"""
-        sliced = self.logprobsLists.slice(0, 3)
+        sliced = self.logprobsLists.slice_request(0, num_positions=9)
         assert len(sliced.logprob_token_ids) == 9  # All tokens
         assert sliced.logprob_token_ids == self.logprobsLists.logprob_token_ids
-        assert (
-            sliced.cu_num_generated_tokens == self.logprobsLists.cu_num_generated_tokens
-        )
+        assert sliced.cu_num_generated_tokens is None
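The expected values above imply a shared fixture with three requests of 2, 3 and 4 generated positions, i.e. cu_num_generated_tokens == [0, 2, 5, 9] (inferred from the asserts, not copied from the repo's setUp). A minimal sketch of how the old slice(start_req, end_req) arguments translate to the new num_positions values:

# Assumed fixture shape, inferred from the asserts above (hypothetical
# reconstruction, not the repo's setUp verbatim).
cu = [0, 2, 5, 9]  # cumulative generated-token counts for 3 requests

def num_positions(start_req: int, end_req: int) -> int:
    # Flat position count covered by the old slice(start_req, end_req).
    return cu[end_req] - cu[start_req]

assert num_positions(0, 2) == 5  # test_slice_from_start
assert num_positions(1, 3) == 7  # test_slice_from_middle
assert num_positions(1, 2) == 3  # test_slice_single_request
assert num_positions(2, 3) == 4  # test_slice_last_request
assert num_positions(0, 3) == 9  # test_slice_all_requests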
@@ -234,11 +234,15 @@ class Scheduler(SchedulerInterface):
                 num_new_tokens = self.scheduler_config.long_prefill_token_threshold
             num_new_tokens = min(num_new_tokens, token_budget)

-            # Make sure the input position does not exceed the max model len or
-            # request's max_tokens.
-            # This is necessary when using spec decoding and/or async scheduling.
+            num_spec_placeholders = max(0, request.num_output_placeholders - 1)
             max_total_tokens = min(
-                request.num_prompt_tokens + request.max_tokens, self.max_model_len
+                # Avoid scheduling tokens that we're sure won't be needed based on
+                # request.max_tokens. For this calculation we assume placeholder
+                # speculated output tokens are rejected.
+                request.num_prompt_tokens + request.max_tokens + num_spec_placeholders,
+                # Make sure the input position does not exceed the max model len.
+                # This is necessary when using spec decoding.
+                self.max_model_len,
             )
             num_new_tokens = min(
                 num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens
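A hedged numeric sketch of the new bound's effect; every value below is an invented example input, not taken from vLLM. The diff's max(0, n - 1) suggests one placeholder is the regular sampled token and the rest are speculative; that reading is treated as an assumption here.

# Invented example: 100-token prompt, max_tokens=8, 104 tokens already
# computed (including unverified speculative placeholders), and 3 output
# placeholders of which 2 are assumed speculative.
num_prompt_tokens, max_tokens = 100, 8
num_computed_tokens, num_output_placeholders = 104, 3
max_model_len = 4096

num_spec_placeholders = max(0, num_output_placeholders - 1)

old_bound = min(num_prompt_tokens + max_tokens, max_model_len)  # 108
new_bound = min(
    num_prompt_tokens + max_tokens + num_spec_placeholders, max_model_len
)  # 110

# Tokens the scheduler may still schedule for this request:
assert old_bound - 1 - num_computed_tokens == 3
# Placeholders may be rejected, so they should not eat the max_tokens
# budget; the new bound leaves room for them:
assert new_bound - 1 - num_computed_tokens == 5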
@@ -1089,7 +1093,7 @@ class Scheduler(SchedulerInterface):
                 and request.sampling_params.logprobs is not None
                 and logprobs
             ):
-                new_logprobs = logprobs.slice(req_index, req_index + 1)
+                new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))

             if new_token_ids and self.structured_output_manager.should_advance(request):
                 struct_output_request = request.structured_output_request
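Note: with speculative decoding a request can emit several accepted tokens in one step, so the logprobs slice for this request covers len(new_token_ids) positions rather than a single-request span expressed in the old request-index space.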
@@ -29,27 +29,15 @@ class LogprobsLists(NamedTuple):
     # different for each request.
     cu_num_generated_tokens: list[int] | None = None

-    def slice(self, start_req_idx: int, end_req_idx: int):
-        if self.cu_num_generated_tokens:
-            start = self.cu_num_generated_tokens[start_req_idx]
-            end = self.cu_num_generated_tokens[end_req_idx]
-            # Recompute cumulative array starting from 0
-            cu_num_offset = self.cu_num_generated_tokens[start_req_idx]
-            sliced_cu_num_generated_tokens = [
-                cu_num - cu_num_offset
-                for cu_num in self.cu_num_generated_tokens[
-                    start_req_idx : end_req_idx + 1
-                ]
-            ]
-        else:
-            start = start_req_idx
-            end = end_req_idx
-            sliced_cu_num_generated_tokens = None
+    def slice_request(self, req_idx: int, num_positions: int):
+        if self.cu_num_generated_tokens is not None:
+            req_idx = self.cu_num_generated_tokens[req_idx]
+        end_idx = req_idx + num_positions
         return LogprobsLists(
-            self.logprob_token_ids[start:end],
-            self.logprobs[start:end],
-            self.sampled_token_ranks[start:end],
-            sliced_cu_num_generated_tokens,
+            self.logprob_token_ids[req_idx:end_idx],
+            self.logprobs[req_idx:end_idx],
+            self.sampled_token_ranks[req_idx:end_idx],
+            None,
         )

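For illustration, a self-contained sketch of the new semantics (a standalone re-implementation using the names in the diff, not an import of vLLM itself): when the cumulative array is present, the request index is first mapped to a flat token offset, and the single-request result carries no cumulative array, hence the None.

from typing import NamedTuple

class LogprobsListsSketch(NamedTuple):
    logprob_token_ids: list[list[int]]
    logprobs: list[list[float]]
    sampled_token_ranks: list[int]
    cu_num_generated_tokens: list[int] | None = None

    def slice_request(self, req_idx: int, num_positions: int):
        # Map request index -> flat token offset when per-request
        # cumulative counts are available; otherwise req_idx is
        # already a flat position.
        if self.cu_num_generated_tokens is not None:
            req_idx = self.cu_num_generated_tokens[req_idx]
        end_idx = req_idx + num_positions
        return LogprobsListsSketch(
            self.logprob_token_ids[req_idx:end_idx],
            self.logprobs[req_idx:end_idx],
            self.sampled_token_ranks[req_idx:end_idx],
            None,  # single request: nothing cumulative to carry over
        )

# Three requests with 2, 3 and 4 generated positions:
lists = LogprobsListsSketch(
    [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
     [11, 12], [13, 14], [15, 16], [17, 18]],
    [[0.0, 0.0]] * 9,
    [0] * 9,
    [0, 2, 5, 9],
)
assert lists.slice_request(1, num_positions=3).logprob_token_ids == [
    [5, 6], [7, 8], [9, 10]
]

Dropping the recomputed cumulative array is what keeps this path cheap, which appears to be the point of the perf fix for the logprobs side.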