vllm/tests/v1/core/test_scheduler.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from unittest.mock import Mock
import pytest
import torch
from vllm.config import (
CacheConfig,
ECTransferConfig,
KVTransferConfig,
ModelConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalKwargsItem,
PlaceholderRange,
)
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.utils.hashing import sha256
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
)
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from .utils import EOS_TOKEN_ID, create_requests, create_scheduler, mock_kv
pytestmark = pytest.mark.cpu_test
def test_add_requests():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
for i, request in enumerate(requests):
scheduler.add_request(request)
assert request.request_id in scheduler.requests
assert len(scheduler.waiting) == i + 1
def test_finish_request():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_ABORTED)
assert request.request_id not in scheduler.requests
assert len(scheduler.waiting) == 9 - i
def test_get_num_unfinished_requests():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_STOPPED)
assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1
@pytest.mark.parametrize(
"enable_prefix_caching, prompt_logprobs",
[
(False, None),
(True, 5),
],
)
def test_schedule(enable_prefix_caching: bool, prompt_logprobs: int | None):
"""Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
"""
scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
requests = create_requests(num_requests=10, prompt_logprobs=prompt_logprobs)
for request in requests:
scheduler.add_request(request)
# Test initial scheduling
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
# Verify requests moved from waiting to running
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == len(requests)
for i, request in enumerate(requests):
assert scheduler.running[i] == request
def test_schedule_multimodal_requests():
scheduler = create_scheduler(model="llava-hf/llava-1.5-7b-hf")
mm_positions = [[PlaceholderRange(offset=i, length=100)] for i in range(10)]
requests = create_requests(
num_requests=10,
num_tokens=200,
mm_positions=mm_positions,
)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
assert len(output.scheduled_encoder_inputs) == 10
for req_id, encoder_input in output.scheduled_encoder_inputs.items():
assert len(encoder_input) == 1
def test_schedule_partial_requests():
"""Test scheduling behavior with partial requests.
This test verifies that:
1. The scheduler can handle multiple partial requests in a single step when
constrained by encoder budget.
2. A request in RUNNING state may be unscheduled in subsequent steps if
there is insufficient encoder budget.
"""
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
max_num_batched_tokens=1024,
)
mm_positions = [[PlaceholderRange(offset=100, length=600)] for _ in range(3)]
requests = create_requests(
num_requests=3,
num_tokens=800,
mm_positions=mm_positions,
)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
assert scheduler.max_num_encoder_input_tokens == 1024
# The first request is scheduled fully.
assert output.num_scheduled_tokens[requests[0].request_id] == 800
# The second request is scheduled partially.
# The <img> tokens are not scheduled because of the encoder budget.
assert output.num_scheduled_tokens[requests[1].request_id] == 100
# The third request is also scheduled partially.
# The <img> tokens are not scheduled because of the encoder budget.
assert output.num_scheduled_tokens[requests[2].request_id] == 100
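    # Why 800 / 100 / 100: the encoder budget equals the token budget (1024).
    # Request 0's 600-token image uses most of it, leaving 424 < 600, so
    # requests 1 and 2 can only schedule the 100 text tokens that precede
    # their image placeholders.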
req_to_index = {request.request_id: i for i, request in enumerate(requests)}
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
        # Only the first request has a sampled token id because
        # the remaining requests are still being prefilled.
sampled_token_ids=[[0], [], []],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)
# Schedule the next step.
# Only the first and second requests are scheduled.
# The third request is in the RUNNING state but not scheduled in this step
# because of the encoder budget.
output = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output.scheduled_new_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 2
assert len(output.finished_req_ids) == 0
assert output.num_scheduled_tokens[requests[0].request_id] == 1
assert output.num_scheduled_tokens[requests[1].request_id] == 700
assert requests[2].request_id not in output.num_scheduled_tokens
def test_no_mm_input_chunking():
# Disable multimodal input chunking.
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
max_num_batched_tokens=1024,
disable_chunked_mm_input=True,
max_model_len=2048,
)
mm_positions = [[PlaceholderRange(offset=400, length=800)]]
requests = create_requests(
num_requests=1, num_tokens=1200, mm_positions=mm_positions
)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# We want to only see the 400 text tokens at the start scheduled
assert output.num_scheduled_tokens[requests[0].request_id] == 400
req_to_index = {request.request_id: i for i, request in enumerate(requests)}
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)
output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(output.scheduled_new_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 1
assert len(output.finished_req_ids) == 0
assert output.num_scheduled_tokens[requests[0].request_id] == 800
# Test that we fail if we disable chunked mm input and use too small
# of a max_num_batched_tokens for the mm input.
with pytest.raises(ValueError):
_ = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
max_num_batched_tokens=100,
disable_chunked_mm_input=True,
)
@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
"""Test scheduling behavior with concurrent partial requests.
    This test verifies that multiple long prefill requests in the RUNNING
    state can be scheduled together.
"""
scheduler = create_scheduler(
model="facebook/opt-125m",
max_num_batched_tokens=1024,
long_prefill_token_threshold=400,
enable_prefix_caching=enable_prefix_caching,
)
requests = create_requests(
num_requests=3,
num_tokens=800,
)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# The first request is scheduled partially - 400.
assert output.num_scheduled_tokens[requests[0].request_id] == 400
# The second request is scheduled partially - 400.
assert output.num_scheduled_tokens[requests[1].request_id] == 400
# The third request is also scheduled partially - 1024 - 400 - 400 = 224.
assert output.num_scheduled_tokens[requests[2].request_id] == 224
req_to_index = {request.request_id: i for i, request in enumerate(requests)}
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)
# Schedule the next step. All three requests are running.
# Processed the remaining prefills of the first and second requests.
output1 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output1.scheduled_new_reqs) == 0
assert output1.scheduled_cached_reqs.num_reqs == 3
assert len(output1.finished_req_ids) == 0
assert output1.num_scheduled_tokens[requests[0].request_id] == 400
assert output1.num_scheduled_tokens[requests[1].request_id] == 400
assert output1.num_scheduled_tokens[requests[2].request_id] == 224
# Schedule the third step. All three requests are running.
# First and second requests are in the decode stage.
# All the remaining tokens in the third request are processed.
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output1, model_runner_output)
output2 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output2.scheduled_new_reqs) == 0
assert output2.scheduled_cached_reqs.num_reqs == 3
assert len(output2.finished_req_ids) == 0
assert output2.num_scheduled_tokens[requests[0].request_id] == 1
assert output2.num_scheduled_tokens[requests[1].request_id] == 1
assert output2.num_scheduled_tokens[requests[2].request_id] == 800 - 224 - 224
def test_stop_via_update_from_output():
"""Test stopping behavior through update_from_output"""
scheduler = create_scheduler(num_speculative_tokens=1)
# Test case 1: Stop on EOS token
requests = create_requests(num_requests=2, max_tokens=10)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={requests[0].request_id: 1, requests[1].request_id: 2},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10],
},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
sampled_token_ids=[
[EOS_TOKEN_ID],
[10, 11],
], # First request hits EOS, second continues
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped, second continues
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
assert list(requests[1].output_token_ids) == [10, 11]
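    # For the second request, the draft token 10 was accepted and 11 is the
    # bonus token sampled after it, so both appear in output_token_ids.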
# Test case 2: Stop on custom stop token
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=2, max_tokens=10, stop_token_ids=[42, 43])
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={requests[0].request_id: 3, requests[1].request_id: 2},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13],
},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped on custom token
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].stop_reason == 42
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 42]
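    # Both draft tokens of the first request were accepted, but 42 is a stop
    # token, so the trailing 12 is discarded and the output ends at 42.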
assert list(requests[1].output_token_ids) == [13, 14]
# Test case 3: Stop on max tokens
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=2, max_tokens=2)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={requests[0].request_id: 3, requests[1].request_id: 1},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: [],
},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped due to length
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 11] # Truncated to max_tokens
assert list(requests[1].output_token_ids) == [13]
# Test case 4: Ignore EOS flag
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output, model_output)
# Verify request continues past EOS
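    # Both draft tokens (EOS_TOKEN_ID, 10) were accepted and the bonus token
    # 11 was appended; with ignore_eos=True the request keeps running.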
assert len(scheduler.running) == 1
assert not requests[0].is_finished()
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
def test_check_stop_min_tokens():
"""Test that requests don't stop when min_tokens requirement isn't met."""
from vllm.v1.core.sched.utils import check_stop
# Test case 1: num_output_tokens < min_tokens
# Should return False (don't stop)
sampling_params = SamplingParams(
ignore_eos=False,
max_tokens=20,
min_tokens=5,
)
request = Request(
request_id="0",
prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
)
# Simulate having generated 3 output tokens (less than min_tokens=5)
request.append_output_token_ids([10, 11, EOS_TOKEN_ID]) # EOS token present
result = check_stop(request, max_model_len=100)
assert result is False, "Should not stop when num_output_tokens<min_tokens"
# Test case 2: num_output_tokens >= min_tokens
# Should follow normal stopping logic (stop on EOS)
request.append_output_token_ids(
[
10,
11,
12,
13,
14,
EOS_TOKEN_ID,
]
) # 6 tokens > min_tokens
result = check_stop(request, max_model_len=100)
assert result is True, "Should stop on EOS when min_tokens met"
assert request.status == RequestStatus.FINISHED_STOPPED
# Test case 3: min_tokens = 0, should follow normal stopping logic
sampling_params_no_min = SamplingParams(
ignore_eos=False,
max_tokens=20,
min_tokens=0,
)
request_no_min = Request(
request_id="1",
prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params_no_min,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
)
request_no_min.append_output_token_ids([10, EOS_TOKEN_ID])
result = check_stop(request_no_min, max_model_len=100)
assert result is True, "Should stop on EOS when min_tokens=0"
assert request_no_min.status == RequestStatus.FINISHED_STOPPED
# Test case 4: min_tokens > 0 with stop token (not EOS)
sampling_params_stop = SamplingParams(
ignore_eos=False,
max_tokens=20,
min_tokens=5,
stop_token_ids=[42],
)
request_stop = Request(
request_id="2",
prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params_stop,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
)
# Only 3 output tokens, less than min_tokens=5, but has stop token
request_stop.append_output_token_ids([10, 11, 42])
result = check_stop(request_stop, max_model_len=100)
assert result is False, "Should not stop when num_output_tokens<min_tokens"
# Test case 5: min_tokens met, should stop on stop token
request_stop.append_output_token_ids(
[10, 11, 12, 13, 14, 42]
) # 6 tokens >= min_tokens=5
result = check_stop(request_stop, max_model_len=100)
assert result is True, "Should stop on stop token when min_tokens met"
assert request_stop.status == RequestStatus.FINISHED_STOPPED
assert request_stop.stop_reason == 42
@pytest.mark.parametrize(
"enable_prefix_caching, prompt_logprobs",
[
(False, None),
(True, 5),
],
)
def test_schedule_concurrent_batches(
enable_prefix_caching: bool, prompt_logprobs: int | None
):
scheduler = create_scheduler(
max_num_batched_tokens=1024,
max_num_seqs=2,
enable_prefix_caching=enable_prefix_caching,
)
requests = create_requests(
num_requests=2,
num_tokens=512,
prompt_logprobs=prompt_logprobs,
)
# Schedule the first request.
scheduler.add_request(requests[0])
scheduler_output0 = scheduler.schedule()
assert len(scheduler_output0.scheduled_new_reqs) == 1
assert scheduler_output0.num_scheduled_tokens[requests[0].request_id] == 512
# The first request is still running, so only schedule the second request.
scheduler.add_request(requests[1])
scheduler_output1 = scheduler.schedule()
assert len(scheduler_output1.scheduled_new_reqs) == 1
assert scheduler_output1.num_scheduled_tokens[requests[1].request_id] == 512
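    # Two prefill batches are now in flight at the same time, modeling
    # max_concurrent_batches > 1 (e.g. pipeline parallelism): the second
    # batch is scheduled before the first batch's output is returned.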
# Model output of the first request.
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output0, model_runner_output)
# Schedule the next step.
# The first request can be scheduled again while the second
# request is still running.
scheduler_output2 = scheduler.schedule()
assert scheduler_output2.num_scheduled_tokens[requests[0].request_id] == 1
# Model output of the second request.
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output1, model_runner_output)
@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
def test_schedule_order(enable_chunked_prefill: bool):
scheduler = create_scheduler(
max_num_batched_tokens=1024,
max_num_seqs=3,
enable_chunked_prefill=enable_chunked_prefill,
)
# long requests
requests = create_requests(num_requests=2, num_tokens=800)
# short requests
requests += create_requests(num_requests=2, num_tokens=10)
for request in requests:
scheduler.add_request(request)
scheduler_output1 = scheduler.schedule()
    if enable_chunked_prefill:
        # When chunked prefill is enabled, the second long request is chunked,
        # so two new requests fit into the 1024-token budget.
        assert len(scheduler_output1.scheduled_new_reqs) == 2
    else:
        # When chunked prefill is disabled, the scheduler must not skip the
        # long request and schedule the subsequent short requests in advance,
        # even though there is still token budget remaining
        # (1024 - 800 = 224 < 800).
assert len(scheduler_output1.scheduled_new_reqs) == 1
def test_preempt_during_execution():
# NOTE(woosuk): The actual number of available blocks is 10 instead of 11
# because block 0 is reserved as the null block.
scheduler = create_scheduler(
max_num_batched_tokens=100,
block_size=16,
num_blocks=11,
enable_prefix_caching=False,
)
requests = create_requests(num_requests=2, num_tokens=80, block_size=16)
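    # Each 80-token request needs 80 / 16 = 5 blocks, so two requests occupy
    # all 10 usable blocks; the first decode token then requires a 6th block
    # per request, which is what triggers the preemption below.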
# Schedule the first request.
scheduler.add_request(requests[0])
scheduler_output0 = scheduler.schedule()
assert len(scheduler_output0.num_scheduled_tokens) == 1
assert len(scheduler_output0.scheduled_new_reqs[0].block_ids[0]) == 5
# Schedule the second request while the first request is still running.
# This scenario can occur in certain cases, when max_concurrent_batches > 1
# (e.g., when pipeline parallelism is used).
scheduler.add_request(requests[1])
scheduler_output1 = scheduler.schedule()
assert len(scheduler_output1.num_scheduled_tokens) == 1
assert len(scheduler_output1.scheduled_new_reqs[0].block_ids[0]) == 5
# Get the output of the first request.
model_runner_output0 = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output0, model_runner_output0)
# Schedule the first request again. This will cause the preemption
# of the second request because the KV cache is full.
_ = scheduler.schedule()
assert len(scheduler.running) == 1
assert scheduler.running[0] == requests[0]
assert requests[1].status == RequestStatus.PREEMPTED
model_runner_output1 = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[42]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output1, model_runner_output1)
# The second request (that is preempted) should be updated with the
# sampled token id.
assert len(requests[1].output_token_ids) == 1
assert requests[1].output_token_ids[0] == 42
def test_scheduler_reset_prefix_cache():
scheduler = create_scheduler(enable_prefix_caching=True)
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
# Initial scheduling, requests should be at the running state now
_ = scheduler.schedule()
# Verify requests moved from waiting to running
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == len(requests)
for i, request in enumerate(requests):
assert scheduler.running[i] == request
# Reset prefix cache should fail since there are still running requests
# and they are taking KV cache
assert not scheduler.reset_prefix_cache()
    # Reset prefix cache with reset_running_requests=True. All running requests
    # should be pushed back to the waiting queue and the KV cache should be freed.
assert scheduler.reset_prefix_cache(reset_running_requests=True)
# Verify requests moved from running to waiting
assert len(scheduler.waiting) == len(requests)
assert len(scheduler.running) == 0
for i, request in enumerate(requests):
assert scheduler.waiting[i] == request
# Note - these test cases mirror some of those in test_rejection_sampler.py
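# expected = (num_drafts, num_draft_tokens, num_accepted_tokens,
#             num_accepted_tokens_per_pos), matching the spec decoding stats
# asserted at the end of the test.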
@pytest.mark.parametrize(
"spec_tokens,output_tokens,expected",
[
([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match
([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch
([[1, 2], [3]], [[1, 2, 5], [3, 4]], (2, 3, 3, [2, 1])), # multiple sequences
([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence
([[]], [[5]], (0, 0, 0, [0])), # empty sequence
(
[[1, 2, 3], [4, 5, 6]],
[[1, 2, 7], [4, 8]],
(2, 6, 3, [2, 1, 0]),
), # multiple mismatches
],
)
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
"""Test scheduling behavior with speculative decoding.
This test verifies that:
1. Speculated tokens get scheduled correctly
2. Spec decoding stats properly count number of draft and accepted tokens
"""
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
# Schedule a decode, which will also draft speculative tokens
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.total_num_scheduled_tokens == len(requests)
for i in range(len(requests)):
req_id = requests[i].request_id
assert output.num_scheduled_tokens[req_id] == 1
assert req_id not in output.scheduled_spec_decode_tokens
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
engine_core_outputs = scheduler.update_from_output(output, model_runner_output)
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
scheduler.update_draft_token_ids(draft_token_ids)
for i in range(len(requests)):
running_req = scheduler.running[i]
# The prompt token
assert running_req.num_computed_tokens == 1
# The prompt token and the sampled token
assert running_req.num_tokens == 2
# The prompt token, the sampled token, and the speculated tokens
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
# No draft or accepted tokens counted yet
assert not engine_core_outputs or (
engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None
)
# Schedule the speculated tokens for validation
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 0
# The sampled token and speculated tokens
assert output.total_num_scheduled_tokens == len(requests) + sum(
len(ids) for ids in spec_tokens
)
for i in range(len(requests)):
req_id = requests[i].request_id
assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
if spec_tokens[i]:
assert len(output.scheduled_spec_decode_tokens[req_id]) == len(
spec_tokens[i]
)
else:
assert req_id not in output.scheduled_spec_decode_tokens
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
engine_core_outputs = scheduler.update_from_output(output, model_runner_output)
scheduler_stats = (
engine_core_outputs[0].scheduler_stats if engine_core_outputs else None
)
if expected[0] == 0:
assert scheduler_stats is not None
assert scheduler_stats.spec_decoding_stats is None
else:
assert scheduler_stats is not None
assert scheduler_stats.spec_decoding_stats is not None
stats = scheduler_stats.spec_decoding_stats
assert stats.num_drafts == expected[0]
assert stats.num_draft_tokens == expected[1]
assert stats.num_accepted_tokens == expected[2]
assert stats.num_accepted_tokens_per_pos == expected[3]
def _assert_right_scheduler_output(
output: SchedulerOutput,
num_requests: int,
expected_num_scheduled_tokens: int,
):
"""Check if SchedulerOutput is correct after remote KV cache hit."""
# We should inject the kv_connector_metadata.
assert len(output.kv_connector_metadata.requests) == num_requests
# Only num_tokens - matched_num_new_tokens should be scheduled.
for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
assert num_scheduled_tokens == expected_num_scheduled_tokens
def _assert_right_kv_cache_manager(
scheduler: Scheduler,
requests: list[Request],
num_tokens: int,
block_size: int,
num_requests: int,
num_total_blocks: int,
):
"""Check whether KVCacheManager is correct after allocate."""
# Make sure the request stats are right.
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
for req in requests:
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[req.request_id]
hashes = req.block_hashes
assert (
scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].num_cached_block[req.request_id]
== EXPECTED_TOTAL_BLOCKS
)
assert len(blocks) == EXPECTED_TOTAL_BLOCKS
assert len(hashes) == EXPECTED_TOTAL_BLOCKS
# Make sure we actually touched all the blocks.
BLOCKS_PER_REQ = num_tokens / block_size
assert (
scheduler.kv_cache_manager.block_pool.get_num_free_blocks()
== num_total_blocks - num_requests * BLOCKS_PER_REQ
)
def _step_until_done(
scheduler: Scheduler,
output: SchedulerOutput,
model_runner_output: ModelRunnerOutput,
):
"""Loop over schedule(), update_from_output() until finished."""
all_finished = False
_ = scheduler.update_from_output(output, model_runner_output)
while not all_finished:
# Schedule + a few iterations until stopping.
output = scheduler.schedule()
assert len(scheduler.running)
for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
# We should be in the decode phase now.
assert num_scheduled_tokens == 1
if scheduler.connector is not None:
assert len(output.kv_connector_metadata.requests) == 0
if scheduler.ec_connector is not None:
assert len(output.ec_connector_metadata.mm_datas) == 0
ecos = scheduler.update_from_output(output, model_runner_output)[0]
all_done = True
for eco in ecos.outputs:
if eco.finish_reason is None:
all_done = False
all_finished = all_done
def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
"""Cycle requests through a KV transfer cyle."""
# Requests should first transition to WAITING_FOR_REMOTE_KVS
output = scheduler.schedule()
assert len(scheduler.waiting) == len(req_ids)
assert len(scheduler.running) == 0
assert len(output.scheduled_new_reqs) == 0
for req in scheduler.requests.values():
assert req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
# No model execution yet
EMPTY_OUTPUT = ModelRunnerOutput(
req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, EMPTY_OUTPUT)
# Simulate KV transfer completion using KVConnectorOutput.finished_recving
output = scheduler.schedule()
assert len(scheduler.waiting) == len(req_ids)
assert len(scheduler.running) == 0
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
kv_connector_output=KVConnectorOutput(finished_recving=req_ids),
)
scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
for req_id in req_ids:
assert req_id in scheduler.finished_recving_kv_req_ids
@pytest.mark.parametrize("is_async", [False, True])
def test_kv_connector_basic(is_async: bool):
"""
Test whether Scheduler with KVConnector schedules tokens, allocates
memory, and cleans up requests as expected under normal operation.
"""
# Setup Scheduler.
BLOCK_SIZE = 16
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=mock_kv(
matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
),
block_size=BLOCK_SIZE,
)
NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks()
######################################################
# FIRST SET OF REQUESTS - External Hit Only
NUM_REQUESTS = 2
NUM_TOKENS = NUM_MATCHED_NEW_TOKENS * 2
MAX_TOKENS = 3
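    # Each 64-token prompt spans 4 blocks; the mock connector reports a remote
    # hit for the first 32 tokens (2 blocks), so only the remaining 32 tokens
    # per request need to be scheduled.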
requests = create_requests(
num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE,
)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
if is_async:
_step_until_kv_transfer_finished(scheduler, req_ids)
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Ensure ScheduleOutput is correct.
output = scheduler.schedule()
_assert_right_scheduler_output(
output=output,
num_requests=NUM_REQUESTS,
# Just the incremental tokens should be scheduled.
expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
)
# Ensure KVCacheManager is correct.
_assert_right_kv_cache_manager(
scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS
)
# Continue Generation until done.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
_ = scheduler.schedule()
# Confirm we clean up the memory properly.
assert (
scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS
)
######################################################
# SECOND SET OF REQUESTS - Local And External Hit
NUM_TOKENS_PREFIX = NUM_TOKENS
# We will get a local prefix cache hit for the first
# NUM_TOKENS_PREFIX tokens since they are used above.
NUM_TOKENS = NUM_TOKENS_PREFIX * 2
requests = create_requests(
num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE,
)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
if is_async:
_step_until_kv_transfer_finished(scheduler, req_ids)
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# We should get a local cache hit of NUM_TOKENS_PREFIX and
# a remote KV cache hit of NUM_MATCHED_NEW_TOKENS.
output = scheduler.schedule()
_assert_right_scheduler_output(
output=output,
num_requests=NUM_REQUESTS,
# Just the incremental tokens after local + remote cache hit.
expected_num_scheduled_tokens=(
NUM_TOKENS - NUM_TOKENS_PREFIX - NUM_MATCHED_NEW_TOKENS
),
)
# Ensure KVCacheManager is correct.
_assert_right_kv_cache_manager(
scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS
)
# Continue Generation until done.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
_ = scheduler.schedule()
# Confirm we clean up the memory properly.
assert (
scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS
)
@pytest.mark.parametrize("is_async", [False, True])
def test_external_prefix_cache_metrics(is_async: bool):
"""
Verify connector prefix cache metrics are updated
correctly when the scheduler processes requests with KV connector hits.
"""
# Setup Scheduler.
NUM_MATCHED_NEW_TOKENS = 4
scheduler = create_scheduler(
enable_prefix_caching=False,
use_kv_connector=mock_kv(
matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
),
)
# --- Prepare simple requests ---
NUM_REQUESTS = 2
NUM_TOKENS = 8
MAX_TOKENS = 2
requests = create_requests(
num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS,
)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
if is_async:
_step_until_kv_transfer_finished(scheduler, req_ids)
# --- Trigger scheduling and simulate model output ---
output = scheduler.schedule()
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[r.request_id for r in requests],
req_id_to_index={r.request_id: i for i, r in enumerate(requests)},
sampled_token_ids=[[1000]] * NUM_REQUESTS,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Update scheduler stats
ecos = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
# --- Assertions ---
assert ecos is not None and len(ecos) > 0
assert ecos[0].scheduler_stats is not None
external_stats = ecos[0].scheduler_stats.connector_prefix_cache_stats
assert external_stats is not None
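    # Prefix caching is disabled, so every prompt token is queried against the
    # connector, and the mock reports NUM_MATCHED_NEW_TOKENS hits per request.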
assert external_stats.queries == NUM_TOKENS * NUM_REQUESTS
assert external_stats.hits == NUM_MATCHED_NEW_TOKENS * NUM_REQUESTS
assert external_stats.requests == NUM_REQUESTS
assert external_stats.preempted_requests == 0
@pytest.mark.parametrize(
"use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
)
def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
"""
    Test whether the scheduler with a KVConnector is able to handle an
    allocation failure (running out of blocks in allocate_slots()).
"""
# Setup Scheduler With Mock External Cache Hit.
BLOCK_SIZE = 4
NUM_BLOCKS = 10
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
# encoder connector should not affect test results
use_ec_connector=use_ec_connector,
ec_role=ec_role,
)
# Create two requests. The second request will not be able to
# allocate slots because it will not have enough blocks.
NUM_REQUESTS = 2
NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE
MAX_TOKENS = 2
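    # Each 24-token prompt needs 6 blocks of size 4, but only 9 blocks are
    # usable (10 minus the null block), so the two requests cannot be
    # allocated at the same time.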
requests = create_requests(
num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE,
)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Just one request should be running.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
num_requests=1,
expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# All memory should be freed, with one request waiting.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
# Just one request should be running.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
num_requests=1,
expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
# All memory should be freed, with no requests waiting / running.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 0
@pytest.mark.parametrize(
"use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
)
def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
"""
    Test whether the scheduler with a KVConnector correctly handles
    preemption and re-scheduling when the KV cache runs out of blocks.
"""
# Setup Scheduler With Mock External Cache Hit.
BLOCK_SIZE = 2
# NOTE: there is 1 null block, so this is 6 blocks.
NUM_BLOCKS = 7
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
# encoder connector should not affect test results
use_ec_connector=use_ec_connector,
ec_role=ec_role,
)
# Create two requests.
# Both can be scheduled at first, but the second request
# will be preempted and re-scheduled.
NUM_REQUESTS = 2
NUM_TOKENS = BLOCK_SIZE * 2 + 1
MAX_TOKENS = BLOCK_SIZE * 2
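    # Each 5-token prompt needs 3 blocks of size 2, so the two requests fill
    # all 6 usable blocks; computing the 3rd generated token then requires a
    # 4th block per request, which forces one of them to be preempted.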
requests = create_requests(
num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE,
)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# All can be scheduled - 1st token.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# 2 remote kv cache hits.
num_requests=2,
expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
)
assert len(scheduler.running) == 2
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
# All can be scheduled - 2nd token.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# no connector_metadata
num_requests=0,
expected_num_scheduled_tokens=1,
)
assert len(scheduler.running) == 2
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
# This will generate a new block and cause a preemption - 3rd token.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# no connector_metadata
num_requests=0,
expected_num_scheduled_tokens=1,
)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# Only 1 can be scheduled - 4th (and last token).
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# no connector_metadata
num_requests=0,
expected_num_scheduled_tokens=1,
)
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 1
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 0
# All memory should be freed since nothing is running.
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
# Restarts the preempted request - generate 3rd token.
# This will have a local and remote cache hit.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# 1 remote kv_cache hit!
num_requests=1,
        # Only 1 block was preempted and there is a single
        # remote hit, so only a single new token is scheduled.
expected_num_scheduled_tokens=1,
)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
# Only 1 can be scheduled - 4th (and last token).
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# no connector_metadata
num_requests=0,
expected_num_scheduled_tokens=1,
)
assert len(scheduler.running) == 1
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 0
# All memory should be freed since nothing is running.
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
def make_output(scheduler: Scheduler):
return ModelRunnerOutput(
req_ids=[req.request_id for req in scheduler.running],
req_id_to_index={req.request_id: i for i, req in enumerate(scheduler.running)},
sampled_token_ids=[[1000]] * len(scheduler.running),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
def assert_scheduler_empty(scheduler: Scheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks."""
# Scheduler Metadata.
assert len(scheduler.requests) == 0
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert (
len(
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks
)
== 0
)
assert (
len(
scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].num_cached_block
)
== 0
)
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
)
assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
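    # One block is always reserved as the null block, hence num_gpu_blocks - 1.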
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
for block in scheduler.kv_cache_manager.block_pool.blocks:
assert block.ref_cnt == 0
# assert block._block_hash is None
# assert (
# len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
# ) == 0)
def test_memory_leak():
"""Test that we do not have a memory leak."""
scheduler = create_scheduler(enable_prefix_caching=True)
NUM_REQUESTS = 5
NUM_TOKENS = 10
MAX_TOKENS = 10
requests = create_requests(
num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, max_tokens=MAX_TOKENS
)
# Add each request.
for request in requests:
scheduler.add_request(request)
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Iterate until done.
while True:
scheduler_output = scheduler.schedule()
if len(scheduler.running) == 0:
break
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm no memory leak.
assert_scheduler_empty(scheduler)
def create_scheduler_with_priority(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
enable_prefix_caching: bool = False,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
use_kv_connector: bool = False,
num_blocks: int = 10000,
block_size: int = 16,
max_model_len: int | None = None,
num_speculative_tokens: int | None = None,
use_ec_connector: bool = False,
ec_role: str | None = None,
) -> Scheduler:
"""Create scheduler with priority policy enabled.
Args:
model: model under test
max_num_seqs: max sequences to schedule
      max_num_batched_tokens: max num tokens to batch
      enable_prefix_caching: optionally force APC config
                             (True/False) or use the default (False)
Returns:
{class}`Scheduler` instance with priority scheduling
"""
model_config = ModelConfig(
model=model,
trust_remote_code=True,
dtype="float16",
seed=42,
)
if max_model_len is None:
max_model_len = max_num_batched_tokens
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input,
enable_chunked_prefill=True,
is_encoder_decoder=model_config.is_encoder_decoder,
policy="priority", # Enable priority scheduling
)
# Cache config, optionally force APC
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
enable_prefix_caching=enable_prefix_caching,
)
kv_transfer_config = (
KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
)
if use_kv_connector
else None
)
speculative_config: SpeculativeConfig | None = None
if num_speculative_tokens is not None:
speculative_config = SpeculativeConfig(
model="ngram", num_speculative_tokens=num_speculative_tokens
)
ec_transfer_config = (
ECTransferConfig(
ec_connector="ECSharedStorageConnector",
ec_role=ec_role,
ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"},
)
if use_ec_connector
else None
)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
ec_transfer_config=ec_transfer_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
)
],
)
cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
block_size=block_size,
)
_none_hash_initialized = False
def create_requests_with_priority(
num_requests: int,
priorities: list[int],
arrival_times: list[float] | None = None,
num_tokens: int = 10,
mm_hashes_list: list[list[str]] | None = None,
mm_positions: list[list[PlaceholderRange]] | None = None,
max_tokens: int = 16,
stop_token_ids: list[int] | None = None,
prompt_logprobs: int | None = None,
starting_idx: int = 0,
same_prompt: bool = False,
block_size: int = 16,
req_ids: list[str] | None = None,
):
"""Create requests with specified priorities and arrival times."""
assert len(priorities) == num_requests
if arrival_times is not None:
assert len(arrival_times) == num_requests
else:
arrival_times = [float(i) for i in range(num_requests)]
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(sha256)
_none_hash_initialized = True
block_hasher = get_request_block_hasher(block_size, sha256)
sampling_params = SamplingParams(
ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs,
)
requests = []
if mm_hashes_list is not None:
        # NOTE: mm_hashes_list allows manual input; some mm items may share
        # the same identifier. The number of mm_hashes and mm_positions for
        # each request must be identical.
assert mm_positions is not None, (
"mm_positions must be provided when mm_hashes_list is provided"
)
assert len(mm_hashes_list) == len(mm_positions) == num_requests
assert [len(h) for h in mm_hashes_list] == [len(p) for p in mm_positions]
        # Items with the same identifier must correspond to identical encoder
        # output, so verify that they have the same mm_position length.
seen_hashes: dict[str, int] = {}
if req_ids:
assert len(req_ids) == num_requests
else:
req_ids = [f"{i + starting_idx}" for i in range(num_requests)]
for i in range(num_requests):
mm_features = []
for j, position in enumerate(
mm_positions[i] if mm_positions is not None else []
):
if mm_hashes_list is not None:
identifier = mm_hashes_list[i][j]
                # Verify that the position length is consistent for this identifier
position_length = position.length
if identifier in seen_hashes:
assert seen_hashes[identifier] == position_length, (
f"mm_hash '{identifier}' has inconsistent position lengths: "
f"previously {seen_hashes[identifier]}, now {position_length} "
f"at request {i}, position {j}"
)
else:
seen_hashes[identifier] = position_length
else:
# Unique dummy hash for each mm item
identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image",
)
mm_features.append(mm_feature)
prompt_token_ids = (
[starting_idx] * num_tokens
if same_prompt
else [i + starting_idx] * num_tokens
)
request = Request(
request_id=req_ids[i],
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=arrival_times[i],
priority=priorities[i],
block_hasher=block_hasher,
)
requests.append(request)
return requests
def test_priority_scheduling_basic_ordering():
"""Test that requests are scheduled in priority order
(lower value = higher priority)."""
scheduler = create_scheduler_with_priority()
# Create requests with different priorities
# Priority 0 (highest), 1, 2 (lowest)
priorities = [2, 0, 1] # Add in non-priority order
arrival_times = [1.0, 2.0, 3.0] # All different arrival times
requests = create_requests_with_priority(
num_requests=3, priorities=priorities, arrival_times=arrival_times
)
# Add requests in non-priority order
for request in requests:
scheduler.add_request(request)
# Schedule and verify priority order
output = scheduler.schedule()
# Should schedule all requests since they fit in budget
assert len(output.scheduled_new_reqs) == 3
# Verify they are scheduled in priority order:
# req_1 (priority 0), req_2 (priority 1), req_0 (priority 2)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert scheduled_req_ids == ["1", "2", "0"]
def test_priority_scheduling_arrival_time_tiebreaker():
"""Test that arrival time is used
as tiebreaker when priorities are equal."""
scheduler = create_scheduler_with_priority()
# Create requests with same priority but different arrival times
priorities = [1, 1, 1] # All same priority
arrival_times = [3.0, 1.0, 2.0] # Different arrival times
requests = create_requests_with_priority(
num_requests=3, priorities=priorities, arrival_times=arrival_times
)
# Add requests in non-arrival order
for request in requests:
scheduler.add_request(request)
# Schedule and verify arrival time order
output = scheduler.schedule()
# Should schedule all requests since they fit in budget
assert len(output.scheduled_new_reqs) == 3
# Verify they are scheduled in arrival time order:
# req_1 (1.0), req_2 (2.0), req_0 (3.0)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert scheduled_req_ids == ["1", "2", "0"]
def test_priority_scheduling_mixed_priority_and_arrival():
"""Test priority scheduling with mixed priorities and arrival times."""
scheduler = create_scheduler_with_priority()
# Create requests with mixed priorities and arrival times
priorities = [2, 1, 1, 0] # Mixed priorities
arrival_times = [1.0, 3.0, 2.0, 4.0] # Mixed arrival times
requests = create_requests_with_priority(
num_requests=4, priorities=priorities, arrival_times=arrival_times
)
# Add requests
for request in requests:
scheduler.add_request(request)
# Schedule and verify order
output = scheduler.schedule()
# Should schedule all requests since they fit in budget
assert len(output.scheduled_new_reqs) == 4
# Expected order:
# 1. req_3 (priority 0, arrival 4.0)
# 2. req_2 (priority 1, arrival 2.0) - earlier arrival than req_1
# 3. req_1 (priority 1, arrival 3.0)
# 4. req_0 (priority 2, arrival 1.0)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert scheduled_req_ids == ["3", "2", "1", "0"]
def test_priority_scheduling_preemption():
"""Test that priority scheduling preempts
lower priority requests when memory is constrained."""
# Create scheduler with very limited memory to force preemption
scheduler = create_scheduler_with_priority(
max_num_seqs=3, # Allow multiple requests
max_num_batched_tokens=200,
num_blocks=6, # Very limited blocks to force memory pressure
block_size=16, # Standard block size
)
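    # With 6 blocks of 16 tokens (5 usable after the null block), each
    # 30-token request needs 2 blocks, so the two running requests leave only
    # 1 free block for the incoming high-priority request.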
# Create initial low-priority requests that will consume most memory
low_priority_requests = create_requests_with_priority(
num_requests=2,
priorities=[5, 5], # Low priority
arrival_times=[1.0, 2.0],
num_tokens=30, # Large enough to consume significant memory
)
# Add and schedule low priority requests
for request in low_priority_requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 2
# Simulate model execution to move requests to running state
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in low_priority_requests],
req_id_to_index={
req.request_id: i for i, req in enumerate(low_priority_requests)
},
sampled_token_ids=[[100] for _ in low_priority_requests],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Verify both requests are running
assert len(scheduler.running) == 2
# Now add a high-priority request that requires memory allocation
# This should trigger preemption due to memory constraints
high_priority_request = create_requests_with_priority(
num_requests=1,
priorities=[0], # High priority
arrival_times=[3.0],
num_tokens=30, # Large enough to require significant memory
)[0]
scheduler.add_request(high_priority_request)
# Schedule again - this should trigger
# preemption when trying to allocate memory
output = scheduler.schedule()
# Due to the scheduler's design, if preemption happens
# during running request scheduling,
# waiting requests won't be scheduled in the same step
# Let's check if preemption occurred by looking at the waiting queue
# If preemption happened, we should see requests in the
# waiting queue
if len(scheduler.waiting) > 1: # high priority + preempted request
# Preemption occurred - verify the high priority request
# gets scheduled next
output2 = scheduler.schedule()
assert len(output2.scheduled_new_reqs) == 1
# High priority request
assert output2.scheduled_new_reqs[0].req_id == "0"
else:
# No preemption needed - all requests fit
# This is also valid behavior if memory allows
assert len(output.scheduled_new_reqs) == 1
# High priority request
assert output.scheduled_new_reqs[0].req_id == "0"
def test_priority_scheduling_no_preemption_when_space_available():
"""Test that preemption doesn't happen
when there's space for new requests."""
scheduler = create_scheduler_with_priority(
max_num_seqs=3, # Allow 3 concurrent requests
max_num_batched_tokens=200, # Sufficient token budget
)
# Add two low-priority running requests
low_priority_requests = create_requests_with_priority(
num_requests=2, priorities=[5, 5], arrival_times=[1.0, 2.0], num_tokens=30
)
for request in low_priority_requests:
scheduler.add_request(request)
output = scheduler.schedule()
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in low_priority_requests],
req_id_to_index={
req.request_id: i for i, req in enumerate(low_priority_requests)
},
sampled_token_ids=[[100] for _ in low_priority_requests],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Add high-priority request
high_priority_request = create_requests_with_priority(
num_requests=1, priorities=[0], arrival_times=[3.0], num_tokens=30
)[0]
scheduler.add_request(high_priority_request)
# Schedule - should not preempt since there's space
output = scheduler.schedule()
# Should schedule the new request without preemption
assert len(output.scheduled_new_reqs) == 1
assert len(scheduler.running) == 3 # All three requests running
assert len(scheduler.waiting) == 0 # No requests waiting
def test_priority_scheduling_preemption_victim_selection():
"""Test that the correct victim is selected for
preemption based on priority and arrival time."""
# This test verifies the priority-based victim selection logic
# by checking the waiting queue order after adding requests with different
# priorities
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Force sequential processing to test priority order
)
# Create requests with different priorities
requests = create_requests_with_priority(
num_requests=3,
priorities=[3, 2, 0], # Different priorities: low, medium, high
arrival_times=[1.0, 2.0, 3.0],
num_tokens=10,
)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should only schedule the highest priority request
# (req_2, priority 0)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_new_reqs[0].req_id == "2" # Highest priority
# Verify the waiting queue has the remaining requests in priority order
assert len(scheduler.waiting) == 2
# Extract waiting requests and verify priority order
waiting_requests = list(scheduler.waiting)
waiting_priorities = [req.priority for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be req_1 (priority 2) then req_0 (priority 3)
assert waiting_priorities == [2, 3]
assert waiting_req_ids == ["1", "0"]
def test_priority_scheduling_equal_priority_preemption():
"""Test arrival time tiebreaker when requests have equal priority."""
# This test verifies that arrival time is used as a tiebreaker for equal
# priorities
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Force sequential processing
)
# Create requests with same priority but different arrival times
requests = create_requests_with_priority(
num_requests=3,
priorities=[2, 2, 2], # Same priority
arrival_times=[3.0, 1.0, 2.0], # Different arrival times
num_tokens=10,
)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should schedule the request with earliest arrival time
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_new_reqs[0].req_id == "1" # Earliest arrival (1.0)
# Verify the waiting queue has remaining requests in arrival time order
assert len(scheduler.waiting) == 2
# Extract waiting requests and verify arrival time order
waiting_requests = list(scheduler.waiting)
waiting_arrival_times = [req.arrival_time for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be req_2 (arrival 2.0) then req_0 (arrival 3.0)
assert waiting_arrival_times == [2.0, 3.0]
assert waiting_req_ids == ["2", "0"]
def test_priority_scheduling_waiting_queue_order():
"""Test that the waiting queue maintains priority order."""
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Only one request can run at a time
)
# Create multiple requests with different priorities
requests = create_requests_with_priority(
num_requests=4,
priorities=[3, 1, 2, 0], # Mixed priorities
arrival_times=[1.0, 2.0, 3.0, 4.0],
num_tokens=10,
)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should only schedule the highest priority request
# (req_3, priority 0)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_new_reqs[0].req_id == "3"
# Verify waiting queue has remaining requests in priority order
assert len(scheduler.waiting) == 3
# Extract requests from waiting queue
# (it's a heap, so we need to pop to see order)
waiting_requests = list(scheduler.waiting)
waiting_priorities = [req.priority for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be ordered by priority: req_1 (1), req_2 (2), req_0 (3)
assert waiting_req_ids == ["1", "2", "0"]
assert waiting_priorities == [1, 2, 3]
def test_priority_scheduling_fcfs_fallback():
"""Test that FCFS behavior is maintained when all
requests have same priority."""
scheduler = create_scheduler_with_priority()
# Create requests with same priority but different arrival times
priorities = [1, 1, 1, 1] # All same priority
arrival_times = [4.0, 1.0, 3.0, 2.0] # Different arrival times
requests = create_requests_with_priority(
num_requests=4, priorities=priorities, arrival_times=arrival_times
)
# Add requests
for request in requests:
scheduler.add_request(request)
# Schedule
output = scheduler.schedule()
# Should schedule all requests in arrival time order
assert len(output.scheduled_new_reqs) == 4
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
# Expected order by arrival time:
# req_1 (1.0), req_3 (2.0), req_2 (3.0), req_0 (4.0)
assert scheduled_req_ids == ["1", "3", "2", "0"]
def test_priority_scheduling_with_limited_slots():
"""Test priority scheduling when max_num_seqs limits concurrent requests."""
scheduler = create_scheduler_with_priority(
max_num_seqs=2, # Only allow 2 concurrent requests
max_num_batched_tokens=1000, # Plenty of token budget
)
# Create requests with different priorities
requests = create_requests_with_priority(
num_requests=4,
priorities=[3, 1, 2, 0], # Mixed priorities
arrival_times=[1.0, 2.0, 3.0, 4.0],
num_tokens=10,
)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should only schedule the 2 highest priority requests
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 2
# Should schedule req_3 (priority 0) and req_1 (priority 1)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert "3" in scheduled_req_ids # Priority 0
assert "1" in scheduled_req_ids # Priority 1
# Remaining requests should be in waiting queue in priority order
assert len(scheduler.waiting) == 2
# Extract waiting requests and verify order
waiting_requests = list(scheduler.waiting)
waiting_priorities = [req.priority for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be req_2 (priority 2) then req_0 (priority 3)
assert waiting_priorities == [2, 3]
assert waiting_req_ids == ["2", "0"]
def test_priority_scheduling_heap_property():
"""Test that the waiting queue maintains heap
property for priority scheduling."""
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Only one request can run at a time
)
# Add requests in random priority order
priorities = [5, 1, 8, 3, 2, 7, 4, 6]
arrival_times = [float(i) for i in range(len(priorities))]
requests = create_requests_with_priority(
num_requests=len(priorities),
priorities=priorities,
arrival_times=arrival_times,
num_tokens=10,
)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule one request at a time and verify priority order
scheduled_priorities = []
while scheduler.waiting:
output = scheduler.schedule()
if output.scheduled_new_reqs:
req = output.scheduled_new_reqs[0]
scheduled_priorities.append(requests[int(req.req_id)].priority)
# Simulate completion to make room for next request
model_output = ModelRunnerOutput(
req_ids=[req.req_id],
req_id_to_index={req.req_id: 0},
sampled_token_ids=[[100]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Finish the request to make room for the next one
scheduler.finish_requests(req.req_id, RequestStatus.FINISHED_STOPPED)
# Verify requests were scheduled in priority order (lowest value first)
expected_priorities = sorted(priorities)
assert scheduled_priorities == expected_priorities
def test_schedule_skip_tokenizer_init():
scheduler = create_scheduler(skip_tokenizer_init=True)
requests = create_requests(num_requests=5)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
def test_schedule_skip_tokenizer_init_structured_output_request():
scheduler = create_scheduler(skip_tokenizer_init=True)
structured_outputs_params = StructuredOutputsParams(regex="[0-9]+")
sampling_params = SamplingParams(
ignore_eos=False,
max_tokens=16,
structured_outputs=structured_outputs_params,
)
request = Request(
request_id="0",
prompt_token_ids=[0, 1],
mm_features=None,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
)
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 0
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
@pytest.mark.parametrize(
"use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
)
def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
use_ec_connector, ec_role
):
"""Test that priority scheduling preempts lower priority requests
when out of KV cache space."""
# Create scheduler with very limited memory to force preemption
scheduler = create_scheduler_with_priority(
max_num_seqs=2, # Allow multiple requests
max_num_batched_tokens=200,
num_blocks=5, # Can hold 64 tokens (first block is null)
block_size=16, # Standard block size
use_kv_connector=True,
# encoder connector should not affect test results
use_ec_connector=use_ec_connector,
ec_role=ec_role,
)
# Create a request and schedule it
request_low = create_requests_with_priority(
num_requests=1,
priorities=[1],
arrival_times=[0.0],
num_tokens=30,
starting_idx=0,
)[0]
scheduler.add_request(request_low)
# 1st schedule
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 1
# Simulate model execution - 1st decode
model_output = ModelRunnerOutput(
req_ids=[request_low.request_id],
req_id_to_index={request_low.request_id: 0},
sampled_token_ids=[[100]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Create a high priority request and schedule it
request_high = create_requests_with_priority(
num_requests=1,
priorities=[0],
arrival_times=[1.0],
num_tokens=32,
starting_idx=1,
)[0]
scheduler.add_request(request_high)
# 2nd schedule
output = scheduler.schedule()
# KV cache should be full at this point
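    # (req_low: 30 prompt + 1 decoded = 31 tokens -> 2 blocks;
    #  req_high: 32 prompt tokens -> 2 blocks; all 4 usable blocks are allocated)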
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_cached_reqs.num_reqs == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 2
# Simulate model execution - 2nd decode
requests = [request_low, request_high]
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
sampled_token_ids=[[100] for _ in requests],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# 3rd schedule - this should trigger preemption
# req_low needs 32 tokens = 2 blocks
# req_high needs 33 tokens = 3 blocks
# so doesn't fit in 4 blocks.
output = scheduler.schedule()
# Should have preempted req_low
assert len(output.scheduled_new_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 1
assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 1
# Simulate model execution - 3rd decode
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
sampled_token_ids=[[], [100]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Finish the requests to make room for the preempted requests to resume
scheduler.update_from_output(output, model_output)
scheduler.finish_requests(request_high.request_id, RequestStatus.FINISHED_STOPPED)
# 4th Schedule - this should trigger the resumption
output = scheduler.schedule()
scheduled_cached_reqs = output.scheduled_cached_reqs
resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption
assert len(output.scheduled_new_reqs) == 0
assert scheduled_cached_reqs.num_reqs == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 1
# Preempted request resumed in scheduled_cached_reqs
assert len(resumed_from_preemption) == 1
assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
assert resumed_from_preemption[0]
assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
# Resumed tokens include 30 prompt tokens and 2 decoded tokens
assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 32
assert scheduled_cached_reqs.resumed_req_token_ids[0][31] == 100
@pytest.mark.parametrize(
("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"),
[
(True, False, True),
(False, False, False),
# Encoder-decoder models should always have it disabled
(False, True, False),
(True, True, False),
],
)
def test_chunked_prefill_disabled_for_encoder_decoder(
enable_chunked_prefill: bool, is_encoder_decoder: bool, expect_enabled: bool
) -> None:
"""Validate that chunked prefill is appropriately disabled for
encoder-decoder models."""
scheduler_config = SchedulerConfig(
enable_chunked_prefill=enable_chunked_prefill,
is_encoder_decoder=is_encoder_decoder,
        # Must be <= max_num_batched_tokens if chunked prefill is disabled
max_model_len=SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
# `is_encoder_decoder` should only be used during construction
# of the config, and otherwise stored in the model config.
assert "is_encoder_decoder" not in vars(scheduler_config)
assert "is_encoder_decoder" not in [
f.name for f in dataclasses.fields(scheduler_config)
]
_validate_chunked_prefill_settings_for_encoder_decoder(
scheduler_config, is_encoder_decoder, expect_enabled
)
# Ensure it is retained in VllmConfig, even after its post-init.
vllm_config = VllmConfig(scheduler_config=scheduler_config)
_validate_chunked_prefill_settings_for_encoder_decoder(
vllm_config.scheduler_config, is_encoder_decoder, expect_enabled
)
def _validate_chunked_prefill_settings_for_encoder_decoder(
scheduler_config: SchedulerConfig, is_encoder_decoder: bool, expect_enabled: bool
) -> None:
"""Validate chunked prefill settings in the scheduler config for
encoder-decoder models."""
assert scheduler_config.enable_chunked_prefill is expect_enabled
if is_encoder_decoder:
# Encoder-decoder models should automatically disable chunked multimodal
# inputs as well
assert scheduler_config.disable_chunked_mm_input is not expect_enabled
if is_encoder_decoder and not expect_enabled:
assert scheduler_config.long_prefill_token_threshold == 0
# ==============================================================================
# EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests start
# NOTE: In E->P->D disagg case, both KV and EC Connector works in P instance
# Unless specify, the existence of KV Connector should not affect any test results
# ==============================================================================
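# Common pattern in the tests below: the scheduler is created with
# use_ec_connector=True and ec_role="ec_consumer", ec_connector.has_caches is
# mocked to simulate external cache hits/misses (one flag per mm item), and
# update_state_after_alloc is wrapped in a Mock so external loads can be
# asserted on.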
def _assert_right_encoder_cache_allocated(
scheduler: Scheduler,
hashes_to_check: list[str] | None = None,
requests: list[Request] | None = None,
expected_total_allocated: int | None = None,
):
"""Check whether encoder cache is allocated correctly."""
encoder_cache_manager = scheduler.encoder_cache_manager
# Verify encoder cache manager exists
assert encoder_cache_manager is not None, "Encoder cache manager should exist"
    # Verify the number of cached mm entries
if expected_total_allocated is not None:
assert len(encoder_cache_manager.cached) == expected_total_allocated
if expected_total_allocated == 0:
return
# Verify each request with MM data is in cache
cached_hashes = set(encoder_cache_manager.cached.keys())
if hashes_to_check:
missed_hashes = set(hashes_to_check) - cached_hashes
assert not missed_hashes, (
f"Miss hashes: {missed_hashes} "
f"Existing encoder cache: {encoder_cache_manager.cached}"
)
for req in requests if requests is not None else []:
if req.mm_features:
mm_hashes = [f.identifier for f in req.mm_features]
req_hashes = set(mm_hashes) # unique hashes set
missed_hashes = req_hashes - cached_hashes
assert not missed_hashes, (
f"Miss hashes in cache for request {req.request_id}: {missed_hashes} "
f"Existing encoder cache: {encoder_cache_manager.cached}"
)
def _assert_right_ec_connector_metadata(
output: SchedulerOutput,
mm_features_list: list[MultiModalFeatureSpec],
):
"""Verify that ECConnector metadata EXACTLY matches the input MM data"""
# Get the connector metadata
metadata = output.ec_connector_metadata
# Create lookup dictionaries for efficient access
metadata_dict = {mm_data.mm_hash: mm_data for mm_data in metadata.mm_datas}
    # Check that all required identifiers exist in the metadata, with no extras,
    # in the ECSharedStorageConnector format.
    # NOTE: even with the same identifier, the mm_features can differ
    # since their mm_position can be at different offsets, etc.
identifiers_dict = {f.identifier for f in mm_features_list}
assert set(metadata_dict.keys()) == identifiers_dict
# Verify the info matches
for i, mm_feature in enumerate(mm_features_list):
identifier = mm_feature.identifier
assert metadata_dict[identifier].mm_hash == identifier
assert metadata_dict[identifier].num_token == mm_feature.mm_position.length
def _assert_right_encoder_inputs(
output: SchedulerOutput,
check_exist: bool | None = True,
requests: list[Request] | None = None,
expected_encoder_inputs: list[list[int]] | None = None,
expected_total_reqs: int | None = None,
):
"""Verify that requests/mm_hashes should (not) in scheduled encoder input
If check_exist is False, this function returns True
if requests are NOT in encoder inputs"""
# Get the scheduled encoder inputs
# NOTE: scheduled_encoder_inputs is a dictionary with request id as key
scheduled_encoder_inputs = output.scheduled_encoder_inputs
# Check if scheduled_encoder_inputs is empty as expected
if expected_total_reqs is not None:
assert len(scheduled_encoder_inputs) == expected_total_reqs
if expected_total_reqs == 0:
return
    # Number of expected encoder inputs should match number of requests
    if expected_encoder_inputs:
        # Expected inputs are only supported when checking for presence
        assert check_exist and requests is not None
assert len(requests) == len(expected_encoder_inputs)
# Check request (not) exist as expected
for i, request in enumerate(requests if requests is not None else []):
        assert (request.request_id in scheduled_encoder_inputs) is check_exist, (
            f"Request {request.request_id} presence mismatch: "
            f"expected {check_exist}, "
            f"got {request.request_id in scheduled_encoder_inputs}"
        )
if expected_encoder_inputs:
scheduled_encoder_input = scheduled_encoder_inputs[request.request_id]
assert scheduled_encoder_input == expected_encoder_inputs[i]
def test_scheduler_no_ec_connector_by_default():
"""Test scheduler doesn't have EC connector by default."""
scheduler = create_scheduler()
assert scheduler.ec_connector is None
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_text_only_request(use_kv_connector):
"""Test text-only requests don't allocate encoder cache."""
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
use_kv_connector=use_kv_connector,
use_ec_connector=True,
ec_role="ec_consumer",
)
NUM_PROMPT_TOKENS = 100
# Create text-only request (no mm_positions)
requests = create_requests(
num_requests=1,
num_tokens=NUM_PROMPT_TOKENS,
)
assert not requests[0].mm_features # No MM data
scheduler.add_request(requests[0])
output = scheduler.schedule()
# Should schedule
assert len(output.scheduled_new_reqs) == 1
# Scheduled tokens should equal prompt tokens exactly
scheduled = output.num_scheduled_tokens[requests[0].request_id]
assert scheduled == NUM_PROMPT_TOKENS, (
f"Text-only should schedule {NUM_PROMPT_TOKENS}, got {scheduled}"
)
# Encoder cache should be empty
_assert_right_encoder_cache_allocated(scheduler, expected_total_allocated=0)
# ECConnector should carry no metadata
_assert_right_ec_connector_metadata(output, mm_features_list=[])
# Scheduled encoder input should be empty; no mm to compute
_assert_right_encoder_inputs(output, expected_total_reqs=0)
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_cache_hit_external_load(use_kv_connector):
"""Test ec_consumer loads from external cache when hit.
A normal basic operation for EPD disaggrgation"""
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
enable_prefix_caching=True,
        # kv connector should not affect test results
use_kv_connector=use_kv_connector,
use_ec_connector=True,
ec_role="ec_consumer",
)
# Create MM request
NUM_TOKENS = 200 # NOTE: includes mm tokens
NUM_ENCODER_TOKENS = 100
mm_hashes_list = [["hash_test1"]]
mm_positions = [[PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS)]]
request = create_requests(
num_requests=1,
num_tokens=NUM_TOKENS,
mm_hashes_list=mm_hashes_list,
mm_positions=mm_positions,
)[0]
# Mock cache hit - encoder cache exists externally
scheduler.ec_connector.has_caches = Mock(return_value=[True])
scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc
)
scheduler.add_request(request)
output = scheduler.schedule()
# Should schedule prompt tokens
scheduled_tokens = output.num_scheduled_tokens[request.request_id]
assert scheduled_tokens == NUM_TOKENS
    # Should have called update_state_after_alloc for the external load
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request, 0)
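    # (the second argument is the index of the mm item loaded via the connector)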
# Encoder cache should contain mm items from request
_assert_right_encoder_cache_allocated(scheduler, requests=[request])
# ECConnector should carry metadata of request
_assert_right_ec_connector_metadata(output, mm_features_list=request.mm_features)
# Scheduled encoder input should be empty; no mm to compute
_assert_right_encoder_inputs(output, expected_total_reqs=0)
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_cache_miss_computes_locally(use_kv_connector):
"""Test consumer can compute encoder locally when cache miss (fallback)."""
# encoder cache itself if it doesn't receive it from external storage
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
enable_prefix_caching=True,
use_kv_connector=use_kv_connector,
use_ec_connector=True,
ec_role="ec_consumer",
)
# Verify consumer role
assert scheduler.ec_connector is not None
assert not scheduler.ec_connector.is_producer
# Create MM request
request_mm_missed = create_requests(
num_requests=1,
num_tokens=200, # Total (including 100 MM)
mm_positions=[[PlaceholderRange(offset=0, length=100)]], # 100 MM tokens
)[0]
# Mock cache miss - encoder cache doesn't exist externally
scheduler.ec_connector.has_caches = Mock(return_value=[False])
scheduler.add_request(request_mm_missed)
output = scheduler.schedule()
# SCHEDULER should decide to compute encoder locally (fallback)
assert len(output.scheduled_new_reqs) == 1
# Should schedule full prompt tokens
scheduled_tokens = output.num_scheduled_tokens[request_mm_missed.request_id]
assert scheduled_tokens == 200, (
f"Expected 200 tokens on cache miss, got {scheduled_tokens}"
)
# Encoder cache should contain mm items from request
_assert_right_encoder_cache_allocated(scheduler, requests=[request_mm_missed])
# ECConnector should carry no metadata (missed cache)
_assert_right_ec_connector_metadata(output, mm_features_list=[])
    # Scheduled encoder inputs should contain the mm item of request_mm_missed
_assert_right_encoder_inputs(
output,
requests=[request_mm_missed],
expected_encoder_inputs=[[0]], # index 0 of the mm item
expected_total_reqs=1,
)
# Then MODEL_RUNNER will execute the encoder and cache the result
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
"""Test consumer with partial cache hit (local & connector) with 2 requests."""
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
enable_prefix_caching=True,
use_kv_connector=use_kv_connector,
use_ec_connector=True,
ec_role="ec_consumer",
)
# Create MM request
NUM_TOKENS_1 = 300 # NOTE: includes mm tokens
NUM_ENCODER_TOKENS_1 = 50
mm_hashes_list_1 = [["hash1_A", "hash1_B", "hash1_A", "hash1_F"]]
mm_positions_1 = [
[
PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_1),
PlaceholderRange(offset=100, length=NUM_ENCODER_TOKENS_1),
PlaceholderRange(offset=200, length=NUM_ENCODER_TOKENS_1),
PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_1),
]
]
# Create request with 4 MM items, with 2 identical items
request1 = create_requests(
num_requests=1,
num_tokens=NUM_TOKENS_1,
mm_hashes_list=mm_hashes_list_1,
mm_positions=mm_positions_1,
max_tokens=1, # For simplicity
)[0]
# Mock partial cache hit: 1st and 3rd missing, 2nd and 4th exist
scheduler.ec_connector.has_caches = Mock(return_value=[False, True, False, True])
scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc
)
scheduler.add_request(request1)
output = scheduler.schedule()
# Should schedule all tokens
scheduled_tokens = output.num_scheduled_tokens[request1.request_id]
assert scheduled_tokens == NUM_TOKENS_1
# Encoder cache should contain all mm items from request
_assert_right_encoder_cache_allocated(scheduler, requests=[request1])
# Should have called update_state_after_alloc for external load
scheduler.ec_connector.update_state_after_alloc.assert_called()
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# ECConnector should carry metadata for 2nd and 4th mm item
_assert_right_ec_connector_metadata(
output, mm_features_list=[request1.mm_features[1], request1.mm_features[3]]
)
# Should schedule ONLY 1 encoder input (index 0), no repeat for identical items
_assert_right_encoder_inputs(
output,
requests=[request1],
expected_encoder_inputs=[[0]], # index 0 of the mm item ONLY
expected_total_reqs=1,
)
# Simulate model execution 1 step
model_output = ModelRunnerOutput(
req_ids=[request1.request_id],
req_id_to_index={request1.request_id: 0},
sampled_token_ids=[[100]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
    # request1 is finished after outputting 1 token
# Finish request
scheduler.finish_requests(request1.request_id, RequestStatus.FINISHED_LENGTH_CAPPED)
# Create another request with 4 MM items
NUM_TOKENS_2 = 400
NUM_ENCODER_TOKENS_2 = 50
mm_hashes_list_2 = [["hash1_C", "hash1_D", "hash1_E", "hash1_A"]]
mm_positions_2 = [
[
PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_2),
PlaceholderRange(offset=100, length=NUM_ENCODER_TOKENS_2),
PlaceholderRange(offset=200, length=NUM_ENCODER_TOKENS_2),
PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_2),
]
]
request2 = create_requests(
num_requests=1,
num_tokens=NUM_TOKENS_2,
mm_hashes_list=mm_hashes_list_2,
mm_positions=mm_positions_2,
max_tokens=1, # For simplicity
)[0]
# Mock partial cache hit: only hash1_A and hash1_C exist in connector
scheduler.ec_connector.has_caches = Mock(return_value=[True, False, False, True])
scheduler.add_request(request2)
output = scheduler.schedule()
# Check
# Should schedule all tokens
scheduled_tokens = output.num_scheduled_tokens[request2.request_id]
assert scheduled_tokens == 400
# Encoder cache should contain all mm items from request2
_assert_right_encoder_cache_allocated(scheduler, requests=[request2])
    # Should call update_state_after_alloc for hash1_C ONLY.
    # hash1_A should not be loaded from the connector since it was computed
    # for the last request and still exists in the local cache.
    # Order of getting encoder cache should be: local cache -> connector -> compute
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 0)
scheduler.ec_connector.update_state_after_alloc.assert_called_once()
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# ECConnector should carry metadata for hash1_C only (index 0)
_assert_right_ec_connector_metadata(
output, mm_features_list=[request2.mm_features[0]]
)
# Should schedule 2 encoder input hash1_D and hash1_E (index 1, 2)
_assert_right_encoder_inputs(
output,
requests=[request2],
expected_encoder_inputs=[[1, 2]],
expected_total_reqs=1,
)
@pytest.mark.parametrize("cache_exist", ["local", "connector_only", "no_where"])
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_schedule_multiple_requests(cache_exist, use_kv_connector):
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
max_num_seqs=10, # allow multiple requests
max_num_batched_tokens=2048,
enable_prefix_caching=True,
use_kv_connector=use_kv_connector,
use_ec_connector=True,
ec_role="ec_consumer",
)
mm_hashes_list = [[f"hash_{i}"] for i in range(10)]
mm_positions = [[PlaceholderRange(offset=i, length=100)] for i in range(10)]
requests = create_requests(
num_requests=10,
num_tokens=200,
mm_hashes_list=mm_hashes_list,
mm_positions=mm_positions,
)
for request in requests:
scheduler.add_request(request)
    # Set up to test different encoder cache existence scenarios
    # Order of getting encoder cache should be: local cache -> connector -> compute
scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc
)
if cache_exist == "local":
        # Manually allocate cache entries to mimic locally computed encoder caches
for req in requests:
scheduler.encoder_cache_manager.allocate(req, 0)
else:
# Make sure local encoder cache empty
scheduler.encoder_cache_manager.cached = {}
if cache_exist == "connector_only":
            # Cache exists in the ec_connector
scheduler.ec_connector.has_caches = Mock(return_value=[True])
elif cache_exist == "no_where":
scheduler.ec_connector.has_caches = Mock(return_value=[False])
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
## Encoder-cache-specific checks:
    # mm_hashes of requests exist in cache after scheduling in all scenarios
_assert_right_encoder_cache_allocated(scheduler, requests=requests)
# Should only call update_state_after_alloc when loaded externally
if cache_exist == "connector_only":
scheduler.ec_connector.update_state_after_alloc.assert_called_with(
requests[-1], 0
)
# Concat mm_features for the 10 requests together
mm_features_list = [feature for req in requests for feature in req.mm_features]
# Check metadata should contain mm data for all 10 requests
_assert_right_ec_connector_metadata(output, mm_features_list=mm_features_list)
else:
scheduler.ec_connector.update_state_after_alloc.assert_not_called()
# ECConnector should carry no metadata
_assert_right_ec_connector_metadata(output, mm_features_list=[])
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# Should only schedule encoder input when cache is not found anywhere
if cache_exist == "no_where":
_assert_right_encoder_inputs(
output,
requests=requests,
expected_encoder_inputs=[[0] for _ in range(10)],
expected_total_reqs=10,
)
else:
_assert_right_encoder_inputs(output, expected_total_reqs=0)
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_unable_to_allocate(use_kv_connector):
"""
    Test whether a scheduler with an ECConnector can handle
    a failed allocation (running out of blocks).
"""
# Setup Scheduler With Mock External Cache Hit.
BLOCK_SIZE = 4
NUM_BLOCKS = 10
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
enable_prefix_caching=True,
use_kv_connector=use_kv_connector,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
use_ec_connector=True,
ec_role="ec_consumer",
)
# Mock ec_connector load external cache behavior
scheduler.ec_connector.has_caches = Mock(return_value=[True])
scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc
)
# Create two requests. The second request will not be able to
# allocate slots because it will not have enough blocks.
NUM_REQUESTS = 2
NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE
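    # Each 24-token prompt needs 6 blocks, while only NUM_BLOCKS - 1 = 9 blocks
    # are usable, so the two requests cannot be allocated at the same time.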
MAX_TOKENS = 2
requests = create_requests(
num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
mm_hashes_list=[["hash_1"], ["hash_2"]],
mm_positions=[
[PlaceholderRange(offset=1, length=10)] for _ in range(NUM_REQUESTS)
],
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE,
)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
# Setup MODEL_RUNNER_OUTPUT to be run in _step_until_done later
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Just one request should be running.
output = scheduler.schedule()
scheduled_tokens = output.num_scheduled_tokens[scheduler.running[0].request_id]
assert scheduled_tokens == NUM_TOKENS
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# Should have called update_state_after_alloc for external load
scheduler.ec_connector.update_state_after_alloc.assert_called_with(
scheduler.running[0], 0
)
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# All memory should be freed, with one request waiting.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
# Just one request should be running.
output = scheduler.schedule()
scheduled_tokens = output.num_scheduled_tokens[scheduler.running[0].request_id]
assert scheduled_tokens == NUM_TOKENS
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
# update_state_after_alloc should be called for loading external cache
scheduler.ec_connector.update_state_after_alloc.assert_called_with(
scheduler.running[0], 0
)
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# All memory should be freed, with no requests waiting / running.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 0
@pytest.mark.parametrize("cache_exist", ["local", "connector_only", "no_where"])
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_priority_scheduling_ec_connector_preemption_and_resumption(
cache_exist, use_kv_connector
):
"""Test that priority scheduling preempts lower priority requests
when out of KV cache space."""
# Create scheduler with very limited memory to force preemption
scheduler = create_scheduler_with_priority(
model="llava-hf/llava-1.5-7b-hf",
enable_prefix_caching=True,
max_num_seqs=2, # allow multiple requests
        # kv connector should not affect test results
use_kv_connector=use_kv_connector,
        num_blocks=15, # can hold 224 tokens with 14 blocks (first block is null)
block_size=16, # standard block size
use_ec_connector=True,
ec_role="ec_consumer",
)
    # Mock cache hit: both caches exist in the connector (as in E->PD initially)
scheduler.ec_connector.has_caches = Mock(return_value=[True])
scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc
)
# Create a request and schedule it (and to be preempted)
request_low = create_requests_with_priority(
num_requests=1,
priorities=[1],
arrival_times=[0.0],
num_tokens=94,
mm_hashes_list=[["hash_low"]],
        # NOTE: this test only preempts the last block.
        # Setting the mm_position at the last block can force the encoding
        # to be recomputed
mm_positions=[[PlaceholderRange(offset=82, length=10)]],
starting_idx=0,
)[0]
scheduler.add_request(request_low)
# 1st schedule
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
scheduled_tokens = output.num_scheduled_tokens[request_low.request_id]
assert scheduled_tokens == 94
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 1
## Encoder-cache-specific checks:
# Encoder cache should contain mm items from request
_assert_right_encoder_cache_allocated(scheduler, requests=[request_low])
# Verify update_state_after_alloc called (external load)
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request_low, 0)
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# ECConnector should carry metadata of request
_assert_right_ec_connector_metadata(
output, mm_features_list=request_low.mm_features
)
# Scheduled encoder input should be empty; no mm to compute
_assert_right_encoder_inputs(output, expected_total_reqs=0)
# Simulate model execution - 1st decode
model_output = ModelRunnerOutput(
req_ids=[request_low.request_id],
req_id_to_index={request_low.request_id: 0},
sampled_token_ids=[[100]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Create a high priority request and schedule it
request_high = create_requests_with_priority(
num_requests=1,
priorities=[0],
arrival_times=[1.0],
num_tokens=128,
mm_hashes_list=[["hash_high"]],
mm_positions=[[PlaceholderRange(offset=1, length=10)]],
max_tokens=2,
starting_idx=1,
)[0]
scheduler.add_request(request_high)
# 2nd schedule
output = scheduler.schedule()
# KV cache should be full at this point
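    # (req_low: 94 prompt + 1 decoded = 95 tokens -> 6 blocks;
    #  req_high: 128 prompt tokens -> 8 blocks; all 14 usable blocks are allocated)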
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_cached_reqs.num_reqs == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 2
## Encoder-cache-specific checks:
# Encoder cache should contain mm items from request
_assert_right_encoder_cache_allocated(scheduler, requests=[request_high])
# Verify update_state_after_alloc called (external load)
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request_high, 0)
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# ECConnector should carry metadata of request
_assert_right_ec_connector_metadata(
output, mm_features_list=request_high.mm_features
)
# Scheduled encoder input should be empty; no mm to compute
_assert_right_encoder_inputs(output, expected_total_reqs=0)
# Simulate model execution - 2nd decode
requests = [request_low, request_high]
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
sampled_token_ids=[[100] for _ in requests],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
    # 3rd schedule - this should trigger preemption
# req_low needs 96 tokens = 6 blocks
# req_high needs 129 tokens = 9 blocks
# so doesn't fit in 14 blocks.
output = scheduler.schedule()
# Should have preempted req_low
assert len(output.scheduled_new_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 1
assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 1
## Encoder-cache-specific checks:
# request_high is in decode phase now
# ECConnector should carry no metadata
_assert_right_ec_connector_metadata(output, mm_features_list=[])
# Scheduled encoder input should be empty; no mm to compute
_assert_right_encoder_inputs(output, expected_total_reqs=0)
# Simulate model execution - 3rd decode, after req_low was preempted
requests = [request_low, request_high]
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
sampled_token_ids=[[100], [100, 200]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Finish the requests to make room for the preempted requests to resume
    # req_high is finished after outputting 2 tokens
scheduler.update_from_output(output, model_output)
scheduler.finish_requests(
request_high.request_id, RequestStatus.FINISHED_LENGTH_CAPPED
)
    # Set up to test different encoder cache existence scenarios after preemption
    # Order of getting encoder cache should be: local cache -> connector -> compute
    # By default, the cache should still exist locally in this test case
if cache_exist != "local":
# Make local encoder cache empty
scheduler.encoder_cache_manager.cached = {}
if cache_exist == "connector_only":
            # Cache exists in the ec_connector
scheduler.ec_connector.has_caches = Mock(return_value=[True])
elif cache_exist == "no_where":
scheduler.ec_connector.has_caches = Mock(return_value=[False])
# 4th Schedule - this should trigger req_low resumption from waiting
output = scheduler.schedule()
scheduled_cached_reqs = output.scheduled_cached_reqs
resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption
assert len(output.scheduled_new_reqs) == 0
assert scheduled_cached_reqs.num_reqs == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 1
# Preempted request resumed in scheduled_cached_reqs
assert len(resumed_from_preemption) == 1
assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
assert resumed_from_preemption[0]
assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
    # Resumed tokens include 94 prompt tokens and 2 decoded tokens
assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 96
assert scheduled_cached_reqs.resumed_req_token_ids[0][95] == 100
assert scheduler.running[0].request_id == request_low.request_id
assert request_high.request_id in output.finished_req_ids
## Encoder-cache-specific checks:
    # mm_hash of request_low exists in cache after scheduling in all scenarios
_assert_right_encoder_cache_allocated(scheduler, requests=[request_low])
# Should only call update_state_after_alloc when loaded externally
if cache_exist == "connector_only":
scheduler.ec_connector.update_state_after_alloc.assert_called_with(
request_low, 0
)
_assert_right_ec_connector_metadata(
output, mm_features_list=request_low.mm_features
)
else:
scheduler.ec_connector.update_state_after_alloc.assert_not_called()
# ECConnector should carry no metadata
_assert_right_ec_connector_metadata(output, mm_features_list=[])
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# Should only schedule encoder input when cache is not found anywhere
if cache_exist == "no_where":
_assert_right_encoder_inputs(
output,
requests=[request_low],
expected_encoder_inputs=[[0]],
expected_total_reqs=1,
)
else:
_assert_right_encoder_inputs(output, expected_total_reqs=0)
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connector):
"""
Scenario:
- Encoder cache size: 32
- Request A: 1 feature (12 tokens) → NOT cached remotely.
- Request B: 3 features (3 x 10 tokens) → ALL cached remotely.
Steps:
1. Schedule Request A (locally uses 12 tokens).
2. Schedule Request B (remote cache) - only schedule 1st and 2nd
3. Free A's cache, then schedule B again (continuation) - schedule 3rd image
"""
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
max_num_batched_tokens=1024,
enable_prefix_caching=True,
use_kv_connector=use_kv_connector,
block_size=16,
num_blocks=11, # Can hold 160 tokens (first block is null)
use_ec_connector=True,
ec_role="ec_consumer",
)
    # Limit the number of available slots of the EncoderCacheManager
scheduler.encoder_cache_manager = EncoderCacheManager(cache_size=32)
# Create MM request1
NUM_TOKENS_1 = 50 # NOTE: includes mm tokens
NUM_ENCODER_TOKENS_1 = 12
mm_hashes_list_1 = [["hash1_1"]]
mm_positions_1 = [[PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_1)]]
request1 = create_requests(
num_requests=1,
num_tokens=NUM_TOKENS_1,
mm_hashes_list=mm_hashes_list_1,
mm_positions=mm_positions_1,
max_tokens=1, # For simplicity
req_ids=["req1"],
)[0]
    # Create MM request2 with 3 MM items
NUM_TOKENS_2 = 40
NUM_ENCODER_TOKENS_2 = 10
mm_hashes_list_2 = [["hash2_1", "hash2_2", "hash2_3"]]
mm_positions_2 = [
[
PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_2),
PlaceholderRange(offset=12, length=NUM_ENCODER_TOKENS_2),
PlaceholderRange(offset=24, length=NUM_ENCODER_TOKENS_2),
]
]
request2 = create_requests(
num_requests=1,
num_tokens=NUM_TOKENS_2,
mm_hashes_list=mm_hashes_list_2,
mm_positions=mm_positions_2,
max_tokens=10,
req_ids=["req2"],
)[0]
# Mock cache hit: MM of request1 NOT cached remotely, request2 cached remotely
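    # has_caches returns one flag per mm item of the queried request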
scheduler.ec_connector.has_caches = Mock(
side_effect=lambda req: [True, True, True] if req == request2 else [False]
)
scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc
)
scheduler.add_request(request1)
scheduler.add_request(request2)
output = scheduler.schedule()
    # Now, since the encoder cache manager can only store 32 tokens,
    # it should allocate mm items hash1_1, hash2_1 and hash2_2
scheduled_tokens = output.num_scheduled_tokens[request1.request_id]
assert scheduled_tokens == NUM_TOKENS_1
assert scheduler.get_num_unfinished_requests() == 2
    # Encoder cache should contain hash1_1 from request1
    # plus hash2_1 and hash2_2 from request2
_assert_right_encoder_cache_allocated(
scheduler, hashes_to_check=["hash1_1", "hash2_1", "hash2_2"]
)
# request2's 2nd mm item is the last call of update_state_after_alloc
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 1)
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# ECConnector should carry metadata of hash2_1 and hash2_2 ONLY
_assert_right_ec_connector_metadata(
output, mm_features_list=[request2.mm_features[0], request2.mm_features[1]]
)
# Should schedule ONLY 1 encoder input
_assert_right_encoder_inputs(
output,
requests=[request1],
expected_encoder_inputs=[[0]], # index 0 of the mm item of request1
expected_total_reqs=1,
)
# Simulate model execution 1 step
model_output = ModelRunnerOutput(
req_ids=[request1.request_id, request2.request_id],
req_id_to_index={request1.request_id: 0, request2.request_id: 1},
sampled_token_ids=[[100], [121]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
    # request1 is finished after outputting 1 token
# Finish request
scheduler.finish_requests(request1.request_id, RequestStatus.FINISHED_LENGTH_CAPPED)
assert scheduler.get_num_unfinished_requests() == 1
# Schedule again; Now request1's encoder cache should be freed
# -> hash2_3 can be scheduled and allocated
output = scheduler.schedule()
# Check
    # The remaining prompt tokens of request2 should be scheduled
    scheduled_tokens = output.num_scheduled_tokens[request2.request_id]
    assert scheduled_tokens > 0
# Encoder cache should contain all mm items from request2
_assert_right_encoder_cache_allocated(scheduler, requests=[request2])
# request2's 3rd mm item is the ONLY call of update_state_after_alloc
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 2)
scheduler.ec_connector.update_state_after_alloc.assert_called_once()
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# ECConnector should carry metadata for hash2_3 ONLY
_assert_right_ec_connector_metadata(
output, mm_features_list=[request2.mm_features[2]]
)
# Should schedule no encoder input
_assert_right_encoder_inputs(
output,
expected_total_reqs=0,
)
# ==============================================================================
# EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests end
# ==============================================================================