# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from unittest.mock import Mock

import numpy as np
import pytest
import torch

from vllm.config import (
    CacheConfig,
    ECTransferConfig,
    KVTransferConfig,
    ModelConfig,
    SchedulerConfig,
    SpeculativeConfig,
    VllmConfig,
)
from vllm.multimodal.inputs import (
    MultiModalFeatureSpec,
    MultiModalKwargsItem,
    PlaceholderRange,
)
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.utils.hashing import sha256
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheGroupSpec,
)
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager

from .utils import EOS_TOKEN_ID, create_requests, create_scheduler, mock_kv

pytestmark = pytest.mark.cpu_test

def test_add_requests():
|
|
scheduler = create_scheduler()
|
|
requests = create_requests(num_requests=10)
|
|
|
|
for i, request in enumerate(requests):
|
|
scheduler.add_request(request)
|
|
assert request.request_id in scheduler.requests
|
|
assert len(scheduler.waiting) == i + 1
|
|
|
|
|
|
def test_finish_request():
|
|
scheduler = create_scheduler()
|
|
requests = create_requests(num_requests=10)
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
for i, request in enumerate(requests):
|
|
scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_ABORTED)
|
|
assert request.request_id not in scheduler.requests
|
|
assert len(scheduler.waiting) == 9 - i
|
|
|
|
|
|
def test_get_num_unfinished_requests():
|
|
scheduler = create_scheduler()
|
|
requests = create_requests(num_requests=10)
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
for i, request in enumerate(requests):
|
|
scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_STOPPED)
|
|
assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"enable_prefix_caching, prompt_logprobs",
|
|
[
|
|
(None, None),
|
|
(True, 5),
|
|
],
|
|
)
|
|
def test_schedule(enable_prefix_caching: bool | None, prompt_logprobs: int | None):
|
|
"""Test scheduling.
|
|
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
|
|
"""
|
|
scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
|
|
requests = create_requests(num_requests=10, prompt_logprobs=prompt_logprobs)
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
# Test initial scheduling
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == len(requests)
|
|
assert output.scheduled_cached_reqs.num_reqs == 0
|
|
assert len(output.finished_req_ids) == 0
|
|
# Verify all requests are scheduled.
|
|
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
|
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
|
|
|
|
# Verify requests moved from waiting to running
|
|
assert len(scheduler.waiting) == 0
|
|
assert len(scheduler.running) == len(requests)
|
|
for i, request in enumerate(requests):
|
|
assert scheduler.running[i] == request
|
|
|
|
|
|
def test_schedule_multimodal_requests():
|
|
scheduler = create_scheduler(model="llava-hf/llava-1.5-7b-hf")
|
|
mm_positions = [[PlaceholderRange(offset=i, length=100)] for i in range(10)]
|
|
requests = create_requests(
|
|
num_requests=10,
|
|
num_tokens=200,
|
|
mm_positions=mm_positions,
|
|
)
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == len(requests)
|
|
assert output.scheduled_cached_reqs.num_reqs == 0
|
|
assert len(output.finished_req_ids) == 0
|
|
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
|
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
|
|
assert len(output.scheduled_encoder_inputs) == 10
|
|
for req_id, encoder_input in output.scheduled_encoder_inputs.items():
|
|
assert len(encoder_input) == 1
|
|
|
|
|
|
def test_schedule_partial_requests():
|
|
"""Test scheduling behavior with partial requests.
|
|
|
|
This test verifies that:
|
|
1. The scheduler can handle multiple partial requests in a single step when
|
|
constrained by encoder budget.
|
|
2. A request in RUNNING state may be unscheduled in subsequent steps if
|
|
there is insufficient encoder budget.
|
|
"""
|
|
scheduler = create_scheduler(
|
|
model="llava-hf/llava-1.5-7b-hf",
|
|
max_num_batched_tokens=1024,
|
|
)
|
|
mm_positions = [[PlaceholderRange(offset=100, length=600)] for _ in range(3)]
|
|
requests = create_requests(
|
|
num_requests=3,
|
|
num_tokens=800,
|
|
mm_positions=mm_positions,
|
|
)
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == 3
|
|
assert output.scheduled_cached_reqs.num_reqs == 0
|
|
assert len(output.finished_req_ids) == 0
|
|
|
|
assert scheduler.max_num_encoder_input_tokens == 1024
|
|
# The first request is scheduled fully.
|
|
assert output.num_scheduled_tokens[requests[0].request_id] == 800
|
|
# The second request is scheduled partially.
|
|
# The <img> tokens are not scheduled because of the encoder budget.
|
|
assert output.num_scheduled_tokens[requests[1].request_id] == 100
|
|
# The third request is also scheduled partially.
|
|
# The <img> tokens are not scheduled because of the encoder budget.
|
|
assert output.num_scheduled_tokens[requests[2].request_id] == 100
|
|
req_to_index = {request.request_id: i for i, request in enumerate(requests)}
|
|
model_runner_output = ModelRunnerOutput(
|
|
req_ids=[request.request_id for request in requests],
|
|
req_id_to_index=req_to_index,
|
|
# Only the first request has a sampled token id because
# the remaining requests are still being prefilled.
|
|
sampled_token_ids=[np.array([0]), np.array([]), np.array([])],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(output, model_runner_output)
|
|
|
|
# Schedule the next step.
|
|
# Only the first and second requests are scheduled.
|
|
# The third request is in the RUNNING state but not scheduled in this step
|
|
# because of the encoder budget.
|
|
output = scheduler.schedule()
|
|
assert len(scheduler.running) == 3
|
|
assert len(output.scheduled_new_reqs) == 0
|
|
assert output.scheduled_cached_reqs.num_reqs == 2
|
|
assert len(output.finished_req_ids) == 0
|
|
assert output.num_scheduled_tokens[requests[0].request_id] == 1
|
|
assert output.num_scheduled_tokens[requests[1].request_id] == 700
|
|
assert requests[2].request_id not in output.num_scheduled_tokens
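

# Illustrative sketch (plain Python, not the vLLM scheduler) of the encoder
# budget behavior asserted above: each image placeholder needs 600 encoder
# tokens, so with a 1024-token encoder budget only the first request's image
# fits in this step; the other requests are cut off at the text preceding
# their image. The helper below is hypothetical and exists only to illustrate
# that rule.
def encoder_budget_schedulable_tokens_sketch(
    prompt_len: int, img_offset: int, img_len: int, encoder_budget: int
) -> int:
    if img_len <= encoder_budget:
        return prompt_len  # the image fits: the whole prompt can be scheduled
    return img_offset  # otherwise stop right before the image placeholder


# Request 0 sees the full 1024-token encoder budget and is scheduled fully;
# requests 1 and 2 see only 1024 - 600 = 424 remaining and stop at token 100.
assert encoder_budget_schedulable_tokens_sketch(800, 100, 600, 1024) == 800
assert encoder_budget_schedulable_tokens_sketch(800, 100, 600, 424) == 100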
|
|
|
|
|
|
def test_no_mm_input_chunking():
|
|
# Disable multimodal input chunking.
|
|
scheduler = create_scheduler(
|
|
model="llava-hf/llava-1.5-7b-hf",
|
|
max_num_batched_tokens=1024,
|
|
disable_chunked_mm_input=True,
|
|
max_model_len=2048,
|
|
)
|
|
mm_positions = [[PlaceholderRange(offset=400, length=800)]]
|
|
requests = create_requests(
|
|
num_requests=1, num_tokens=1200, mm_positions=mm_positions
|
|
)
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == 1
|
|
assert output.scheduled_cached_reqs.num_reqs == 0
|
|
assert len(output.finished_req_ids) == 0
|
|
# We want to only see the 400 text tokens at the start scheduled
|
|
assert output.num_scheduled_tokens[requests[0].request_id] == 400
|
|
|
|
req_to_index = {request.request_id: i for i, request in enumerate(requests)}
|
|
model_runner_output = ModelRunnerOutput(
|
|
req_ids=[request.request_id for request in requests],
|
|
req_id_to_index=req_to_index,
|
|
sampled_token_ids=[np.array([]) for _ in range(len(requests))],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(output, model_runner_output)
|
|
|
|
output = scheduler.schedule()
|
|
assert len(scheduler.running) == 1
|
|
assert len(output.scheduled_new_reqs) == 0
|
|
assert output.scheduled_cached_reqs.num_reqs == 1
|
|
assert len(output.finished_req_ids) == 0
|
|
assert output.num_scheduled_tokens[requests[0].request_id] == 800
|
|
|
|
# Test that scheduler creation fails if chunked mm input is disabled and
# max_num_batched_tokens is too small to fit the mm input.
|
|
with pytest.raises(ValueError):
|
|
_ = create_scheduler(
|
|
model="llava-hf/llava-1.5-7b-hf",
|
|
max_num_batched_tokens=100,
|
|
disable_chunked_mm_input=True,
|
|
)
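

# Illustrative sketch (hypothetical helper, not vLLM code) of the rule the
# test above exercises: with chunked multimodal input disabled, a placeholder
# that does not fit entirely within the per-step token budget is deferred, so
# only the text tokens before it are scheduled in that step.
def no_mm_chunking_schedulable_tokens_sketch(
    prompt_len: int, mm_offset: int, mm_length: int, token_budget: int
) -> int:
    if mm_offset + mm_length <= token_budget:
        # The whole placeholder fits; schedule as much of the prompt as we can.
        return min(prompt_len, token_budget)
    # Otherwise stop right before the placeholder.
    return min(mm_offset, token_budget)


# A 1200-token prompt with an 800-token placeholder at offset 400 and a
# 1024-token budget schedules only the 400 leading text tokens, matching the
# first step of the test above.
assert no_mm_chunking_schedulable_tokens_sketch(1200, 400, 800, 1024) == 400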
|
|
|
|
|
|
@pytest.mark.parametrize("enable_prefix_caching", [True, False])
|
|
def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
|
"""Test scheduling behavior with concurrent partial requests.
|
|
|
|
This test verifies that: there are multiple long prefill requests in the
|
|
RUNNING state, and we can schedule them together.
|
|
|
|
"""
|
|
scheduler = create_scheduler(
|
|
model="facebook/opt-125m",
|
|
max_num_batched_tokens=1024,
|
|
long_prefill_token_threshold=400,
|
|
enable_prefix_caching=enable_prefix_caching,
|
|
)
|
|
requests = create_requests(
|
|
num_requests=3,
|
|
num_tokens=800,
|
|
)
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == 3
|
|
assert output.scheduled_cached_reqs.num_reqs == 0
|
|
assert len(output.finished_req_ids) == 0
|
|
|
|
# The first request is scheduled partially - 400.
|
|
assert output.num_scheduled_tokens[requests[0].request_id] == 400
|
|
# The second request is scheduled partially - 400.
|
|
assert output.num_scheduled_tokens[requests[1].request_id] == 400
|
|
# The third request is also scheduled partially - 1024 - 400 - 400 = 224.
|
|
assert output.num_scheduled_tokens[requests[2].request_id] == 224
|
|
req_to_index = {request.request_id: i for i, request in enumerate(requests)}
|
|
model_runner_output = ModelRunnerOutput(
|
|
req_ids=[request.request_id for request in requests],
|
|
req_id_to_index=req_to_index,
|
|
sampled_token_ids=[np.array([]) for _ in range(len(requests))],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(output, model_runner_output)
|
|
|
|
# Schedule the next step. All three requests are running.
|
|
# Processed the remaining prefills of the first and second requests.
|
|
output1 = scheduler.schedule()
|
|
assert len(scheduler.running) == 3
|
|
assert len(output1.scheduled_new_reqs) == 0
|
|
assert output1.scheduled_cached_reqs.num_reqs == 3
|
|
assert len(output1.finished_req_ids) == 0
|
|
assert output1.num_scheduled_tokens[requests[0].request_id] == 400
|
|
assert output1.num_scheduled_tokens[requests[1].request_id] == 400
|
|
assert output1.num_scheduled_tokens[requests[2].request_id] == 224
|
|
|
|
# Schedule the third step. All three requests are running.
|
|
# First and second requests are in the decode stage.
|
|
# All the remaining tokens in the third request are processed.
|
|
model_runner_output = ModelRunnerOutput(
|
|
req_ids=[request.request_id for request in requests],
|
|
req_id_to_index=req_to_index,
|
|
sampled_token_ids=[np.array([0]), np.array([0])]
|
|
+ [np.array([]) for _ in range(len(requests) - 2)],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(output1, model_runner_output)
|
|
output2 = scheduler.schedule()
|
|
assert len(scheduler.running) == 3
|
|
assert len(output2.scheduled_new_reqs) == 0
|
|
assert output2.scheduled_cached_reqs.num_reqs == 3
|
|
assert len(output2.finished_req_ids) == 0
|
|
assert output2.num_scheduled_tokens[requests[0].request_id] == 1
|
|
assert output2.num_scheduled_tokens[requests[1].request_id] == 1
|
|
assert output2.num_scheduled_tokens[requests[2].request_id] == 800 - 224 - 224
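

# Illustrative sketch (hypothetical helper, not vLLM's implementation) of how
# a per-step token budget combined with long_prefill_token_threshold yields
# the 400/400/224 split asserted above: each request is capped at the long
# prefill threshold, and the last one gets whatever budget remains.
def split_prefill_tokens_sketch(
    remaining_prompt_lens: list[int], token_budget: int, long_prefill_cap: int
) -> list[int]:
    scheduled = []
    for remaining in remaining_prompt_lens:
        num = min(remaining, long_prefill_cap, token_budget)
        scheduled.append(num)
        token_budget -= num
    return scheduled


# Three 800-token prompts, a 1024-token budget, and a 400-token cap.
assert split_prefill_tokens_sketch([800, 800, 800], 1024, 400) == [400, 400, 224]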
|
|
|
|
|
|
def test_stop_via_update_from_output():
|
|
"""Test stopping behavior through update_from_output"""
|
|
scheduler = create_scheduler(num_speculative_tokens=1)
|
|
|
|
# Test case 1: Stop on EOS token
|
|
requests = create_requests(num_requests=2, max_tokens=10)
|
|
for req in requests:
|
|
req.num_computed_tokens = req.num_tokens
|
|
scheduler.requests[req.request_id] = req
|
|
scheduler.running.append(req)
|
|
req.status = RequestStatus.RUNNING
|
|
|
|
scheduler_output = SchedulerOutput(
|
|
scheduled_new_reqs=[],
|
|
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
|
num_scheduled_tokens={requests[0].request_id: 1, requests[1].request_id: 2},
|
|
total_num_scheduled_tokens=3,
|
|
scheduled_encoder_inputs={},
|
|
scheduled_spec_decode_tokens={
|
|
requests[0].request_id: [],
|
|
requests[1].request_id: [10],
|
|
},
|
|
num_common_prefix_blocks=[],
|
|
finished_req_ids=set(),
|
|
free_encoder_mm_hashes=[],
|
|
)
|
|
|
|
model_output = ModelRunnerOutput(
|
|
req_ids=[req.request_id for req in requests],
|
|
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
|
sampled_token_ids=[
|
|
np.array([EOS_TOKEN_ID]),
|
|
np.array([10, 11]),
|
|
], # First request hits EOS, second continues
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
|
|
scheduler.update_from_output(scheduler_output, model_output)
|
|
|
|
# Verify first request stopped, second continues
|
|
assert len(scheduler.running) == 1
|
|
assert scheduler.running[0].request_id == requests[1].request_id
|
|
assert requests[0].status == RequestStatus.FINISHED_STOPPED
|
|
assert requests[0].request_id in scheduler.finished_req_ids
|
|
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
|
|
assert list(requests[1].output_token_ids) == [10, 11]
|
|
|
|
# Test case 2: Stop on custom stop token
|
|
scheduler = create_scheduler(num_speculative_tokens=2)
|
|
requests = create_requests(num_requests=2, max_tokens=10, stop_token_ids=[42, 43])
|
|
for req in requests:
|
|
req.num_computed_tokens = req.num_tokens
|
|
scheduler.requests[req.request_id] = req
|
|
scheduler.running.append(req)
|
|
req.status = RequestStatus.RUNNING
|
|
|
|
scheduler_output = SchedulerOutput(
|
|
scheduled_new_reqs=[],
|
|
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
|
num_scheduled_tokens={requests[0].request_id: 3, requests[1].request_id: 2},
|
|
total_num_scheduled_tokens=5,
|
|
scheduled_encoder_inputs={},
|
|
scheduled_spec_decode_tokens={
|
|
requests[0].request_id: [10, 42],
|
|
requests[1].request_id: [13],
|
|
},
|
|
num_common_prefix_blocks=[],
|
|
finished_req_ids=set(),
|
|
free_encoder_mm_hashes=[],
|
|
)
|
|
|
|
model_output = ModelRunnerOutput(
|
|
req_ids=[req.request_id for req in requests],
|
|
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
|
sampled_token_ids=[
|
|
np.array([10, 42, 12]),
|
|
np.array([13, 14]),
|
|
], # First request hits stop token
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
|
|
scheduler.update_from_output(scheduler_output, model_output)
|
|
|
|
# Verify first request stopped on custom token
|
|
assert len(scheduler.running) == 1
|
|
assert scheduler.running[0].request_id == requests[1].request_id
|
|
assert requests[0].status == RequestStatus.FINISHED_STOPPED
|
|
assert requests[0].stop_reason == 42
|
|
assert requests[0].request_id in scheduler.finished_req_ids
|
|
assert list(requests[0].output_token_ids) == [10, 42]
|
|
assert list(requests[1].output_token_ids) == [13, 14]
|
|
|
|
# Test case 3: Stop on max tokens
|
|
scheduler = create_scheduler(num_speculative_tokens=2)
|
|
requests = create_requests(num_requests=2, max_tokens=2)
|
|
for req in requests:
|
|
req.num_computed_tokens = req.num_tokens
|
|
scheduler.requests[req.request_id] = req
|
|
scheduler.running.append(req)
|
|
req.status = RequestStatus.RUNNING
|
|
|
|
scheduler_output = SchedulerOutput(
|
|
scheduled_new_reqs=[],
|
|
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
|
num_scheduled_tokens={requests[0].request_id: 3, requests[1].request_id: 1},
|
|
total_num_scheduled_tokens=4,
|
|
scheduled_encoder_inputs={},
|
|
scheduled_spec_decode_tokens={
|
|
requests[0].request_id: [10, 11],
|
|
requests[1].request_id: [],
|
|
},
|
|
num_common_prefix_blocks=[],
|
|
finished_req_ids=set(),
|
|
free_encoder_mm_hashes=[],
|
|
)
|
|
|
|
model_output = ModelRunnerOutput(
|
|
req_ids=[req.request_id for req in requests],
|
|
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
|
sampled_token_ids=[
|
|
np.array([10, 11, 12]),
|
|
np.array([13]),
|
|
], # First request exceeds max_tokens
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
|
|
scheduler.update_from_output(scheduler_output, model_output)
|
|
|
|
# Verify first request stopped due to length
|
|
assert len(scheduler.running) == 1
|
|
assert scheduler.running[0].request_id == requests[1].request_id
|
|
assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
|
|
assert requests[0].request_id in scheduler.finished_req_ids
|
|
assert list(requests[0].output_token_ids) == [10, 11] # Truncated to max_tokens
|
|
assert list(requests[1].output_token_ids) == [13]
|
|
|
|
# Test case 4: Ignore EOS flag
|
|
scheduler = create_scheduler(num_speculative_tokens=2)
|
|
requests = create_requests(num_requests=1, max_tokens=10)
|
|
requests[0].sampling_params.ignore_eos = True
|
|
requests[0].num_computed_tokens = requests[0].num_tokens
|
|
scheduler.requests[requests[0].request_id] = requests[0]
|
|
scheduler.running.append(requests[0])
|
|
|
|
scheduler_output = SchedulerOutput(
|
|
scheduled_new_reqs=[],
|
|
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
|
num_scheduled_tokens={requests[0].request_id: 3},
|
|
total_num_scheduled_tokens=3,
|
|
scheduled_encoder_inputs={},
|
|
scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]},
|
|
num_common_prefix_blocks=[],
|
|
finished_req_ids=set(),
|
|
free_encoder_mm_hashes=[],
|
|
)
|
|
|
|
model_output = ModelRunnerOutput(
|
|
req_ids=[requests[0].request_id],
|
|
req_id_to_index={requests[0].request_id: 0},
|
|
sampled_token_ids=[np.array([EOS_TOKEN_ID, 10, 11])],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
|
|
scheduler.update_from_output(scheduler_output, model_output)
|
|
|
|
# Verify request continues past EOS
|
|
assert len(scheduler.running) == 1
|
|
assert not requests[0].is_finished()
|
|
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
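

# Illustrative sketch (hypothetical helper, not vLLM code) of the truncation
# behavior exercised above: sampled (including speculative) tokens are kept up
# to and including the first stop condition -- EOS (unless ignored), a custom
# stop token, or the max_tokens limit.
def truncate_sampled_tokens_sketch(
    sampled: list[int],
    num_prior_output_tokens: int,
    eos_token_id: int,
    stop_token_ids: set[int],
    max_tokens: int,
    ignore_eos: bool = False,
) -> list[int]:
    kept: list[int] = []
    for tok in sampled:
        kept.append(tok)
        if num_prior_output_tokens + len(kept) >= max_tokens:
            break
        if (not ignore_eos and tok == eos_token_id) or tok in stop_token_ids:
            break
    return kept


# Mirrors test case 2 above: sampled [10, 42, 12] with stop token 42 is
# truncated to [10, 42]. The EOS id here is an arbitrary placeholder; the
# real EOS_TOKEN_ID comes from .utils.
_placeholder_eos = 100
assert truncate_sampled_tokens_sketch(
    [10, 42, 12], 0, _placeholder_eos, {42, 43}, max_tokens=10
) == [10, 42]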
|
|
|
|
|
|
def test_check_stop_min_tokens():
|
|
"""Test that requests don't stop when min_tokens requirement isn't met."""
|
|
from vllm.v1.core.sched.utils import check_stop
|
|
|
|
# Test case 1: num_output_tokens < min_tokens
|
|
# Should return False (don't stop)
|
|
sampling_params = SamplingParams(
|
|
ignore_eos=False,
|
|
max_tokens=20,
|
|
min_tokens=5,
|
|
)
|
|
request = Request(
|
|
request_id="0",
|
|
prompt_token_ids=[0, 1, 2],
|
|
sampling_params=sampling_params,
|
|
pooling_params=None,
|
|
eos_token_id=EOS_TOKEN_ID,
|
|
)
|
|
# Simulate having generated 3 output tokens (less than min_tokens=5)
|
|
request.append_output_token_ids([10, 11, EOS_TOKEN_ID]) # EOS token present
|
|
|
|
result = check_stop(request, max_model_len=100)
|
|
assert result is False, "Should not stop when num_output_tokens<min_tokens"
|
|
|
|
# Test case 2: num_output_tokens >= min_tokens
|
|
# Should follow normal stopping logic (stop on EOS)
|
|
request.append_output_token_ids(
|
|
[
|
|
10,
|
|
11,
|
|
12,
|
|
13,
|
|
14,
|
|
EOS_TOKEN_ID,
|
|
]
|
|
) # now 9 output tokens in total, >= min_tokens=5
|
|
|
|
result = check_stop(request, max_model_len=100)
|
|
assert result is True, "Should stop on EOS when min_tokens met"
|
|
assert request.status == RequestStatus.FINISHED_STOPPED
|
|
|
|
# Test case 3: min_tokens = 0, should follow normal stopping logic
|
|
sampling_params_no_min = SamplingParams(
|
|
ignore_eos=False,
|
|
max_tokens=20,
|
|
min_tokens=0,
|
|
)
|
|
request_no_min = Request(
|
|
request_id="1",
|
|
prompt_token_ids=[0, 1, 2],
|
|
sampling_params=sampling_params_no_min,
|
|
pooling_params=None,
|
|
eos_token_id=EOS_TOKEN_ID,
|
|
)
|
|
request_no_min.append_output_token_ids([10, EOS_TOKEN_ID])
|
|
|
|
result = check_stop(request_no_min, max_model_len=100)
|
|
assert result is True, "Should stop on EOS when min_tokens=0"
|
|
assert request_no_min.status == RequestStatus.FINISHED_STOPPED
|
|
|
|
# Test case 4: min_tokens > 0 with stop token (not EOS)
|
|
sampling_params_stop = SamplingParams(
|
|
ignore_eos=False,
|
|
max_tokens=20,
|
|
min_tokens=5,
|
|
stop_token_ids=[42],
|
|
)
|
|
request_stop = Request(
|
|
request_id="2",
|
|
prompt_token_ids=[0, 1, 2],
|
|
sampling_params=sampling_params_stop,
|
|
pooling_params=None,
|
|
eos_token_id=EOS_TOKEN_ID,
|
|
)
|
|
# Only 3 output tokens, less than min_tokens=5, but has stop token
|
|
request_stop.append_output_token_ids([10, 11, 42])
|
|
result = check_stop(request_stop, max_model_len=100)
|
|
assert result is False, "Should not stop when num_output_tokens<min_tokens"
|
|
|
|
# Test case 5: min_tokens met, should stop on stop token
|
|
request_stop.append_output_token_ids(
|
|
[10, 11, 12, 13, 14, 42]
|
|
) # now 9 output tokens in total, >= min_tokens=5
|
|
|
|
result = check_stop(request_stop, max_model_len=100)
|
|
assert result is True, "Should stop on stop token when min_tokens met"
|
|
assert request_stop.status == RequestStatus.FINISHED_STOPPED
|
|
assert request_stop.stop_reason == 42
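

# Minimal sketch of the min_tokens gating illustrated above (hypothetical
# helper, not the real check_stop): stop conditions are only honored once the
# request has produced at least min_tokens output tokens.
def should_stop_sketch(
    output_token_ids: list[int],
    eos_token_id: int,
    stop_token_ids: set[int],
    min_tokens: int,
    max_tokens: int,
) -> bool:
    if len(output_token_ids) < min_tokens:
        return False
    last = output_token_ids[-1]
    if last == eos_token_id or last in stop_token_ids:
        return True
    return len(output_token_ids) >= max_tokens


# Three output tokens ending in stop token 42 do not stop when min_tokens=5...
assert not should_stop_sketch([10, 11, 42], 0, {42}, min_tokens=5, max_tokens=20)
# ...but nine output tokens ending in the stop token do.
assert should_stop_sketch(
    [10, 11, 42, 10, 11, 12, 13, 14, 42], 0, {42}, min_tokens=5, max_tokens=20
)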
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"enable_prefix_caching, prompt_logprobs",
|
|
[
|
|
(None, None),
|
|
(True, 5),
|
|
],
|
|
)
|
|
def test_schedule_concurrent_batches(
|
|
enable_prefix_caching: bool | None, prompt_logprobs: int | None
|
|
):
|
|
scheduler = create_scheduler(
|
|
max_num_batched_tokens=1024,
|
|
max_num_seqs=2,
|
|
enable_prefix_caching=enable_prefix_caching,
|
|
)
|
|
requests = create_requests(
|
|
num_requests=2,
|
|
num_tokens=512,
|
|
prompt_logprobs=prompt_logprobs,
|
|
)
|
|
|
|
# Schedule the first request.
|
|
scheduler.add_request(requests[0])
|
|
scheduler_output0 = scheduler.schedule()
|
|
assert len(scheduler_output0.scheduled_new_reqs) == 1
|
|
assert scheduler_output0.num_scheduled_tokens[requests[0].request_id] == 512
|
|
|
|
# The first request is still running, so only schedule the second request.
|
|
scheduler.add_request(requests[1])
|
|
scheduler_output1 = scheduler.schedule()
|
|
assert len(scheduler_output1.scheduled_new_reqs) == 1
|
|
assert scheduler_output1.num_scheduled_tokens[requests[1].request_id] == 512
|
|
|
|
# Model output of the first request.
|
|
model_runner_output = ModelRunnerOutput(
|
|
req_ids=[requests[0].request_id],
|
|
req_id_to_index={requests[0].request_id: 0},
|
|
sampled_token_ids=[np.array([0])],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(scheduler_output0, model_runner_output)
|
|
|
|
# Schedule the next step.
|
|
# The first request can be scheduled again while the second
|
|
# request is still running.
|
|
scheduler_output2 = scheduler.schedule()
|
|
assert scheduler_output2.num_scheduled_tokens[requests[0].request_id] == 1
|
|
|
|
# Model output of the second request.
|
|
model_runner_output = ModelRunnerOutput(
|
|
req_ids=[requests[1].request_id],
|
|
req_id_to_index={requests[1].request_id: 0},
|
|
sampled_token_ids=[np.array([0])],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(scheduler_output1, model_runner_output)
|
|
|
|
|
|
def test_preempt_during_execution():
|
|
# NOTE(woosuk): The actual number of available blocks is 10 instead of 11
|
|
# because block 0 is reserved as the null block.
|
|
scheduler = create_scheduler(
|
|
max_num_batched_tokens=100,
|
|
block_size=16,
|
|
num_blocks=11,
|
|
enable_prefix_caching=False,
|
|
)
|
|
requests = create_requests(num_requests=2, num_tokens=80, block_size=16)
|
|
|
|
# Schedule the first request.
|
|
scheduler.add_request(requests[0])
|
|
scheduler_output0 = scheduler.schedule()
|
|
assert len(scheduler_output0.num_scheduled_tokens) == 1
|
|
assert len(scheduler_output0.scheduled_new_reqs[0].block_ids[0]) == 5
|
|
|
|
# Schedule the second request while the first request is still running.
|
|
# This scenario can occur in certain cases, when max_concurrent_batches > 1
|
|
# (e.g., when pipeline parallelism is used).
|
|
scheduler.add_request(requests[1])
|
|
scheduler_output1 = scheduler.schedule()
|
|
assert len(scheduler_output1.num_scheduled_tokens) == 1
|
|
assert len(scheduler_output1.scheduled_new_reqs[0].block_ids[0]) == 5
|
|
|
|
# Get the output of the first request.
|
|
model_runner_output0 = ModelRunnerOutput(
|
|
req_ids=[requests[0].request_id],
|
|
req_id_to_index={requests[0].request_id: 0},
|
|
sampled_token_ids=[np.array([0])],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(scheduler_output0, model_runner_output0)
|
|
|
|
# Schedule the first request again. This will cause the preemption
|
|
# of the second request because the KV cache is full.
|
|
_ = scheduler.schedule()
|
|
assert len(scheduler.running) == 1
|
|
assert scheduler.running[0] == requests[0]
|
|
assert requests[1].status == RequestStatus.PREEMPTED
|
|
|
|
model_runner_output1 = ModelRunnerOutput(
|
|
req_ids=[requests[1].request_id],
|
|
req_id_to_index={requests[1].request_id: 0},
|
|
sampled_token_ids=[np.array([42])],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(scheduler_output1, model_runner_output1)
|
|
|
|
# The second request (that is preempted) should be updated with the
|
|
# sampled token id.
|
|
assert len(requests[1].output_token_ids) == 1
|
|
assert requests[1].output_token_ids[0] == 42
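

# Illustrative arithmetic for the preemption scenario above (plain Python, not
# vLLM code): of the 11 blocks, one is reserved as the null block, leaving 10
# usable. Two 80-token prompts at block_size=16 need 5 blocks each and fill
# the pool, so growing request 0 by one decode token requires a 6th block,
# which can only be obtained by preempting request 1.
def blocks_needed_sketch(num_tokens: int, block_size: int) -> int:
    return (num_tokens + block_size - 1) // block_size  # ceil division


usable_blocks = 11 - 1  # one block is reserved as the null block
assert blocks_needed_sketch(80, 16) == 5
assert 2 * blocks_needed_sketch(80, 16) == usable_blocks  # pool is full
assert blocks_needed_sketch(80 + 1, 16) == 6  # one more token needs a new block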
|
|
|
|
|
|
# Note - these test cases mirror some of those in test_rejection_sampler.py
|
|
@pytest.mark.parametrize(
|
|
"spec_tokens,output_tokens,expected",
|
|
[
|
|
([[1, 2, 3]], [np.array([1, 2, 3, 4])], (1, 3, 3, [1, 1, 1])), # perfect match
|
|
([[1, 2, 3]], [np.array([1, 5])], (1, 3, 1, [1, 0, 0])), # early mismatch
|
|
(
|
|
[[1, 2], [3]],
|
|
[np.array([1, 2, 5]), np.array([3, 4])],
|
|
(2, 3, 3, [2, 1]),
|
|
), # multiple sequences
|
|
([[1]], [np.array([1, 2])], (1, 1, 1, [1])), # single token sequence
|
|
([[]], [np.array([5])], (0, 0, 0, [0])), # empty sequence
|
|
(
|
|
[[1, 2, 3], [4, 5, 6]],
|
|
[np.array([1, 2, 7]), np.array([4, 8])],
|
|
(2, 6, 3, [2, 1, 0]),
|
|
), # multiple mismatches
|
|
],
|
|
)
|
|
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
|
|
"""Test scheduling behavior with speculative decoding.
|
|
|
|
This test verifies that:
|
|
1. Speculated tokens get scheduled correctly
|
|
2. Spec decoding stats properly count number of draft and accepted tokens
|
|
"""
|
|
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
|
|
scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
|
|
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
|
|
req_ids = []
|
|
req_to_index = {}
|
|
for i, request in enumerate(requests):
|
|
scheduler.add_request(request)
|
|
req_ids.append(request.request_id)
|
|
req_to_index[request.request_id] = i
|
|
|
|
# Schedule a decode, which will also draft speculative tokens
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == len(requests)
|
|
assert output.total_num_scheduled_tokens == len(requests)
|
|
for i in range(len(requests)):
|
|
req_id = requests[i].request_id
|
|
assert output.num_scheduled_tokens[req_id] == 1
|
|
assert req_id not in output.scheduled_spec_decode_tokens
|
|
|
|
model_runner_output = ModelRunnerOutput(
|
|
req_ids=req_ids,
|
|
req_id_to_index=req_to_index,
|
|
sampled_token_ids=[np.array([0]) for _ in range(len(requests))],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
engine_core_outputs = scheduler.update_from_output(output, model_runner_output)
|
|
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
|
|
scheduler.update_draft_token_ids(draft_token_ids)
|
|
|
|
for i in range(len(requests)):
|
|
running_req = scheduler.running[i]
|
|
# The prompt token
|
|
assert running_req.num_computed_tokens == 1
|
|
# The prompt token and the sampled token
|
|
assert running_req.num_tokens == 2
|
|
# The prompt token, the sampled token, and the speculated tokens
|
|
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
|
|
|
|
# No draft or accepted tokens counted yet
|
|
assert not engine_core_outputs or (
|
|
engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None
|
|
)
|
|
|
|
# Schedule the speculated tokens for validation
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == 0
|
|
# The sampled token and speculated tokens
|
|
assert output.total_num_scheduled_tokens == len(requests) + sum(
|
|
len(ids) for ids in spec_tokens
|
|
)
|
|
for i in range(len(requests)):
|
|
req_id = requests[i].request_id
|
|
assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
|
|
if spec_tokens[i]:
|
|
assert len(output.scheduled_spec_decode_tokens[req_id]) == len(
|
|
spec_tokens[i]
|
|
)
|
|
else:
|
|
assert req_id not in output.scheduled_spec_decode_tokens
|
|
|
|
model_runner_output = ModelRunnerOutput(
|
|
req_ids=req_ids,
|
|
req_id_to_index=req_to_index,
|
|
sampled_token_ids=output_tokens,
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
engine_core_outputs = scheduler.update_from_output(output, model_runner_output)
|
|
|
|
scheduler_stats = (
|
|
engine_core_outputs[0].scheduler_stats if engine_core_outputs else None
|
|
)
|
|
if expected[0] == 0:
|
|
assert scheduler_stats is not None
|
|
assert scheduler_stats.spec_decoding_stats is None
|
|
else:
|
|
assert scheduler_stats is not None
|
|
assert scheduler_stats.spec_decoding_stats is not None
|
|
stats = scheduler_stats.spec_decoding_stats
|
|
assert stats.num_drafts == expected[0]
|
|
assert stats.num_draft_tokens == expected[1]
|
|
assert stats.num_accepted_tokens == expected[2]
|
|
assert stats.num_accepted_tokens_per_pos == expected[3]
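

# Illustrative sketch (hypothetical helper; the real accounting lives in the
# scheduler and rejection sampler) of how the expected tuples above can be
# derived from the draft tokens and the sampled outputs. For each request with
# a non-empty draft, the sampler emits the accepted draft prefix plus one
# extra token, so len(sampled) - 1 draft tokens were accepted.
def spec_decode_stats_sketch(
    spec_tokens: list[list[int]], sampled: list[list[int]]
) -> tuple[int, int, int, list[int]]:
    num_spec = max((len(d) for d in spec_tokens), default=0)
    num_drafts = sum(1 for d in spec_tokens if d)
    num_draft_tokens = sum(len(d) for d in spec_tokens)
    accepted_per_pos = [0] * max(num_spec, 1)
    for drafts, out in zip(spec_tokens, sampled):
        num_accepted = len(out) - 1 if drafts else 0
        for pos in range(num_accepted):
            accepted_per_pos[pos] += 1
    return num_drafts, num_draft_tokens, sum(accepted_per_pos), accepted_per_pos


# Matches the "multiple mismatches" case in the parametrization above.
assert spec_decode_stats_sketch(
    [[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]]
) == (2, 6, 3, [2, 1, 0])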
|
|
|
|
|
|
def _assert_right_scheduler_output(
|
|
output: SchedulerOutput,
|
|
num_requests: int,
|
|
expected_num_scheduled_tokens: int,
|
|
):
|
|
"""Check if SchedulerOutput is correct after remote KV cache hit."""
|
|
|
|
# We should inject the kv_connector_metadata.
|
|
assert len(output.kv_connector_metadata.requests) == num_requests
|
|
|
|
# Only num_tokens - matched_num_new_tokens should be scheduled.
|
|
for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
|
|
assert num_scheduled_tokens == expected_num_scheduled_tokens
|
|
|
|
|
|
def _assert_right_kv_cache_manager(
|
|
scheduler: Scheduler,
|
|
requests: list[Request],
|
|
num_tokens: int,
|
|
block_size: int,
|
|
num_requests: int,
|
|
num_total_blocks: int,
|
|
):
|
|
"""Check whether KVCacheManager is correct after allocate."""
|
|
|
|
# Make sure the request stats are right.
|
|
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
|
|
for req in requests:
|
|
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
|
0
|
|
].req_to_blocks[req.request_id]
|
|
hashes = req.block_hashes
|
|
assert (
|
|
scheduler.kv_cache_manager.coordinator.single_type_managers[
|
|
0
|
|
].num_cached_block[req.request_id]
|
|
== EXPECTED_TOTAL_BLOCKS
|
|
)
|
|
assert len(blocks) == EXPECTED_TOTAL_BLOCKS
|
|
assert len(hashes) == EXPECTED_TOTAL_BLOCKS
|
|
|
|
# Make sure we actually touched all the blocks.
|
|
BLOCKS_PER_REQ = num_tokens / block_size
|
|
assert (
|
|
scheduler.kv_cache_manager.block_pool.get_num_free_blocks()
|
|
== num_total_blocks - num_requests * BLOCKS_PER_REQ
|
|
)
|
|
|
|
|
|
def _step_until_done(
|
|
scheduler: Scheduler,
|
|
output: SchedulerOutput,
|
|
model_runner_output: ModelRunnerOutput,
|
|
):
|
|
"""Loop over schedule(), update_from_output() until finished."""
|
|
|
|
all_finished = False
|
|
_ = scheduler.update_from_output(output, model_runner_output)
|
|
while not all_finished:
|
|
# Schedule + a few iterations until stopping.
|
|
output = scheduler.schedule()
|
|
assert len(scheduler.running)
|
|
for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
|
|
# We should be in the decode phase now.
|
|
assert num_scheduled_tokens == 1
|
|
if scheduler.connector is not None:
|
|
assert len(output.kv_connector_metadata.requests) == 0
|
|
if scheduler.ec_connector is not None:
|
|
assert len(output.ec_connector_metadata.mm_datas) == 0
|
|
ecos = scheduler.update_from_output(output, model_runner_output)[0]
|
|
all_done = True
|
|
for eco in ecos.outputs:
|
|
if eco.finish_reason is None:
|
|
all_done = False
|
|
all_finished = all_done
|
|
|
|
|
|
def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
|
|
"""Cycle requests through a KV transfer cyle."""
|
|
|
|
# Requests should first transition to WAITING_FOR_REMOTE_KVS
|
|
output = scheduler.schedule()
|
|
assert len(scheduler.waiting) == len(req_ids)
|
|
assert len(scheduler.running) == 0
|
|
assert len(output.scheduled_new_reqs) == 0
|
|
for req in scheduler.requests.values():
|
|
assert req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
|
|
|
# No model execution yet
|
|
EMPTY_OUTPUT = ModelRunnerOutput(
|
|
req_ids=[],
|
|
req_id_to_index={},
|
|
sampled_token_ids=[],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(output, EMPTY_OUTPUT)
|
|
|
|
# Simulate KV transfer completion using KVConnectorOutput.finished_recving
|
|
output = scheduler.schedule()
|
|
assert len(scheduler.waiting) == len(req_ids)
|
|
assert len(scheduler.running) == 0
|
|
|
|
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
|
req_ids=[],
|
|
req_id_to_index={},
|
|
sampled_token_ids=[],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
kv_connector_output=KVConnectorOutput(finished_recving=req_ids),
|
|
)
|
|
scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
|
|
for req_id in req_ids:
|
|
assert req_id in scheduler.finished_recving_kv_req_ids
|
|
|
|
|
|
@pytest.mark.parametrize("is_async", [False, True])
|
|
def test_kv_connector_basic(is_async: bool):
|
|
"""
|
|
Test whether Scheduler with KVConnector schedules tokens, allocates
|
|
memory, and cleans up requests as expected under normal operation.
|
|
"""
|
|
|
|
# Setup Scheduler.
|
|
BLOCK_SIZE = 16
|
|
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
|
|
scheduler = create_scheduler(
|
|
enable_prefix_caching=True,
|
|
use_kv_connector=mock_kv(
|
|
matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
|
|
),
|
|
block_size=BLOCK_SIZE,
|
|
)
|
|
NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks()
|
|
|
|
######################################################
|
|
# FIRST SET OF REQUESTS - External Hit Only
|
|
NUM_REQUESTS = 2
|
|
NUM_TOKENS = NUM_MATCHED_NEW_TOKENS * 2
|
|
MAX_TOKENS = 3
|
|
requests = create_requests(
|
|
num_requests=NUM_REQUESTS,
|
|
num_tokens=NUM_TOKENS,
|
|
max_tokens=MAX_TOKENS,
|
|
block_size=BLOCK_SIZE,
|
|
)
|
|
req_ids = []
|
|
req_to_index = {}
|
|
for i, request in enumerate(requests):
|
|
scheduler.add_request(request)
|
|
req_ids.append(request.request_id)
|
|
req_to_index[request.request_id] = i
|
|
|
|
if is_async:
|
|
_step_until_kv_transfer_finished(scheduler, req_ids)
|
|
|
|
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
|
req_ids=req_ids,
|
|
req_id_to_index=req_to_index,
|
|
sampled_token_ids=[np.array([1000])] * len(req_ids),
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
|
|
# Ensure ScheduleOutput is correct.
|
|
output = scheduler.schedule()
|
|
_assert_right_scheduler_output(
|
|
output=output,
|
|
num_requests=NUM_REQUESTS,
|
|
# Just the incremental tokens should be scheduled.
|
|
expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
|
|
)
|
|
|
|
# Ensure KVCacheManager is correct.
|
|
_assert_right_kv_cache_manager(
|
|
scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS
|
|
)
|
|
|
|
# Continue Generation until done.
|
|
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
|
|
_ = scheduler.schedule()
|
|
# Confirm we clean up the memory properly.
|
|
assert (
|
|
scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS
|
|
)
|
|
|
|
######################################################
|
|
# SECOND SET OF REQUESTS - Local And External Hit
|
|
NUM_TOKENS_PREFIX = NUM_TOKENS
|
|
# We will get a local prefix cache hit for the first
|
|
# NUM_TOKENS_PREFIX tokens since they are used above.
|
|
NUM_TOKENS = NUM_TOKENS_PREFIX * 2
|
|
requests = create_requests(
|
|
num_requests=NUM_REQUESTS,
|
|
num_tokens=NUM_TOKENS,
|
|
max_tokens=MAX_TOKENS,
|
|
block_size=BLOCK_SIZE,
|
|
)
|
|
req_ids = []
|
|
req_to_index = {}
|
|
for i, request in enumerate(requests):
|
|
scheduler.add_request(request)
|
|
req_ids.append(request.request_id)
|
|
req_to_index[request.request_id] = i
|
|
|
|
if is_async:
|
|
_step_until_kv_transfer_finished(scheduler, req_ids)
|
|
|
|
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
|
req_ids=req_ids,
|
|
req_id_to_index=req_to_index,
|
|
sampled_token_ids=[np.array([1000])] * len(req_ids),
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
|
|
# We should get a local cache hit of NUM_TOKENS_PREFIX and
|
|
# a remote KV cache hit of NUM_MATCHED_NEW_TOKENS.
|
|
output = scheduler.schedule()
|
|
_assert_right_scheduler_output(
|
|
output=output,
|
|
num_requests=NUM_REQUESTS,
|
|
# Just the incremental tokens after local + remote cache hit.
|
|
expected_num_scheduled_tokens=(
|
|
NUM_TOKENS - NUM_TOKENS_PREFIX - NUM_MATCHED_NEW_TOKENS
|
|
),
|
|
)
|
|
|
|
# Ensure KVCacheManager is correct.
|
|
_assert_right_kv_cache_manager(
|
|
scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS
|
|
)
|
|
|
|
# Continue Generation until done.
|
|
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
|
|
_ = scheduler.schedule()
|
|
# Confirm we clean up the memory properly.
|
|
assert (
|
|
scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS
|
|
)
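

# Illustrative arithmetic for the connector test above (plain Python): with a
# local prefix-cache hit plus a remote (connector) hit, only the remaining
# suffix of the prompt has to be scheduled.
def remaining_prefill_tokens_sketch(
    num_prompt_tokens: int, local_hit_tokens: int, remote_hit_tokens: int
) -> int:
    return num_prompt_tokens - local_hit_tokens - remote_hit_tokens


# Second batch above: 128-token prompts, 64 tokens already cached locally from
# the first batch, and 32 more tokens matched by the mock connector.
assert remaining_prefill_tokens_sketch(128, 64, 32) == 32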
|
|
|
|
|
|
def test_external_prefix_cache_metrics():
|
|
"""
|
|
Verify connector prefix cache metrics are updated
|
|
correctly when the scheduler processes requests with KV connector hits.
|
|
"""
|
|
|
|
# Setup Scheduler.
|
|
NUM_MATCHED_NEW_TOKENS = 4
|
|
scheduler = create_scheduler(
|
|
enable_prefix_caching=False,
|
|
use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
|
|
)
|
|
|
|
# --- Prepare simple requests ---
|
|
NUM_REQUESTS = 2
|
|
NUM_TOKENS = 8
|
|
MAX_TOKENS = 2
|
|
requests = create_requests(
|
|
num_requests=NUM_REQUESTS,
|
|
num_tokens=NUM_TOKENS,
|
|
max_tokens=MAX_TOKENS,
|
|
)
|
|
|
|
for req in requests:
|
|
scheduler.add_request(req)
|
|
|
|
# --- Trigger scheduling and simulate model output ---
|
|
output = scheduler.schedule()
|
|
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
|
req_ids=[r.request_id for r in requests],
|
|
req_id_to_index={r.request_id: i for i, r in enumerate(requests)},
|
|
sampled_token_ids=[np.array([1000])] * NUM_REQUESTS,
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
|
|
# Update scheduler stats
|
|
ecos = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
|
|
|
|
# --- Assertions ---
|
|
assert ecos is not None and len(ecos) > 0
|
|
assert ecos[0].scheduler_stats is not None
|
|
|
|
external_stats = ecos[0].scheduler_stats.connector_prefix_cache_stats
|
|
assert external_stats is not None
|
|
|
|
assert external_stats.queries == NUM_TOKENS * NUM_REQUESTS
|
|
assert external_stats.hits == NUM_MATCHED_NEW_TOKENS * NUM_REQUESTS
|
|
assert external_stats.requests == NUM_REQUESTS
|
|
assert external_stats.preempted_requests == 0
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
|
|
)
|
|
def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
|
|
"""
|
|
Test whether scheduler with KVConnector is able to handle
|
|
unable to allocate (run out of blocks in allocate_slots().
|
|
"""
|
|
|
|
# Setup Scheduler With Mock External Cache Hit.
|
|
BLOCK_SIZE = 4
|
|
NUM_BLOCKS = 10
|
|
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
|
|
scheduler = create_scheduler(
|
|
enable_prefix_caching=True,
|
|
use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
|
|
block_size=BLOCK_SIZE,
|
|
num_blocks=NUM_BLOCKS,
|
|
# encoder connector should not affect test results
|
|
use_ec_connector=use_ec_connector,
|
|
ec_role=ec_role,
|
|
)
|
|
|
|
# Create two requests. The second request will not be able to
|
|
# allocate slots because it will not have enough blocks.
|
|
NUM_REQUESTS = 2
|
|
NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE
|
|
MAX_TOKENS = 2
|
|
requests = create_requests(
|
|
num_requests=NUM_REQUESTS,
|
|
num_tokens=NUM_TOKENS,
|
|
max_tokens=MAX_TOKENS,
|
|
block_size=BLOCK_SIZE,
|
|
)
|
|
req_ids = []
|
|
req_to_index = {}
|
|
for i, request in enumerate(requests):
|
|
scheduler.add_request(request)
|
|
req_ids.append(request.request_id)
|
|
req_to_index[request.request_id] = i
|
|
|
|
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
|
req_ids=req_ids,
|
|
req_id_to_index=req_to_index,
|
|
sampled_token_ids=[np.array([1000])] * len(req_ids),
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
|
|
# Just one request should be running.
|
|
output = scheduler.schedule()
|
|
_assert_right_scheduler_output(
|
|
output,
|
|
num_requests=1,
|
|
expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
|
|
)
|
|
assert len(scheduler.running) == 1
|
|
assert len(scheduler.waiting) == 1
|
|
|
|
# All memory should be freed, with one request waiting.
|
|
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
|
|
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
|
|
assert len(scheduler.running) == 0
|
|
assert len(scheduler.waiting) == 1
|
|
|
|
# Just one request should be running.
|
|
output = scheduler.schedule()
|
|
_assert_right_scheduler_output(
|
|
output,
|
|
num_requests=1,
|
|
expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
|
|
)
|
|
assert len(scheduler.running) == 1
|
|
assert len(scheduler.waiting) == 0
|
|
|
|
# All memory should be freed, with no requests waiting / running.
|
|
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
|
|
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
|
|
assert len(scheduler.running) == 0
|
|
assert len(scheduler.waiting) == 0
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
|
|
)
|
|
def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
|
|
"""
|
|
Test whether scheduler with KVConnector is able to handle
|
|
unable to allocate (run out of blocks in allocate_slots().
|
|
"""
|
|
|
|
# Setup Scheduler With Mock External Cache Hit.
|
|
BLOCK_SIZE = 2
|
|
# NOTE: there is 1 null block, so this is 6 blocks.
|
|
NUM_BLOCKS = 7
|
|
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
|
|
scheduler = create_scheduler(
|
|
enable_prefix_caching=True,
|
|
use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
|
|
block_size=BLOCK_SIZE,
|
|
num_blocks=NUM_BLOCKS,
|
|
# encoder connector should not affect test results
|
|
use_ec_connector=use_ec_connector,
|
|
ec_role=ec_role,
|
|
)
|
|
|
|
# Create two requests.
|
|
# Both can be scheduled at first, but the second request
|
|
# will be preempted and re-scheduled.
|
|
NUM_REQUESTS = 2
|
|
NUM_TOKENS = BLOCK_SIZE * 2 + 1
|
|
MAX_TOKENS = BLOCK_SIZE * 2
|
|
requests = create_requests(
|
|
num_requests=NUM_REQUESTS,
|
|
num_tokens=NUM_TOKENS,
|
|
max_tokens=MAX_TOKENS,
|
|
block_size=BLOCK_SIZE,
|
|
)
|
|
req_ids = []
|
|
req_to_index = {}
|
|
for i, request in enumerate(requests):
|
|
scheduler.add_request(request)
|
|
req_ids.append(request.request_id)
|
|
req_to_index[request.request_id] = i
|
|
|
|
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
|
req_ids=req_ids,
|
|
req_id_to_index=req_to_index,
|
|
sampled_token_ids=[np.array([1000])] * len(req_ids),
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
|
|
# All can be scheduled - 1st token.
|
|
output = scheduler.schedule()
|
|
_assert_right_scheduler_output(
|
|
output,
|
|
# 2 remote kv cache hits.
|
|
num_requests=2,
|
|
expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
|
|
)
|
|
assert len(scheduler.running) == 2
|
|
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
|
|
|
|
# All can be scheduled - 2nd token.
|
|
output = scheduler.schedule()
|
|
_assert_right_scheduler_output(
|
|
output,
|
|
# no connector_metadata
|
|
num_requests=0,
|
|
expected_num_scheduled_tokens=1,
|
|
)
|
|
assert len(scheduler.running) == 2
|
|
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
|
|
|
|
# This will generate a new block and cause a preemption - 3rd token.
|
|
output = scheduler.schedule()
|
|
_assert_right_scheduler_output(
|
|
output,
|
|
# no connector_metadata
|
|
num_requests=0,
|
|
expected_num_scheduled_tokens=1,
|
|
)
|
|
assert len(scheduler.running) == 1
|
|
assert len(scheduler.waiting) == 1
|
|
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
|
|
assert len(scheduler.running) == 1
|
|
assert len(scheduler.waiting) == 1
|
|
|
|
# Only 1 can be scheduled - 4th (and last token).
|
|
output = scheduler.schedule()
|
|
_assert_right_scheduler_output(
|
|
output,
|
|
# no connector_metadata
|
|
num_requests=0,
|
|
expected_num_scheduled_tokens=1,
|
|
)
|
|
assert len(scheduler.waiting) == 1
|
|
assert len(scheduler.running) == 1
|
|
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
|
|
assert len(scheduler.running) == 0
|
|
# All memory should be freed since nothing is running.
|
|
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
|
|
|
|
# Restarts the preempted request - generate 3rd token.
|
|
# This will have a local and remote cache hit.
|
|
output = scheduler.schedule()
|
|
_assert_right_scheduler_output(
|
|
output,
|
|
# 1 remote kv_cache hit!
|
|
num_requests=1,
|
|
# Only 1 block was preempted and there is a single
|
|
# remote hit. So only single new token is scheduled.
|
|
expected_num_scheduled_tokens=1,
|
|
)
|
|
assert len(scheduler.running) == 1
|
|
assert len(scheduler.waiting) == 0
|
|
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
|
|
assert len(scheduler.running) == 1
|
|
assert len(scheduler.waiting) == 0
|
|
|
|
# Only 1 can be scheduled - 4th (and last token).
|
|
output = scheduler.schedule()
|
|
_assert_right_scheduler_output(
|
|
output,
|
|
# no connector_metadata
|
|
num_requests=0,
|
|
expected_num_scheduled_tokens=1,
|
|
)
|
|
assert len(scheduler.running) == 1
|
|
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
|
|
assert len(scheduler.running) == 0
|
|
# All memory should be freed since nothing is running.
|
|
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
|
|
|
|
|
|
def make_output(scheduler: Scheduler):
|
|
return ModelRunnerOutput(
|
|
req_ids=[req.request_id for req in scheduler.running],
|
|
req_id_to_index={req.request_id: i for i, req in enumerate(scheduler.running)},
|
|
sampled_token_ids=[np.array([1000])] * len(scheduler.running),
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
|
|
|
|
def assert_scheduler_empty(scheduler: Scheduler):
|
|
"""Confirm the scheduler is "empty" - i.e. no leaks."""
|
|
# Scheduler Metadata.
|
|
assert len(scheduler.requests) == 0
|
|
assert len(scheduler.waiting) == 0
|
|
assert len(scheduler.running) == 0
|
|
assert len(scheduler.finished_req_ids) == 0
|
|
|
|
# EncoderCacheManager.
|
|
assert len(scheduler.encoder_cache_manager.freed) == 0
|
|
assert len(scheduler.encoder_cache_manager.cached) == 0
|
|
|
|
# KVCache Manager.
|
|
assert (
|
|
len(
|
|
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks
|
|
)
|
|
== 0
|
|
)
|
|
assert (
|
|
len(
|
|
scheduler.kv_cache_manager.coordinator.single_type_managers[
|
|
0
|
|
].num_cached_block
|
|
)
|
|
== 0
|
|
)
|
|
num_free_blocks = (
|
|
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
|
|
)
|
|
assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
|
|
|
|
# NOTE(rob): just the ref count on blocks will be 0. The hash
|
|
# value, etc will remain since we lazily evict for prefix cache.
|
|
for block in scheduler.kv_cache_manager.block_pool.blocks:
|
|
assert block.ref_cnt == 0
|
|
# assert block._block_hash is None
|
|
# assert (
|
|
# len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
|
|
# ) == 0)
|
|
|
|
|
|
def test_memory_leak():
|
|
"""Test that we do not have a memory leak."""
|
|
|
|
scheduler = create_scheduler(enable_prefix_caching=True)
|
|
|
|
NUM_REQUESTS = 5
|
|
NUM_TOKENS = 10
|
|
MAX_TOKENS = 10
|
|
requests = create_requests(
|
|
num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, max_tokens=MAX_TOKENS
|
|
)
|
|
|
|
# Add each request.
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
scheduler_output = scheduler.schedule()
|
|
model_runner_output = make_output(scheduler)
|
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
|
|
|
# Iterate until done.
|
|
while True:
|
|
scheduler_output = scheduler.schedule()
|
|
if len(scheduler.running) == 0:
|
|
break
|
|
model_runner_output = make_output(scheduler)
|
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
|
|
|
# Confirm no memory leak.
|
|
assert_scheduler_empty(scheduler)
|
|
|
|
|
|
def create_scheduler_with_priority(
|
|
model: str = "facebook/opt-125m",
|
|
max_num_seqs: int = 16,
|
|
max_num_batched_tokens: int = 8192,
|
|
enable_prefix_caching: bool | None = None,
|
|
long_prefill_token_threshold: int = 0,
|
|
disable_chunked_mm_input: bool = False,
|
|
use_kv_connector: bool = False,
|
|
num_blocks: int = 10000,
|
|
block_size: int = 16,
|
|
max_model_len: int | None = None,
|
|
num_speculative_tokens: int | None = None,
|
|
use_ec_connector: bool = False,
|
|
ec_role: str | None = None,
|
|
) -> Scheduler:
|
|
"""Create scheduler with priority policy enabled.
|
|
|
|
Args:
|
|
model: model under test
|
|
max_num_seqs: max sequences to schedule
|
|
max_num_batch_tokens: max num tokens to batch
|
|
enable_prefix_caching: optionally force APC config
|
|
(True/False) or use default
|
|
(None)
|
|
|
|
Returns:
|
|
{class}`Scheduler` instance with priority scheduling
|
|
"""
|
|
if max_model_len is None:
|
|
max_model_len = max_num_batched_tokens
|
|
scheduler_config = SchedulerConfig(
|
|
max_num_seqs=max_num_seqs,
|
|
max_num_batched_tokens=max_num_batched_tokens,
|
|
max_model_len=max_model_len,
|
|
long_prefill_token_threshold=long_prefill_token_threshold,
|
|
disable_chunked_mm_input=disable_chunked_mm_input,
|
|
enable_chunked_prefill=True,
|
|
policy="priority", # Enable priority scheduling
|
|
)
|
|
model_config = ModelConfig(
|
|
model=model,
|
|
trust_remote_code=True,
|
|
dtype="float16",
|
|
seed=42,
|
|
)
|
|
# Cache config, optionally force APC
|
|
kwargs_cache = (
|
|
{}
|
|
if enable_prefix_caching is None
|
|
else {"enable_prefix_caching": enable_prefix_caching}
|
|
)
|
|
cache_config = CacheConfig(
|
|
block_size=block_size,
|
|
gpu_memory_utilization=0.9,
|
|
swap_space=0,
|
|
cache_dtype="auto",
|
|
**kwargs_cache,
|
|
)
|
|
kv_transfer_config = (
|
|
KVTransferConfig(
|
|
kv_connector="SharedStorageConnector",
|
|
kv_role="kv_both",
|
|
kv_connector_extra_config={"shared_storage_path": "local_storage"},
|
|
)
|
|
if use_kv_connector
|
|
else None
|
|
)
|
|
|
|
speculative_config: SpeculativeConfig | None = None
|
|
if num_speculative_tokens is not None:
|
|
speculative_config = SpeculativeConfig(
|
|
model="ngram", num_speculative_tokens=num_speculative_tokens
|
|
)
|
|
|
|
ec_transfer_config = (
|
|
ECTransferConfig(
|
|
ec_connector="ECSharedStorageConnector",
|
|
ec_role=ec_role,
|
|
ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"},
|
|
)
|
|
if use_ec_connector
|
|
else None
|
|
)
|
|
|
|
vllm_config = VllmConfig(
|
|
scheduler_config=scheduler_config,
|
|
model_config=model_config,
|
|
cache_config=cache_config,
|
|
kv_transfer_config=kv_transfer_config,
|
|
speculative_config=speculative_config,
|
|
ec_transfer_config=ec_transfer_config,
|
|
)
|
|
kv_cache_config = KVCacheConfig(
|
|
num_blocks=num_blocks, # A large number of blocks to hold all requests
|
|
kv_cache_tensors=[],
|
|
kv_cache_groups=[
|
|
KVCacheGroupSpec(
|
|
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
|
|
)
|
|
],
|
|
)
|
|
cache_config.num_gpu_blocks = num_blocks
|
|
return Scheduler(
|
|
vllm_config=vllm_config,
|
|
kv_cache_config=kv_cache_config,
|
|
log_stats=True,
|
|
structured_output_manager=StructuredOutputManager(vllm_config),
|
|
block_size=block_size,
|
|
)
|
|
|
|
|
|
_none_hash_initialized = False
|
|
|
|
|
|
def create_requests_with_priority(
|
|
num_requests: int,
|
|
priorities: list[int],
|
|
arrival_times: list[float] | None = None,
|
|
num_tokens: int = 10,
|
|
mm_hashes_list: list[list[str]] | None = None,
|
|
mm_positions: list[list[PlaceholderRange]] | None = None,
|
|
max_tokens: int = 16,
|
|
stop_token_ids: list[int] | None = None,
|
|
prompt_logprobs: int | None = None,
|
|
starting_idx: int = 0,
|
|
same_prompt: bool = False,
|
|
block_size: int = 16,
|
|
req_ids: list[str] | None = None,
|
|
):
|
|
"""Create requests with specified priorities and arrival times."""
|
|
assert len(priorities) == num_requests
|
|
if arrival_times is not None:
|
|
assert len(arrival_times) == num_requests
|
|
else:
|
|
arrival_times = [float(i) for i in range(num_requests)]
|
|
|
|
global _none_hash_initialized
|
|
if not _none_hash_initialized:
|
|
init_none_hash(sha256)
|
|
_none_hash_initialized = True
|
|
|
|
block_hasher = get_request_block_hasher(block_size, sha256)
|
|
sampling_params = SamplingParams(
|
|
ignore_eos=False,
|
|
max_tokens=max_tokens,
|
|
stop_token_ids=stop_token_ids,
|
|
prompt_logprobs=prompt_logprobs,
|
|
)
|
|
requests = []
|
|
|
|
if mm_hashes_list is not None:
|
|
# NOTE: allow manually provided hashes; some mm items may share the same
# identifier. The number of mm_hashes and mm_positions must match for
# each request.
|
|
assert mm_positions is not None, (
|
|
"mm_positions must be provided when mm_hashes_list is provided"
|
|
)
|
|
assert len(mm_hashes_list) == len(mm_positions) == num_requests
|
|
assert [len(h) for h in mm_hashes_list] == [len(p) for p in mm_positions]
|
|
|
|
# Since the same identifier implies identical encoder output, verify that
# mm items sharing an identifier also share the same mm_position.length.
|
|
seen_hashes: dict[str, int] = {}
|
|
|
|
if req_ids:
|
|
assert len(req_ids) == num_requests
|
|
else:
|
|
req_ids = [f"{i + starting_idx}" for i in range(num_requests)]
|
|
|
|
for i in range(num_requests):
|
|
mm_features = []
|
|
|
|
for j, position in enumerate(
|
|
mm_positions[i] if mm_positions is not None else []
|
|
):
|
|
if mm_hashes_list is not None:
|
|
identifier = mm_hashes_list[i][j]
|
|
|
|
# Verify if position length is identical
|
|
position_length = position.length
|
|
if identifier in seen_hashes:
|
|
assert seen_hashes[identifier] == position_length, (
|
|
f"mm_hash '{identifier}' has inconsistent position lengths: "
|
|
f"previously {seen_hashes[identifier]}, now {position_length} "
|
|
f"at request {i}, position {j}"
|
|
)
|
|
else:
|
|
seen_hashes[identifier] = position_length
|
|
else:
|
|
# Unique dummy hash for each mm item
|
|
identifier = f"hash{i}_{j}"
|
|
mm_feature = MultiModalFeatureSpec(
|
|
data=MultiModalKwargsItem.dummy("dummy_m"),
|
|
mm_position=position,
|
|
identifier=identifier,
|
|
modality="image",
|
|
)
|
|
mm_features.append(mm_feature)
|
|
|
|
prompt_token_ids = (
|
|
[starting_idx] * num_tokens
|
|
if same_prompt
|
|
else [i + starting_idx] * num_tokens
|
|
)
|
|
request = Request(
|
|
request_id=req_ids[i],
|
|
prompt_token_ids=prompt_token_ids,
|
|
sampling_params=sampling_params,
|
|
pooling_params=None,
|
|
mm_features=mm_features if mm_features else None,
|
|
eos_token_id=EOS_TOKEN_ID,
|
|
arrival_time=arrival_times[i],
|
|
priority=priorities[i],
|
|
block_hasher=block_hasher,
|
|
)
|
|
requests.append(request)
|
|
return requests
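
# NOTE: The priority tests below all exercise the same ordering policy:
# requests are served in ascending priority value, with arrival time breaking
# ties. A reference sketch of the expected order (illustration only, not the
# scheduler's implementation):
#
#     sorted(requests, key=lambda r: (r.priority, r.arrival_time))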
|
|
|
|
|
|
def test_priority_scheduling_basic_ordering():
|
|
"""Test that requests are scheduled in priority order
|
|
(lower value = higher priority)."""
|
|
scheduler = create_scheduler_with_priority()
|
|
|
|
# Create requests with different priorities
|
|
# Priority 0 (highest), 1, 2 (lowest)
|
|
priorities = [2, 0, 1] # Add in non-priority order
|
|
arrival_times = [1.0, 2.0, 3.0] # All different arrival times
|
|
requests = create_requests_with_priority(
|
|
num_requests=3, priorities=priorities, arrival_times=arrival_times
|
|
)
|
|
|
|
# Add requests in non-priority order
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
# Schedule and verify priority order
|
|
output = scheduler.schedule()
|
|
|
|
# Should schedule all requests since they fit in budget
|
|
assert len(output.scheduled_new_reqs) == 3
|
|
|
|
# Verify they are scheduled in priority order:
|
|
# req_1 (priority 0), req_2 (priority 1), req_0 (priority 2)
|
|
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
|
|
assert scheduled_req_ids == ["1", "2", "0"]
|
|
|
|
|
|
def test_priority_scheduling_arrival_time_tiebreaker():
|
|
"""Test that arrival time is used
|
|
as tiebreaker when priorities are equal."""
|
|
scheduler = create_scheduler_with_priority()
|
|
|
|
# Create requests with same priority but different arrival times
|
|
priorities = [1, 1, 1] # All same priority
|
|
arrival_times = [3.0, 1.0, 2.0] # Different arrival times
|
|
requests = create_requests_with_priority(
|
|
num_requests=3, priorities=priorities, arrival_times=arrival_times
|
|
)
|
|
|
|
# Add requests in non-arrival order
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
# Schedule and verify arrival time order
|
|
output = scheduler.schedule()
|
|
|
|
# Should schedule all requests since they fit in budget
|
|
assert len(output.scheduled_new_reqs) == 3
|
|
|
|
# Verify they are scheduled in arrival time order:
|
|
# req_1 (1.0), req_2 (2.0), req_0 (3.0)
|
|
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
|
|
assert scheduled_req_ids == ["1", "2", "0"]
|
|
|
|
|
|
def test_priority_scheduling_mixed_priority_and_arrival():
|
|
"""Test priority scheduling with mixed priorities and arrival times."""
|
|
scheduler = create_scheduler_with_priority()
|
|
|
|
# Create requests with mixed priorities and arrival times
|
|
priorities = [2, 1, 1, 0] # Mixed priorities
|
|
arrival_times = [1.0, 3.0, 2.0, 4.0] # Mixed arrival times
|
|
requests = create_requests_with_priority(
|
|
num_requests=4, priorities=priorities, arrival_times=arrival_times
|
|
)
|
|
|
|
# Add requests
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
# Schedule and verify order
|
|
output = scheduler.schedule()
|
|
|
|
# Should schedule all requests since they fit in budget
|
|
assert len(output.scheduled_new_reqs) == 4
|
|
|
|
# Expected order:
|
|
# 1. req_3 (priority 0, arrival 4.0)
|
|
# 2. req_2 (priority 1, arrival 2.0) - earlier arrival than req_1
|
|
# 3. req_1 (priority 1, arrival 3.0)
|
|
# 4. req_0 (priority 2, arrival 1.0)
|
|
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
|
|
assert scheduled_req_ids == ["3", "2", "1", "0"]
|
|
|
|
|
|
def test_priority_scheduling_preemption():
|
|
"""Test that priority scheduling preempts
|
|
lower priority requests when memory is constrained."""
|
|
# Create scheduler with very limited memory to force preemption
|
|
scheduler = create_scheduler_with_priority(
|
|
max_num_seqs=3, # Allow multiple requests
|
|
max_num_batched_tokens=200,
|
|
num_blocks=6, # Very limited blocks to force memory pressure
|
|
block_size=16, # Standard block size
|
|
)
|
|
|
|
# Create initial low-priority requests that will consume most memory
|
|
low_priority_requests = create_requests_with_priority(
|
|
num_requests=2,
|
|
priorities=[5, 5], # Low priority
|
|
arrival_times=[1.0, 2.0],
|
|
num_tokens=30, # Large enough to consume significant memory
|
|
)
|
|
|
|
# Add and schedule low priority requests
|
|
for request in low_priority_requests:
|
|
scheduler.add_request(request)
|
|
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == 2
|
|
|
|
# Simulate model execution to move requests to running state
|
|
model_output = ModelRunnerOutput(
|
|
req_ids=[req.request_id for req in low_priority_requests],
|
|
req_id_to_index={
|
|
req.request_id: i for i, req in enumerate(low_priority_requests)
|
|
},
|
|
sampled_token_ids=[np.array([100]) for _ in low_priority_requests],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(output, model_output)
|
|
|
|
# Verify both requests are running
|
|
assert len(scheduler.running) == 2
|
|
|
|
# Now add a high-priority request that requires memory allocation
|
|
# This should trigger preemption due to memory constraints
|
|
high_priority_request = create_requests_with_priority(
|
|
num_requests=1,
|
|
priorities=[0], # High priority
|
|
arrival_times=[3.0],
|
|
num_tokens=30, # Large enough to require significant memory
|
|
)[0]
|
|
|
|
scheduler.add_request(high_priority_request)
|
|
|
|
# Schedule again - this should trigger
|
|
# preemption when trying to allocate memory
|
|
output = scheduler.schedule()
|
|
|
|
# Due to the scheduler's design, if preemption happens
|
|
# during running request scheduling,
|
|
# waiting requests won't be scheduled in the same step
|
|
# Let's check if preemption occurred by looking at the waiting queue
|
|
|
|
# If preemption happened, we should see requests in the
|
|
# waiting queue
|
|
if len(scheduler.waiting) > 1: # high priority + preempted request
|
|
# Preemption occurred - verify the high priority request
|
|
# gets scheduled next
|
|
output2 = scheduler.schedule()
|
|
assert len(output2.scheduled_new_reqs) == 1
|
|
# High priority request
|
|
assert output2.scheduled_new_reqs[0].req_id == "0"
|
|
else:
|
|
# No preemption needed - all requests fit
|
|
# This is also valid behavior if memory allows
|
|
assert len(output.scheduled_new_reqs) == 1
|
|
# High priority request
|
|
assert output.scheduled_new_reqs[0].req_id == "0"
|
|
|
|
|
|
def test_priority_scheduling_no_preemption_when_space_available():
|
|
"""Test that preemption doesn't happen
|
|
when there's space for new requests."""
|
|
scheduler = create_scheduler_with_priority(
|
|
max_num_seqs=3, # Allow 3 concurrent requests
|
|
max_num_batched_tokens=200, # Sufficient token budget
|
|
)
|
|
|
|
# Add two low-priority running requests
|
|
low_priority_requests = create_requests_with_priority(
|
|
num_requests=2, priorities=[5, 5], arrival_times=[1.0, 2.0], num_tokens=30
|
|
)
|
|
|
|
for request in low_priority_requests:
|
|
scheduler.add_request(request)
|
|
|
|
output = scheduler.schedule()
|
|
model_output = ModelRunnerOutput(
|
|
req_ids=[req.request_id for req in low_priority_requests],
|
|
req_id_to_index={
|
|
req.request_id: i for i, req in enumerate(low_priority_requests)
|
|
},
|
|
sampled_token_ids=[np.array([100]) for _ in low_priority_requests],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(output, model_output)
|
|
|
|
# Add high-priority request
|
|
high_priority_request = create_requests_with_priority(
|
|
num_requests=1, priorities=[0], arrival_times=[3.0], num_tokens=30
|
|
)[0]
|
|
|
|
scheduler.add_request(high_priority_request)
|
|
|
|
# Schedule - should not preempt since there's space
|
|
output = scheduler.schedule()
|
|
|
|
# Should schedule the new request without preemption
|
|
assert len(output.scheduled_new_reqs) == 1
|
|
assert len(scheduler.running) == 3 # All three requests running
|
|
assert len(scheduler.waiting) == 0 # No requests waiting
|
|
|
|
|
|
def test_priority_scheduling_preemption_victim_selection():
|
|
"""Test that the correct victim is selected for
|
|
preemption based on priority and arrival time."""
|
|
# This test verifies the priority-based victim selection logic
|
|
# by checking the waiting queue order after adding requests with different
|
|
# priorities
|
|
scheduler = create_scheduler_with_priority(
|
|
max_num_seqs=1, # Force sequential processing to test priority order
|
|
)
|
|
|
|
# Create requests with different priorities
|
|
requests = create_requests_with_priority(
|
|
num_requests=3,
|
|
priorities=[3, 2, 0], # Different priorities: low, medium, high
|
|
arrival_times=[1.0, 2.0, 3.0],
|
|
num_tokens=10,
|
|
)
|
|
|
|
# Add all requests
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
# Schedule - should only schedule the highest priority request
|
|
# (req_2, priority 0)
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == 1
|
|
assert output.scheduled_new_reqs[0].req_id == "2" # Highest priority
|
|
|
|
# Verify the waiting queue has the remaining requests in priority order
|
|
assert len(scheduler.waiting) == 2
|
|
|
|
# Extract waiting requests and verify priority order
|
|
waiting_requests = list(scheduler.waiting)
|
|
|
|
waiting_priorities = [req.priority for req in waiting_requests]
|
|
waiting_req_ids = [req.request_id for req in waiting_requests]
|
|
|
|
# Should be req_1 (priority 2) then req_0 (priority 3)
|
|
assert waiting_priorities == [2, 3]
|
|
assert waiting_req_ids == ["1", "0"]
|
|
|
|
|
|
def test_priority_scheduling_equal_priority_preemption():
|
|
"""Test arrival time tiebreaker when requests have equal priority."""
|
|
# This test verifies that arrival time is used as a tiebreaker for equal
|
|
# priorities
|
|
scheduler = create_scheduler_with_priority(
|
|
max_num_seqs=1, # Force sequential processing
|
|
)
|
|
|
|
# Create requests with same priority but different arrival times
|
|
requests = create_requests_with_priority(
|
|
num_requests=3,
|
|
priorities=[2, 2, 2], # Same priority
|
|
arrival_times=[3.0, 1.0, 2.0], # Different arrival times
|
|
num_tokens=10,
|
|
)
|
|
|
|
# Add all requests
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
# Schedule - should schedule the request with earliest arrival time
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == 1
|
|
assert output.scheduled_new_reqs[0].req_id == "1" # Earliest arrival (1.0)
|
|
|
|
# Verify the waiting queue has remaining requests in arrival time order
|
|
assert len(scheduler.waiting) == 2
|
|
|
|
# Extract waiting requests and verify arrival time order
|
|
waiting_requests = list(scheduler.waiting)
|
|
|
|
waiting_arrival_times = [req.arrival_time for req in waiting_requests]
|
|
waiting_req_ids = [req.request_id for req in waiting_requests]
|
|
|
|
# Should be req_2 (arrival 2.0) then req_0 (arrival 3.0)
|
|
assert waiting_arrival_times == [2.0, 3.0]
|
|
assert waiting_req_ids == ["2", "0"]
|
|
|
|
|
|
def test_priority_scheduling_waiting_queue_order():
|
|
"""Test that the waiting queue maintains priority order."""
|
|
scheduler = create_scheduler_with_priority(
|
|
max_num_seqs=1, # Only one request can run at a time
|
|
)
|
|
|
|
# Create multiple requests with different priorities
|
|
requests = create_requests_with_priority(
|
|
num_requests=4,
|
|
priorities=[3, 1, 2, 0], # Mixed priorities
|
|
arrival_times=[1.0, 2.0, 3.0, 4.0],
|
|
num_tokens=10,
|
|
)
|
|
|
|
# Add all requests
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
# Schedule - should only schedule the highest priority request
|
|
# (req_3, priority 0)
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == 1
|
|
assert output.scheduled_new_reqs[0].req_id == "3"
|
|
|
|
# Verify waiting queue has remaining requests in priority order
|
|
assert len(scheduler.waiting) == 3
|
|
|
|
# Extract requests from waiting queue
|
|
# (it's a heap, so we need to pop to see order)
|
|
waiting_requests = list(scheduler.waiting)
|
|
|
|
waiting_priorities = [req.priority for req in waiting_requests]
|
|
waiting_req_ids = [req.request_id for req in waiting_requests]
|
|
|
|
# Should be ordered by priority: req_1 (1), req_2 (2), req_0 (3)
|
|
assert waiting_req_ids == ["1", "2", "0"]
|
|
assert waiting_priorities == [1, 2, 3]
|
|
|
|
|
|
def test_priority_scheduling_fcfs_fallback():
|
|
"""Test that FCFS behavior is maintained when all
|
|
requests have same priority."""
|
|
scheduler = create_scheduler_with_priority()
|
|
|
|
# Create requests with same priority but different arrival times
|
|
priorities = [1, 1, 1, 1] # All same priority
|
|
arrival_times = [4.0, 1.0, 3.0, 2.0] # Different arrival times
|
|
requests = create_requests_with_priority(
|
|
num_requests=4, priorities=priorities, arrival_times=arrival_times
|
|
)
|
|
|
|
# Add requests
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
# Schedule
|
|
output = scheduler.schedule()
|
|
|
|
# Should schedule all requests in arrival time order
|
|
assert len(output.scheduled_new_reqs) == 4
|
|
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
|
|
|
|
# Expected order by arrival time:
|
|
# req_1 (1.0), req_3 (2.0), req_2 (3.0), req_0 (4.0)
|
|
assert scheduled_req_ids == ["1", "3", "2", "0"]
|
|
|
|
|
|
def test_priority_scheduling_with_limited_slots():
|
|
"""Test priority scheduling when max_num_seqs limits concurrent requests."""
|
|
scheduler = create_scheduler_with_priority(
|
|
max_num_seqs=2, # Only allow 2 concurrent requests
|
|
max_num_batched_tokens=1000, # Plenty of token budget
|
|
)
|
|
|
|
# Create requests with different priorities
|
|
requests = create_requests_with_priority(
|
|
num_requests=4,
|
|
priorities=[3, 1, 2, 0], # Mixed priorities
|
|
arrival_times=[1.0, 2.0, 3.0, 4.0],
|
|
num_tokens=10,
|
|
)
|
|
|
|
# Add all requests
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
# Schedule - should only schedule the 2 highest priority requests
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == 2
|
|
|
|
# Should schedule req_3 (priority 0) and req_1 (priority 1)
|
|
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
|
|
assert "3" in scheduled_req_ids # Priority 0
|
|
assert "1" in scheduled_req_ids # Priority 1
|
|
|
|
# Remaining requests should be in waiting queue in priority order
|
|
assert len(scheduler.waiting) == 2
|
|
|
|
# Extract waiting requests and verify order
|
|
waiting_requests = list(scheduler.waiting)
|
|
waiting_priorities = [req.priority for req in waiting_requests]
|
|
waiting_req_ids = [req.request_id for req in waiting_requests]
|
|
|
|
# Should be req_2 (priority 2) then req_0 (priority 3)
|
|
assert waiting_priorities == [2, 3]
|
|
assert waiting_req_ids == ["2", "0"]
|
|
|
|
|
|
def test_priority_scheduling_heap_property():
|
|
"""Test that the waiting queue maintains heap
|
|
property for priority scheduling."""
|
|
scheduler = create_scheduler_with_priority(
|
|
max_num_seqs=1, # Only one request can run at a time
|
|
)
|
|
|
|
# Add requests in random priority order
|
|
priorities = [5, 1, 8, 3, 2, 7, 4, 6]
|
|
arrival_times = [float(i) for i in range(len(priorities))]
|
|
requests = create_requests_with_priority(
|
|
num_requests=len(priorities),
|
|
priorities=priorities,
|
|
arrival_times=arrival_times,
|
|
num_tokens=10,
|
|
)
|
|
|
|
# Add all requests
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
# Schedule one request at a time and verify priority order
|
|
scheduled_priorities = []
|
|
|
|
while scheduler.waiting:
|
|
output = scheduler.schedule()
|
|
if output.scheduled_new_reqs:
|
|
req = output.scheduled_new_reqs[0]
|
|
scheduled_priorities.append(requests[int(req.req_id)].priority)
|
|
|
|
# Simulate completion to make room for next request
|
|
model_output = ModelRunnerOutput(
|
|
req_ids=[req.req_id],
|
|
req_id_to_index={req.req_id: 0},
|
|
sampled_token_ids=[np.array([100])],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(output, model_output)
|
|
|
|
# Finish the request to make room for the next one
|
|
scheduler.finish_requests(req.req_id, RequestStatus.FINISHED_STOPPED)
|
|
|
|
# Verify requests were scheduled in priority order (lowest value first)
|
|
expected_priorities = sorted(priorities)
|
|
assert scheduled_priorities == expected_priorities
|
|
|
|
|
|
def test_schedule_skip_tokenizer_init():
|
|
scheduler = create_scheduler(skip_tokenizer_init=True)
|
|
requests = create_requests(num_requests=5)
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == len(requests)
|
|
|
|
|
|
def test_schedule_skip_tokenizer_init_structured_output_request():
|
|
scheduler = create_scheduler(skip_tokenizer_init=True)
|
|
structured_outputs_params = StructuredOutputsParams(regex="[0-9]+")
|
|
sampling_params = SamplingParams(
|
|
ignore_eos=False,
|
|
max_tokens=16,
|
|
structured_outputs=structured_outputs_params,
|
|
)
|
|
request = Request(
|
|
request_id="0",
|
|
prompt_token_ids=[0, 1],
|
|
mm_features=None,
|
|
sampling_params=sampling_params,
|
|
pooling_params=None,
|
|
eos_token_id=EOS_TOKEN_ID,
|
|
)
|
|
scheduler.add_request(request)
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == 0
|
|
assert len(scheduler.running) == 0
|
|
assert len(scheduler.waiting) == 1
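    # With skip_tokenizer_init, the structured-output (regex) grammar cannot be
    # prepared (grammar compilation presumably needs the tokenizer), so the
    # request stays in the waiting queue instead of being scheduled.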
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
|
|
)
|
|
def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
|
|
use_ec_connector, ec_role
|
|
):
|
|
"""Test that priority scheduling preempts lower priority requests
|
|
when out of KV cache space."""
|
|
# Create scheduler with very limited memory to force preemption
|
|
scheduler = create_scheduler_with_priority(
|
|
max_num_seqs=2, # Allow multiple requests
|
|
max_num_batched_tokens=200,
|
|
num_blocks=5, # Can hold 64 tokens (first block is null)
|
|
block_size=16, # Standard block size
|
|
use_kv_connector=True,
|
|
# encoder connector should not affect test results
|
|
use_ec_connector=use_ec_connector,
|
|
ec_role=ec_role,
|
|
)
|
|
|
|
# Create a request and schedule it
|
|
request_low = create_requests_with_priority(
|
|
num_requests=1,
|
|
priorities=[1],
|
|
arrival_times=[0.0],
|
|
num_tokens=30,
|
|
starting_idx=0,
|
|
)[0]
|
|
scheduler.add_request(request_low)
|
|
# 1st schedule
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == 1
|
|
assert len(scheduler.waiting) == 0
|
|
assert len(scheduler.running) == 1
|
|
|
|
# Simulate model execution - 1st decode
|
|
model_output = ModelRunnerOutput(
|
|
req_ids=[request_low.request_id],
|
|
req_id_to_index={request_low.request_id: 0},
|
|
sampled_token_ids=[np.array([100])],
|
|
# spec_token_ids=None,
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(output, model_output)
|
|
|
|
# Create a high priority request and schedule it
|
|
request_high = create_requests_with_priority(
|
|
num_requests=1,
|
|
priorities=[0],
|
|
arrival_times=[1.0],
|
|
num_tokens=32,
|
|
starting_idx=1,
|
|
)[0]
|
|
scheduler.add_request(request_high)
|
|
# 2nd schedule
|
|
output = scheduler.schedule()
|
|
# KV cache should be full at this point
|
|
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
|
|
assert len(output.scheduled_new_reqs) == 1
|
|
assert output.scheduled_cached_reqs.num_reqs == 1
|
|
assert len(scheduler.waiting) == 0
|
|
assert len(scheduler.running) == 2
|
|
|
|
# Simulate model execution - 2nd decode
|
|
requests = [request_low, request_high]
|
|
model_output = ModelRunnerOutput(
|
|
req_ids=[req.request_id for req in requests],
|
|
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
|
sampled_token_ids=[np.array([100]) for _ in requests],
|
|
# spec_token_ids=None,
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(output, model_output)
|
|
|
|
# 3rd schedule - this should trigger preemption
|
|
# req_low needs 32 tokens = 2 blocks
|
|
# req_high needs 33 tokens = 3 blocks
|
|
# so doesn't fit in 4 blocks.
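    # (32 = 30 prompt + 2 sampled tokens; 33 = 32 prompt + 1 sampled token;
    # only 4 of the 5 blocks are usable since the first block is the null block.)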
|
|
output = scheduler.schedule()
|
|
|
|
# Should have preempted req_low
|
|
assert len(output.scheduled_new_reqs) == 0
|
|
assert output.scheduled_cached_reqs.num_reqs == 1
|
|
assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
|
|
assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED
|
|
assert len(scheduler.waiting) == 1
|
|
assert len(scheduler.running) == 1
|
|
|
|
# Simulate model execution - 3rd decode
|
|
model_output = ModelRunnerOutput(
|
|
req_ids=[req.request_id for req in requests],
|
|
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
|
sampled_token_ids=[np.array([]), np.array([100])],
|
|
# spec_token_ids=None,
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
# Finish the requests to make room for the preempted requests to resume
|
|
scheduler.update_from_output(output, model_output)
|
|
scheduler.finish_requests(request_high.request_id, RequestStatus.FINISHED_STOPPED)
|
|
|
|
# 4th Schedule - this should trigger the resumption
|
|
output = scheduler.schedule()
|
|
scheduled_cached_reqs = output.scheduled_cached_reqs
|
|
resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption
|
|
|
|
assert len(output.scheduled_new_reqs) == 0
|
|
assert scheduled_cached_reqs.num_reqs == 1
|
|
assert len(scheduler.waiting) == 0
|
|
assert len(scheduler.running) == 1
|
|
|
|
# Preempted request resumed in scheduled_cached_reqs
|
|
assert len(resumed_from_preemption) == 1
|
|
assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
|
|
assert resumed_from_preemption[0]
|
|
assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
|
|
assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
|
|
# Resumed tokens include 30 prompt tokens and 2 decoded tokens
|
|
assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 32
|
|
assert scheduled_cached_reqs.resumed_req_token_ids[0][31] == 100
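    # resumed_req_token_ids carries the request's full token history (prompt
    # plus previously sampled tokens), presumably so the worker can rebuild the
    # preempted request's state from scratch.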
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"),
|
|
[
|
|
(True, False, True),
|
|
(False, False, False),
|
|
# Encoder-decoder models should always have it disabled
|
|
(False, True, False),
|
|
(True, True, False),
|
|
],
|
|
)
|
|
def test_chunked_prefill_disabled_for_encoder_decoder(
|
|
enable_chunked_prefill: bool, is_encoder_decoder: bool, expect_enabled: bool
|
|
) -> None:
|
|
"""Validate that chunked prefill is appropriately disabled for
|
|
encoder-decoder models."""
|
|
scheduler_config = SchedulerConfig(
|
|
enable_chunked_prefill=enable_chunked_prefill,
|
|
is_encoder_decoder=is_encoder_decoder,
|
|
        # Must be <= max_num_batched_tokens when chunked prefill is disabled
|
|
max_model_len=SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
|
|
)
|
|
|
|
# `is_encoder_decoder` should only be used during construction
|
|
# of the config, and otherwise stored in the model config.
|
|
assert "is_encoder_decoder" not in vars(scheduler_config)
|
|
assert "is_encoder_decoder" not in [
|
|
f.name for f in dataclasses.fields(scheduler_config)
|
|
]
|
|
_validate_chunked_prefill_settings_for_encoder_decoder(
|
|
scheduler_config, is_encoder_decoder, expect_enabled
|
|
)
|
|
|
|
# Ensure it is retained in VllmConfig, even after its post-init.
|
|
vllm_config = VllmConfig(scheduler_config=scheduler_config)
|
|
_validate_chunked_prefill_settings_for_encoder_decoder(
|
|
vllm_config.scheduler_config, is_encoder_decoder, expect_enabled
|
|
)
|
|
|
|
|
|
def _validate_chunked_prefill_settings_for_encoder_decoder(
|
|
scheduler_config: SchedulerConfig, is_encoder_decoder: bool, expect_enabled: bool
|
|
) -> None:
|
|
"""Validate chunked prefill settings in the scheduler config for
|
|
encoder-decoder models."""
|
|
assert scheduler_config.enable_chunked_prefill is expect_enabled
|
|
if is_encoder_decoder:
|
|
# Encoder-decoder models should automatically disable chunked multimodal
|
|
# inputs as well
|
|
assert scheduler_config.disable_chunked_mm_input is not expect_enabled
|
|
if is_encoder_decoder and not expect_enabled:
|
|
assert scheduler_config.long_prefill_token_threshold == 0
|
|
|
|
|
|
# ==============================================================================
|
|
# EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests start
|
|
# NOTE: In the E->P->D disagg case, both the KV and EC connectors run in the
# P instance. Unless specified otherwise, the presence of a KV connector should
# not affect any test results.
|
|
# ==============================================================================
|
|
|
|
|
|
def _assert_right_encoder_cache_allocated(
|
|
scheduler: Scheduler,
|
|
hashes_to_check: list[str] | None = None,
|
|
requests: list[Request] | None = None,
|
|
expected_total_allocated: int | None = None,
|
|
):
|
|
"""Check whether encoder cache is allocated correctly."""
|
|
encoder_cache_manager = scheduler.encoder_cache_manager
|
|
|
|
# Verify encoder cache manager exists
|
|
assert encoder_cache_manager is not None, "Encoder cache manager should exist"
|
|
|
|
# Verify number of cache
|
|
if expected_total_allocated is not None:
|
|
assert len(encoder_cache_manager.cached) == expected_total_allocated
|
|
if expected_total_allocated == 0:
|
|
return
|
|
|
|
# Verify each request with MM data is in cache
|
|
cached_hashes = set(encoder_cache_manager.cached.keys())
|
|
|
|
if hashes_to_check:
|
|
missed_hashes = set(hashes_to_check) - cached_hashes
|
|
assert not missed_hashes, (
|
|
f"Miss hashes: {missed_hashes} "
|
|
f"Existing encoder cache: {encoder_cache_manager.cached}"
|
|
)
|
|
|
|
for req in requests if requests is not None else []:
|
|
if req.mm_features:
|
|
mm_hashes = [f.identifier for f in req.mm_features]
|
|
req_hashes = set(mm_hashes) # unique hashes set
|
|
missed_hashes = req_hashes - cached_hashes
|
|
assert not missed_hashes, (
|
|
f"Miss hashes in cache for request {req.request_id}: {missed_hashes} "
|
|
f"Existing encoder cache: {encoder_cache_manager.cached}"
|
|
)
|
|
|
|
|
|
def _assert_right_ec_connector_metadata(
|
|
output: SchedulerOutput,
|
|
mm_features_list: list[MultiModalFeatureSpec],
|
|
):
|
|
"""Verify that ECConnector metadata EXACTLY matches the input MM data"""
|
|
# Get the connector metadata
|
|
metadata = output.ec_connector_metadata
|
|
|
|
# Create lookup dictionaries for efficient access
|
|
metadata_dict = {mm_data.mm_hash: mm_data for mm_data in metadata.mm_datas}
|
|
|
|
# Check all required identifiers exist in metadata; and no extra
|
|
# In ECSharedStorageConnector format
|
|
    # NOTE: even with the same identifier, the mm_features can differ, e.g.
    # their mm_positions may have different offsets.
|
|
identifiers_dict = {f.identifier for f in mm_features_list}
|
|
assert set(metadata_dict.keys()) == identifiers_dict
|
|
|
|
# Verify the info matches
|
|
    for mm_feature in mm_features_list:
|
|
identifier = mm_feature.identifier
|
|
assert metadata_dict[identifier].mm_hash == identifier
|
|
assert metadata_dict[identifier].num_token == mm_feature.mm_position.length
|
|
|
|
|
|
def _assert_right_encoder_inputs(
|
|
output: SchedulerOutput,
|
|
check_exist: bool | None = True,
|
|
requests: list[Request] | None = None,
|
|
expected_encoder_inputs: list[list[int]] | None = None,
|
|
expected_total_reqs: int | None = None,
|
|
):
|
|
"""Verify that requests/mm_hashes should (not) in scheduled encoder input
|
|
If check_exist is False, this function returns True
|
|
if requests are NOT in encoder inputs"""
|
|
|
|
# Get the scheduled encoder inputs
|
|
# NOTE: scheduled_encoder_inputs is a dictionary with request id as key
|
|
scheduled_encoder_inputs = output.scheduled_encoder_inputs
|
|
|
|
# Check if scheduled_encoder_inputs is empty as expected
|
|
if expected_total_reqs is not None:
|
|
assert len(scheduled_encoder_inputs) == expected_total_reqs
|
|
if expected_total_reqs == 0:
|
|
return
|
|
|
|
    # Number of expected encoder inputs should match the number of requests
|
|
if expected_encoder_inputs:
|
|
assert check_exist and requests is not None # only support expect input exist
|
|
assert len(requests) == len(expected_encoder_inputs)
|
|
|
|
    # Check that each request is (or is not) present as expected
|
|
for i, request in enumerate(requests if requests is not None else []):
|
|
assert (request.request_id in scheduled_encoder_inputs) is check_exist, (
|
|
f"Request {request.id} presence mismatch: expected {check_exist}, "
|
|
f"got {request.id in scheduled_encoder_inputs}"
|
|
)
|
|
if expected_encoder_inputs:
|
|
scheduled_encoder_input = scheduled_encoder_inputs[request.request_id]
|
|
assert scheduled_encoder_input == expected_encoder_inputs[i]
|
|
|
|
|
|
def test_scheduler_no_ec_connector_by_default():
|
|
"""Test scheduler doesn't have EC connector by default."""
|
|
scheduler = create_scheduler()
|
|
assert scheduler.ec_connector is None
|
|
|
|
|
|
@pytest.mark.parametrize("use_kv_connector", [False, True])
|
|
def test_ec_connector_text_only_request(use_kv_connector):
|
|
"""Test text-only requests don't allocate encoder cache."""
|
|
scheduler = create_scheduler(
|
|
model="llava-hf/llava-1.5-7b-hf",
|
|
use_kv_connector=use_kv_connector,
|
|
use_ec_connector=True,
|
|
ec_role="ec_consumer",
|
|
)
|
|
|
|
NUM_PROMPT_TOKENS = 100
|
|
|
|
# Create text-only request (no mm_positions)
|
|
requests = create_requests(
|
|
num_requests=1,
|
|
num_tokens=NUM_PROMPT_TOKENS,
|
|
)
|
|
assert not requests[0].mm_features # No MM data
|
|
|
|
scheduler.add_request(requests[0])
|
|
output = scheduler.schedule()
|
|
|
|
# Should schedule
|
|
assert len(output.scheduled_new_reqs) == 1
|
|
|
|
# Scheduled tokens should equal prompt tokens exactly
|
|
scheduled = output.num_scheduled_tokens[requests[0].request_id]
|
|
assert scheduled == NUM_PROMPT_TOKENS, (
|
|
f"Text-only should schedule {NUM_PROMPT_TOKENS}, got {scheduled}"
|
|
)
|
|
|
|
# Encoder cache should be empty
|
|
_assert_right_encoder_cache_allocated(scheduler, expected_total_allocated=0)
|
|
|
|
# ECConnector should carry no metadata
|
|
_assert_right_ec_connector_metadata(output, mm_features_list=[])
|
|
|
|
# Scheduled encoder input should be empty; no mm to compute
|
|
_assert_right_encoder_inputs(output, expected_total_reqs=0)
|
|
|
|
|
|
@pytest.mark.parametrize("use_kv_connector", [False, True])
|
|
def test_ec_connector_cache_hit_external_load(use_kv_connector):
|
|
"""Test ec_consumer loads from external cache when hit.
|
|
    A basic operation in EPD disaggregation."""
|
|
scheduler = create_scheduler(
|
|
model="llava-hf/llava-1.5-7b-hf",
|
|
enable_prefix_caching=True,
|
|
        # kv connector should not affect test results
|
|
use_kv_connector=use_kv_connector,
|
|
use_ec_connector=True,
|
|
ec_role="ec_consumer",
|
|
)
|
|
|
|
# Create MM request
|
|
NUM_TOKENS = 200 # NOTE: includes mm tokens
|
|
NUM_ENCODER_TOKENS = 100
|
|
mm_hashes_list = [["hash_test1"]]
|
|
mm_positions = [[PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS)]]
|
|
|
|
request = create_requests(
|
|
num_requests=1,
|
|
num_tokens=NUM_TOKENS,
|
|
mm_hashes_list=mm_hashes_list,
|
|
mm_positions=mm_positions,
|
|
)[0]
|
|
|
|
# Mock cache hit - encoder cache exists externally
|
|
scheduler.ec_connector.has_caches = Mock(return_value=[True])
|
|
scheduler.ec_connector.update_state_after_alloc = Mock(
|
|
wraps=scheduler.ec_connector.update_state_after_alloc
|
|
)
|
|
|
|
scheduler.add_request(request)
|
|
output = scheduler.schedule()
|
|
|
|
# Should schedule prompt tokens
|
|
scheduled_tokens = output.num_scheduled_tokens[request.request_id]
|
|
assert scheduled_tokens == NUM_TOKENS
|
|
|
|
    # Should have called update_state_after_alloc for the external load
|
|
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request, 0)
|
|
|
|
# Encoder cache should contain mm items from request
|
|
_assert_right_encoder_cache_allocated(scheduler, requests=[request])
|
|
|
|
# ECConnector should carry metadata of request
|
|
_assert_right_ec_connector_metadata(output, mm_features_list=request.mm_features)
|
|
|
|
# Scheduled encoder input should be empty; no mm to compute
|
|
_assert_right_encoder_inputs(output, expected_total_reqs=0)
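    # Summary of the external-hit path exercised here: the scheduler reserves
    # encoder-cache space locally, records the mm item in the EC connector
    # metadata so the worker can load the precomputed encoder output, and
    # schedules no encoder inputs (nothing is recomputed).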
|
|
|
|
|
|
@pytest.mark.parametrize("use_kv_connector", [False, True])
|
|
def test_ec_connector_cache_miss_computes_locally(use_kv_connector):
|
|
"""Test consumer can compute encoder locally when cache miss (fallback)."""
|
|
# encoder cache itself if it doesn't receive it from external storage
|
|
|
|
scheduler = create_scheduler(
|
|
model="llava-hf/llava-1.5-7b-hf",
|
|
enable_prefix_caching=True,
|
|
use_kv_connector=use_kv_connector,
|
|
use_ec_connector=True,
|
|
ec_role="ec_consumer",
|
|
)
|
|
|
|
# Verify consumer role
|
|
assert scheduler.ec_connector is not None
|
|
assert not scheduler.ec_connector.is_producer
|
|
|
|
# Create MM request
|
|
request_mm_missed = create_requests(
|
|
num_requests=1,
|
|
num_tokens=200, # Total (including 100 MM)
|
|
mm_positions=[[PlaceholderRange(offset=0, length=100)]], # 100 MM tokens
|
|
)[0]
|
|
|
|
# Mock cache miss - encoder cache doesn't exist externally
|
|
scheduler.ec_connector.has_caches = Mock(return_value=[False])
|
|
|
|
scheduler.add_request(request_mm_missed)
|
|
output = scheduler.schedule()
|
|
|
|
# SCHEDULER should decide to compute encoder locally (fallback)
|
|
assert len(output.scheduled_new_reqs) == 1
|
|
|
|
# Should schedule full prompt tokens
|
|
scheduled_tokens = output.num_scheduled_tokens[request_mm_missed.request_id]
|
|
assert scheduled_tokens == 200, (
|
|
f"Expected 200 tokens on cache miss, got {scheduled_tokens}"
|
|
)
|
|
|
|
# Encoder cache should contain mm items from request
|
|
_assert_right_encoder_cache_allocated(scheduler, requests=[request_mm_missed])
|
|
|
|
# ECConnector should carry no metadata (missed cache)
|
|
_assert_right_ec_connector_metadata(output, mm_features_list=[])
|
|
|
|
# Scheduled encoder input contain mm for request_mm_missed
|
|
_assert_right_encoder_inputs(
|
|
output,
|
|
requests=[request_mm_missed],
|
|
expected_encoder_inputs=[[0]], # index 0 of the mm item
|
|
expected_total_reqs=1,
|
|
)
|
|
|
|
# Then MODEL_RUNNER will execute the encoder and cache the result
|
|
|
|
|
|
@pytest.mark.parametrize("use_kv_connector", [False, True])
|
|
def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
|
|
"""Test consumer with partial cache hit (local & connector) with 2 requests."""
|
|
scheduler = create_scheduler(
|
|
model="llava-hf/llava-1.5-7b-hf",
|
|
enable_prefix_caching=True,
|
|
use_kv_connector=use_kv_connector,
|
|
use_ec_connector=True,
|
|
ec_role="ec_consumer",
|
|
)
|
|
|
|
# Create MM request
|
|
NUM_TOKENS_1 = 300 # NOTE: includes mm tokens
|
|
NUM_ENCODER_TOKENS_1 = 50
|
|
mm_hashes_list_1 = [["hash1_A", "hash1_B", "hash1_A", "hash1_F"]]
|
|
mm_positions_1 = [
|
|
[
|
|
PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_1),
|
|
PlaceholderRange(offset=100, length=NUM_ENCODER_TOKENS_1),
|
|
PlaceholderRange(offset=200, length=NUM_ENCODER_TOKENS_1),
|
|
PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_1),
|
|
]
|
|
]
|
|
|
|
# Create request with 4 MM items, with 2 identical items
|
|
request1 = create_requests(
|
|
num_requests=1,
|
|
num_tokens=NUM_TOKENS_1,
|
|
mm_hashes_list=mm_hashes_list_1,
|
|
mm_positions=mm_positions_1,
|
|
max_tokens=1, # For simplicity
|
|
)[0]
|
|
|
|
# Mock partial cache hit: 1st and 3rd missing, 2nd and 4th exist
|
|
scheduler.ec_connector.has_caches = Mock(return_value=[False, True, False, True])
|
|
scheduler.ec_connector.update_state_after_alloc = Mock(
|
|
wraps=scheduler.ec_connector.update_state_after_alloc
|
|
)
|
|
|
|
scheduler.add_request(request1)
|
|
output = scheduler.schedule()
|
|
|
|
# Should schedule all tokens
|
|
scheduled_tokens = output.num_scheduled_tokens[request1.request_id]
|
|
assert scheduled_tokens == NUM_TOKENS_1
|
|
|
|
# Encoder cache should contain all mm items from request
|
|
_assert_right_encoder_cache_allocated(scheduler, requests=[request1])
|
|
|
|
# Should have called update_state_after_alloc for external load
|
|
scheduler.ec_connector.update_state_after_alloc.assert_called()
|
|
scheduler.ec_connector.update_state_after_alloc.reset_mock()
|
|
|
|
# ECConnector should carry metadata for 2nd and 4th mm item
|
|
_assert_right_ec_connector_metadata(
|
|
output, mm_features_list=[request1.mm_features[1], request1.mm_features[3]]
|
|
)
|
|
|
|
# Should schedule ONLY 1 encoder input (index 0), no repeat for identical items
|
|
_assert_right_encoder_inputs(
|
|
output,
|
|
requests=[request1],
|
|
expected_encoder_inputs=[[0]], # index 0 of the mm item ONLY
|
|
expected_total_reqs=1,
|
|
)
|
|
|
|
# Simulate model execution 1 step
|
|
model_output = ModelRunnerOutput(
|
|
req_ids=[request1.request_id],
|
|
req_id_to_index={request1.request_id: 0},
|
|
sampled_token_ids=[np.array([100])],
|
|
# spec_token_ids=None,
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(output, model_output)
|
|
|
|
    # request1 is finished after outputting 1 token
|
|
# Finish request
|
|
scheduler.finish_requests(request1.request_id, RequestStatus.FINISHED_LENGTH_CAPPED)
|
|
|
|
# Create another request with 4 MM items
|
|
NUM_TOKENS_2 = 400
|
|
NUM_ENCODER_TOKENS_2 = 50
|
|
mm_hashes_list_2 = [["hash1_C", "hash1_D", "hash1_E", "hash1_A"]]
|
|
mm_positions_2 = [
|
|
[
|
|
PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_2),
|
|
PlaceholderRange(offset=100, length=NUM_ENCODER_TOKENS_2),
|
|
PlaceholderRange(offset=200, length=NUM_ENCODER_TOKENS_2),
|
|
PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_2),
|
|
]
|
|
]
|
|
|
|
request2 = create_requests(
|
|
num_requests=1,
|
|
num_tokens=NUM_TOKENS_2,
|
|
mm_hashes_list=mm_hashes_list_2,
|
|
mm_positions=mm_positions_2,
|
|
max_tokens=1, # For simplicity
|
|
)[0]
|
|
|
|
# Mock partial cache hit: only hash1_A and hash1_C exist in connector
|
|
scheduler.ec_connector.has_caches = Mock(return_value=[True, False, False, True])
|
|
|
|
scheduler.add_request(request2)
|
|
output = scheduler.schedule()
|
|
|
|
# Check
|
|
# Should schedule all tokens
|
|
scheduled_tokens = output.num_scheduled_tokens[request2.request_id]
|
|
assert scheduled_tokens == 400
|
|
|
|
# Encoder cache should contain all mm items from request2
|
|
_assert_right_encoder_cache_allocated(scheduler, requests=[request2])
|
|
|
|
# Should call update_state_after_alloc for hash1_C, ONLY
|
|
    # hash1_A should not be loaded from the connector since it was computed in
    # the last request and already exists in the local cache.
|
|
    # Order of getting the encoder cache should be:
    # local cache -> connector -> compute
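    # A sketch of that per-mm-item decision (the policy these tests expect, not
    # the scheduler's literal code):
    #   if the identifier is already in the local encoder cache: reuse it
    #   elif ec_connector.has_caches(...) reports it: load it via the connector
    #   else: schedule the encoder input for local computation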
|
|
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 0)
|
|
scheduler.ec_connector.update_state_after_alloc.assert_called_once()
|
|
|
|
scheduler.ec_connector.update_state_after_alloc.reset_mock()
|
|
|
|
# ECConnector should carry metadata for hash1_C only (index 0)
|
|
_assert_right_ec_connector_metadata(
|
|
output, mm_features_list=[request2.mm_features[0]]
|
|
)
|
|
|
|
# Should schedule 2 encoder input hash1_D and hash1_E (index 1, 2)
|
|
_assert_right_encoder_inputs(
|
|
output,
|
|
requests=[request2],
|
|
expected_encoder_inputs=[[1, 2]],
|
|
expected_total_reqs=1,
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize("cache_exist", ["local", "connector_only", "no_where"])
|
|
@pytest.mark.parametrize("use_kv_connector", [False, True])
|
|
def test_ec_connector_schedule_multiple_requests(cache_exist, use_kv_connector):
|
|
scheduler = create_scheduler(
|
|
model="llava-hf/llava-1.5-7b-hf",
|
|
max_num_seqs=10, # allow multiple requests
|
|
max_num_batched_tokens=2048,
|
|
enable_prefix_caching=True,
|
|
use_kv_connector=use_kv_connector,
|
|
use_ec_connector=True,
|
|
ec_role="ec_consumer",
|
|
)
|
|
mm_hashes_list = [[f"hash_{i}"] for i in range(10)]
|
|
mm_positions = [[PlaceholderRange(offset=i, length=100)] for i in range(10)]
|
|
requests = create_requests(
|
|
num_requests=10,
|
|
num_tokens=200,
|
|
mm_hashes_list=mm_hashes_list,
|
|
mm_positions=mm_positions,
|
|
)
|
|
for request in requests:
|
|
scheduler.add_request(request)
|
|
|
|
    # Set up to test different encoder cache existence scenarios
    # Order of getting the encoder cache should be:
    # local cache -> connector -> compute
|
|
scheduler.ec_connector.update_state_after_alloc = Mock(
|
|
wraps=scheduler.ec_connector.update_state_after_alloc
|
|
)
|
|
|
|
if cache_exist == "local":
|
|
        # Allocate cache in the cache manager manually to mimic locally cached entries
|
|
for req in requests:
|
|
scheduler.encoder_cache_manager.allocate(req, 0)
|
|
else:
|
|
# Make sure local encoder cache empty
|
|
scheduler.encoder_cache_manager.cached = {}
|
|
|
|
if cache_exist == "connector_only":
|
|
# Cache exist in ec_connector
|
|
scheduler.ec_connector.has_caches = Mock(return_value=[True])
|
|
elif cache_exist == "no_where":
|
|
scheduler.ec_connector.has_caches = Mock(return_value=[False])
|
|
|
|
output = scheduler.schedule()
|
|
assert len(output.scheduled_new_reqs) == len(requests)
|
|
assert output.scheduled_cached_reqs.num_reqs == 0
|
|
assert len(output.finished_req_ids) == 0
|
|
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
|
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
|
|
|
|
## Encoder-cache-specific checks:
|
|
# mm_hashes of requests exist in cache after scheduling for all scenario
|
|
_assert_right_encoder_cache_allocated(scheduler, requests=requests)
|
|
|
|
# Should only call update_state_after_alloc when loaded externally
|
|
if cache_exist == "connector_only":
|
|
scheduler.ec_connector.update_state_after_alloc.assert_called_with(
|
|
requests[-1], 0
|
|
)
|
|
|
|
# Concat mm_features for the 10 requests together
|
|
mm_features_list = [feature for req in requests for feature in req.mm_features]
|
|
|
|
# Check metadata should contain mm data for all 10 requests
|
|
_assert_right_ec_connector_metadata(output, mm_features_list=mm_features_list)
|
|
else:
|
|
scheduler.ec_connector.update_state_after_alloc.assert_not_called()
|
|
# ECConnector should carry no metadata
|
|
_assert_right_ec_connector_metadata(output, mm_features_list=[])
|
|
|
|
scheduler.ec_connector.update_state_after_alloc.reset_mock()
|
|
|
|
# Should only schedule encoder input when cache is not found anywhere
|
|
if cache_exist == "no_where":
|
|
_assert_right_encoder_inputs(
|
|
output,
|
|
requests=requests,
|
|
expected_encoder_inputs=[[0] for _ in range(10)],
|
|
expected_total_reqs=10,
|
|
)
|
|
else:
|
|
_assert_right_encoder_inputs(output, expected_total_reqs=0)
|
|
|
|
|
|
@pytest.mark.parametrize("use_kv_connector", [False, True])
|
|
def test_ec_connector_unable_to_allocate(use_kv_connector):
|
|
"""
|
|
Test whether scheduler with ECConnector is able to handle
|
|
unable to allocate (run out of blocks).
|
|
"""
|
|
|
|
    # Set up a scheduler with a mocked external cache hit.
|
|
BLOCK_SIZE = 4
|
|
NUM_BLOCKS = 10
|
|
scheduler = create_scheduler(
|
|
model="llava-hf/llava-1.5-7b-hf",
|
|
enable_prefix_caching=True,
|
|
use_kv_connector=use_kv_connector,
|
|
block_size=BLOCK_SIZE,
|
|
num_blocks=NUM_BLOCKS,
|
|
use_ec_connector=True,
|
|
ec_role="ec_consumer",
|
|
)
|
|
|
|
# Mock ec_connector load external cache behavior
|
|
scheduler.ec_connector.has_caches = Mock(return_value=[True])
|
|
scheduler.ec_connector.update_state_after_alloc = Mock(
|
|
wraps=scheduler.ec_connector.update_state_after_alloc
|
|
)
|
|
|
|
# Create two requests. The second request will not be able to
|
|
# allocate slots because it will not have enough blocks.
|
|
NUM_REQUESTS = 2
|
|
NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE
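    # With NUM_BLOCKS=10 and BLOCK_SIZE=4, each request needs at least
    # 24 tokens = 6 blocks, but only 9 blocks are usable (the first block is
    # the null block), so the two requests cannot be allocated simultaneously.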
|
|
MAX_TOKENS = 2
|
|
requests = create_requests(
|
|
num_requests=NUM_REQUESTS,
|
|
num_tokens=NUM_TOKENS,
|
|
mm_hashes_list=[["hash_1"], ["hash_2"]],
|
|
mm_positions=[
|
|
[PlaceholderRange(offset=1, length=10)] for _ in range(NUM_REQUESTS)
|
|
],
|
|
max_tokens=MAX_TOKENS,
|
|
block_size=BLOCK_SIZE,
|
|
)
|
|
req_ids = []
|
|
req_to_index = {}
|
|
for i, request in enumerate(requests):
|
|
scheduler.add_request(request)
|
|
req_ids.append(request.request_id)
|
|
req_to_index[request.request_id] = i
|
|
|
|
# Setup MODEL_RUNNER_OUTPUT to be run in _step_until_done later
|
|
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
|
req_ids=req_ids,
|
|
req_id_to_index=req_to_index,
|
|
sampled_token_ids=[np.array([1000])] * len(req_ids),
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
|
|
# Just one request should be running.
|
|
output = scheduler.schedule()
|
|
scheduled_tokens = output.num_scheduled_tokens[scheduler.running[0].request_id]
|
|
assert scheduled_tokens == NUM_TOKENS
|
|
assert len(scheduler.running) == 1
|
|
assert len(scheduler.waiting) == 1
|
|
|
|
# Should have called update_state_after_alloc for external load
|
|
scheduler.ec_connector.update_state_after_alloc.assert_called_with(
|
|
scheduler.running[0], 0
|
|
)
|
|
scheduler.ec_connector.update_state_after_alloc.reset_mock()
|
|
|
|
# All memory should be freed, with one request waiting.
|
|
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
|
|
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
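    # (NUM_BLOCKS - 1 rather than NUM_BLOCKS free: the first block is the
    # reserved null block.)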
|
|
assert len(scheduler.running) == 0
|
|
assert len(scheduler.waiting) == 1
|
|
|
|
# Just one request should be running.
|
|
output = scheduler.schedule()
|
|
scheduled_tokens = output.num_scheduled_tokens[scheduler.running[0].request_id]
|
|
assert scheduled_tokens == NUM_TOKENS
|
|
assert len(scheduler.running) == 1
|
|
assert len(scheduler.waiting) == 0
|
|
|
|
# update_state_after_alloc should be called for loading external cache
|
|
scheduler.ec_connector.update_state_after_alloc.assert_called_with(
|
|
scheduler.running[0], 0
|
|
)
|
|
scheduler.ec_connector.update_state_after_alloc.reset_mock()
|
|
|
|
# All memory should be freed, with no requests waiting / running.
|
|
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
|
|
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
|
|
assert len(scheduler.running) == 0
|
|
assert len(scheduler.waiting) == 0
|
|
|
|
|
|
@pytest.mark.parametrize("cache_exist", ["local", "connector_only", "no_where"])
|
|
@pytest.mark.parametrize("use_kv_connector", [False, True])
|
|
def test_priority_scheduling_ec_connector_preemption_and_resumption(
|
|
cache_exist, use_kv_connector
|
|
):
|
|
"""Test that priority scheduling preempts lower priority requests
|
|
when out of KV cache space."""
|
|
# Create scheduler with very limited memory to force preemption
|
|
scheduler = create_scheduler_with_priority(
|
|
model="llava-hf/llava-1.5-7b-hf",
|
|
enable_prefix_caching=True,
|
|
max_num_seqs=2, # allow multiple requests
|
|
        # kv connector should not affect test results
|
|
use_kv_connector=use_kv_connector,
|
|
        num_blocks=15,  # holds 224 tokens in 14 usable blocks (first block is null)
|
|
block_size=16, # standard block size
|
|
use_ec_connector=True,
|
|
ec_role="ec_consumer",
|
|
)
|
|
|
|
    # Mock cache hit: both caches exist in the connector (at E->PD initially)
|
|
scheduler.ec_connector.has_caches = Mock(return_value=[True])
|
|
scheduler.ec_connector.update_state_after_alloc = Mock(
|
|
wraps=scheduler.ec_connector.update_state_after_alloc
|
|
)
|
|
|
|
# Create a request and schedule it (and to be preempted)
|
|
request_low = create_requests_with_priority(
|
|
num_requests=1,
|
|
priorities=[1],
|
|
arrival_times=[0.0],
|
|
num_tokens=94,
|
|
mm_hashes_list=[["hash_low"]],
|
|
# NOTE: this test only preempt the last block.
|
|
# Setting mm_position at the last block can force to recompute encoding
|
|
mm_positions=[[PlaceholderRange(offset=82, length=10)]],
|
|
starting_idx=0,
|
|
)[0]
|
|
scheduler.add_request(request_low)
|
|
# 1st schedule
|
|
output = scheduler.schedule()
|
|
|
|
assert len(output.scheduled_new_reqs) == 1
|
|
scheduled_tokens = output.num_scheduled_tokens[request_low.request_id]
|
|
assert scheduled_tokens == 94
|
|
assert len(scheduler.waiting) == 0
|
|
assert len(scheduler.running) == 1
|
|
|
|
## Encoder-cache-specific checks:
|
|
# Encoder cache should contain mm items from request
|
|
_assert_right_encoder_cache_allocated(scheduler, requests=[request_low])
|
|
|
|
# Verify update_state_after_alloc called (external load)
|
|
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request_low, 0)
|
|
scheduler.ec_connector.update_state_after_alloc.reset_mock()
|
|
|
|
# ECConnector should carry metadata of request
|
|
_assert_right_ec_connector_metadata(
|
|
output, mm_features_list=request_low.mm_features
|
|
)
|
|
|
|
# Scheduled encoder input should be empty; no mm to compute
|
|
_assert_right_encoder_inputs(output, expected_total_reqs=0)
|
|
|
|
# Simulate model execution - 1st decode
|
|
model_output = ModelRunnerOutput(
|
|
req_ids=[request_low.request_id],
|
|
req_id_to_index={request_low.request_id: 0},
|
|
sampled_token_ids=[np.array([100])],
|
|
# spec_token_ids=None,
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(output, model_output)
|
|
|
|
# Create a high priority request and schedule it
|
|
request_high = create_requests_with_priority(
|
|
num_requests=1,
|
|
priorities=[0],
|
|
arrival_times=[1.0],
|
|
num_tokens=128,
|
|
mm_hashes_list=[["hash_high"]],
|
|
mm_positions=[[PlaceholderRange(offset=1, length=10)]],
|
|
max_tokens=2,
|
|
starting_idx=1,
|
|
)[0]
|
|
scheduler.add_request(request_high)
|
|
# 2nd schedule
|
|
output = scheduler.schedule()
|
|
|
|
# KV cache should be full at this point
|
|
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
|
|
assert len(output.scheduled_new_reqs) == 1
|
|
assert output.scheduled_cached_reqs.num_reqs == 1
|
|
assert len(scheduler.waiting) == 0
|
|
assert len(scheduler.running) == 2
|
|
|
|
## Encoder-cache-specific checks:
|
|
# Encoder cache should contain mm items from request
|
|
_assert_right_encoder_cache_allocated(scheduler, requests=[request_high])
|
|
|
|
# Verify update_state_after_alloc called (external load)
|
|
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request_high, 0)
|
|
scheduler.ec_connector.update_state_after_alloc.reset_mock()
|
|
|
|
# ECConnector should carry metadata of request
|
|
_assert_right_ec_connector_metadata(
|
|
output, mm_features_list=request_high.mm_features
|
|
)
|
|
|
|
# Scheduled encoder input should be empty; no mm to compute
|
|
_assert_right_encoder_inputs(output, expected_total_reqs=0)
|
|
|
|
# Simulate model execution - 2nd decode
|
|
requests = [request_low, request_high]
|
|
model_output = ModelRunnerOutput(
|
|
req_ids=[req.request_id for req in requests],
|
|
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
|
sampled_token_ids=[np.array([100]) for _ in requests],
|
|
# spec_token_ids=None,
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
scheduler.update_from_output(output, model_output)
|
|
|
|
    # 3rd schedule - this should trigger preemption
|
|
# req_low needs 96 tokens = 6 blocks
|
|
# req_high needs 129 tokens = 9 blocks
|
|
# so doesn't fit in 14 blocks.
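    # (96 = 94 prompt + 2 sampled tokens -> ceil(96 / 16) = 6 blocks;
    # 129 = 128 prompt + 1 sampled token -> ceil(129 / 16) = 9 blocks;
    # 6 + 9 = 15 > 14 usable blocks, hence the preemption of req_low.)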
|
|
output = scheduler.schedule()
|
|
|
|
# Should have preempted req_low
|
|
assert len(output.scheduled_new_reqs) == 0
|
|
assert output.scheduled_cached_reqs.num_reqs == 1
|
|
assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
|
|
assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED
|
|
assert len(scheduler.waiting) == 1
|
|
assert len(scheduler.running) == 1
|
|
|
|
## Encoder-cache-specific checks:
|
|
# request_high is in decode phase now
|
|
# ECConnector should carry no metadata
|
|
_assert_right_ec_connector_metadata(output, mm_features_list=[])
|
|
|
|
# Scheduled encoder input should be empty; no mm to compute
|
|
_assert_right_encoder_inputs(output, expected_total_reqs=0)
|
|
|
|
# Simulate model execution - 3rd decode, after req_low was preempted
|
|
requests = [request_low, request_high]
|
|
model_output = ModelRunnerOutput(
|
|
req_ids=[req.request_id for req in requests],
|
|
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
|
sampled_token_ids=[np.array([100]), np.array([100, 200])],
|
|
# spec_token_ids=None,
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=[],
|
|
)
|
|
# Finish the requests to make room for the preempted requests to resume
|
|
    # req_high is finished after outputting 2 tokens
|
|
scheduler.update_from_output(output, model_output)
|
|
scheduler.finish_requests(
|
|
request_high.request_id, RequestStatus.FINISHED_LENGTH_CAPPED
|
|
)
|
|
|
|
    # Set up to test different encoder cache existence scenarios after preemption
    # Order of getting the encoder cache should be:
    # local cache -> connector -> compute
    # By default, the cache should still exist locally in this test case
|
|
if cache_exist != "local":
|
|
# Make local encoder cache empty
|
|
scheduler.encoder_cache_manager.cached = {}
|
|
|
|
if cache_exist == "connector_only":
|
|
# Cache exist in ec_connector
|
|
scheduler.ec_connector.has_caches = Mock(return_value=[True])
|
|
elif cache_exist == "no_where":
|
|
scheduler.ec_connector.has_caches = Mock(return_value=[False])
|
|
|
|
# 4th Schedule - this should trigger req_low resumption from waiting
|
|
output = scheduler.schedule()
|
|
scheduled_cached_reqs = output.scheduled_cached_reqs
|
|
resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption
|
|
|
|
assert len(output.scheduled_new_reqs) == 0
|
|
assert scheduled_cached_reqs.num_reqs == 1
|
|
assert len(scheduler.waiting) == 0
|
|
assert len(scheduler.running) == 1
|
|
|
|
# Preempted request resumed in scheduled_cached_reqs
|
|
assert len(resumed_from_preemption) == 1
|
|
assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
|
|
assert resumed_from_preemption[0]
|
|
assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
|
|
assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
|
|
## Resumed tokens include 94 prompt tokens and 2 decoded tokens
|
|
assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 96
|
|
assert scheduled_cached_reqs.resumed_req_token_ids[0][95] == 100
|
|
assert scheduler.running[0].request_id == request_low.request_id
|
|
assert request_high.request_id in output.finished_req_ids
|
|
|
|
## Encoder-cache-specific checks:
|
|
# mm_hash of request_low exists in cache after scheduling for all scenario
|
|
_assert_right_encoder_cache_allocated(scheduler, requests=[request_low])
|
|
|
|
# Should only call update_state_after_alloc when loaded externally
|
|
if cache_exist == "connector_only":
|
|
scheduler.ec_connector.update_state_after_alloc.assert_called_with(
|
|
request_low, 0
|
|
)
|
|
_assert_right_ec_connector_metadata(
|
|
output, mm_features_list=request_low.mm_features
|
|
)
|
|
else:
|
|
scheduler.ec_connector.update_state_after_alloc.assert_not_called()
|
|
# ECConnector should carry no metadata
|
|
_assert_right_ec_connector_metadata(output, mm_features_list=[])
|
|
|
|
scheduler.ec_connector.update_state_after_alloc.reset_mock()
|
|
|
|
# Should only schedule encoder input when cache is not found anywhere
|
|
if cache_exist == "no_where":
|
|
_assert_right_encoder_inputs(
|
|
output,
|
|
requests=[request_low],
|
|
expected_encoder_inputs=[[0]],
|
|
expected_total_reqs=1,
|
|
)
|
|
else:
|
|
_assert_right_encoder_inputs(output, expected_total_reqs=0)
|
|
|
|
|
|
@pytest.mark.parametrize("use_kv_connector", [False, True])
|
|
def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connector):
|
|
"""
|
|
Scenario:
|
|
- Encoder cache size: 32
|
|
- Request A: 1 feature (12 tokens) → NOT cached remotely.
|
|
- Request B: 3 features (3 x 10 tokens) → ALL cached remotely.
|
|
|
|
Steps:
|
|
1. Schedule Request A (locally uses 12 tokens).
|
|
2. Schedule Request B (remote cache) - only schedule 1st and 2nd
|
|
3. Free A's cache, then schedule B again (continuation) - schedule 3rd image
|
|
"""
|
|
scheduler = create_scheduler(
|
|
model="llava-hf/llava-1.5-7b-hf",
|
|
max_num_batched_tokens=1024,
|
|
enable_prefix_caching=True,
|
|
use_kv_connector=use_kv_connector,
|
|
block_size=16,
|
|
num_blocks=11, # Can hold 160 tokens (first block is null)
|
|
use_ec_connector=True,
|
|
ec_role="ec_consumer",
|
|
)
|
|
|
|
    # Limit the number of available slots in the EncoderCacheManager
|
|
scheduler.encoder_cache_manager = EncoderCacheManager(cache_size=32)
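    # Encoder-token budget check: request1 needs 12 tokens and request2 needs
    # 3 x 10 = 30 tokens (42 in total), so at most 32 of them can be resident
    # at once; hash2_3 must wait until request1's 12 tokens are freed.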
|
|
|
|
# Create MM request1
|
|
NUM_TOKENS_1 = 50 # NOTE: includes mm tokens
|
|
NUM_ENCODER_TOKENS_1 = 12
|
|
mm_hashes_list_1 = [["hash1_1"]]
|
|
mm_positions_1 = [[PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_1)]]
|
|
|
|
request1 = create_requests(
|
|
num_requests=1,
|
|
num_tokens=NUM_TOKENS_1,
|
|
mm_hashes_list=mm_hashes_list_1,
|
|
mm_positions=mm_positions_1,
|
|
max_tokens=1, # For simplicity
|
|
req_ids=["req1"],
|
|
)[0]
|
|
|
|
# Create MM request1 with 3 MM items
|
|
NUM_TOKENS_2 = 40
|
|
NUM_ENCODER_TOKENS_2 = 10
|
|
mm_hashes_list_2 = [["hash2_1", "hash2_2", "hash2_3"]]
|
|
mm_positions_2 = [
|
|
[
|
|
PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_2),
|
|
PlaceholderRange(offset=12, length=NUM_ENCODER_TOKENS_2),
|
|
PlaceholderRange(offset=24, length=NUM_ENCODER_TOKENS_2),
|
|
]
|
|
]
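    # The three 10-token placeholders sit at offsets 0, 12, and 24 of the
    # 40-token prompt, i.e. 30 of request2's tokens are encoder placeholders.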

    request2 = create_requests(
        num_requests=1,
        num_tokens=NUM_TOKENS_2,
        mm_hashes_list=mm_hashes_list_2,
        mm_positions=mm_positions_2,
        max_tokens=10,
        req_ids=["req2"],
    )[0]

    # Mock cache hit: MM of request1 NOT cached remotely, request2 cached remotely
    scheduler.ec_connector.has_caches = Mock(
        side_effect=lambda req: [True, True, True] if req == request2 else [False]
    )
    scheduler.ec_connector.update_state_after_alloc = Mock(
        wraps=scheduler.ec_connector.update_state_after_alloc
    )
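    # Wrapping (rather than replacing) update_state_after_alloc lets the test
    # assert on its call arguments below while keeping the connector's real
    # bookkeeping behavior intact.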

    scheduler.add_request(request1)
    scheduler.add_request(request2)
    output = scheduler.schedule()

    # Now, since the encoder cache manager can only store 32 tokens,
    # it should allocate mm items hash1_1, hash2_1, and hash2_2
    scheduled_tokens = output.num_scheduled_tokens[request1.request_id]
    assert scheduled_tokens == NUM_TOKENS_1
    assert scheduler.get_num_unfinished_requests() == 2

    # Encoder cache should contain hash1_1 plus request2's first two mm items
    _assert_right_encoder_cache_allocated(
        scheduler, hashes_to_check=["hash1_1", "hash2_1", "hash2_2"]
    )

    # request2's 2nd mm item is the last call of update_state_after_alloc
    scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 1)
    scheduler.ec_connector.update_state_after_alloc.reset_mock()

    # ECConnector should carry metadata of hash2_1 and hash2_2 ONLY
    _assert_right_ec_connector_metadata(
        output, mm_features_list=[request2.mm_features[0], request2.mm_features[1]]
    )

    # Should schedule ONLY 1 encoder input
    _assert_right_encoder_inputs(
        output,
        requests=[request1],
        expected_encoder_inputs=[[0]],  # index 0 of the mm item of request1
        expected_total_reqs=1,
    )
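    # hash2_1 and hash2_2 are loaded through the EC connector, so no local
    # encoder run is needed for them; only request1's image requires one.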

    # Simulate model execution 1 step
    model_output = ModelRunnerOutput(
        req_ids=[request1.request_id, request2.request_id],
        req_id_to_index={request1.request_id: 0, request2.request_id: 1},
        sampled_token_ids=[np.array([100]), np.array([121])],
        # spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)

    # request1 is finished after outputting 1 token
    # Finish request
    scheduler.finish_requests(request1.request_id, RequestStatus.FINISHED_LENGTH_CAPPED)
    assert scheduler.get_num_unfinished_requests() == 1

    # Schedule again; now request1's encoder cache should be freed
    # -> hash2_3 can be scheduled and allocated
    output = scheduler.schedule()

    # request2 should continue to be scheduled now that hash2_3 fits
    scheduled_tokens = output.num_scheduled_tokens[request2.request_id]
    assert scheduled_tokens > 0

    # Encoder cache should contain all mm items from request2
    _assert_right_encoder_cache_allocated(scheduler, requests=[request2])

    # request2's 3rd mm item is the ONLY call of update_state_after_alloc
    scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 2)
    scheduler.ec_connector.update_state_after_alloc.assert_called_once()

    scheduler.ec_connector.update_state_after_alloc.reset_mock()

    # ECConnector should carry metadata for hash2_3 ONLY
    _assert_right_ec_connector_metadata(
        output, mm_features_list=[request2.mm_features[2]]
    )

    # Should schedule no encoder input
    _assert_right_encoder_inputs(
        output,
        expected_total_reqs=0,
    )
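    # hash2_3 is also delivered via the EC connector, so allocating it does not
    # require scheduling a local encoder run.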


# ==============================================================================
# EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests end
# ==============================================================================