[Core] feat: Implement Priority Scheduling in V1 Engine (#19057)

Signed-off-by: amit <amit.man@gmail.com>
Co-authored-by: Roger Wang <Rogerw0108@gmail.com>

This commit is contained in: parent c4cf260677, commit 4a0f7888a3
@@ -45,6 +45,18 @@ For each item, our progress towards V1 support falls into one of the following s
 - **🟠 Delayed**: Temporarily dropped in V1 but planned to be re-introduced later.
 - **🔴 Deprecated**: Not planned for V1 unless there is strong demand.
 
+!!! note
+    vLLM V1’s unified scheduler treats both prompt and output tokens the same
+    way by using a simple dictionary (e.g., `{request_id: num_tokens}`) to dynamically
+    allocate a fixed token budget per request, enabling features like chunked prefills,
+    prefix caching, and speculative decoding without a strict separation between prefill
+    and decode phases.
+
+    The V1 scheduler supports multiple scheduling policies, including First-Come,
+    First-Served (FCFS) and priority-based scheduling (where requests are processed
+    based on assigned priority, with FCFS as a tie-breaker), configurable via the
+    `--scheduling-policy` argument.
+
 ### Hardware
 
 | Hardware | Status |
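
The ordering rule in the note above can be illustrated with a short, self-contained sketch (plain Python, not vLLM code): each waiting request is keyed by (priority, arrival_time), so a lower priority value is served first and ties fall back to FCFS by arrival time.

# Standalone sketch of the (priority, arrival_time) ordering described above.
import heapq

waiting = []  # heap of (priority, arrival_time, request_id)
for req_id, priority, arrival in [("A", 2, 1.0), ("B", 0, 2.0), ("C", 0, 3.0)]:
    heapq.heappush(waiting, (priority, arrival, req_id))

order = [heapq.heappop(waiting)[2] for _ in range(len(waiting))]
print(order)  # ['B', 'C', 'A'] -- priority first, then arrival time as tie-breaker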
@@ -1150,7 +1150,6 @@ def test_kv_connector_handles_preemption():
     assert len(scheduler.running) == 1
     _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
     assert len(scheduler.running) == 0
-    assert len(scheduler.waiting) == 1
     # All memory should be freed since nothing is running.
     assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
         == NUM_BLOCKS - 1
@@ -1265,3 +1264,592 @@ def test_memory_leak():

    # Confirm no memory leak.
    assert_scheduler_empty(scheduler)


def create_scheduler_with_priority(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: Optional[bool] = None,
    long_prefill_token_threshold: int = 0,
    disable_chunked_mm_input: bool = False,
    use_kv_connector: bool = False,
    num_blocks: int = 10000,
    block_size: int = 16,
    max_model_len: Optional[int] = None,
    num_speculative_tokens: Optional[int] = None,
) -> Scheduler:
    '''Create scheduler with priority policy enabled.

    Args:
      model: model under test
      max_num_seqs: max sequences to schedule
      max_num_batch_tokens: max num tokens to batch
      enable_prefix_caching: optionally force APC config
                             (True/False) or use default
                             (None)

    Returns:
      {class}`Scheduler` instance with priority scheduling
    '''
    if max_model_len is None:
        max_model_len = max_num_batched_tokens
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        long_prefill_token_threshold=long_prefill_token_threshold,
        disable_chunked_mm_input=disable_chunked_mm_input,
        enable_chunked_prefill=True,
        policy="priority",  # Enable priority scheduling
    )
    model_config = ModelConfig(
        model=model,
        task="auto",
        tokenizer=model,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    # Cache config, optionally force APC
    kwargs_cache = ({} if enable_prefix_caching is None else {
        'enable_prefix_caching': enable_prefix_caching
    })
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        **kwargs_cache,
    )
    kv_transfer_config = KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": "local_storage"},
    ) if use_kv_connector else None

    speculative_config: Optional[SpeculativeConfig] = None
    if num_speculative_tokens is not None:
        speculative_config = SpeculativeConfig(
            model="ngram", num_speculative_tokens=num_speculative_tokens)

    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        speculative_config=speculative_config,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(['layer'],
                             FullAttentionSpec(block_size, 1, 1, torch.float32,
                                               False))
        ],
    )
    cache_config.num_gpu_blocks = num_blocks
    return Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )


def create_requests_with_priority(
        num_requests: int,
        priorities: list[int],
        arrival_times: Optional[list[float]] = None,
        num_tokens: int = 10,
        mm_positions: Optional[list[PlaceholderRange]] = None,
        max_tokens: int = 16,
        stop_token_ids: Optional[list[int]] = None,
        prompt_logprobs: Optional[int] = None):
    """Create requests with specified priorities and arrival times."""
    assert len(priorities) == num_requests
    if arrival_times is not None:
        assert len(arrival_times) == num_requests
    else:
        arrival_times = [float(i) for i in range(num_requests)]

    sampling_params = SamplingParams(ignore_eos=False,
                                     max_tokens=max_tokens,
                                     stop_token_ids=stop_token_ids,
                                     prompt_logprobs=prompt_logprobs)
    requests = []
    for i in range(num_requests):
        if mm_positions is not None:
            mm_position = mm_positions[i]
            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
        else:
            mm_position = None
            mm_inputs = None
        request = Request(
            request_id=f"{i}",
            prompt_token_ids=[i] * num_tokens,
            sampling_params=sampling_params,
            pooling_params=None,
            multi_modal_inputs=mm_inputs,
            multi_modal_placeholders=mm_position,
            multi_modal_hashes=None,
            eos_token_id=EOS_TOKEN_ID,
            arrival_time=arrival_times[i],
            priority=priorities[i],
        )
        requests.append(request)
    return requests

def test_priority_scheduling_basic_ordering():
    """Test that requests are scheduled in priority order
    (lower value = higher priority)."""
    scheduler = create_scheduler_with_priority()

    # Create requests with different priorities
    # Priority 0 (highest), 1, 2 (lowest)
    priorities = [2, 0, 1]  # Add in non-priority order
    arrival_times = [1.0, 2.0, 3.0]  # All different arrival times
    requests = create_requests_with_priority(num_requests=3,
                                             priorities=priorities,
                                             arrival_times=arrival_times)

    # Add requests in non-priority order
    for request in requests:
        scheduler.add_request(request)

    # Schedule and verify priority order
    output = scheduler.schedule()

    # Should schedule all requests since they fit in budget
    assert len(output.scheduled_new_reqs) == 3

    # Verify they are scheduled in priority order:
    # req_1 (priority 0), req_2 (priority 1), req_0 (priority 2)
    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
    assert scheduled_req_ids == ["1", "2", "0"]


def test_priority_scheduling_arrival_time_tiebreaker():
    """Test that arrival time is used
    as tiebreaker when priorities are equal."""
    scheduler = create_scheduler_with_priority()

    # Create requests with same priority but different arrival times
    priorities = [1, 1, 1]  # All same priority
    arrival_times = [3.0, 1.0, 2.0]  # Different arrival times
    requests = create_requests_with_priority(num_requests=3,
                                             priorities=priorities,
                                             arrival_times=arrival_times)

    # Add requests in non-arrival order
    for request in requests:
        scheduler.add_request(request)

    # Schedule and verify arrival time order
    output = scheduler.schedule()

    # Should schedule all requests since they fit in budget
    assert len(output.scheduled_new_reqs) == 3

    # Verify they are scheduled in arrival time order:
    # req_1 (1.0), req_2 (2.0), req_0 (3.0)
    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
    assert scheduled_req_ids == ["1", "2", "0"]


def test_priority_scheduling_mixed_priority_and_arrival():
    """Test priority scheduling with mixed priorities and arrival times."""
    scheduler = create_scheduler_with_priority()

    # Create requests with mixed priorities and arrival times
    priorities = [2, 1, 1, 0]  # Mixed priorities
    arrival_times = [1.0, 3.0, 2.0, 4.0]  # Mixed arrival times
    requests = create_requests_with_priority(num_requests=4,
                                             priorities=priorities,
                                             arrival_times=arrival_times)

    # Add requests
    for request in requests:
        scheduler.add_request(request)

    # Schedule and verify order
    output = scheduler.schedule()

    # Should schedule all requests since they fit in budget
    assert len(output.scheduled_new_reqs) == 4

    # Expected order:
    # 1. req_3 (priority 0, arrival 4.0)
    # 2. req_2 (priority 1, arrival 2.0) - earlier arrival than req_1
    # 3. req_1 (priority 1, arrival 3.0)
    # 4. req_0 (priority 2, arrival 1.0)
    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
    assert scheduled_req_ids == ["3", "2", "1", "0"]

def test_priority_scheduling_preemption():
    """Test that priority scheduling preempts
    lower priority requests when memory is constrained."""
    # Create scheduler with very limited memory to force preemption
    scheduler = create_scheduler_with_priority(
        max_num_seqs=3,  # Allow multiple requests
        max_num_batched_tokens=200,
        num_blocks=6,  # Very limited blocks to force memory pressure
        block_size=16,  # Standard block size
    )

    # Create initial low-priority requests that will consume most memory
    low_priority_requests = create_requests_with_priority(
        num_requests=2,
        priorities=[5, 5],  # Low priority
        arrival_times=[1.0, 2.0],
        num_tokens=30  # Large enough to consume significant memory
    )

    # Add and schedule low priority requests
    for request in low_priority_requests:
        scheduler.add_request(request)

    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 2

    # Simulate model execution to move requests to running state
    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in low_priority_requests],
        req_id_to_index={
            req.request_id: i
            for i, req in enumerate(low_priority_requests)
        },
        sampled_token_ids=[[100] for _ in low_priority_requests],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)

    # Verify both requests are running
    assert len(scheduler.running) == 2

    # Now add a high-priority request that requires memory allocation
    # This should trigger preemption due to memory constraints
    high_priority_request = create_requests_with_priority(
        num_requests=1,
        priorities=[0],  # High priority
        arrival_times=[3.0],
        num_tokens=30  # Large enough to require significant memory
    )[0]

    scheduler.add_request(high_priority_request)

    # Schedule again - this should trigger
    # preemption when trying to allocate memory
    output = scheduler.schedule()

    # Due to the scheduler's design, if preemption happens
    # during running request scheduling,
    # waiting requests won't be scheduled in the same step
    # Let's check if preemption occurred by looking at the waiting queue

    # If preemption happened, we should see requests in the
    # waiting queue
    if len(scheduler.waiting) > 1:  # high priority + preempted request
        # Preemption occurred - verify the high priority request
        # gets scheduled next
        output2 = scheduler.schedule()
        assert len(output2.scheduled_new_reqs) == 1
        # High priority request
        assert output2.scheduled_new_reqs[0].req_id == "0"
    else:
        # No preemption needed - all requests fit
        # This is also valid behavior if memory allows
        assert len(output.scheduled_new_reqs) == 1
        # High priority request
        assert output.scheduled_new_reqs[0].req_id == "0"

def test_priority_scheduling_no_preemption_when_space_available():
    """Test that preemption doesn't happen
    when there's space for new requests."""
    scheduler = create_scheduler_with_priority(
        max_num_seqs=3,  # Allow 3 concurrent requests
        max_num_batched_tokens=200,  # Sufficient token budget
    )

    # Add two low-priority running requests
    low_priority_requests = create_requests_with_priority(
        num_requests=2,
        priorities=[5, 5],
        arrival_times=[1.0, 2.0],
        num_tokens=30)

    for request in low_priority_requests:
        scheduler.add_request(request)

    output = scheduler.schedule()
    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in low_priority_requests],
        req_id_to_index={
            req.request_id: i
            for i, req in enumerate(low_priority_requests)
        },
        sampled_token_ids=[[100] for _ in low_priority_requests],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)

    # Add high-priority request
    high_priority_request = create_requests_with_priority(num_requests=1,
                                                          priorities=[0],
                                                          arrival_times=[3.0],
                                                          num_tokens=30)[0]

    scheduler.add_request(high_priority_request)

    # Schedule - should not preempt since there's space
    output = scheduler.schedule()

    # Should schedule the new request without preemption
    assert len(output.scheduled_new_reqs) == 1
    assert len(scheduler.running) == 3  # All three requests running
    assert len(scheduler.waiting) == 0  # No requests waiting


def test_priority_scheduling_preemption_victim_selection():
    """Test that the correct victim is selected for
    preemption based on priority and arrival time."""
    # This test verifies the priority-based victim selection logic
    # by checking the waiting queue order after adding requests with different
    # priorities
    scheduler = create_scheduler_with_priority(
        max_num_seqs=1,  # Force sequential processing to test priority order
    )

    # Create requests with different priorities
    requests = create_requests_with_priority(
        num_requests=3,
        priorities=[3, 2, 0],  # Different priorities: low, medium, high
        arrival_times=[1.0, 2.0, 3.0],
        num_tokens=10)

    # Add all requests
    for request in requests:
        scheduler.add_request(request)

    # Schedule - should only schedule the highest priority request
    # (req_2, priority 0)
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 1
    assert output.scheduled_new_reqs[0].req_id == "2"  # Highest priority

    # Verify the waiting queue has the remaining requests in priority order
    assert len(scheduler.waiting) == 2

    # Extract waiting requests and verify priority order
    waiting_requests = list(scheduler.waiting)

    waiting_priorities = [req.priority for req in waiting_requests]
    waiting_req_ids = [req.request_id for req in waiting_requests]

    # Should be req_1 (priority 2) then req_0 (priority 3)
    assert waiting_priorities == [2, 3]
    assert waiting_req_ids == ["1", "0"]

def test_priority_scheduling_equal_priority_preemption():
    """Test arrival time tiebreaker when requests have equal priority."""
    # This test verifies that arrival time is used as a tiebreaker for equal
    # priorities
    scheduler = create_scheduler_with_priority(
        max_num_seqs=1,  # Force sequential processing
    )

    # Create requests with same priority but different arrival times
    requests = create_requests_with_priority(
        num_requests=3,
        priorities=[2, 2, 2],  # Same priority
        arrival_times=[3.0, 1.0, 2.0],  # Different arrival times
        num_tokens=10)

    # Add all requests
    for request in requests:
        scheduler.add_request(request)

    # Schedule - should schedule the request with earliest arrival time
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 1
    assert output.scheduled_new_reqs[0].req_id == "1"  # Earliest arrival (1.0)

    # Verify the waiting queue has remaining requests in arrival time order
    assert len(scheduler.waiting) == 2

    # Extract waiting requests and verify arrival time order
    waiting_requests = list(scheduler.waiting)

    waiting_arrival_times = [req.arrival_time for req in waiting_requests]
    waiting_req_ids = [req.request_id for req in waiting_requests]

    # Should be req_2 (arrival 2.0) then req_0 (arrival 3.0)
    assert waiting_arrival_times == [2.0, 3.0]
    assert waiting_req_ids == ["2", "0"]


def test_priority_scheduling_waiting_queue_order():
    """Test that the waiting queue maintains priority order."""
    scheduler = create_scheduler_with_priority(
        max_num_seqs=1,  # Only one request can run at a time
    )

    # Create multiple requests with different priorities
    requests = create_requests_with_priority(
        num_requests=4,
        priorities=[3, 1, 2, 0],  # Mixed priorities
        arrival_times=[1.0, 2.0, 3.0, 4.0],
        num_tokens=10)

    # Add all requests
    for request in requests:
        scheduler.add_request(request)

    # Schedule - should only schedule the highest priority request
    # (req_3, priority 0)
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 1
    assert output.scheduled_new_reqs[0].req_id == "3"

    # Verify waiting queue has remaining requests in priority order
    assert len(scheduler.waiting) == 3

    # Extract requests from waiting queue
    # (it's a heap, so we need to pop to see order)
    waiting_requests = list(scheduler.waiting)

    waiting_priorities = [req.priority for req in waiting_requests]
    waiting_req_ids = [req.request_id for req in waiting_requests]

    # Should be ordered by priority: req_1 (1), req_2 (2), req_0 (3)
    assert waiting_req_ids == ["1", "2", "0"]
    assert waiting_priorities == [1, 2, 3]

def test_priority_scheduling_fcfs_fallback():
    """Test that FCFS behavior is maintained when all
    requests have same priority."""
    scheduler = create_scheduler_with_priority()

    # Create requests with same priority but different arrival times
    priorities = [1, 1, 1, 1]  # All same priority
    arrival_times = [4.0, 1.0, 3.0, 2.0]  # Different arrival times
    requests = create_requests_with_priority(num_requests=4,
                                             priorities=priorities,
                                             arrival_times=arrival_times)

    # Add requests
    for request in requests:
        scheduler.add_request(request)

    # Schedule
    output = scheduler.schedule()

    # Should schedule all requests in arrival time order
    assert len(output.scheduled_new_reqs) == 4
    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]

    # Expected order by arrival time:
    # req_1 (1.0), req_3 (2.0), req_2 (3.0), req_0 (4.0)
    assert scheduled_req_ids == ["1", "3", "2", "0"]


def test_priority_scheduling_with_limited_slots():
    """Test priority scheduling when max_num_seqs limits concurrent requests."""
    scheduler = create_scheduler_with_priority(
        max_num_seqs=2,  # Only allow 2 concurrent requests
        max_num_batched_tokens=1000,  # Plenty of token budget
    )

    # Create requests with different priorities
    requests = create_requests_with_priority(
        num_requests=4,
        priorities=[3, 1, 2, 0],  # Mixed priorities
        arrival_times=[1.0, 2.0, 3.0, 4.0],
        num_tokens=10)

    # Add all requests
    for request in requests:
        scheduler.add_request(request)

    # Schedule - should only schedule the 2 highest priority requests
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 2

    # Should schedule req_3 (priority 0) and req_1 (priority 1)
    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
    assert "3" in scheduled_req_ids  # Priority 0
    assert "1" in scheduled_req_ids  # Priority 1

    # Remaining requests should be in waiting queue in priority order
    assert len(scheduler.waiting) == 2

    # Extract waiting requests and verify order
    waiting_requests = list(scheduler.waiting)
    waiting_priorities = [req.priority for req in waiting_requests]
    waiting_req_ids = [req.request_id for req in waiting_requests]

    # Should be req_2 (priority 2) then req_0 (priority 3)
    assert waiting_priorities == [2, 3]
    assert waiting_req_ids == ["2", "0"]

def test_priority_scheduling_heap_property():
    """Test that the waiting queue maintains heap
    property for priority scheduling."""
    scheduler = create_scheduler_with_priority(
        max_num_seqs=1,  # Only one request can run at a time
    )

    # Add requests in random priority order
    priorities = [5, 1, 8, 3, 2, 7, 4, 6]
    arrival_times = [float(i) for i in range(len(priorities))]
    requests = create_requests_with_priority(num_requests=len(priorities),
                                             priorities=priorities,
                                             arrival_times=arrival_times,
                                             num_tokens=10)

    # Add all requests
    for request in requests:
        scheduler.add_request(request)

    # Schedule one request at a time and verify priority order
    scheduled_priorities = []

    while scheduler.waiting:
        output = scheduler.schedule()
        if output.scheduled_new_reqs:
            req = output.scheduled_new_reqs[0]
            scheduled_priorities.append(requests[int(req.req_id)].priority)

            # Simulate completion to make room for next request
            model_output = ModelRunnerOutput(
                req_ids=[req.req_id],
                req_id_to_index={req.req_id: 0},
                sampled_token_ids=[[100]],
                spec_token_ids=None,
                logprobs=None,
                prompt_logprobs_dict={},
                pooler_output=[],
            )
            scheduler.update_from_output(output, model_output)

            # Finish the request to make room for the next one
            scheduler.finish_requests(req.req_id,
                                      RequestStatus.FINISHED_STOPPED)

    # Verify requests were scheduled in priority order (lowest value first)
    expected_priorities = sorted(priorities)
    assert scheduled_priorities == expected_priorities
vllm/v1/core/sched/request_queue.py (new file, 224 lines)
@@ -0,0 +1,224 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import heapq
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterable, Iterator
from enum import Enum

from vllm.v1.request import Request


class SchedulingPolicy(Enum):
    """Enum for scheduling policies."""
    FCFS = "fcfs"
    PRIORITY = "priority"


class RequestQueue(ABC):
    """Abstract base class for request queues."""

    @abstractmethod
    def add_request(self, request: Request) -> None:
        """Add a request to the queue according to the policy."""
        pass

    @abstractmethod
    def pop_request(self) -> Request:
        """Pop a request from the queue according to the policy."""
        pass

    @abstractmethod
    def peek_request(self) -> Request:
        """Peek at the request at the front of the queue without removing it."""
        pass

    @abstractmethod
    def prepend_request(self, request: Request) -> None:
        """Prepend a request to the front of the queue."""
        pass

    @abstractmethod
    def prepend_requests(self, requests: RequestQueue) -> None:
        """Prepend all requests from another queue to the front of this
        queue."""
        pass

    @abstractmethod
    def remove_request(self, request: Request) -> None:
        """Remove a specific request from the queue."""
        pass

    @abstractmethod
    def remove_requests(self, requests: Iterable[Request]) -> None:
        """Remove multiple specific requests from the queue."""
        pass

    @abstractmethod
    def __bool__(self) -> bool:
        """Check if queue has any requests."""
        pass

    @abstractmethod
    def __len__(self) -> int:
        """Get number of requests in queue."""
        pass

    @abstractmethod
    def __iter__(self) -> Iterator[Request]:
        """Iterate over the queue according to the policy."""
        pass

    @abstractmethod
    def __reversed__(self) -> Iterator[Request]:
        """Iterate over the queue in reverse order."""
        pass


class FCFSRequestQueue(deque[Request], RequestQueue):
    """A first-come-first-served queue that supports deque operations."""

    def add_request(self, request: Request) -> None:
        """Add a request to the queue according to FCFS policy."""
        self.append(request)

    def pop_request(self) -> Request:
        """Pop a request from the queue according to FCFS policy."""
        return self.popleft()

    def peek_request(self) -> Request:
        """Peek at the next request in the queue without removing it."""
        if not self:
            raise IndexError("peek from an empty queue")
        return self[0]

    def prepend_request(self, request: Request) -> None:
        """Prepend a request to the front of the queue."""
        self.appendleft(request)

    def prepend_requests(self, requests: RequestQueue) -> None:
        """Prepend all requests from another queue to the front of this
        queue."""
        self.extendleft(reversed(requests))

    def remove_request(self, request: Request) -> None:
        """Remove a specific request from the queue."""
        self.remove(request)

    def remove_requests(self, requests: Iterable[Request]) -> None:
        """Remove multiple specific requests from the queue."""
        requests_to_remove = set(requests)
        filtered_requests = [
            req for req in self if req not in requests_to_remove
        ]
        # deque does not support in-place filtering, so we need to clear
        # and extend
        self.clear()
        self.extend(filtered_requests)

    def __bool__(self) -> bool:
        """Check if queue has any requests."""
        return len(self) > 0

    def __len__(self) -> int:
        """Get number of requests in queue."""
        return super().__len__()

    def __iter__(self) -> Iterator[Request]:
        """Iterate over the queue according to FCFS policy."""
        return super().__iter__()

    def __reversed__(self) -> Iterator[Request]:
        """Iterate over the queue in reverse order."""
        return super().__reversed__()


class PriorityRequestQueue(RequestQueue):
    """
    A priority queue that supports heap operations.

    Requests with a smaller value of `priority` are processed first.
    If multiple requests have the same priority, the one with the earlier
    `arrival_time` is processed first.
    """

    def __init__(self) -> None:
        self._heap: list[tuple[int, float, Request]] = []

    def add_request(self, request: Request) -> None:
        """Add a request to the queue according to priority policy."""
        heapq.heappush(self._heap,
                       (request.priority, request.arrival_time, request))

    def pop_request(self) -> Request:
        """Pop a request from the queue according to priority policy."""
        if not self._heap:
            raise IndexError("pop from empty heap")
        _, _, request = heapq.heappop(self._heap)
        return request

    def peek_request(self) -> Request:
        """Peek at the next request in the queue without removing it."""
        if not self._heap:
            raise IndexError("peek from empty heap")
        _, _, request = self._heap[0]
        return request

    def prepend_request(self, request: Request) -> None:
        """Add a request to the queue according to priority policy.

        Note: In a priority queue, there is no concept of prepending to the
        front. Requests are ordered by (priority, arrival_time)."""
        self.add_request(request)

    def prepend_requests(self, requests: RequestQueue) -> None:
        """Add all requests from another queue according to priority policy.

        Note: In a priority queue, there is no concept of prepending to the
        front. Requests are ordered by (priority, arrival_time)."""
        for request in requests:
            self.add_request(request)

    def remove_request(self, request: Request) -> None:
        """Remove a specific request from the queue."""
        self._heap = [(p, t, r) for p, t, r in self._heap if r != request]
        heapq.heapify(self._heap)

    def remove_requests(self, requests: Iterable[Request]) -> None:
        """Remove multiple specific requests from the queue."""
        requests_to_remove = set(requests)
        self._heap = [(p, t, r) for p, t, r in self._heap
                      if r not in requests_to_remove]
        heapq.heapify(self._heap)

    def __bool__(self) -> bool:
        """Check if queue has any requests."""
        return bool(self._heap)

    def __len__(self) -> int:
        """Get number of requests in queue."""
        return len(self._heap)

    def __iter__(self) -> Iterator[Request]:
        """Iterate over the queue according to priority policy."""
        heap_copy = self._heap[:]
        while heap_copy:
            _, _, request = heapq.heappop(heap_copy)
            yield request

    def __reversed__(self) -> Iterator[Request]:
        """Iterate over the queue in reverse priority order."""
        return reversed(list(self))


def create_request_queue(policy: SchedulingPolicy) -> RequestQueue:
    """Create request queue based on scheduling policy."""
    if policy == SchedulingPolicy.PRIORITY:
        return PriorityRequestQueue()
    elif policy == SchedulingPolicy.FCFS:
        return FCFSRequestQueue()
    else:
        raise ValueError(f"Unknown scheduling policy: {policy}")
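
A small usage sketch of the module above. It assumes this commit's request_queue module is importable; the stand-in objects are hypothetical and only carry the two attributes (priority, arrival_time) that PriorityRequestQueue actually reads.

# Usage sketch (assumes the new module is installed); stand-ins replace Request.
from types import SimpleNamespace

from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
                                              create_request_queue)

queue = create_request_queue(SchedulingPolicy.PRIORITY)
queue.add_request(SimpleNamespace(priority=1, arrival_time=2.0, request_id="b"))
queue.add_request(SimpleNamespace(priority=0, arrival_time=3.0, request_id="c"))
queue.add_request(SimpleNamespace(priority=1, arrival_time=1.0, request_id="a"))

assert queue.peek_request().request_id == "c"  # lowest priority value first
assert [r.request_id for r in queue] == ["c", "a", "b"]  # FCFS within a priority
assert queue.pop_request().request_id == "c"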
@@ -22,6 +22,8 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager
 from vllm.v1.core.sched.interface import SchedulerInterface
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
+from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
+                                              create_request_queue)
 from vllm.v1.core.sched.utils import check_stop
 from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
                             EngineCoreOutputs)
@@ -94,8 +96,16 @@ class Scheduler(SchedulerInterface):
 
         # req_id -> Request
         self.requests: dict[str, Request] = {}
+        # Scheduling policy
+        if self.scheduler_config.policy == "priority":
+            self.policy = SchedulingPolicy.PRIORITY
+        elif self.scheduler_config.policy == "fcfs":
+            self.policy = SchedulingPolicy.FCFS
+        else:
+            raise ValueError(
+                f"Unknown scheduling policy: {self.scheduler_config.policy}")
         # Priority queues for requests.
-        self.waiting: deque[Request] = deque()
+        self.waiting = create_request_queue(self.policy)
         self.running: list[Request] = []
 
         # The request IDs that are finished in between the previous and the
@@ -247,7 +257,15 @@ class Scheduler(SchedulerInterface):
                 if new_blocks is None:
                     # The request cannot be scheduled.
                     # Preempt the lowest-priority request.
-                    preempted_req = self.running.pop()
+                    if self.policy == SchedulingPolicy.PRIORITY:
+                        preempted_req = max(
+                            self.running,
+                            key=lambda r: (r.priority, r.arrival_time),
+                        )
+                        self.running.remove(preempted_req)
+                    else:
+                        preempted_req = self.running.pop()
+
                     self.kv_cache_manager.free(preempted_req)
                     preempted_req.status = RequestStatus.PREEMPTED
                     preempted_req.num_computed_tokens = 0
@@ -255,7 +273,7 @@ class Scheduler(SchedulerInterface):
                         preempted_req.record_event(
                             EngineCoreEventType.PREEMPTED, scheduled_timestamp)
 
-                    self.waiting.appendleft(preempted_req)
+                    self.waiting.prepend_request(preempted_req)
                     preempted_reqs.append(preempted_req)
                     if preempted_req == request:
                         # No more request to preempt.
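
The victim-selection change above can be read in isolation: under the priority policy the scheduler evicts the running request with the largest (priority, arrival_time) key, i.e. the lowest-priority, most-recently-arrived request, instead of always popping the tail of the running list. A standalone sketch of that selection, using hypothetical request stand-ins:

# Standalone sketch of the preemption key used above (not vLLM code).
from types import SimpleNamespace

running = [
    SimpleNamespace(request_id="r0", priority=5, arrival_time=1.0),
    SimpleNamespace(request_id="r1", priority=5, arrival_time=2.0),
    SimpleNamespace(request_id="r2", priority=0, arrival_time=3.0),
]

# Largest (priority, arrival_time) loses: lowest priority, latest arrival.
victim = max(running, key=lambda r: (r.priority, r.arrival_time))
running.remove(victim)
assert victim.request_id == "r1"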
@@ -311,9 +329,9 @@ class Scheduler(SchedulerInterface):
                 if req.lora_request and req.lora_request.lora_int_id > 0)
             assert len(scheduled_loras) <= self.lora_config.max_loras
 
-        # Use a temporary deque to collect requests that need to be skipped
-        # and put back at the head of the waiting queue later
-        skipped_waiting_requests: deque[Request] = deque()
+        # Use a temporary RequestQueue to collect requests that need to be
+        # skipped and put back at the head of the waiting queue later
+        skipped_waiting_requests = create_request_queue(self.policy)
 
         # Next, schedule the WAITING requests.
         if not preempted_reqs:
@@ -321,7 +339,7 @@ class Scheduler(SchedulerInterface):
                 if len(self.running) == self.max_num_running_reqs:
                     break
 
-                request = self.waiting[0]
+                request = self.waiting.peek_request()
 
                 # KVTransfer: skip request if still waiting for remote kvs.
                 if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
@@ -332,8 +350,8 @@ class Scheduler(SchedulerInterface):
                         logger.debug(
                             "%s is still in WAITING_FOR_REMOTE_KVS state.",
                             request.request_id)
-                        self.waiting.popleft()
-                        skipped_waiting_requests.appendleft(request)
+                        self.waiting.pop_request()
+                        skipped_waiting_requests.prepend_request(request)
                         continue
 
                 # Skip request if the structured output request is still waiting
@@ -343,19 +361,18 @@ class Scheduler(SchedulerInterface):
                     if structured_output_req and structured_output_req.grammar:
                         request.status = RequestStatus.WAITING
                     else:
-                        self.waiting.popleft()
-                        skipped_waiting_requests.appendleft(request)
+                        self.waiting.pop_request()
+                        skipped_waiting_requests.prepend_request(request)
                         continue
 
                 # Check that adding the request still respects the max_loras
                 # constraint.
-                if self.lora_config and request.lora_request and (
-                        len(scheduled_loras) == self.lora_config.max_loras
-                        and request.lora_request.lora_int_id
-                        not in scheduled_loras):
+                if (self.lora_config and request.lora_request and
+                    (len(scheduled_loras) == self.lora_config.max_loras and
+                     request.lora_request.lora_int_id not in scheduled_loras)):
                     # Scheduling would exceed max_loras, skip.
-                    self.waiting.popleft()
-                    skipped_waiting_requests.appendleft(request)
+                    self.waiting.pop_request()
+                    skipped_waiting_requests.prepend_request(request)
                     continue
 
                 num_external_computed_tokens = 0
@@ -407,8 +424,8 @@ class Scheduler(SchedulerInterface):
                 # pooling requests to be chunked
                 if not self.scheduler_config.chunked_prefill_enabled and \
                         num_new_tokens > token_budget:
-                    self.waiting.popleft()
-                    skipped_waiting_requests.appendleft(request)
+                    self.waiting.pop_request()
+                    skipped_waiting_requests.prepend_request(request)
                     continue
 
                 num_new_tokens = min(num_new_tokens, token_budget)
@@ -448,17 +465,19 @@ class Scheduler(SchedulerInterface):
                         num_external_computed_tokens,
                     )
 
-                self.waiting.popleft()
+                # Request was already popped from self.waiting
+                # unless it was re-added above due to new_blocks being None.
+                request = self.waiting.pop_request()
                 if load_kv_async:
                     # If loading async, allocate memory and put request
                     # into the WAITING_FOR_REMOTE_KV state.
-                    skipped_waiting_requests.appendleft(request)
+                    skipped_waiting_requests.prepend_request(request)
                     request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
                     continue
 
                 if request.use_structured_output:
-                    structured_output_request_ids[
-                        request.request_id] = req_index
+                    structured_output_request_ids[request.request_id] = (
+                        req_index)
                 req_index += 1
                 self.running.append(request)
                 if self.log_stats:
@@ -494,7 +513,7 @@ class Scheduler(SchedulerInterface):
 
         # Put back any skipped requests at the head of the waiting queue
        if skipped_waiting_requests:
-            self.waiting.extendleft(skipped_waiting_requests)
+            self.waiting.prepend_requests(skipped_waiting_requests)
 
         # Check if the scheduling constraints are satisfied.
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
@@ -896,7 +915,7 @@ class Scheduler(SchedulerInterface):
         return len(self.running), len(self.waiting)
 
     def add_request(self, request: Request) -> None:
-        self.waiting.append(request)
+        self.waiting.add_request(request)
         self.requests[request.request_id] = request
         if self.log_stats:
             request.record_event(EngineCoreEventType.QUEUED)
@@ -917,16 +936,31 @@ class Scheduler(SchedulerInterface):
         else:
             request_ids = set(request_ids)
 
+        running_requests_to_remove = []
+        waiting_requests_to_remove = []
+        valid_requests = []
+
+        # First pass: collect requests to remove from queues
         for req_id in request_ids:
             request = self.requests.get(req_id)
             if request is None:
                 # Invalid request ID.
                 continue
 
+            valid_requests.append(request)
             if request.status == RequestStatus.RUNNING:
-                self.running.remove(request)
+                running_requests_to_remove.append(request)
             else:
-                self.waiting.remove(request)
+                waiting_requests_to_remove.append(request)
+
+        # Remove all requests from queues at once for better efficiency
+        for request in running_requests_to_remove:
+            self.running.remove(request)
+        if waiting_requests_to_remove:
+            self.waiting.remove_requests(waiting_requests_to_remove)
+
+        # Second pass: set status and free requests
+        for request in valid_requests:
             request.status = finished_status
             self._free_request(request)
 
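
The two-pass finish_requests change above batches queue removals so the priority heap is filtered and re-heapified once per call instead of once per finished request. A standalone sketch of the batched filter, using plain tuples in place of Request objects:

# Standalone sketch of batched heap removal (mirrors remove_requests above).
import heapq

heap = [(3, 1.0, 0), (1, 2.0, 1), (2, 3.0, 2), (1, 4.0, 3)]  # (priority, arrival, id)
heapq.heapify(heap)
to_remove = {0, 2}

# One scan plus one heapify, rather than a scan-and-heapify per removed request.
heap = [(p, t, rid) for p, t, rid in heap if rid not in to_remove]
heapq.heapify(heap)
assert [rid for _, _, rid in sorted(heap)] == [1, 3]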
@@ -68,6 +68,7 @@ class EngineCoreRequest(
     # belong to, to cover a race condition where the request is sent before
     # a wave finished notification is received.
     current_wave: int = 0
+    priority: int = 0
 
 
 class EngineCoreEventType(enum.IntEnum):
@@ -219,8 +219,6 @@ class Processor:
         # TODO(woosuk): Support encoder-decoder models.
         self._validate_lora(lora_request)
         self._validate_params(params, lora_request)
-        if priority != 0:
-            raise ValueError("V1 does not support priority yet.")
         if trace_headers is not None:
             raise ValueError("V1 does not support tracing yet.")
         if prompt_adapter_request is not None:
@@ -340,6 +338,7 @@ class Processor:
             arrival_time=arrival_time,
             lora_request=lora_request,
             cache_salt=decoder_inputs.get("cache_salt"),
+            priority=priority,
             data_parallel_rank=data_parallel_rank,
         )
 
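
With the "V1 does not support priority yet" guard removed and priority forwarded into EngineCoreRequest, a per-request priority can flow from the entrypoints down to the V1 scheduler. A hedged sketch of offline usage follows; it assumes the LLM constructor accepts a scheduling_policy keyword and that LLM.generate accepts a per-prompt priority list, as in vLLM's pre-existing priority-scheduling API, neither of which is shown in this diff.

# Hedged sketch: `scheduling_policy` and `priority` are assumed entrypoint
# arguments (from vLLM's existing priority-scheduling API), not part of this diff.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", scheduling_policy="priority")
params = SamplingParams(max_tokens=16)
outputs = llm.generate(
    ["low priority prompt", "high priority prompt"],
    params,
    priority=[10, 0],  # lower value = higher priority
)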
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import enum
+import time
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
@@ -30,18 +31,23 @@ class Request:
         pooling_params: Optional[PoolingParams],
         eos_token_id: Optional[int],
         client_index: int = 0,
+        arrival_time: Optional[float] = None,
         lora_request: Optional["LoRARequest"] = None,
         structured_output_request: Optional["StructuredOutputRequest"] = None,
         cache_salt: Optional[str] = None,
+        priority: int = 0,
     ) -> None:
         self.request_id = request_id
         self.client_index = client_index
+        self.priority = priority
         self.sampling_params = sampling_params
         self.pooling_params = pooling_params
         # Because of LoRA, the eos token id can be different for each request.
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
         self.structured_output_request = structured_output_request
+        self.arrival_time = arrival_time if arrival_time is not None else \
+            time.time()
 
         self.status = RequestStatus.WAITING
         if sampling_params and sampling_params.guided_decoding is not None:
@@ -118,11 +124,13 @@ class Request:
             sampling_params=request.sampling_params,
             pooling_params=request.pooling_params,
             eos_token_id=request.eos_token_id,
+            arrival_time=request.arrival_time,
             lora_request=request.lora_request,
             structured_output_request=StructuredOutputRequest(
                 sampling_params=request.sampling_params) \
                     if request.sampling_params else None,
             cache_salt=request.cache_salt,
+            priority=request.priority,
         )
 
     def append_output_token_ids(