Implement Async Scheduling (#19970)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 85bd6599e4, commit d4d309409f
tests/v1/core/__init__.py (new file, 0 lines)
tests/v1/core/test_async_scheduler.py (new file, 228 lines)
@@ -0,0 +1,228 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import deque

import pytest

from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import RequestStatus

from .utils import create_requests, create_scheduler


def _make_model_runner_output(
        scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput:
    req_ids = list(scheduler_output.num_scheduled_tokens.keys())
    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index={
            req_id: i
            for i, req_id in enumerate(req_ids)
        },
        sampled_token_ids=[[i] for i in range(len(req_ids))],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )


@pytest.mark.parametrize("max_tokens", [1, 2, 3, 5])
def test_stop_by_max_tokens(max_tokens: int):
    scheduler = create_scheduler(async_scheduling=True)
    requests = create_requests(num_requests=2, max_tokens=max_tokens)
    req0, req1 = requests

    sched_outputs: deque[SchedulerOutput] = deque()
    scheduler.add_request(req0)
    sched_outputs.append(scheduler.schedule())

    scheduler.add_request(req1)
    sched_outputs.append(scheduler.schedule())

    while sched_outputs:
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)

        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    assert scheduler.get_num_unfinished_requests() == 0
    assert req0.num_output_tokens == max_tokens
    assert req1.num_output_tokens == max_tokens


def test_abort():
    scheduler = create_scheduler(async_scheduling=True)
    requests = create_requests(num_requests=10, max_tokens=20)

    for req in requests:
        scheduler.add_request(req)

    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

    abort_order = [0, 8, 3, 1, 6, 4, 2, 5, 7, 9]
    abort_order_copy = abort_order.copy()

    def abort_request():
        if not abort_order:
            return
        req = requests[abort_order.pop(0)]
        scheduler.finish_requests(req.request_id,
                                  RequestStatus.FINISHED_ABORTED)

    while sched_outputs:
        # Abort a scheduled request.
        abort_request()
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)

        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    for i, req in enumerate(requests):
        assert req.status == RequestStatus.FINISHED_ABORTED
        assert req.num_output_tokens == abort_order_copy.index(i)


def test_preempt():
    scheduler = create_scheduler(async_scheduling=True)
    requests = create_requests(num_requests=10, max_tokens=20)

    for req in requests:
        scheduler.add_request(req)

    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

    abort_order = [0, 8, 3, 1, 6, 4, 2, 5, 7, 9]
    abort_order_copy = abort_order.copy()

    def abort_request():
        if not abort_order:
            return
        req = requests[abort_order.pop(0)]
        scheduler.finish_requests(req.request_id,
                                  RequestStatus.FINISHED_ABORTED)

    while sched_outputs:
        # Abort a scheduled request.
        abort_request()
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)

        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    for i, req in enumerate(requests):
        assert req.status == RequestStatus.FINISHED_ABORTED
        assert req.num_output_tokens == abort_order_copy.index(i)


def test_prefix_caching_for_prefill_dedup():
    CHUNK_SIZE = 1000
    BLOCK_SIZE = 16
    num_prompt_tokens = 100
    scheduler = create_scheduler(async_scheduling=True,
                                 max_num_batched_tokens=CHUNK_SIZE,
                                 enable_prefix_caching=True,
                                 block_size=BLOCK_SIZE)
    requests = create_requests(num_requests=5,
                               num_tokens=num_prompt_tokens,
                               max_tokens=3,
                               same_prompt=True)
    requests_copy = requests.copy()

    # Two requests with the same prompt.
    req0 = requests.pop(0)
    req1 = requests.pop(0)
    scheduler.add_request(req0)
    scheduler.add_request(req1)

    sched_outputs: deque[SchedulerOutput] = deque()
    sched_output = scheduler.schedule()
    sched_outputs.append(sched_output)
    # Make sure prefix caching de-duplicates the prompts in the same step,
    # so all the blocks except the last are shared between the two requests.
    assert len(sched_output.num_scheduled_tokens) == 2
    num_blocks = num_prompt_tokens // BLOCK_SIZE
    assert req0.num_cached_tokens == 0
    assert req1.num_cached_tokens >= num_blocks * BLOCK_SIZE

    sched_outputs.append(scheduler.schedule())
    while sched_outputs:
        if requests:
            scheduler.add_request(requests.pop(0))
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)
        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    # Other requests scheduled after the two requests should also get
    # prefix cache hit.
    assert scheduler.get_num_unfinished_requests() == 0
    for req in requests_copy[1:]:
        assert req.num_cached_tokens >= num_blocks * BLOCK_SIZE


def test_prefix_caching_for_multi_turn():
    CHUNK_SIZE = 1000
    BLOCK_SIZE = 16
    num_prompt_tokens = 100
    num_output_tokens = 200
    scheduler = create_scheduler(async_scheduling=True,
                                 max_num_batched_tokens=CHUNK_SIZE,
                                 enable_prefix_caching=True,
                                 block_size=BLOCK_SIZE)
    requests = create_requests(num_requests=5,
                               num_tokens=num_prompt_tokens,
                               max_tokens=num_output_tokens)

    for req in requests:
        scheduler.add_request(req)
    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

    # Process the requests.
    while sched_outputs:
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)
        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)
    assert scheduler.get_num_unfinished_requests() == 0

    # Create next-turn requests whose prompts are the full output of the
    # previous turn.
    next_turn_requests = create_requests(
        num_requests=5,
        num_tokens=num_prompt_tokens + num_output_tokens,
        max_tokens=num_output_tokens,
    )
    for i, req in enumerate(next_turn_requests):
        req.prompt_token_ids = (requests[i].prompt_token_ids +
                                list(requests[i].output_token_ids))
    # Schedule the next-turn requests.
    for req in next_turn_requests:
        scheduler.add_request(req)
    sched_outputs.append(scheduler.schedule())

    # Make sure the next-turn requests get prefix cache hit by the previous
    # requests.
    for req in next_turn_requests:
        assert (req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE *
                BLOCK_SIZE)
tests/v1/core/test_scheduler.py
@@ -19,133 +19,7 @@ from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.v1.structured_output.request import StructuredOutputRequest
 
-EOS_TOKEN_ID = 50256
-
-
-def create_scheduler(
-    model: str = "facebook/opt-125m",
-    max_num_seqs: int = 16,
-    max_num_batched_tokens: int = 8192,
-    enable_prefix_caching: Optional[bool] = None,
-    long_prefill_token_threshold: int = 0,
-    disable_chunked_mm_input: bool = False,
-    use_kv_connector: bool = False,
-    num_blocks: int = 10000,
-    block_size: int = 16,
-    max_model_len: Optional[int] = None,
-    num_speculative_tokens: Optional[int] = None,
-    skip_tokenizer_init: bool = False,
-) -> Scheduler:
-    '''Create scheduler under test.
-
-    Args:
-      model: model under test
-      max_num_seqs: max sequences to schedule
-      max_num_batch_tokens: max num tokens to batch
-      enable_prefix_caching: optionally force APC config
-                             (True/False) or use default
-                             (None)
-
-    Returns:
-      {class}`Scheduler` instance
-    '''
-    if max_model_len is None:
-        max_model_len = max_num_batched_tokens
-    scheduler_config = SchedulerConfig(
-        max_num_seqs=max_num_seqs,
-        max_num_batched_tokens=max_num_batched_tokens,
-        max_model_len=max_model_len,
-        long_prefill_token_threshold=long_prefill_token_threshold,
-        disable_chunked_mm_input=disable_chunked_mm_input,
-        enable_chunked_prefill=True,
-    )
-    model_config = ModelConfig(
-        model=model,
-        task="auto",
-        tokenizer=model,
-        tokenizer_mode="auto",
-        trust_remote_code=True,
-        dtype="float16",
-        seed=42,
-        skip_tokenizer_init=skip_tokenizer_init,
-    )
-    # Cache config, optionally force APC
-    kwargs_cache = ({} if enable_prefix_caching is None else {
-        'enable_prefix_caching': enable_prefix_caching
-    })
-    cache_config = CacheConfig(
-        block_size=block_size,
-        gpu_memory_utilization=0.9,
-        swap_space=0,
-        cache_dtype="auto",
-        **kwargs_cache,
-    )
-    kv_transfer_config = KVTransferConfig(
-        kv_connector="SharedStorageConnector",
-        kv_role="kv_both",
-        kv_connector_extra_config={"shared_storage_path": "local_storage"},
-    ) if use_kv_connector else None
-
-    speculative_config: Optional[SpeculativeConfig] = None
-    if num_speculative_tokens is not None:
-        speculative_config = SpeculativeConfig(
-            model="ngram", num_speculative_tokens=num_speculative_tokens)
-
-    vllm_config = VllmConfig(
-        scheduler_config=scheduler_config,
-        model_config=model_config,
-        cache_config=cache_config,
-        kv_transfer_config=kv_transfer_config,
-        speculative_config=speculative_config,
-    )
-    kv_cache_config = KVCacheConfig(
-        num_blocks=num_blocks,  # A large number of blocks to hold all requests
-        kv_cache_tensors=[],
-        kv_cache_groups=[
-            KVCacheGroupSpec(['layer'],
-                             FullAttentionSpec(block_size, 1, 1, torch.float32,
-                                               False))
-        ],
-    )
-    cache_config.num_gpu_blocks = num_blocks
-    return Scheduler(
-        vllm_config=vllm_config,
-        kv_cache_config=kv_cache_config,
-        log_stats=True,
-        structured_output_manager=StructuredOutputManager(vllm_config),
-    )
-
-
-def create_requests(num_requests: int,
-                    num_tokens: int = 10,
-                    mm_positions: Optional[list[PlaceholderRange]] = None,
-                    max_tokens: int = 16,
-                    stop_token_ids: Optional[list[int]] = None,
-                    prompt_logprobs: Optional[int] = None):
-    sampling_params = SamplingParams(ignore_eos=False,
-                                     max_tokens=max_tokens,
-                                     stop_token_ids=stop_token_ids,
-                                     prompt_logprobs=prompt_logprobs)
-    requests = []
-    for i in range(num_requests):
-        if mm_positions is not None:
-            mm_position = mm_positions[i]
-            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
-        else:
-            mm_position = None
-            mm_inputs = None
-        request = Request(
-            request_id=f"{i}",
-            prompt_token_ids=[i] * num_tokens,
-            sampling_params=sampling_params,
-            pooling_params=None,
-            multi_modal_inputs=mm_inputs,
-            multi_modal_placeholders=mm_position,
-            multi_modal_hashes=None,
-            eos_token_id=EOS_TOKEN_ID,
-        )
-        requests.append(request)
-    return requests
+from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
 
 
 def test_add_requests():
tests/v1/core/utils.py (new file, 152 lines)
@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union

import torch

from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                         SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec)
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager

EOS_TOKEN_ID = 50256


def create_scheduler(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: Optional[bool] = None,
    long_prefill_token_threshold: int = 0,
    disable_chunked_mm_input: bool = False,
    use_kv_connector: bool = False,
    num_blocks: int = 10000,
    block_size: int = 16,
    max_model_len: Optional[int] = None,
    num_speculative_tokens: Optional[int] = None,
    skip_tokenizer_init: bool = False,
    async_scheduling: bool = False,
) -> Union[Scheduler, AsyncScheduler]:
    '''Create scheduler under test.

    Args:
      model: model under test
      max_num_seqs: max sequences to schedule
      max_num_batch_tokens: max num tokens to batch
      enable_prefix_caching: optionally force APC config
                             (True/False) or use default
                             (None)

    Returns:
      {class}`Scheduler` instance
    '''
    if max_model_len is None:
        max_model_len = max_num_batched_tokens
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        long_prefill_token_threshold=long_prefill_token_threshold,
        disable_chunked_mm_input=disable_chunked_mm_input,
        enable_chunked_prefill=True,
        async_scheduling=async_scheduling,
    )
    model_config = ModelConfig(
        model=model,
        task="auto",
        tokenizer=model,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
        skip_tokenizer_init=skip_tokenizer_init,
    )
    # Cache config, optionally force APC
    kwargs_cache = ({} if enable_prefix_caching is None else {
        'enable_prefix_caching': enable_prefix_caching
    })
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        **kwargs_cache,
    )
    kv_transfer_config = KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": "local_storage"},
    ) if use_kv_connector else None

    speculative_config: Optional[SpeculativeConfig] = None
    if num_speculative_tokens is not None:
        speculative_config = SpeculativeConfig(
            model="ngram", num_speculative_tokens=num_speculative_tokens)

    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        speculative_config=speculative_config,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(['layer'],
                             FullAttentionSpec(block_size, 1, 1, torch.float32,
                                               False))
        ],
    )
    cache_config.num_gpu_blocks = num_blocks
    scheduler_cls = AsyncScheduler if async_scheduling else Scheduler
    return scheduler_cls(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )


def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_positions: Optional[list[PlaceholderRange]] = None,
    max_tokens: int = 16,
    stop_token_ids: Optional[list[int]] = None,
    prompt_logprobs: Optional[int] = None,
    same_prompt: bool = False,
) -> list[Request]:
    sampling_params = SamplingParams(ignore_eos=False,
                                     max_tokens=max_tokens,
                                     stop_token_ids=stop_token_ids,
                                     prompt_logprobs=prompt_logprobs)
    requests = []
    for i in range(num_requests):
        if mm_positions is not None:
            mm_position = mm_positions[i]
            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
        else:
            mm_position = None
            mm_inputs = None
        prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
                            num_tokens)
        request = Request(
            request_id=f"{i}",
            prompt_token_ids=prompt_token_ids,
            sampling_params=sampling_params,
            pooling_params=None,
            multi_modal_inputs=mm_inputs,
            multi_modal_placeholders=mm_position,
            multi_modal_hashes=None,
            eos_token_id=EOS_TOKEN_ID,
        )
        requests.append(request)
    return requests
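For orientation, here is a minimal sketch of how these helpers are intended to be combined in a test module under tests/v1/core/; it simply restates the pattern used by test_async_scheduler.py above, and the request count and max_tokens value are arbitrary.

from .utils import create_requests, create_scheduler


def test_async_scheduler_smoke():
    # Build a scheduler backed by AsyncScheduler and enqueue two small requests.
    scheduler = create_scheduler(async_scheduling=True)
    for req in create_requests(num_requests=2, max_tokens=4):
        scheduler.add_request(req)
    # A single scheduling step should pick up both prompts.
    scheduler_output = scheduler.schedule()
    assert len(scheduler_output.num_scheduled_tokens) == 2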
vllm/config.py
@@ -2308,6 +2308,13 @@ class SchedulerConfig:
     like full attention and sliding window attention.
     """
 
+    async_scheduling: bool = False
+    """EXPERIMENTAL: If set to True, perform async scheduling. This may help
+    reduce the CPU overheads, leading to better latency and throughput. However,
+    async scheduling is currently not supported with some features such as
+    structured outputs, speculative decoding, and pipeline parallelism.
+    """
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -2401,6 +2408,10 @@ class SchedulerConfig:
         if not self.cuda_graph_sizes:
             self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)]
 
+        if self.async_scheduling:
+            self.scheduler_cls = (
+                "vllm.v1.core.sched.async_scheduler.AsyncScheduler")
+
     @model_validator(mode='after')
     def _verify_args(self) -> Self:
         if (self.max_num_batched_tokens < self.max_model_len
vllm/engine/arg_utils.py
@@ -484,6 +484,8 @@ class EngineArgs:
     enable_multimodal_encoder_data_parallel: bool = \
         ParallelConfig.enable_multimodal_encoder_data_parallel
 
+    async_scheduling: bool = SchedulerConfig.async_scheduling
+
     def __post_init__(self):
         # support `EngineArgs(compilation_config={...})`
         # without having to manually construct a
@@ -921,6 +923,8 @@ class EngineArgs:
         scheduler_group.add_argument(
             "--disable-hybrid-kv-cache-manager",
             **scheduler_kwargs["disable_hybrid_kv_cache_manager"])
+        scheduler_group.add_argument("--async-scheduling",
+                                     **scheduler_kwargs["async_scheduling"])
 
         # vLLM arguments
         vllm_kwargs = get_kwargs(VllmConfig)
@@ -1206,6 +1210,26 @@ class EngineArgs:
             self.data_parallel_rpc_port
             is not None) else ParallelConfig.data_parallel_rpc_port
 
+        if self.async_scheduling:
+            # Async scheduling does not work with the uniprocess backend.
+            if self.distributed_executor_backend is None:
+                self.distributed_executor_backend = "mp"
+                logger.info("Using mp-based distributed executor backend "
+                            "for async scheduling.")
+            if self.distributed_executor_backend == "uni":
+                raise ValueError("Async scheduling is not supported with "
+                                 "uni-process backend.")
+            if self.pipeline_parallel_size > 1:
+                raise ValueError("Async scheduling is not supported with "
+                                 "pipeline-parallel-size > 1.")
+
+            # Currently, async scheduling does not support speculative decoding.
+            # TODO(woosuk): Support it.
+            if self.speculative_config is not None:
+                raise ValueError(
+                    "Currently, speculative decoding is not supported with "
+                    "async scheduling.")
+
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
@@ -1286,6 +1310,7 @@ class EngineArgs:
             long_prefill_token_threshold=self.long_prefill_token_threshold,
             disable_hybrid_kv_cache_manager=self.
             disable_hybrid_kv_cache_manager,
+            async_scheduling=self.async_scheduling,
         )
 
         if not model_config.is_multimodal_model and self.default_mm_loras:
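With the engine argument plumbed through above, turning the feature on from user code should look roughly like the sketch below. This is an illustration rather than part of the diff: it assumes the offline LLM entrypoint forwards keyword arguments into EngineArgs (so async_scheduling=True reaches SchedulerConfig), and the model name is only an example. The server-side equivalent is the --async-scheduling flag registered above.

from vllm import LLM, SamplingParams

# Assumed usage: async_scheduling is forwarded from LLM(**kwargs) to EngineArgs,
# mirroring the new field added in this commit.
llm = LLM(model="facebook/opt-125m", async_scheduling=True)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)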
vllm/v1/core/sched/async_scheduler.py (new file, 47 lines)
@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus

logger = init_logger(__name__)


class AsyncScheduler(Scheduler):

    def _update_after_schedule(
        self,
        scheduler_output: SchedulerOutput,
    ) -> None:
        super()._update_after_schedule(scheduler_output)
        for req_id in scheduler_output.num_scheduled_tokens:
            request = self.requests[req_id]
            if (request.num_computed_tokens == request.num_tokens +
                    request.num_output_placeholders):
                # The request will generate a new token in this scheduling step.
                # TODO(woosuk): Support speculative decoding.
                request.num_output_placeholders += 1

    def _update_request_with_output(
        self,
        request: Request,
        new_token_ids: list[int],
    ) -> tuple[list[int], bool]:
        status_before_update = request.status
        new_token_ids, stopped = super()._update_request_with_output(
            request, new_token_ids)

        # Update the number of output placeholders.
        request.num_output_placeholders -= len(new_token_ids)
        assert request.num_output_placeholders >= 0

        # Cache the new tokens. Preempted requests should be skipped.
        if status_before_update == RequestStatus.RUNNING:
            self.kv_cache_manager.cache_blocks(
                request,
                request.num_computed_tokens - request.num_output_placeholders)
        return new_token_ids, stopped
vllm/v1/core/sched/scheduler.py
@@ -204,7 +204,8 @@ class Scheduler(SchedulerInterface):
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]
 
-            num_new_tokens = (request.num_tokens_with_spec -
+            num_new_tokens = (request.num_tokens_with_spec +
+                              request.num_output_placeholders -
                               request.num_computed_tokens)
             if (0 < self.scheduler_config.long_prefill_token_threshold <
                     num_new_tokens):
@@ -230,9 +231,11 @@ class Scheduler(SchedulerInterface):
             if num_new_tokens == 0:
                 # The request cannot be scheduled because one of the following
                 # reasons:
-                # 1. No new tokens to schedule. This may happen when PP>1 and
-                #    we have already scheduled all prompt tokens but they are
-                #    not finished yet.
+                # 1. No new tokens to schedule. This may happen when
+                #    (1) PP>1 and we have already scheduled all prompt tokens
+                #        but they are not finished yet.
+                #    (2) Async scheduling and the request has reached to either
+                #        its max_total_tokens or max_model_len.
                 # 2. The encoder budget is exhausted.
                 # 3. The encoder cache is exhausted.
                 # NOTE(woosuk): Here, by doing `continue` instead of `break`,
@@ -598,6 +601,14 @@ class Scheduler(SchedulerInterface):
             request = self.requests[req_id]
             request.num_computed_tokens += num_scheduled_token
 
+            # NOTE: _free_encoder_inputs relies on num_computed_tokens, which
+            # may be updated again in _update_from_output for speculative
+            # decoding. However, it is safe to call the method here because
+            # encoder inputs are always part of the prompt, not the output,
+            # and thus are unaffected by speculative decoding.
+            if request.has_encoder_inputs:
+                self._free_encoder_inputs(request)
+
         # Clear the finished request IDs.
         # NOTE: We shouldn't do self.finished_req_ids.clear() here because
         # it will also affect the scheduler output.
@@ -785,29 +796,16 @@ class Scheduler(SchedulerInterface):
                     num_draft_tokens=len(scheduled_spec_token_ids),
                     num_accepted_tokens=len(generated_token_ids) - 1)
 
-            # NOTE(woosuk): This has to be executed after updating
-            # `request.num_computed_tokens`.
-            if request.has_encoder_inputs:
-                self._free_encoder_inputs(request)
-
             stopped = False
             new_logprobs = None
             new_token_ids = generated_token_ids
             kv_transfer_params = None
             status_before_stop = request.status
 
-            # Append generated tokens and check for stop. Note that if
-            # a request is still being prefilled, we expect the model runner
-            # to return empty token ids for the request.
-            for num_new, output_token_id in enumerate(new_token_ids, 1):
-                request.append_output_token_ids(output_token_id)
-
-                # Check for stop and update request state.
-                # This must be called before we make the EngineCoreOutput.
-                stopped = check_stop(request, self.max_model_len)
-                if stopped:
-                    del new_token_ids[num_new:]  # Trim new tokens if needed.
-                    break
+            # Check for stop and update request status.
+            if new_token_ids:
+                new_token_ids, stopped = self._update_request_with_output(
+                    request, new_token_ids)
 
             # Stop checking for pooler models.
             pooler_output = None
@@ -915,6 +913,26 @@ class Scheduler(SchedulerInterface):
 
         return engine_core_outputs
 
+    def _update_request_with_output(
+        self,
+        request: Request,
+        new_token_ids: list[int],
+    ) -> tuple[list[int], bool]:
+        # Append generated tokens and check for stop. Note that if
+        # a request is still being prefilled, we expect the model runner
+        # to return empty token ids for the request.
+        stopped = False
+        for num_new, output_token_id in enumerate(new_token_ids, 1):
+            request.append_output_token_ids(output_token_id)
+
+            # Check for stop and update request state.
+            # This must be called before we make the EngineCoreOutput.
+            stopped = check_stop(request, self.max_model_len)
+            if stopped:
+                del new_token_ids[num_new:]  # Trim new tokens if needed.
+                break
+        return new_token_ids, stopped
+
     def _free_encoder_inputs(self, request: Request) -> None:
         cached_encoder_input_ids = (
             self.encoder_cache_manager.get_cached_input_ids(request))
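Taken together, AsyncScheduler and the modified num_new_tokens formula maintain one invariant: every token that has been scheduled but not yet returned by the model runner is represented by a placeholder. The toy trace below is only an illustration of that bookkeeping; it uses plain integers instead of real Request/Scheduler objects and ignores speculative tokens, so num_tokens_with_spec equals num_tokens here.

num_tokens = 4                 # prompt length; grows as outputs are appended
num_computed_tokens = 0
num_output_placeholders = 0


def num_new_tokens() -> int:
    # The updated formula in Scheduler.schedule().
    return num_tokens + num_output_placeholders - num_computed_tokens


def update_after_schedule(scheduled: int) -> None:
    # Mirrors Scheduler/AsyncScheduler._update_after_schedule.
    global num_computed_tokens, num_output_placeholders
    num_computed_tokens += scheduled
    if num_computed_tokens == num_tokens + num_output_placeholders:
        # This step will sample a token that does not exist yet.
        num_output_placeholders += 1


def update_with_output(n_new: int) -> None:
    # Mirrors AsyncScheduler._update_request_with_output.
    global num_tokens, num_output_placeholders
    num_tokens += n_new
    num_output_placeholders -= n_new


update_after_schedule(num_new_tokens())  # step 0: prefill all 4 prompt tokens
update_after_schedule(num_new_tokens())  # step 1: scheduled before step 0 returns
assert (num_computed_tokens, num_output_placeholders) == (5, 2)

update_with_output(1)                    # step 0's sampled token arrives
assert (num_tokens, num_output_placeholders) == (5, 1)
assert num_new_tokens() == 1             # step 2 can again schedule one token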
vllm/v1/executor/multiproc_executor.py
@@ -367,6 +367,8 @@ class MultiprocExecutor(Executor):
 
     @property
     def max_concurrent_batches(self) -> int:
+        if self.scheduler_config.async_scheduling:
+            return 2
         return self.parallel_config.pipeline_parallel_size
 
     def _get_output_rank(self) -> int:
vllm/v1/executor/ray_distributed_executor.py
@@ -33,6 +33,8 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
         """Ray distributed executor supports pipeline parallelism,
         meaning that it allows PP size batches to be executed concurrently.
         """
+        if self.scheduler_config.async_scheduling:
+            return 2
         return self.parallel_config.pipeline_parallel_size
 
     def execute_model(
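Both executors cap max_concurrent_batches at 2 under async scheduling because that is all the overlap the design needs: while the workers execute step N, the scheduler prepares step N+1, and step N+2 cannot be formed until step N's sampled tokens have been folded back in. The snippet below is a vLLM-free illustration of that two-in-flight pattern; the thread pool and fake_execute_model are stand-ins invented for the example, not engine APIs.

import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor

MAX_CONCURRENT_BATCHES = 2  # what the executors above report for async scheduling


def fake_execute_model(step: int) -> str:
    time.sleep(0.01)  # stands in for the GPU forward pass of one batch
    return f"output of step {step}"


with ThreadPoolExecutor(max_workers=1) as worker:
    in_flight: deque = deque()
    for step in range(5):
        # "Schedule" the next step while the previous one is still running.
        in_flight.append(worker.submit(fake_execute_model, step))
        if len(in_flight) == MAX_CONCURRENT_BATCHES:
            # Block on the oldest batch, as update_from_output would.
            print(in_flight.popleft().result())
    while in_flight:
        print(in_flight.popleft().result())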
vllm/v1/request.py
@@ -77,6 +77,7 @@ class Request:
         self.num_prompt_tokens = len(self.prompt_token_ids)
         self._output_token_ids: list[int] = []
         self._all_token_ids: list[int] = self.prompt_token_ids.copy()
+        self.num_output_placeholders = 0  # Used in async scheduling.
         self.spec_token_ids: list[int] = []
         self.num_computed_tokens = 0
         self.cache_salt: Optional[str] = cache_salt