[BugFix] Make penalties and bad_words work with async scheduling (#26467)
Signed-off-by: Nick Hill <nhill@redhat.com>
commit 5bc26c438d
parent eef921f45e
@@ -28,9 +28,8 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
     sampling_param_tests: list[dict[str, Any]] = [
         dict(),
         # dict(min_tokens=20),
-        # TODO enable these with https://github.com/vllm-project/vllm/pull/26467.
-        # dict(repetition_penalty=0.1),
-        # dict(bad_words=[]),
+        dict(presence_penalty=-1.0),
+        dict(bad_words=["the", " the"]),
     ]
 
     default_params = dict(
@@ -42,9 +41,9 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
         # m.setenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")
 
-        outputs = []
+        outputs: list[tuple[str, list]] = []
         for test_preemption in [False, True]:
-            for executor in ["uni", "mp"]:
+            for executor in ["mp", "uni"]:
                 for async_scheduling in [False, True]:
                     cache_arg: dict[str, Any] = (
                         dict(num_gpu_blocks_override=32)
@@ -78,6 +77,21 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
                             ),
                         )
                     )
+
+                    if not outputs:
+                        # First check that the different parameter configs
+                        # actually result in different output.
+                        for other_test, params in zip(
+                            results[1:], sampling_param_tests[1:]
+                        ):
+                            with pytest.raises(AssertionError):
+                                check_outputs_equal(
+                                    outputs_0_lst=results[0],
+                                    outputs_1_lst=other_test,
+                                    name_0=f"baseline params={params}",
+                                    name_1=f"other params={params}",
+                                )
+
                     outputs.append((test_config, results))
 
     baseline_config, baseline_tests = outputs[0]
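For orientation, the parameter combinations enabled above (a negative presence penalty and a bad_words list under async scheduling) can be exercised against the public offline API roughly as follows. This is a minimal sketch, not part of the commit: the model name, max_tokens values, and the async_scheduling engine argument are illustrative assumptions.

# Minimal sketch (not part of this commit): exercising presence_penalty and
# bad_words together with async scheduling via vLLM's offline API.
# The model name and the async_scheduling flag are illustrative assumptions.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", async_scheduling=True)  # assumed engine arg

params = [
    SamplingParams(max_tokens=20),                              # baseline
    SamplingParams(max_tokens=20, presence_penalty=-1.0),       # penalties path
    SamplingParams(max_tokens=20, bad_words=["the", " the"]),   # bad_words path
]

prompt = "The capital of France is"
for sp in params:
    out = llm.generate([prompt], sp)[0]
    print(sp, "->", out.outputs[0].text)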
@@ -737,7 +737,9 @@ class Scheduler(SchedulerInterface):
                 req_to_new_blocks[req_id].get_block_ids(allow_none=True)
             )
             num_computed_tokens.append(req.num_computed_tokens)
-            num_output_tokens.append(req.num_output_tokens)
+            num_output_tokens.append(
+                req.num_output_tokens + req.num_output_placeholders
+            )
 
         return CachedRequestData(
             req_ids=req_ids,
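The scheduler change above reflects that, under async scheduling, the newest output position of a request may still be a placeholder whose token id has not been sampled yet; penalties need the true output length, so placeholders are counted as well. A rough standalone sketch of that accounting, using a hypothetical _Req stand-in rather than vLLM's real Request class:

# Rough sketch (assumption, not vLLM internals): why placeholder tokens are
# added to the reported output-token count under async scheduling.
from dataclasses import dataclass

@dataclass
class _Req:                       # hypothetical stand-in for a scheduled request
    num_computed_tokens: int
    num_output_tokens: int        # tokens whose ids are already known
    num_output_placeholders: int  # tokens scheduled but not yet sampled (async)

def report_output_lengths(reqs: list[_Req]) -> list[int]:
    # Placeholder slots still occupy positions in the output sequence, so they
    # must count toward the per-request output length.
    return [r.num_output_tokens + r.num_output_placeholders for r in reqs]

print(report_output_lengths([_Req(10, 3, 1), _Req(8, 5, 0)]))  # [4, 5]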
@@ -79,6 +79,7 @@ class InputBatch:
         block_sizes: list[int],  # The block_size of each kv cache group
         kernel_block_sizes: list[int],
         logitsprocs: Optional[LogitsProcessors] = None,
+        logitsprocs_need_output_token_ids: bool = False,
         is_spec_decode: bool = False,
         is_pooling_model: bool = False,
         num_speculative_tokens: int = 0,
@@ -240,6 +241,7 @@ class InputBatch:
         # Store provided logitsprocs. If none are provided, initialize empty
         # data structure
         self.logitsprocs = logitsprocs or LogitsProcessors()
+        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids
 
         # Store last speculative tokens for sampler.
         self.spec_token_ids: list[Optional[list[int]]] = []
@@ -252,6 +254,11 @@ class InputBatch:
         # Cached reference to the GPU tensor of previously sampled tokens
         self.prev_sampled_token_ids: Optional[torch.Tensor] = None
         self.prev_req_id_to_index: Optional[dict[str, int]] = None
+        # These are used to update output_token_ids with real sampled
+        # ids from prior step, if required by current sampling params
+        # (e.g. penalties).
+        self.sampled_token_ids_cpu: Optional[torch.Tensor] = None
+        self.async_copy_ready_event: Optional[torch.cuda.Event] = None
 
     @property
     def req_ids(self) -> list[str]:
@@ -776,6 +783,19 @@ class InputBatch:
             self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
         )
 
+        # Only set output_token_ids if required by the current requests'
+        # sampling parameters.
+        needs_output_token_ids = (
+            not self.no_penalties
+            or bool(self.bad_words_token_ids)
+            or self.logitsprocs_need_output_token_ids
+        )
+        output_token_ids = (
+            cast(list[list[int]], self.req_output_token_ids)
+            if needs_output_token_ids
+            else []
+        )
+
         allowed_token_ids_mask: Optional[torch.Tensor] = None
         if not self.no_allowed_token_ids:
             assert self.allowed_token_ids_mask is not None
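The gating added above avoids exposing per-request output-token-id lists to sampling metadata unless something in the batch will actually read them (penalties, bad_words, or a custom logits processor). The same decision in isolation, as a small sketch with hypothetical argument names mirroring the diff:

# Sketch (assumption): only expose output token ids to sampling metadata when
# a feature that consumes them is active for the current batch.
def needs_output_token_ids(
    no_penalties: bool,
    bad_words_token_ids: dict[int, list[list[int]]],
    logitsprocs_need_output_token_ids: bool,
) -> bool:
    return (
        not no_penalties                      # any penalty enabled
        or bool(bad_words_token_ids)          # bad_words configured
        or logitsprocs_need_output_token_ids  # custom logitsproc may read them
    )

assert needs_output_token_ids(True, {}, False) is False
assert needs_output_token_ids(False, {}, False) is True
assert needs_output_token_ids(True, {0: [[42]]}, False) is True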
@@ -798,7 +818,7 @@ class InputBatch:
             frequency_penalties=self.frequency_penalties[:num_reqs],
             presence_penalties=self.presence_penalties[:num_reqs],
             repetition_penalties=self.repetition_penalties[:num_reqs],
-            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
+            output_token_ids=output_token_ids,
             spec_token_ids=cast(list[list[int]], self.spec_token_ids),
             no_penalties=self.no_penalties,
             allowed_token_ids_mask=allowed_token_ids_mask,
@@ -859,6 +879,52 @@ class InputBatch:
 
         return prompt_lora_mapping, token_lora_mapping, active_lora_requests
 
+    def set_async_sampled_token_ids(
+        self,
+        sampled_token_ids_cpu: torch.Tensor,
+        async_copy_ready_event: torch.cuda.Event,
+    ) -> None:
+        """
+        In async scheduling case, store ref to sampled_token_ids_cpu
+        tensor and corresponding copy-ready event. Used to repair
+        output_token_ids prior to sampling, if needed by logits processors.
+        """
+        if self.sampling_metadata.output_token_ids:
+            self.sampled_token_ids_cpu = sampled_token_ids_cpu
+            self.async_copy_ready_event = async_copy_ready_event
+        else:
+            self.sampled_token_ids_cpu = None
+            self.async_copy_ready_event = None
+
+    def update_async_output_token_ids(self) -> None:
+        """
+        In async scheduling case, update output_token_ids in sampling metadata
+        from prior steps sampled token ids once they've finished copying to CPU.
+        This is called right before they are needed by the logits processors.
+        """
+        output_token_ids = self.sampling_metadata.output_token_ids
+        if self.sampled_token_ids_cpu is None or not output_token_ids:
+            # Output token ids not needed or not async scheduling.
+            return
+
+        assert self.prev_req_id_to_index is not None
+        sampled_token_ids = None
+        for index, req_id in enumerate(self.req_ids):
+            prev_index = self.prev_req_id_to_index.get(req_id)
+            if prev_index is None:
+                continue
+            req_output_token_ids = output_token_ids[index]
+            if not req_output_token_ids or req_output_token_ids[-1] != -1:
+                # Final output id is not a placeholder, some tokens must have
+                # been discarded after a kv-load failure.
+                continue
+            if sampled_token_ids is None:
+                assert self.async_copy_ready_event is not None
+                self.async_copy_ready_event.synchronize()
+                sampled_token_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist()
+            # Replace placeholder token id with actual sampled id.
+            req_output_token_ids[-1] = sampled_token_ids[prev_index]
+
     @property
     def num_reqs(self) -> int:
         return len(self.req_id_to_index)
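The repair logic in update_async_output_token_ids can be pictured with plain Python lists in place of the CUDA tensor and event: each request's output list ends in a -1 placeholder until the previous step's sampled ids are available on the CPU, at which point the placeholder is overwritten in place. A self-contained sketch with hypothetical names:

# Self-contained sketch (assumption): filling -1 placeholders in per-request
# output token id lists from the previous step's sampled ids, keyed by the
# previous batch's request ordering.
def repair_placeholders(
    output_token_ids: list[list[int]],     # per-request lists, may end in -1
    req_ids: list[str],                    # current batch order
    prev_req_id_to_index: dict[str, int],  # previous batch order
    prev_sampled: list[int],               # previous step's sampled ids (on CPU)
) -> None:
    for index, req_id in enumerate(req_ids):
        prev_index = prev_req_id_to_index.get(req_id)
        if prev_index is None:
            continue  # request was not in the previous batch
        toks = output_token_ids[index]
        if not toks or toks[-1] != -1:
            continue  # nothing to repair
        toks[-1] = prev_sampled[prev_index]

outs = [[5, 7, -1], [9, -1]]
repair_placeholders(outs, ["a", "b"], {"a": 0, "b": 1}, [11, 12])
print(outs)  # [[5, 7, 11], [9, 12]]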
@@ -178,7 +178,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
         self._invalid_req_indices = invalid_req_indices
 
         # Event on the copy stream so we can synchronize the non-blocking copy.
-        self._async_copy_ready_event = torch.cuda.Event()
+        self.async_copy_ready_event = torch.cuda.Event()
 
         # Keep a reference to the device tensor to avoid it being
         # deallocated until we finish copying it to the host.
@@ -188,22 +188,22 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
         default_stream = torch.cuda.current_stream()
         with torch.cuda.stream(async_output_copy_stream):
             async_output_copy_stream.wait_stream(default_stream)
-            self._sampled_token_ids_cpu = self._sampled_token_ids.to(
+            self.sampled_token_ids_cpu = self._sampled_token_ids.to(
                 "cpu", non_blocking=True
             )
-            self._async_copy_ready_event.record()
+            self.async_copy_ready_event.record()
 
     def get_output(self) -> ModelRunnerOutput:
         """Copy the device tensors to the host and return a ModelRunnerOutput.
 
        This function blocks until the copy is finished.
        """
-        self._async_copy_ready_event.synchronize()
+        self.async_copy_ready_event.synchronize()
 
         # Release the device tensor once the copy has completed
         del self._sampled_token_ids
 
-        valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
+        valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
         for i in self._invalid_req_indices:
             valid_sampled_token_ids[i].clear()
 
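The renames above expose the CPU tensor and the copy-ready event so the input batch can hold references to them. The underlying pattern is a device-to-host copy issued on a side stream, fenced by a CUDA event that consumers synchronize on only when they actually need the host data. A minimal sketch of that pattern, assuming a CUDA device is available (not vLLM's actual implementation):

# Minimal sketch (assumption, requires a CUDA device): non-blocking D2H copy on
# a side stream, fenced by an event that consumers synchronize on later.
import torch

def async_copy_to_cpu(
    sampled_token_ids: torch.Tensor, copy_stream: torch.cuda.Stream
) -> tuple[torch.Tensor, torch.cuda.Event]:
    ready = torch.cuda.Event()
    default_stream = torch.cuda.current_stream()
    with torch.cuda.stream(copy_stream):
        # Make the copy wait for the kernels that produced the tensor.
        copy_stream.wait_stream(default_stream)
        cpu_tensor = sampled_token_ids.to("cpu", non_blocking=True)
        ready.record()  # recorded on copy_stream
    return cpu_tensor, ready

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    ids = torch.randint(0, 32000, (8, 1), device="cuda")
    cpu_ids, ready = async_copy_to_cpu(ids, stream)
    ready.synchronize()  # block only when the host data is needed
    print(cpu_ids.squeeze(-1).tolist())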
@@ -349,6 +349,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # solution, we initialize the input batch here, and re-initialize it
         # in `initialize_kv_cache` if the block_sizes here is different from
         # the block_sizes in the kv cache config.
+        custom_logitsprocs = model_config.logits_processors
         self.input_batch = InputBatch(
             max_num_reqs=self.max_num_reqs,
             # We need to use the encoder length for encoder-decoer
@@ -366,8 +367,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 self.device,
                 self.pin_memory,
                 self.is_pooling_model,
-                self.vllm_config.model_config.logits_processors,
+                custom_logitsprocs,
             ),
+            # We currently don't know whether a particular custom logits processor
+            # uses output token ids so we set this conservatively.
+            logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
             is_pooling_model=self.is_pooling_model,
         )
 
@@ -2210,6 +2214,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         if spec_decode_metadata is None:
+            # Update output token ids with tokens sampled in last step
+            # if async scheduling and required by current sampling params.
+            self.input_batch.update_async_output_token_ids()
             return self.sampler(
                 logits=logits,
                 sampling_metadata=sampling_metadata,
@@ -2666,13 +2673,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if not self.use_async_scheduling:
             return output
 
-        return AsyncGPUModelRunnerOutput(
+        async_output = AsyncGPUModelRunnerOutput(
             model_runner_output=output,
             sampled_token_ids=sampler_output.sampled_token_ids,
             invalid_req_indices=invalid_req_indices,
             async_output_copy_stream=self.async_output_copy_stream,
         )
+
+        # Save ref of sampled_token_ids CPU tensor if the batch contains
+        # any requests with sampling params that that require output ids.
+        self.input_batch.set_async_sampled_token_ids(
+            async_output.sampled_token_ids_cpu,
+            async_output.async_copy_ready_event,
+        )
+
+        return async_output
 
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         if self._draft_token_ids is None:
             return None
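Putting the pieces together, the per-step order under async scheduling is: repair last step's placeholders before sampling, run the model and wrap the result asynchronously, then hand the new CPU tensor and ready event back to the input batch for the next step. A schematic, runnable sketch with stub classes standing in for the real runner objects (not vLLM's actual control flow):

# Schematic sketch (assumption): order of operations per step under async
# scheduling. Stub classes stand in for the real runner and output objects.
class _StubBatch:
    def __init__(self):
        self.registered = None
    def update_async_output_token_ids(self):
        print("1. repair placeholders from previous step (if needed)")
    def set_async_sampled_token_ids(self, cpu_ids, ready_event):
        self.registered = (cpu_ids, ready_event)
        print("3. register CPU ids + ready event for the next step")

class _StubAsyncOutput:
    sampled_token_ids_cpu = "cpu-tensor"    # placeholder for torch.Tensor
    async_copy_ready_event = "ready-event"  # placeholder for torch.cuda.Event

def execute_step(batch: _StubBatch) -> _StubAsyncOutput:
    batch.update_async_output_token_ids()   # before sampling
    out = _StubAsyncOutput()                # 2. forward + sample + async D2H copy
    batch.set_async_sampled_token_ids(
        out.sampled_token_ids_cpu, out.async_copy_ready_event
    )
    return out

execute_step(_StubBatch())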
@@ -4198,6 +4214,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             kernel_block_sizes=kernel_block_sizes,
             is_spec_decode=bool(self.vllm_config.speculative_config),
             logitsprocs=self.input_batch.logitsprocs,
+            logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
             is_pooling_model=self.is_pooling_model,
             num_speculative_tokens=(
                 self.vllm_config.speculative_config.num_speculative_tokens