[BugFix] Make penalties and bad_words work with async scheduling (#26467)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-10-10 16:27:04 -07:00 committed by GitHub
parent eef921f45e
commit 5bc26c438d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 113 additions and 14 deletions

View File

@ -28,9 +28,8 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
sampling_param_tests: list[dict[str, Any]] = [
dict(),
# dict(min_tokens=20),
# TODO enable these with https://github.com/vllm-project/vllm/pull/26467.
# dict(repetition_penalty=0.1),
# dict(bad_words=[]),
dict(presence_penalty=-1.0),
dict(bad_words=["the", " the"]),
]
default_params = dict(
@ -42,9 +41,9 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
# m.setenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")
outputs = []
outputs: list[tuple[str, list]] = []
for test_preemption in [False, True]:
for executor in ["uni", "mp"]:
for executor in ["mp", "uni"]:
for async_scheduling in [False, True]:
cache_arg: dict[str, Any] = (
dict(num_gpu_blocks_override=32)
@ -78,6 +77,21 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
),
)
)
if not outputs:
# First check that the different parameter configs
# actually result in different output.
for other_test, params in zip(
results[1:], sampling_param_tests[1:]
):
with pytest.raises(AssertionError):
check_outputs_equal(
outputs_0_lst=results[0],
outputs_1_lst=other_test,
name_0=f"baseline params={params}",
name_1=f"other params={params}",
)
outputs.append((test_config, results))
baseline_config, baseline_tests = outputs[0]

View File

@ -737,7 +737,9 @@ class Scheduler(SchedulerInterface):
req_to_new_blocks[req_id].get_block_ids(allow_none=True)
)
num_computed_tokens.append(req.num_computed_tokens)
num_output_tokens.append(req.num_output_tokens)
num_output_tokens.append(
req.num_output_tokens + req.num_output_placeholders
)
return CachedRequestData(
req_ids=req_ids,

View File

@ -79,6 +79,7 @@ class InputBatch:
block_sizes: list[int], # The block_size of each kv cache group
kernel_block_sizes: list[int],
logitsprocs: Optional[LogitsProcessors] = None,
logitsprocs_need_output_token_ids: bool = False,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
@ -240,6 +241,7 @@ class InputBatch:
# Store provided logitsprocs. If none are provided, initialize empty
# data structure
self.logitsprocs = logitsprocs or LogitsProcessors()
self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids
# Store last speculative tokens for sampler.
self.spec_token_ids: list[Optional[list[int]]] = []
@ -252,6 +254,11 @@ class InputBatch:
# Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
self.prev_req_id_to_index: Optional[dict[str, int]] = None
# These are used to update output_token_ids with real sampled
# ids from prior step, if required by current sampling params
# (e.g. penalties).
self.sampled_token_ids_cpu: Optional[torch.Tensor] = None
self.async_copy_ready_event: Optional[torch.cuda.Event] = None
@property
def req_ids(self) -> list[str]:
@ -776,6 +783,19 @@ class InputBatch:
self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
)
# Only set output_token_ids if required by the current requests'
# sampling parameters.
needs_output_token_ids = (
not self.no_penalties
or bool(self.bad_words_token_ids)
or self.logitsprocs_need_output_token_ids
)
output_token_ids = (
cast(list[list[int]], self.req_output_token_ids)
if needs_output_token_ids
else []
)
allowed_token_ids_mask: Optional[torch.Tensor] = None
if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None
@ -798,7 +818,7 @@ class InputBatch:
frequency_penalties=self.frequency_penalties[:num_reqs],
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
output_token_ids=output_token_ids,
spec_token_ids=cast(list[list[int]], self.spec_token_ids),
no_penalties=self.no_penalties,
allowed_token_ids_mask=allowed_token_ids_mask,
@ -859,6 +879,52 @@ class InputBatch:
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
def set_async_sampled_token_ids(
self,
sampled_token_ids_cpu: torch.Tensor,
async_copy_ready_event: torch.cuda.Event,
) -> None:
"""
In async scheduling case, store ref to sampled_token_ids_cpu
tensor and corresponding copy-ready event. Used to repair
output_token_ids prior to sampling, if needed by logits processors.
"""
if self.sampling_metadata.output_token_ids:
self.sampled_token_ids_cpu = sampled_token_ids_cpu
self.async_copy_ready_event = async_copy_ready_event
else:
self.sampled_token_ids_cpu = None
self.async_copy_ready_event = None
def update_async_output_token_ids(self) -> None:
"""
In async scheduling case, update output_token_ids in sampling metadata
from prior steps sampled token ids once they've finished copying to CPU.
This is called right before they are needed by the logits processors.
"""
output_token_ids = self.sampling_metadata.output_token_ids
if self.sampled_token_ids_cpu is None or not output_token_ids:
# Output token ids not needed or not async scheduling.
return
assert self.prev_req_id_to_index is not None
sampled_token_ids = None
for index, req_id in enumerate(self.req_ids):
prev_index = self.prev_req_id_to_index.get(req_id)
if prev_index is None:
continue
req_output_token_ids = output_token_ids[index]
if not req_output_token_ids or req_output_token_ids[-1] != -1:
# Final output id is not a placeholder, some tokens must have
# been discarded after a kv-load failure.
continue
if sampled_token_ids is None:
assert self.async_copy_ready_event is not None
self.async_copy_ready_event.synchronize()
sampled_token_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist()
# Replace placeholder token id with actual sampled id.
req_output_token_ids[-1] = sampled_token_ids[prev_index]
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)

