Implement Async Scheduling (#19970)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-07-14 23:01:46 -07:00 (committed by GitHub)
Commit: d4d309409f (parent: 85bd6599e4)
11 changed files with 508 additions and 148 deletions

@@ -0,0 +1,228 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import deque
import pytest
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import RequestStatus
from .utils import create_requests, create_scheduler
def _make_model_runner_output(
scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput:
req_ids = list(scheduler_output.num_scheduled_tokens.keys())
return ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index={
req_id: i
for i, req_id in enumerate(req_ids)
},
sampled_token_ids=[[i] for i in range(len(req_ids))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
@pytest.mark.parametrize("max_tokens", [1, 2, 3, 5])
def test_stop_by_max_tokens(max_tokens: int):
scheduler = create_scheduler(async_scheduling=True)
requests = create_requests(num_requests=2, max_tokens=max_tokens)
req0, req1 = requests
sched_outputs: deque[SchedulerOutput] = deque()
scheduler.add_request(req0)
sched_outputs.append(scheduler.schedule())
scheduler.add_request(req1)
sched_outputs.append(scheduler.schedule())
while sched_outputs:
sched_output = sched_outputs.popleft()
model_runner_output = _make_model_runner_output(sched_output)
scheduler.update_from_output(sched_output, model_runner_output)
sched_output = scheduler.schedule()
if sched_output.num_scheduled_tokens:
sched_outputs.append(sched_output)
assert scheduler.get_num_unfinished_requests() == 0
assert req0.num_output_tokens == max_tokens
assert req1.num_output_tokens == max_tokens
def test_abort():
scheduler = create_scheduler(async_scheduling=True)
requests = create_requests(num_requests=10, max_tokens=20)
for req in requests:
scheduler.add_request(req)
sched_outputs: deque[SchedulerOutput] = deque()
sched_outputs.append(scheduler.schedule())
sched_outputs.append(scheduler.schedule())
abort_order = [0, 8, 3, 1, 6, 4, 2, 5, 7, 9]
abort_order_copy = abort_order.copy()
def abort_request():
if not abort_order:
return
req = requests[abort_order.pop(0)]
scheduler.finish_requests(req.request_id,
RequestStatus.FINISHED_ABORTED)
while sched_outputs:
# Abort a scheduled request.
abort_request()
sched_output = sched_outputs.popleft()
model_runner_output = _make_model_runner_output(sched_output)
scheduler.update_from_output(sched_output, model_runner_output)
sched_output = scheduler.schedule()
if sched_output.num_scheduled_tokens:
sched_outputs.append(sched_output)
for i, req in enumerate(requests):
assert req.status == RequestStatus.FINISHED_ABORTED
assert req.num_output_tokens == abort_order_copy.index(i)
def test_preempt():
scheduler = create_scheduler(async_scheduling=True)
requests = create_requests(num_requests=10, max_tokens=20)
for req in requests:
scheduler.add_request(req)
sched_outputs: deque[SchedulerOutput] = deque()
sched_outputs.append(scheduler.schedule())
sched_outputs.append(scheduler.schedule())
abort_order = [0, 8, 3, 1, 6, 4, 2, 5, 7, 9]
abort_order_copy = abort_order.copy()
def abort_request():
if not abort_order:
return
req = requests[abort_order.pop(0)]
scheduler.finish_requests(req.request_id,
RequestStatus.FINISHED_ABORTED)
while sched_outputs:
# Abort a scheduled request.
abort_request()
sched_output = sched_outputs.popleft()
model_runner_output = _make_model_runner_output(sched_output)
scheduler.update_from_output(sched_output, model_runner_output)
sched_output = scheduler.schedule()
if sched_output.num_scheduled_tokens:
sched_outputs.append(sched_output)
for i, req in enumerate(requests):
assert req.status == RequestStatus.FINISHED_ABORTED
assert req.num_output_tokens == abort_order_copy.index(i)
def test_prefix_caching_for_prefill_dedup():
CHUNK_SIZE = 1000
BLOCK_SIZE = 16
num_prompt_tokens = 100
scheduler = create_scheduler(async_scheduling=True,
max_num_batched_tokens=CHUNK_SIZE,
enable_prefix_caching=True,
block_size=BLOCK_SIZE)
requests = create_requests(num_requests=5,
num_tokens=num_prompt_tokens,
max_tokens=3,
same_prompt=True)
requests_copy = requests.copy()
# Two requests with the same prompt.
req0 = requests.pop(0)
req1 = requests.pop(0)
scheduler.add_request(req0)
scheduler.add_request(req1)
sched_outputs: deque[SchedulerOutput] = deque()
sched_output = scheduler.schedule()
sched_outputs.append(sched_output)
# Make sure prefix caching de-duplicates the prompts in the same step,
# so all the blocks except the last are shared between the two requests.
assert len(sched_output.num_scheduled_tokens) == 2
num_blocks = num_prompt_tokens // BLOCK_SIZE
assert req0.num_cached_tokens == 0
assert req1.num_cached_tokens >= num_blocks * BLOCK_SIZE
sched_outputs.append(scheduler.schedule())
while sched_outputs:
if requests:
scheduler.add_request(requests.pop(0))
sched_output = sched_outputs.popleft()
model_runner_output = _make_model_runner_output(sched_output)
scheduler.update_from_output(sched_output, model_runner_output)
sched_output = scheduler.schedule()
if sched_output.num_scheduled_tokens:
sched_outputs.append(sched_output)
# Other requests scheduled after the two requests should also get
# prefix cache hits.
assert scheduler.get_num_unfinished_requests() == 0
for req in requests_copy[1:]:
assert req.num_cached_tokens >= num_blocks * BLOCK_SIZE
def test_prefix_caching_for_multi_turn():
CHUNK_SIZE = 1000
BLOCK_SIZE = 16
num_prompt_tokens = 100
num_output_tokens = 200
scheduler = create_scheduler(async_scheduling=True,
max_num_batched_tokens=CHUNK_SIZE,
enable_prefix_caching=True,
block_size=BLOCK_SIZE)
requests = create_requests(num_requests=5,
num_tokens=num_prompt_tokens,
max_tokens=num_output_tokens)
for req in requests:
scheduler.add_request(req)
sched_outputs: deque[SchedulerOutput] = deque()
sched_outputs.append(scheduler.schedule())
sched_outputs.append(scheduler.schedule())
# Process the requests.
while sched_outputs:
sched_output = sched_outputs.popleft()
model_runner_output = _make_model_runner_output(sched_output)
scheduler.update_from_output(sched_output, model_runner_output)
sched_output = scheduler.schedule()
if sched_output.num_scheduled_tokens:
sched_outputs.append(sched_output)
assert scheduler.get_num_unfinished_requests() == 0
# Create next-turn requests whose prompts are the full output of the
# previous turn.
next_turn_requests = create_requests(
num_requests=5,
num_tokens=num_prompt_tokens + num_output_tokens,
max_tokens=num_output_tokens,
)
for i, req in enumerate(next_turn_requests):
req.prompt_token_ids = (requests[i].prompt_token_ids +
list(requests[i].output_token_ids))
# Schedule the next-turn requests.
for req in next_turn_requests:
scheduler.add_request(req)
sched_outputs.append(scheduler.schedule())
# Make sure the next-turn requests get prefix cache hits from the
# previous turn's requests.
for req in next_turn_requests:
assert (req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE *
BLOCK_SIZE)
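
A quick back-of-the-envelope check of the block arithmetic these prefix-caching tests rely on (a standalone sketch, not part of the test file): only full blocks can be shared, so a 100-token prompt contributes at most 6 blocks of 16 tokens, and a 300-token next-turn prompt at most 18 blocks.

```python
# Illustrative arithmetic only; mirrors the assertions in the tests above.
BLOCK_SIZE = 16

# test_prefix_caching_for_prefill_dedup: 100-token prompts.
num_prompt_tokens = 100
full_blocks = num_prompt_tokens // BLOCK_SIZE            # 6 full blocks
assert full_blocks * BLOCK_SIZE == 96                    # >= 96 cached tokens expected

# test_prefix_caching_for_multi_turn: 100 prompt + 200 output tokens.
next_turn_prompt_tokens = 100 + 200
assert next_turn_prompt_tokens // BLOCK_SIZE * BLOCK_SIZE == 288
```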


@@ -19,133 +19,7 @@ from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.request import StructuredOutputRequest
EOS_TOKEN_ID = 50256
def create_scheduler(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
enable_prefix_caching: Optional[bool] = None,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
use_kv_connector: bool = False,
num_blocks: int = 10000,
block_size: int = 16,
max_model_len: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
skip_tokenizer_init: bool = False,
) -> Scheduler:
'''Create scheduler under test.
Args:
model: model under test
max_num_seqs: max sequences to schedule
max_num_batch_tokens: max num tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
(None)
Returns:
{class}`Scheduler` instance
'''
if max_model_len is None:
max_model_len = max_num_batched_tokens
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input,
enable_chunked_prefill=True,
)
model_config = ModelConfig(
model=model,
task="auto",
tokenizer=model,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="float16",
seed=42,
skip_tokenizer_init=skip_tokenizer_init,
)
# Cache config, optionally force APC
kwargs_cache = ({} if enable_prefix_caching is None else {
'enable_prefix_caching': enable_prefix_caching
})
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
**kwargs_cache,
)
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
) if use_kv_connector else None
speculative_config: Optional[SpeculativeConfig] = None
if num_speculative_tokens is not None:
speculative_config = SpeculativeConfig(
model="ngram", num_speculative_tokens=num_speculative_tokens)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
False))
],
)
cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
def create_requests(num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
if mm_positions is not None:
mm_position = mm_positions[i]
mm_inputs = [MultiModalKwargs({})] * len(mm_position)
else:
mm_position = None
mm_inputs = None
request = Request(
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
)
requests.append(request)
return requests
from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
def test_add_requests():

tests/v1/core/utils.py (new file, 152 lines)

@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
EOS_TOKEN_ID = 50256
def create_scheduler(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
enable_prefix_caching: Optional[bool] = None,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
use_kv_connector: bool = False,
num_blocks: int = 10000,
block_size: int = 16,
max_model_len: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
skip_tokenizer_init: bool = False,
async_scheduling: bool = False,
) -> Union[Scheduler, AsyncScheduler]:
'''Create scheduler under test.
Args:
model: model under test
max_num_seqs: max sequences to schedule
max_num_batched_tokens: max number of tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
(None)
Returns:
{class}`Scheduler` or {class}`AsyncScheduler` instance
'''
if max_model_len is None:
max_model_len = max_num_batched_tokens
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input,
enable_chunked_prefill=True,
async_scheduling=async_scheduling,
)
model_config = ModelConfig(
model=model,
task="auto",
tokenizer=model,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="float16",
seed=42,
skip_tokenizer_init=skip_tokenizer_init,
)
# Cache config, optionally force APC
kwargs_cache = ({} if enable_prefix_caching is None else {
'enable_prefix_caching': enable_prefix_caching
})
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
**kwargs_cache,
)
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
) if use_kv_connector else None
speculative_config: Optional[SpeculativeConfig] = None
if num_speculative_tokens is not None:
speculative_config = SpeculativeConfig(
model="ngram", num_speculative_tokens=num_speculative_tokens)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
False))
],
)
cache_config.num_gpu_blocks = num_blocks
scheduler_cls = AsyncScheduler if async_scheduling else Scheduler
return scheduler_cls(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
def create_requests(
num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None,
same_prompt: bool = False,
) -> list[Request]:
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
if mm_positions is not None:
mm_position = mm_positions[i]
mm_inputs = [MultiModalKwargs({})] * len(mm_position)
else:
mm_position = None
mm_inputs = None
prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
num_tokens)
request = Request(
request_id=f"{i}",
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
)
requests.append(request)
return requests
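
For reference, the helpers above combine the same way the new tests use them: build an `AsyncScheduler` via `create_scheduler(async_scheduling=True)` and keep two `SchedulerOutput`s in flight. A minimal sketch, assuming it lives next to `utils.py` inside the tests package so the relative import resolves, and omitting real model execution:

```python
# Minimal sketch of driving the helpers above; real model execution omitted.
from collections import deque

from .utils import create_requests, create_scheduler

scheduler = create_scheduler(async_scheduling=True)
for req in create_requests(num_requests=2, max_tokens=4):
    scheduler.add_request(req)

# Two schedule() calls before the first update_from_output() mimics an
# executor whose max_concurrent_batches is 2.
in_flight = deque([scheduler.schedule(), scheduler.schedule()])
```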


@@ -2308,6 +2308,13 @@ class SchedulerConfig:
like full attention and sliding window attention.
"""
async_scheduling: bool = False
"""EXPERIMENTAL: If set to True, perform async scheduling. This may help
reduce the CPU overheads, leading to better latency and throughput. However,
async scheduling is currently not supported with some features such as
structured outputs, speculative decoding, and pipeline parallelism.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -2401,6 +2408,10 @@ class SchedulerConfig:
if not self.cuda_graph_sizes:
self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)]
if self.async_scheduling:
self.scheduler_cls = (
"vllm.v1.core.sched.async_scheduler.AsyncScheduler")
@model_validator(mode='after')
def _verify_args(self) -> Self:
if (self.max_num_batched_tokens < self.max_model_len
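
Setting `async_scheduling=True` therefore swaps `scheduler_cls` to the dotted path of the new scheduler class. A small sketch (assuming a vLLM build that contains this change) showing that the path resolves to a `Scheduler` subclass:

```python
# Sketch: the class path assigned above resolves to the new AsyncScheduler.
import importlib

from vllm.v1.core.sched.scheduler import Scheduler

path = "vllm.v1.core.sched.async_scheduler.AsyncScheduler"
module_name, class_name = path.rsplit(".", 1)
async_scheduler_cls = getattr(importlib.import_module(module_name), class_name)

assert issubclass(async_scheduler_cls, Scheduler)
```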


@@ -484,6 +484,8 @@ class EngineArgs:
enable_multimodal_encoder_data_parallel: bool = \
ParallelConfig.enable_multimodal_encoder_data_parallel
async_scheduling: bool = SchedulerConfig.async_scheduling
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
@@ -921,6 +923,8 @@ class EngineArgs:
scheduler_group.add_argument(
"--disable-hybrid-kv-cache-manager",
**scheduler_kwargs["disable_hybrid_kv_cache_manager"])
scheduler_group.add_argument("--async-scheduling",
**scheduler_kwargs["async_scheduling"])
# vLLM arguments
vllm_kwargs = get_kwargs(VllmConfig)
@@ -1206,6 +1210,26 @@ class EngineArgs:
self.data_parallel_rpc_port
is not None) else ParallelConfig.data_parallel_rpc_port
if self.async_scheduling:
# Async scheduling does not work with the uniprocess backend.
if self.distributed_executor_backend is None:
self.distributed_executor_backend = "mp"
logger.info("Using mp-based distributed executor backend "
"for async scheduling.")
if self.distributed_executor_backend == "uni":
raise ValueError("Async scheduling is not supported with "
"uni-process backend.")
if self.pipeline_parallel_size > 1:
raise ValueError("Async scheduling is not supported with "
"pipeline-parallel-size > 1.")
# Currently, async scheduling does not support speculative decoding.
# TODO(woosuk): Support it.
if self.speculative_config is not None:
raise ValueError(
"Currently, speculative decoding is not supported with "
"async scheduling.")
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
@@ -1286,6 +1310,7 @@ class EngineArgs:
long_prefill_token_threshold=self.long_prefill_token_threshold,
disable_hybrid_kv_cache_manager=self.
disable_hybrid_kv_cache_manager,
async_scheduling=self.async_scheduling,
)
if not model_config.is_multimodal_model and self.default_mm_loras:
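
Taken together with the new `--async-scheduling` CLI flag, async scheduling implies the multiprocessing executor backend and is rejected when combined with the uni-process backend, pipeline parallelism, or speculative decoding. A hedged sketch of the Python-side path (the model name is only an example, and building the config may need to fetch the HF model config):

```python
# Sketch: enabling async scheduling via EngineArgs, the same path the
# --async-scheduling CLI flag takes (assumes access to the HF model config).
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="facebook/opt-125m", async_scheduling=True)
vllm_config = args.create_engine_config()

assert vllm_config.scheduler_config.async_scheduling
# With no backend given, EngineArgs falls back to the "mp" executor backend;
# pipeline_parallel_size > 1 or a speculative config would raise ValueError.
print(vllm_config.parallel_config.distributed_executor_backend)  # expected: "mp"
```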


@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus
logger = init_logger(__name__)
class AsyncScheduler(Scheduler):
def _update_after_schedule(
self,
scheduler_output: SchedulerOutput,
) -> None:
super()._update_after_schedule(scheduler_output)
for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id]
if (request.num_computed_tokens == request.num_tokens +
request.num_output_placeholders):
# The request will generate a new token in this scheduling step.
# TODO(woosuk): Support speculative decoding.
request.num_output_placeholders += 1
def _update_request_with_output(
self,
request: Request,
new_token_ids: list[int],
) -> tuple[list[int], bool]:
status_before_update = request.status
new_token_ids, stopped = super()._update_request_with_output(
request, new_token_ids)
# Update the number of output placeholders.
request.num_output_placeholders -= len(new_token_ids)
assert request.num_output_placeholders >= 0
# Cache the new tokens. Preempted requests should be skipped.
if status_before_update == RequestStatus.RUNNING:
self.kv_cache_manager.cache_blocks(
request,
request.num_computed_tokens - request.num_output_placeholders)
return new_token_ids, stopped
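
The key piece of bookkeeping is `num_output_placeholders`: when a decode step is scheduled, its output token id is not yet known, so a slot is reserved immediately and reconciled once the model runner output arrives. A toy walkthrough of the counters for one request across two pipelined steps (plain integers standing in for the real `Request` fields):

```python
# Toy walkthrough of the placeholder accounting above (plain ints, not the
# real Request/AsyncScheduler objects).
num_tokens = 4                  # prompt tokens; no output tokens known yet
num_computed_tokens = 0
num_output_placeholders = 0

# Step 1: the whole prompt is scheduled. After the update, every known token
# is computed, so one output slot is reserved for the token being sampled.
num_computed_tokens += 4
if num_computed_tokens == num_tokens + num_output_placeholders:
    num_output_placeholders += 1                    # -> 1

# Step 2: the next decode step is scheduled *before* step 1's output is back.
num_computed_tokens += 1                            # reserved slot counts as computed
if num_computed_tokens == num_tokens + num_output_placeholders:
    num_output_placeholders += 1                    # -> 2

# Step 1's output arrives: one real token id replaces one placeholder.
new_token_ids = [123]                               # hypothetical sampled token
num_tokens += len(new_token_ids)                    # -> 5
num_output_placeholders -= len(new_token_ids)       # -> 1
assert num_output_placeholders >= 0
```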


@@ -204,7 +204,8 @@ class Scheduler(SchedulerInterface):
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
num_new_tokens = (request.num_tokens_with_spec -
num_new_tokens = (request.num_tokens_with_spec +
request.num_output_placeholders -
request.num_computed_tokens)
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
@@ -230,9 +231,11 @@ class Scheduler(SchedulerInterface):
if num_new_tokens == 0:
# The request cannot be scheduled because one of the following
# reasons:
# 1. No new tokens to schedule. This may happen when PP>1 and
# we have already scheduled all prompt tokens but they are
# not finished yet.
# 1. No new tokens to schedule. This may happen when
# (1) PP>1 and we have already scheduled all prompt tokens
# but they are not finished yet.
# (2) Async scheduling and the request has reached either
# its max_total_tokens or max_model_len.
# 2. The encoder budget is exhausted.
# 3. The encoder cache is exhausted.
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
@@ -598,6 +601,14 @@ class Scheduler(SchedulerInterface):
request = self.requests[req_id]
request.num_computed_tokens += num_scheduled_token
# NOTE: _free_encoder_inputs relies on num_computed_tokens, which
# may be updated again in _update_from_output for speculative
# decoding. However, it is safe to call the method here because
# encoder inputs are always part of the prompt, not the output,
# and thus are unaffected by speculative decoding.
if request.has_encoder_inputs:
self._free_encoder_inputs(request)
# Clear the finished request IDs.
# NOTE: We shouldn't do self.finished_req_ids.clear() here because
# it will also affect the scheduler output.
@@ -785,29 +796,16 @@ class Scheduler(SchedulerInterface):
num_draft_tokens=len(scheduled_spec_token_ids),
num_accepted_tokens=len(generated_token_ids) - 1)
# NOTE(woosuk): This has to be executed after updating
# `request.num_computed_tokens`.
if request.has_encoder_inputs:
self._free_encoder_inputs(request)
stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
kv_transfer_params = None
status_before_stop = request.status
# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
# to return empty token ids for the request.
for num_new, output_token_id in enumerate(new_token_ids, 1):
request.append_output_token_ids(output_token_id)
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, self.max_model_len)
if stopped:
del new_token_ids[num_new:] # Trim new tokens if needed.
break
# Check for stop and update request status.
if new_token_ids:
new_token_ids, stopped = self._update_request_with_output(
request, new_token_ids)
# Stop checking for pooler models.
pooler_output = None
@@ -915,6 +913,26 @@ class Scheduler(SchedulerInterface):
return engine_core_outputs
def _update_request_with_output(
self,
request: Request,
new_token_ids: list[int],
) -> tuple[list[int], bool]:
# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
# to return empty token ids for the request.
stopped = False
for num_new, output_token_id in enumerate(new_token_ids, 1):
request.append_output_token_ids(output_token_id)
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, self.max_model_len)
if stopped:
del new_token_ids[num_new:] # Trim new tokens if needed.
break
return new_token_ids, stopped
def _free_encoder_inputs(self, request: Request) -> None:
cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request))
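
The per-token stop check now lives in `_update_request_with_output`, which `AsyncScheduler` overrides to also reconcile placeholders and cache completed blocks. One detail worth spelling out is the trimming: if a stop condition fires partway through the sampled tokens, the remaining ones are dropped. A self-contained sketch with a toy stop rule standing in for `check_stop`:

```python
# Standalone sketch of the append-and-trim loop above, with a toy stop rule.
def update_with_output(output_token_ids: list[int], new_token_ids: list[int],
                       eos_token_id: int) -> tuple[list[int], bool]:
    stopped = False
    for num_new, token_id in enumerate(new_token_ids, 1):
        output_token_ids.append(token_id)
        stopped = (token_id == eos_token_id)   # stand-in for check_stop()
        if stopped:
            del new_token_ids[num_new:]        # trim tokens after the stop
            break
    return new_token_ids, stopped


outputs: list[int] = []
tokens, stopped = update_with_output(outputs, [7, 50256, 9], eos_token_id=50256)
assert stopped and tokens == [7, 50256] and outputs == [7, 50256]
```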


@@ -367,6 +367,8 @@ class MultiprocExecutor(Executor):
@property
def max_concurrent_batches(self) -> int:
if self.scheduler_config.async_scheduling:
return 2
return self.parallel_config.pipeline_parallel_size
def _get_output_rank(self) -> int:
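
Returning 2 here lets the engine core keep one batch executing while the scheduler prepares the next, which is the whole point of async scheduling. The control flow is roughly the loop below, mirroring the pattern the new tests use (a conceptual sketch, not the real engine core; `run_model` is a stand-in for asynchronous model execution):

```python
# Conceptual two-deep pipeline; mirrors the loop in the new tests above.
from collections import deque


def run_two_deep(scheduler, run_model) -> None:
    # Two batches are scheduled up front, so one can "execute" while the
    # scheduler already has the next one prepared.
    in_flight = deque([scheduler.schedule(), scheduler.schedule()])
    while in_flight:
        sched_output = in_flight.popleft()
        model_output = run_model(sched_output)
        scheduler.update_from_output(sched_output, model_output)
        next_output = scheduler.schedule()
        if next_output.num_scheduled_tokens:
            in_flight.append(next_output)
```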


@@ -33,6 +33,8 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
"""Ray distributed executor supports pipeline parallelism,
meaning that it allows PP size batches to be executed concurrently.
"""
if self.scheduler_config.async_scheduling:
return 2
return self.parallel_config.pipeline_parallel_size
def execute_model(


@@ -77,6 +77,7 @@ class Request:
self.num_prompt_tokens = len(self.prompt_token_ids)
self._output_token_ids: list[int] = []
self._all_token_ids: list[int] = self.prompt_token_ids.copy()
self.num_output_placeholders = 0 # Used in async scheduling.
self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0
self.cache_salt: Optional[str] = cache_salt
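
`num_output_placeholders` counts decode positions that have been scheduled but whose token ids have not come back yet, which is why `AsyncScheduler._update_request_with_output` caches only `num_computed_tokens - num_output_placeholders` tokens. Toy numbers illustrating the arithmetic (illustrative only, not the real `Request` class):

```python
# Toy numbers showing why only num_computed_tokens - num_output_placeholders
# tokens are eligible for prefix-cache block hashing (illustrative only).
num_prompt_tokens = 100
num_output_tokens = 28          # output tokens whose ids are already known
num_output_placeholders = 2     # scheduled decode steps still in flight

num_tokens = num_prompt_tokens + num_output_tokens            # 128 known tokens
num_computed_tokens = num_tokens + num_output_placeholders    # 130 positions scheduled

# Only positions with known token ids can be hashed into full cache blocks.
num_cacheable = num_computed_tokens - num_output_placeholders
assert num_cacheable == num_tokens == 128
```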