# Source: vllm/tests/v1/e2e/test_async_sched_and_preempt.py
# Commit 4fe5895361 — "[AsyncScheduling] Make async overlap work with logprobs (#27615)"
# Author: Nick Hill <nhill@redhat.com>, 2025-10-28

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import pytest
import torch._dynamo.config as dynamo_config
from vllm import SamplingParams
from vllm.logprobs import Logprob
from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal
# Small model used for every scheduler/executor configuration under test.
MODEL = "Qwen/Qwen3-0.6B"
@dynamo_config.patch(cache_size_limit=16)
def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor, and various sampling parameters.

    The same batch of prompts is generated under every combination of
    (preemption forced / not forced) x (executor "mp"/"uni") x
    (async scheduling on/off), once per sampling-parameter variant.
    The first combination acts as the baseline; every other combination
    must produce identical texts and matching logprobs per variant.
    """
    # One prompt with a highly predictable continuation, plus filler
    # prompts so that many requests are in flight simultaneously.
    first_prompt = (
        "The following numbers of the sequence "
        + ", ".join(str(i) for i in range(10))
        + " are:"
    )
    example_prompts = [first_prompt, "In one word, the capital of France is "] + [
        f"Tell me about the number {i}: " for i in range(32)
    ]
    # Per-run overrides layered on top of default_params below; the empty
    # dict is the baseline variant the other variants are compared against.
    sampling_param_tests: list[dict[str, Any]] = [
        dict(),
        # dict(min_tokens=20),
        dict(presence_penalty=-1.0),
        dict(bad_words=["the", " the"]),
        dict(logprobs=2),
        dict(logprobs=2, presence_penalty=-1.0),
    ]
    default_params = dict(
        temperature=0.0,  # greedy
        max_tokens=20,
    )
    with monkeypatch.context() as m:
        # Pin the attention backend so every configuration runs the same
        # kernels — presumably required for bitwise-comparable outputs;
        # TODO(review): confirm FLEX_ATTENTION is necessary here.
        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
        # m.setenv("VLLM_BATCH_INVARIANT", "1")
        # Each entry is (config description, per-variant generation results).
        outputs: list[tuple[str, list]] = []
        for test_preemption in [False, True]:
            for executor in ["mp", "uni"]:
                for async_scheduling in [False, True]:
                    # A tiny KV-cache block budget forces preemption;
                    # otherwise just cap GPU memory usage.
                    cache_arg: dict[str, Any] = (
                        dict(num_gpu_blocks_override=32)
                        if test_preemption
                        else dict(gpu_memory_utilization=0.7)
                    )
                    test_config = (
                        f"executor={executor}, preemption={test_preemption},"
                        f" async_sched={async_scheduling}"
                    )
                    print("-" * 80)
                    print(f"---- TESTING: {test_config}")
                    print("-" * 80)
                    with VllmRunner(
                        MODEL,
                        max_model_len=512,
                        enforce_eager=True,
                        async_scheduling=async_scheduling,
                        distributed_executor_backend=executor,
                        dtype="float32",  # avoid precision errors
                        **cache_arg,
                    ) as vllm_model:
                        # One generate() call per sampling-parameter variant;
                        # each result unpacks as an (outputs, logprobs) pair
                        # further below.
                        results = []
                        for override_params in sampling_param_tests:
                            print(f"----------- RUNNING PARAMS: {override_params}")
                            results.append(
                                vllm_model.generate(
                                    example_prompts,
                                    sampling_params=SamplingParams(
                                        **default_params, **override_params
                                    ),
                                    return_logprobs=True,
                                )
                            )
                        if not outputs:
                            # First check that the different parameter configs
                            # actually result in different output.
                            for (other_test_outs, other_test_logprobs), params in zip(
                                results[1:], sampling_param_tests[1:]
                            ):
                                # NOTE(review): name_0 reports the *other*
                                # variant's params rather than the baseline's
                                # (empty) params. The label never surfaces
                                # because the assertion failure is expected,
                                # but it is mislabelled.
                                with pytest.raises(AssertionError):
                                    check_outputs_equal(
                                        outputs_0_lst=results[0][0],
                                        outputs_1_lst=other_test_outs,
                                        name_0=f"baseline params={params}",
                                        name_1=f"other params={params}",
                                    )
                                # Presumably verifies the returned logprobs
                                # stay consistent across variants even though
                                # the texts differ; TODO(review): confirm the
                                # intended invariant.
                                assert _all_logprobs_match(
                                    results[0][1], other_test_logprobs
                                )
                        outputs.append((test_config, results))
    # Compare every non-baseline configuration against the first one,
    # variant by variant: texts must be identical and logprobs must match
    # within tolerance.
    baseline_config, baseline_tests = outputs[0]
    for test_config, test_outputs in outputs[1:]:
        for (base_outs, base_logprobs), (test_outs, test_logprobs), params in zip(
            baseline_tests, test_outputs, sampling_param_tests
        ):
            check_outputs_equal(
                outputs_0_lst=base_outs,
                outputs_1_lst=test_outs,
                name_0=f"baseline=[{baseline_config}], params={params}",
                name_1=f"config=[{test_config}], params={params}",
            )
            assert _all_logprobs_match(base_logprobs, test_logprobs)
            print(f"PASSED: config=[{test_config}], params={params}")
def _all_logprobs_match(req_a, req_b) -> bool:
return (
req_a == req_b
or len(req_a) == len(req_b)
and all(
len(seq_a) == len(seq_b)
and all(_logprobs_match(a, b) for a, b in zip(seq_a, seq_b))
for seq_a, seq_b in zip(req_a, req_b)
)
)
def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
return len(lps_a) == len(lps_b) and all(
a.decoded_token == b.decoded_token
and a.rank == b.rank
and a.logprob == pytest.approx(b.logprob, rel=1e-3, abs=1e-6)
for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
)