View File

@ -178,7 +178,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
self._invalid_req_indices = invalid_req_indices
# Event on the copy stream so we can synchronize the non-blocking copy.
self._async_copy_ready_event = torch.cuda.Event()
self.async_copy_ready_event = torch.cuda.Event()
# Keep a reference to the device tensor to avoid it being
# deallocated until we finish copying it to the host.
@ -188,22 +188,22 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
default_stream = torch.cuda.current_stream()
with torch.cuda.stream(async_output_copy_stream):
async_output_copy_stream.wait_stream(default_stream)
self._sampled_token_ids_cpu = self._sampled_token_ids.to(
self.sampled_token_ids_cpu = self._sampled_token_ids.to(
"cpu", non_blocking=True
)
self._async_copy_ready_event.record()
self.async_copy_ready_event.record()
def get_output(self) -> ModelRunnerOutput:
"""Copy the device tensors to the host and return a ModelRunnerOutput.
This function blocks until the copy is finished.
"""
self._async_copy_ready_event.synchronize()
self.async_copy_ready_event.synchronize()
# Release the device tensor once the copy has completed
del self._sampled_token_ids
valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()
@ -349,6 +349,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# solution, we initialize the input batch here, and re-initialize it
# in `initialize_kv_cache` if the block_sizes here is different from
# the block_sizes in the kv cache config.
custom_logitsprocs = model_config.logits_processors
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
# We need to use the encoder length for encoder-decoer
@ -366,8 +367,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.device,
self.pin_memory,
self.is_pooling_model,
self.vllm_config.model_config.logits_processors,
custom_logitsprocs,
),
# We currently don't know whether a particular custom logits processor
# uses output token ids so we set this conservatively.
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
is_pooling_model=self.is_pooling_model,
)
@ -2210,6 +2214,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
# Update output token ids with tokens sampled in last step
# if async scheduling and required by current sampling params.
self.input_batch.update_async_output_token_ids()
return self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
@ -2666,13 +2673,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if not self.use_async_scheduling:
return output
return AsyncGPUModelRunnerOutput(
async_output = AsyncGPUModelRunnerOutput(
model_runner_output=output,
sampled_token_ids=sampler_output.sampled_token_ids,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
)
# Save ref of sampled_token_ids CPU tensor if the batch contains
# any requests with sampling params that that require output ids.
self.input_batch.set_async_sampled_token_ids(
async_output.sampled_token_ids_cpu,
async_output.async_copy_ready_event,
)
return async_output
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
if self._draft_token_ids is None:
return None
@ -4198,6 +4214,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kernel_block_sizes=kernel_block_sizes,
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=self.input_batch.logitsprocs,
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
is_pooling_model=self.is_pooling_model,
num_speculative_tokens=(
self.vllm_config.speculative_config.num_speculative_tokens