[V1][Feature] Enable Speculative Decoding with Structured Outputs (#14702)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>

Parent: 7489ec0bab
Commit: 34120f5acd
@@ -260,6 +260,7 @@ async def async_request_openai_completions(
                 if request_func_input.model_name else request_func_input.model,
             "prompt": request_func_input.prompt,
             "temperature": 0.0,
+            "repetition_penalty": 1.0,
             "max_tokens": request_func_input.output_len,
             "logprobs": request_func_input.logprobs,
             "stream": True,
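The benchmark's completions payload now pins repetition_penalty to its neutral value of 1.0 alongside temperature 0.0. A hedged sketch of sending an equivalent (non-streaming) request to a vLLM OpenAI-compatible server; the URL and model name are illustrative assumptions:

# Hedged sketch, not part of the diff: an equivalent completions request.
# The server URL and model name are assumptions; the benchmark itself streams.
import requests

payload = {
    "model": "Qwen/Qwen2.5-1.5B-Instruct",
    "prompt": "Write one sentence about speculative decoding.",
    "temperature": 0.0,
    "repetition_penalty": 1.0,
    "max_tokens": 64,
    "stream": False,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload, timeout=60)
print(resp.json()["choices"][0]["text"])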
@@ -123,6 +123,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
             copy.deepcopy(schema) for _ in range(args.num_prompts)
         ]
         for i in range(len(json_schemas)):
+            if "properties" not in json_schemas[i]:
+                json_schemas[i]["properties"] = {}
             json_schemas[i]["properties"][
                 f"__optional_field_{uuid.uuid4()}"] = {
                     "type":
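The json-unique mode deep-copies the schema per prompt and injects a uniquely named optional property so every prompt differs; the new guard makes this work for schemas that lack a "properties" key. A small self-contained illustration of that mutation; the base schema and the "string" field type are made up for the demo:

# Self-contained illustration of the per-prompt schema mutation in the diff.
# The base schema and the "string" field type are assumptions for this demo.
import copy
import json
import uuid

schema = {"type": "object"}  # deliberately has no "properties" key
json_schemas = [copy.deepcopy(schema) for _ in range(3)]

for i in range(len(json_schemas)):
    if "properties" not in json_schemas[i]:
        json_schemas[i]["properties"] = {}
    json_schemas[i]["properties"][f"__optional_field_{uuid.uuid4()}"] = {
        "type": "string",
    }

print(json.dumps(json_schemas[0], indent=2))  # each copy carries a distinct field name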
@@ -134,7 +136,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
         json_schemas = [schema] * args.num_prompts

     def gen_prompt(index: int):
-        return f"Generate an example of a user profile given the following schema: {json.dumps(get_schema(index))}"  # noqa: E501
+        return f"Generate an example of a brief user profile given the following schema: {json.dumps(get_schema(index))}"  # noqa: E501

     def get_schema(index: int):
         return json_schemas[index % len(json_schemas)]
@@ -231,7 +233,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
                 idx -= len_dataset
             schema = dataset["schema"][idx]
             prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
-                                                   tokenize=False)
+                                                   tokenize=False,
+                                                   add_generation_prompt=True)
             input_len = len(tokenizer(prompt).input_ids)
             completion = dataset["completion"][idx]

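Rendering the dataset prompts with add_generation_prompt=True appends the assistant turn header, so the measured completion starts where the model would actually generate. A quick hedged illustration with a Hugging Face tokenizer; the model name is only an example of a model that ships a chat template:

# Hedged illustration of add_generation_prompt; the model name is an example.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
messages = [{"role": "user", "content": "Return a JSON user profile."}]

print(repr(tokenizer.apply_chat_template(messages, tokenize=False)))
print(repr(tokenizer.apply_chat_template(messages, tokenize=False,
                                         add_generation_prompt=True)))
# The second string additionally ends with the assistant header, e.g.
# "<|im_start|>assistant\n" for this family of templates.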
@@ -849,7 +852,7 @@ if __name__ == "__main__":
                            'json', 'json-unique', 'grammar', 'regex',
                            'choice', 'xgrammar_bench'
                        ])
-    parser.add_argument("--json_schema_path",
+    parser.add_argument("--json-schema-path",
                         type=str,
                         default=None,
                         help="Path to json schema.")
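The flag rename to --json-schema-path matches the dash-separated style of the other options; argparse still exposes it as args.json_schema_path, because dashes in option names become underscores in the attribute name:

# Minimal demonstration that "--json-schema-path" maps to args.json_schema_path.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--json-schema-path",
                    type=str,
                    default=None,
                    help="Path to json schema.")
args = parser.parse_args(["--json-schema-path", "schema.json"])
print(args.json_schema_path)  # -> schema.json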
@@ -16,13 +16,31 @@ from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams

+NGRAM_SPEC_CONFIG = {
+    "model": "[ngram]",
+    "num_speculative_tokens": 5,
+    "prompt_lookup_max": 5,
+    "prompt_lookup_min": 1,
+}
+
+EAGLE_SPEC_CONFIG = {
+    "method": "eagle",
+    "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+    "num_speculative_tokens": 5,
+}
+
 PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
-    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto"),
-    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto"),
-    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral"),
-    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto"),
+    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
+    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
+    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
+    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
     #FIXME: This test is flaky on CI thus disabled
     #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
+    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
+     NGRAM_SPEC_CONFIG),
+    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG),
+    ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto",
+     EAGLE_SPEC_CONFIG)
 ]

 PARAMS_MODELS_TOKENIZER_MODE = [
@@ -45,8 +63,9 @@ class CarDescription(BaseModel):


 @pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("model_name, guided_decoding_backend, tokenizer_mode",
-                         PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
+@pytest.mark.parametrize(
+    "model_name, guided_decoding_backend, tokenizer_mode, speculative_config",
+    PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
 def test_structured_output(
     monkeypatch: pytest.MonkeyPatch,
     sample_json_schema: dict[str, Any],
@@ -58,6 +77,7 @@ def test_structured_output(
     guided_decoding_backend: str,
     tokenizer_mode: str,
     model_name: str,
+    speculative_config: dict[str, Any],
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")

@@ -71,7 +91,8 @@ def test_structured_output(
               max_model_len=1024,
               guided_decoding_backend=guided_decoding_backend,
               guided_decoding_disable_any_whitespace=True,
-              tokenizer_mode=tokenizer_mode)
+              tokenizer_mode=tokenizer_mode,
+              speculative_config=speculative_config)

     #
     # Test 1: Generate JSON output based on a provided schema
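Outside the test harness, the same combination can be exercised directly by handing one of the speculative configs to LLM together with a guided-decoding backend. A hedged sketch assembled from the test above; the model, schema, and prompt are illustrative, and a vLLM build with V1 speculative decoding support is assumed:

# Hedged sketch of combining speculative decoding with structured outputs.
# Model, schema, and prompt are illustrative; this mirrors the test setup above.
import os

os.environ["VLLM_USE_V1"] = "1"  # the test enables the V1 engine explicitly

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    max_model_len=1024,
    guided_decoding_backend="xgrammar",
    speculative_config={
        "model": "[ngram]",  # prompt-lookup (n-gram) drafting
        "num_speculative_tokens": 5,
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 1,
    },
)

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}
params = SamplingParams(
    temperature=0.0,
    max_tokens=128,
    guided_decoding=GuidedDecodingParams(json=schema),
)
outputs = llm.generate([f"Generate a user profile for this schema: {schema}"], params)
print(outputs[0].outputs[0].text)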
@@ -441,7 +441,7 @@ class Scheduler(SchedulerInterface):
         grammar_bitmask = self.structured_output_manager.grammar_bitmask(
             self.requests,
             structured_output_request_ids,
-            len(self.running),
+            scheduled_spec_decode_tokens,
         )
         # Construct the scheduler output.
         new_reqs_data = [
@@ -682,10 +682,6 @@ class Scheduler(SchedulerInterface):
                     self.encoder_cache_manager.free_encoder_input(
                         request, input_id)

-            # Add newly generated spec token ids to the request.
-            if spec_token_ids is not None:
-                request.spec_token_ids = spec_token_ids[req_index]
-
             stopped = False
             new_logprobs = None
             new_token_ids = generated_token_ids
@@ -717,6 +713,17 @@ class Scheduler(SchedulerInterface):
                 request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                     req_id, new_token_ids)

+            # Add newly generated spec token ids to the request.
+            if spec_token_ids is not None:
+                if request.use_structured_output:
+                    metadata = request.structured_output_request
+                    assert metadata is not None and metadata.grammar is not None
+                    # Needs to happen after new_token_ids are accepted.
+                    request.spec_token_ids = metadata.grammar.validate_tokens(
+                        spec_token_ids[req_index])
+                else:
+                    request.spec_token_ids = spec_token_ids[req_index]
+
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
             if new_token_ids:
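For structured-output requests the drafted tokens are no longer attached verbatim: they are first passed through grammar.validate_tokens, which returns only the prefix the grammar accepts, and this must happen after the sampled tokens have been accepted so the FSM is at the right state. A stub-based sketch of that control flow; the classes below are stand-ins, not vLLM types:

# Stand-in classes to illustrate the scheduler-side flow; not vLLM types.
from dataclasses import dataclass, field
from typing import Optional


class StubGrammar:
    """Toy grammar that accepts only even token ids, to make the prefix cut visible."""

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        accepted: list[int] = []
        for t in tokens:
            if t % 2 != 0:
                break
            accepted.append(t)
        return accepted


@dataclass
class StubRequest:
    use_structured_output: bool
    grammar: Optional[StubGrammar] = None
    spec_token_ids: list[int] = field(default_factory=list)


def attach_spec_tokens(request: StubRequest, drafted: list[int]) -> None:
    # Mirrors the diff: validate drafts for structured-output requests,
    # attach them unchanged otherwise.
    if request.use_structured_output:
        assert request.grammar is not None
        request.spec_token_ids = request.grammar.validate_tokens(drafted)
    else:
        request.spec_token_ids = drafted


req = StubRequest(use_structured_output=True, grammar=StubGrammar())
attach_spec_tokens(req, [2, 4, 7, 8])
print(req.spec_token_ids)  # -> [2, 4]: only the grammar-valid prefix is kept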
@@ -27,6 +27,7 @@ class StructuredOutputManager:
     def __init__(self, vllm_config: VllmConfig):
         self.backend: Optional[StructuredOutputBackend] = None
         self.vllm_config = vllm_config
+
         self._grammar_bitmask: Optional[torch.Tensor] = None

         # The default max_workers if not specified is the number of CPUs * 5,
@@ -80,7 +81,7 @@ class StructuredOutputManager:
         self,
         requests: dict[str, Request],
         structured_output_request_ids: dict[str, int],
-        batch_len: int,
+        scheduled_spec_decode_tokens: dict[str, list[int]],
     ) -> Optional[npt.NDArray[np.int32]]:
         # Prepare the structured output bitmask for this batch.
         if not structured_output_request_ids:
@@ -88,20 +89,52 @@ class StructuredOutputManager:

         if self._grammar_bitmask is None:
             assert self.backend is not None
-            self._grammar_bitmask = self.backend.allocate_token_bitmask(
-                self.vllm_config.scheduler_config.max_num_seqs)
+            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
+            if self.vllm_config.speculative_config is not None:
+                max_num_spec_tokens = self.vllm_config.\
+                    speculative_config.num_speculative_tokens
+            else:
+                max_num_spec_tokens = 0

-        # Fill the bitmask using the index of each request equal to its
-        # position in the batch. Resize the bitmask down to the size of
-        # the batch.
-        bitmask_tensor = self._grammar_bitmask
-        for req_id, batch_index in structured_output_request_ids.items():
+            # Allocate a bitmask for each token needing to be checked:
+            # one for each speculative position, and one more for the
+            # bonus token / non-speculative token.
+            self._grammar_bitmask = \
+                self.backend.allocate_token_bitmask(
+                    max_batch_size * (1 + max_num_spec_tokens))
+
+        # Generate a batched bitmask for all structured output requests.
+        # When speculative decoding is enabled, we need to include multiple
+        # masks for each request, one for each possible bonus token position.
+        # These are stored inline in the tensor and unpacked by the gpu runner.
+        cumulative_index = 0
+        ordered_seq = sorted(structured_output_request_ids.items(),
+                             key=lambda x: x[1])
+        # NOTE: This outer loop can likely be parallelized to improve
+        # performance of bitmask generation for large batches.
+        for req_id, _ in ordered_seq:
             request = requests[req_id].structured_output_request
             assert request is not None and request.grammar is not None
-            if not request.grammar.is_terminated():
-                request.grammar.fill_bitmask(bitmask_tensor, batch_index)
-        if batch_len < self._grammar_bitmask.shape[0]:
-            bitmask_tensor = self._grammar_bitmask[:batch_len]
+            state_advancements = 0
+            req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
+            for i, token in enumerate(req_tokens):
+                if not request.grammar.is_terminated():
+                    request.grammar.fill_bitmask(self._grammar_bitmask,
+                                                 cumulative_index)
+                if token is not None:
+                    # In order to generate the correct bitmask for each
+                    # position in the speculative sequence, we advance
+                    # the FSM state for each speculative token and rollback
+                    # to restore the previous state when we are finished.
+                    assert request.grammar.accept_tokens(req_id, [token])
+                    state_advancements += 1
+                cumulative_index += 1
+            if state_advancements > 0:
+                request.grammar.rollback(state_advancements)
+
+        bitmask_tensor = self._grammar_bitmask
+        if cumulative_index < self._grammar_bitmask.shape[0]:
+            bitmask_tensor = self._grammar_bitmask[:cumulative_index]

         # After finishing with the xgrammar operations, we convert to
         # np.ndarray, because that is much more efficient for serialization
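The bitmask therefore no longer has one row per structured-output request but one row per checked position: every scheduled draft token plus the bonus position, laid out contiguously per request in batch order. A small self-contained sketch of how rows map to (request, position); the request ids and draft counts are made up:

# Sketch of the inline bitmask row layout; request ids and draft tokens are made up.
structured_output_request_ids = {"req-b": 1, "req-a": 0}         # req_id -> batch position
scheduled_spec_decode_tokens = {"req-a": [11, 12], "req-b": []}  # drafted token ids

rows: list[tuple[str, str]] = []
for req_id, _ in sorted(structured_output_request_ids.items(), key=lambda x: x[1]):
    drafts = scheduled_spec_decode_tokens.get(req_id, [])
    # One row per draft position, plus one for the bonus / non-speculative token.
    for pos in range(len(drafts) + 1):
        label = f"draft {pos}" if pos < len(drafts) else "bonus"
        rows.append((req_id, label))

for row_index, (req_id, label) in enumerate(rows):
    print(row_index, req_id, label)
# 0 req-a draft 0
# 1 req-a draft 1
# 2 req-a bonus
# 3 req-b bonus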
@@ -144,6 +144,27 @@ class GuidanceGrammar(StructuredOutputGrammar):

         return r

+    def validate_tokens(self, tokens: list[int]) -> list[int]:
+        """Checks if the list of tokens are accepted by the parser in sequence.
+        Will not advance the parser.
+
+        Returns the prefix list of tokens that are accepted by the parser.
+        """
+        if len(tokens) == 0:
+            return []
+        if self.ll_matcher.is_stopped():
+            return []
+
+        num_tokens = self.ll_matcher.validate_tokens(tokens)
+
+        self.check_error()
+
+        return tokens[:num_tokens]
+
+    def rollback(self, num_tokens: int) -> None:
+        self.ll_matcher.rollback(num_tokens)
+        self.check_error()
+
     def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
         # this will automatically return [EOS] mask if the matcher is stopped
         # or otherwise in an error state
@@ -35,6 +35,30 @@ class StructuredOutputGrammar(ABC):
             bool: True if the tokens are accepted, False otherwise.
         """

+    @abstractmethod
+    def validate_tokens(self, tokens: list[int]) -> list[int]:
+        """
+        Validates the provided tokens against the grammar.
+        Will not advance the FSM.
+
+        Args:
+            tokens (list[int]): A list of token IDs to validate.
+
+        Returns:
+            list[int]: A list of accepted token IDs. Will be a prefix
+                of the input tokens, and empty if none are accepted.
+        """
+
+    @abstractmethod
+    def rollback(self, num_tokens: int) -> None:
+        """
+        Rolls back the state of the grammar by a specified number of tokens.
+        Will also revert counters for the number of processed tokens.
+
+        Args:
+            num_tokens (int): The number of tokens to roll back.
+        """
+
     @abstractmethod
     def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
         """
|||||||
@ -40,6 +40,11 @@ class XgrammarBackend(StructuredOutputBackend):
|
|||||||
self.disable_any_whitespace = \
|
self.disable_any_whitespace = \
|
||||||
vllm_config.decoding_config.disable_any_whitespace
|
vllm_config.decoding_config.disable_any_whitespace
|
||||||
|
|
||||||
|
self.num_speculative_tokens = 0
|
||||||
|
if self.vllm_config.speculative_config is not None:
|
||||||
|
self.num_speculative_tokens = \
|
||||||
|
self.vllm_config.speculative_config.num_speculative_tokens
|
||||||
|
|
||||||
tokenizer = tokenizer_group.get_lora_tokenizer(None)
|
tokenizer = tokenizer_group.get_lora_tokenizer(None)
|
||||||
self.vocab_size = vllm_config.model_config.get_vocab_size()
|
self.vocab_size = vllm_config.model_config.get_vocab_size()
|
||||||
if isinstance(tokenizer, MistralTokenizer):
|
if isinstance(tokenizer, MistralTokenizer):
|
||||||
@@ -118,7 +123,10 @@ class XgrammarBackend(StructuredOutputBackend):
                 f"grammar is not of valid supported types. ({request_type!s})")

         return XgrammarGrammar(
-            matcher=xgr.GrammarMatcher(ctx),
+            matcher=xgr.GrammarMatcher(
+                ctx,
+                max_rollback_tokens=self.num_speculative_tokens,
+            ),
             vocab_size=self.vocab_size,
             ctx=ctx,
         )
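Passing max_rollback_tokens when constructing the matcher is what lets the manager advance the FSM through each draft token while filling the bitmask and then rewind. A hedged standalone sketch of the xgrammar calls involved, assuming xgrammar and transformers are installed; the model is only an example:

# Hedged sketch of xgrammar rollback support; assumes xgrammar + transformers
# are installed, and uses an example model's tokenizer.
import xgrammar as xgr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
compiler = xgr.GrammarCompiler(tokenizer_info)
compiled = compiler.compile_builtin_json_grammar()

# Allow rewinding up to 5 accepted tokens, matching num_speculative_tokens above.
matcher = xgr.GrammarMatcher(compiled, max_rollback_tokens=5)

bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
matcher.fill_next_token_bitmask(bitmask, 0)      # mask for the current position

token_id = tokenizer.encode("{", add_special_tokens=False)[0]
if matcher.accept_token(token_id):               # advance through a draft token...
    matcher.rollback(1)                          # ...then restore the previous state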
@@ -136,7 +144,6 @@ class XgrammarGrammar(StructuredOutputGrammar):
     # supporting different backends, in the future.
     # For now, just xgrammar.
     #
-    # TODO: support max_rollback_tokens
     # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
     # for jump-forward decoding

@@ -163,6 +170,27 @@ class XgrammarGrammar(StructuredOutputGrammar):
         self.num_processed_tokens += 1
         return True

+    def validate_tokens(self, tokens: list[int]) -> list[int]:
+        """Checks if the list of tokens are accepted by the FSM in sequence.
+        Will not advance the FSM.
+
+        Returns the prefix list of tokens that are accepted by the FSM.
+        """
+        accepted_tokens = []
+        for token in tokens:
+            if self.matcher.accept_token(token):
+                accepted_tokens.append(token)
+            else:
+                break
+        if len(accepted_tokens) > 0:
+            # Rollback the FSM to the initial state
+            self.matcher.rollback(len(accepted_tokens))
+        return accepted_tokens
+
+    def rollback(self, num_tokens: int) -> None:
+        self.matcher.rollback(num_tokens)
+        self.num_processed_tokens -= num_tokens
+
     def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
         self.matcher.fill_next_token_bitmask(bitmask, idx)

@@ -957,46 +957,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         scheduler_output: "SchedulerOutput",
         logits: torch.Tensor,
     ):
-        # Serialization of np.ndarray is much more efficient than a tensor,
-        # so we receive it in that format.
         grammar_bitmask = scheduler_output.grammar_bitmask
         if grammar_bitmask is None:
             return

-        # We receive the structured output bitmask from the scheduler, but the
-        # indices of the requests in the batch may not match the indices of
-        # the bitmask since the scheduler doesn't know how the gpu runner is
-        # ordering the requests in the batch. We need to sort the bitmask to
-        # match the order of the requests used here.
+        # We receive the structured output bitmask from the scheduler,
+        # compacted to contain bitmasks only for structured output requests.
+        # The order of the requests in the bitmask is not guaranteed to be the
+        # same as the order of the requests in the gpu runner's batch. We need
+        # to sort the bitmask to match the order of the requests used here.
+
+        # Get the batch indices of the structured output requests.
+        # Keep track of the number of speculative tokens scheduled for every
+        # request in the batch, as the logit indices are offset by this amount.
         struct_out_req_batch_indices: dict[str, int] = {}
-        indices_match = True
-        for req_id in self.input_batch.req_ids:
-            mask_index = scheduler_output.structured_output_request_ids.get(
-                req_id)
-            if mask_index is None:
-                # not a structured output request
-                continue
-            batch_index = self.input_batch.req_id_to_index[req_id]
-            if batch_index != mask_index:
-                indices_match = False
-            struct_out_req_batch_indices[req_id] = batch_index
+        cumulative_offset = 0
+        seq = sorted(self.input_batch.req_id_to_index.items(),
+                     key=lambda x: x[1])
+        for req_id, batch_index in seq:
+            logit_index = batch_index + cumulative_offset
+            cumulative_offset += len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            if req_id in scheduler_output.structured_output_request_ids:
+                struct_out_req_batch_indices[req_id] = logit_index

-        if not indices_match:
-            # Sort the bitmask to match the order of the requests
-            sorted_bitmask = np.zeros_like(grammar_bitmask)
-            for req_id, batch_index in struct_out_req_batch_indices.items():
-                orig_index = scheduler_output.structured_output_request_ids[
-                    req_id]
-                sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
-            grammar_bitmask = sorted_bitmask
+        out_indices = []

+        # Reorder the bitmask to match the order of the requests in the batch.
+        sorted_bitmask = np.zeros_like(grammar_bitmask,
+                                       shape=(logits.shape[0],
+                                              grammar_bitmask.shape[1]))
+        cumulative_index = 0
+        seq = sorted(scheduler_output.structured_output_request_ids.items(),
+                     key=lambda x: x[1])
+        for req_id, _ in seq:
+            logit_index = struct_out_req_batch_indices[req_id]
+            num_spec_tokens = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            for i in range(1 + num_spec_tokens):
+                sorted_bitmask[logit_index + i] = \
+                    grammar_bitmask[cumulative_index + i]
+                out_indices.append(logit_index + i)
+            cumulative_index += 1 + num_spec_tokens
+        grammar_bitmask = sorted_bitmask
+
+        # Serialization of np.ndarray is much more efficient than a tensor,
+        # so we receive it in that format.
         grammar_bitmask = torch.from_numpy(grammar_bitmask)

-        # TODO: compatibility with spec decode
         xgr.apply_token_bitmask_inplace(
             logits,
             grammar_bitmask.to(self.device, non_blocking=True),
-            indices=list(struct_out_req_batch_indices.values()),
+            indices=out_indices,
         )

     @torch.inference_mode()
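On the runner side the compact bitmask is unpacked: each structured-output request's rows are scattered to the logit rows of its bonus and draft positions, and only those rows are listed in indices when the mask is applied. A self-contained numerical sketch of the scatter and of masking logits; shapes and values are made up, and booleans stand in for the real bit-packed int32 mask:

# Self-contained sketch of reordering a compact mask to logit rows and applying it.
# Shapes/values are made up; the real mask is bit-packed int32, booleans used here.
import numpy as np
import torch

vocab = 8
req_id_to_index = {"req-a": 0, "req-b": 1}           # batch order in the runner
scheduled_spec = {"req-a": [11, 12], "req-b": []}    # drafted token ids per request
structured_ids = {"req-a": 0, "req-b": 1}            # both are structured-output requests

# Compact mask from the scheduler: one row per checked position, in batch order.
compact = np.zeros((4, vocab), dtype=bool)
compact[:, [2, 5]] = True                            # pretend only tokens 2 and 5 are legal

# Logit rows: bonus + drafts per request -> 3 rows for req-a, 1 row for req-b.
logits = torch.randn(4, vocab)

sorted_full = np.zeros((logits.shape[0], vocab), dtype=bool)
out_indices: list[int] = []
cumulative, offset = 0, 0
for req_id, batch_index in sorted(req_id_to_index.items(), key=lambda x: x[1]):
    logit_index = batch_index + offset               # offset by earlier requests' drafts
    n_spec = len(scheduled_spec.get(req_id, []))
    offset += n_spec
    if req_id in structured_ids:
        for i in range(1 + n_spec):
            sorted_full[logit_index + i] = compact[cumulative + i]
            out_indices.append(logit_index + i)
        cumulative += 1 + n_spec

disallowed = torch.from_numpy(~sorted_full)          # True where a token is not allowed
logits[out_indices] = logits[out_indices].masked_fill(disallowed[out_indices], float("-inf"))
print(out_indices)                                   # logit rows that were constrained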