mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 05:45:01 +08:00
[Structured Outputs] Refactor bitmask construction into get_grammar_bitmask (#23361)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
044931f97b
commit
800349c2a5
@ -177,14 +177,6 @@ class Scheduler(SchedulerInterface):
|
|||||||
scheduled_running_reqs: list[Request] = []
|
scheduled_running_reqs: list[Request] = []
|
||||||
preempted_reqs: list[Request] = []
|
preempted_reqs: list[Request] = []
|
||||||
|
|
||||||
# NOTE: structured_output_request_ids maps
|
|
||||||
# a request's (request that uses structured output)
|
|
||||||
# request_id to the running request index.
|
|
||||||
# This will helps us determine to slice the grammar bitmask
|
|
||||||
# and only applies valid mask for requests that
|
|
||||||
# uses structured decoding.
|
|
||||||
structured_output_request_ids: dict[str, int] = {}
|
|
||||||
|
|
||||||
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
|
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
|
||||||
num_scheduled_tokens: dict[str, int] = {}
|
num_scheduled_tokens: dict[str, int] = {}
|
||||||
token_budget = self.max_num_scheduled_tokens
|
token_budget = self.max_num_scheduled_tokens
|
||||||
@ -282,12 +274,6 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
# Schedule the request.
|
# Schedule the request.
|
||||||
scheduled_running_reqs.append(request)
|
scheduled_running_reqs.append(request)
|
||||||
if request.use_structured_output:
|
|
||||||
# PERF: in case of chunked prefill,
|
|
||||||
# request might not include any new tokens.
|
|
||||||
# Therefore, we might introduce some additional
|
|
||||||
# cycle to fill in the bitmask, which could be a big no-op.
|
|
||||||
structured_output_request_ids[request.request_id] = req_index
|
|
||||||
req_to_new_blocks[request.request_id] = new_blocks
|
req_to_new_blocks[request.request_id] = new_blocks
|
||||||
num_scheduled_tokens[request.request_id] = num_new_tokens
|
num_scheduled_tokens[request.request_id] = num_new_tokens
|
||||||
token_budget -= num_new_tokens
|
token_budget -= num_new_tokens
|
||||||
@ -477,9 +463,6 @@ class Scheduler(SchedulerInterface):
|
|||||||
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if request.use_structured_output:
|
|
||||||
structured_output_request_ids[request.request_id] = (
|
|
||||||
req_index)
|
|
||||||
req_index += 1
|
req_index += 1
|
||||||
self.running.append(request)
|
self.running.append(request)
|
||||||
if self.log_stats:
|
if self.log_stats:
|
||||||
@ -538,11 +521,6 @@ class Scheduler(SchedulerInterface):
|
|||||||
self.kv_cache_manager.get_num_common_prefix_blocks(
|
self.kv_cache_manager.get_num_common_prefix_blocks(
|
||||||
any_request, len(self.running)))
|
any_request, len(self.running)))
|
||||||
|
|
||||||
grammar_bitmask = self.structured_output_manager.grammar_bitmask(
|
|
||||||
self.requests,
|
|
||||||
structured_output_request_ids,
|
|
||||||
scheduled_spec_decode_tokens,
|
|
||||||
)
|
|
||||||
# Construct the scheduler output.
|
# Construct the scheduler output.
|
||||||
new_reqs_data = [
|
new_reqs_data = [
|
||||||
NewRequestData.from_request(
|
NewRequestData.from_request(
|
||||||
@ -556,6 +534,9 @@ class Scheduler(SchedulerInterface):
|
|||||||
scheduled_spec_decode_tokens,
|
scheduled_spec_decode_tokens,
|
||||||
req_to_new_blocks,
|
req_to_new_blocks,
|
||||||
)
|
)
|
||||||
|
structured_output_request_ids, grammar_bitmask = (
|
||||||
|
self.get_grammar_bitmask(self.running,
|
||||||
|
scheduled_spec_decode_tokens))
|
||||||
scheduler_output = SchedulerOutput(
|
scheduler_output = SchedulerOutput(
|
||||||
scheduled_new_reqs=new_reqs_data,
|
scheduled_new_reqs=new_reqs_data,
|
||||||
scheduled_cached_reqs=cached_reqs_data,
|
scheduled_cached_reqs=cached_reqs_data,
|
||||||
@ -753,6 +734,36 @@ class Scheduler(SchedulerInterface):
|
|||||||
encoder_inputs_to_schedule.append(i)
|
encoder_inputs_to_schedule.append(i)
|
||||||
return encoder_inputs_to_schedule, num_new_tokens, encoder_budget
|
return encoder_inputs_to_schedule, num_new_tokens, encoder_budget
|
||||||
|
|
||||||
|
def get_grammar_bitmask(
    self,
    requests: list[Request],
    scheduled_spec_decode_tokens: dict[str, list[int]],
):
    """Collect structured-output requests and build their grammar bitmask.

    Args:
        requests: Requests to inspect; a request's position in this list
            is recorded as its index.
        scheduled_spec_decode_tokens: Mapping from request id to a list of
            token ids; forwarded unchanged to the structured output
            manager.

    Returns:
        A ``(structured_output_request_ids, bitmask)`` tuple. The first
        item maps the request id of every request that uses structured
        output to its position in ``requests``. The second item is
        whatever ``self.structured_output_manager.grammar_bitmask``
        returns, or ``None`` when no request uses structured output (the
        manager is not called in that case).
    """
    # NOTE: the request_id -> index mapping is what lets the consumer
    # slice the grammar bitmask and apply a mask only to the requests
    # that use structured decoding.
    # PERF: with chunked prefill a request might not include any new
    # tokens, so filling in its bitmask row may turn out to be a no-op.
    structured_output_request_ids: dict[str, int] = {
        request.request_id: position
        for position, request in enumerate(requests)
        if request.use_structured_output
    }

    # Nothing to mask: skip the manager call entirely.
    if not structured_output_request_ids:
        return structured_output_request_ids, None

    grammar_bitmask = self.structured_output_manager.grammar_bitmask(
        self.requests,
        structured_output_request_ids,
        scheduled_spec_decode_tokens,
    )
    return structured_output_request_ids, grammar_bitmask
|
||||||
|
|
||||||
def update_from_output(
|
def update_from_output(
|
||||||
self,
|
self,
|
||||||
scheduler_output: SchedulerOutput,
|
scheduler_output: SchedulerOutput,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user