[Core] Asynchronous Output Processor (#7049)

Co-authored-by: Alexander Matveev <alexm@neuralmagic.com>
Megha Agarwal 2024-08-26 20:53:20 -07:00 committed by GitHub
parent 015e6cc252
commit 2eedede875
21 changed files with 652 additions and 214 deletions


@ -86,6 +86,7 @@ def run_vllm(
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
@ -110,6 +111,7 @@ def run_vllm(
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
)
# Add the requests to the engine.
@ -237,7 +239,8 @@ def main(args: argparse.Namespace):
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format)
args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@ -418,6 +421,11 @@ if __name__ == "__main__":
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
"--disable-async-output-proc",
action='store_true',
default=False,
help="Disable async output processor for vLLM backend.")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
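
As a usage note (not part of the diff): the new benchmark flag is simply forwarded to the vllm.LLM constructor, which exposes the same switch as disable_async_output_proc. A minimal sketch, assuming a vLLM build containing this commit; the model name and prompt are placeholders:

# Hedged sketch: forward the CLI switch into vllm.LLM, mirroring run_vllm().
import argparse
from vllm import LLM, SamplingParams

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--disable-async-output-proc",
                    action="store_true",
                    default=False,
                    help="Disable async output processor for vLLM backend.")
args = parser.parse_args()

llm = LLM(model=args.model,
          disable_async_output_proc=args.disable_async_output_proc)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
for out in outputs:
    print(out.outputs[0].text)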


@ -88,6 +88,9 @@ def test_models(
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
# Due to low-precision numerical divergence, this test is too sensitive to
# the async postprocessor
@pytest.mark.parametrize("disable_async_output_proc", [True])
def test_models_with_fp8_kv_cache(
vllm_runner,
example_prompts,
@ -97,6 +100,7 @@ def test_models_with_fp8_kv_cache(
chunked_prefill_token_size: int,
enforce_eager: bool,
tensor_parallel_size: int,
disable_async_output_proc: bool,
) -> None:
"""
Only checks log probs match between chunked-prefill and
@ -126,6 +130,7 @@ def test_models_with_fp8_kv_cache(
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
**extra_kwargs,
) as vllm_model:
no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
@ -139,6 +144,7 @@ def test_models_with_fp8_kv_cache(
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
**extra_kwargs,
) as vllm_model:
chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(


@ -209,7 +209,6 @@ def test_swap_infeasible(
prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE
example_prompts = example_prompts[:1]
with vllm_runner(
model,
dtype=dtype,


@ -21,7 +21,7 @@ def append_new_token(seq_group, token_id: int):
def schedule_and_update_computed_tokens(scheduler):
metas, out = scheduler.schedule()
metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out
@ -180,7 +180,7 @@ def test_maximal_decoding():
"""Verify decoding requests are prioritized."""
block_size = 4
max_seqs = 2
max_model_len = 2
max_model_len = 8
max_num_batched_tokens = 2
scheduler_config = SchedulerConfig(max_num_batched_tokens,
max_seqs,


@ -199,7 +199,7 @@ def append_new_token(out, token_id: int):
def schedule_and_update_computed_tokens(scheduler):
metas, out = scheduler.schedule()
metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out


@ -7,6 +7,8 @@ from vllm import CompletionOutput, LLMEngine, SamplingParams
MODEL = "meta-llama/llama-2-7b-hf"
MAX_TOKENS = 200
IS_ASYNC = False
@pytest.fixture(scope="session")
def vllm_model(vllm_runner):
@ -14,77 +16,13 @@ def vllm_model(vllm_runner):
yield vllm_model
@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
stop=["."],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=".")
_test_stopping(vllm_model.model.llm_engine,
stop=["."],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization.",
expected_reason=".")
@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
_test_stopping(
vllm_model.model.llm_engine,
stop=["group of peo", "short"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization. We are a ",
expected_reason="group of peo")
_test_stopping(
vllm_model.model.llm_engine,
stop=["group of peo", "short"],
include_in_output=True,
expected_output=
"VLLM is a 100% volunteer organization. We are a group of peo",
expected_reason="group of peo")
@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
stop=["gani"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer or",
expected_reason="gani")
_test_stopping(vllm_model.model.llm_engine,
stop=["gani"],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organi",
expected_reason="gani")
@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
# token id 13013 => " organization"
_test_stopping(vllm_model.model.llm_engine,
stop_token_ids=[13013],
include_in_output=False,
expected_output="VLLM is a 100% volunteer",
expected_reason=13013)
_test_stopping(vllm_model.model.llm_engine,
stop_token_ids=[13013],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=13013)
def _test_stopping(llm_engine: LLMEngine,
expected_output: str,
expected_reason: Any,
stop: Optional[List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
include_in_output: bool = False) -> None:
include_in_output: bool = False,
use_async_output_proc: bool = False) -> None:
llm_engine.add_request(
"id", "A story about vLLM:\n",
SamplingParams(
@ -98,6 +36,10 @@ def _test_stopping(llm_engine: LLMEngine,
output: Optional[CompletionOutput] = None
output_text = ""
stop_reason = None
if use_async_output_proc:
llm_engine.step()
while llm_engine.has_unfinished_requests():
(request_output, ) = llm_engine.step()
(output, ) = request_output.outputs
@ -110,3 +52,112 @@ def _test_stopping(llm_engine: LLMEngine,
assert output is not None
assert output_text == expected_output
assert stop_reason == expected_reason
def _set_async_mode(llm_engine, is_async):
llm_engine.scheduler[0].use_async_output_proc = is_async
def _stop_basic(llm_engine, is_async):
_test_stopping(llm_engine,
stop=["."],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=".",
use_async_output_proc=is_async)
_test_stopping(llm_engine,
stop=["."],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization.",
expected_reason=".",
use_async_output_proc=is_async)
def _stop_multi_tokens(llm_engine, is_async):
_test_stopping(
llm_engine,
stop=["group of peo", "short"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization. We are a ",
expected_reason="group of peo",
use_async_output_proc=is_async)
_test_stopping(
llm_engine,
stop=["group of peo", "short"],
include_in_output=True,
expected_output=
"VLLM is a 100% volunteer organization. We are a group of peo",
expected_reason="group of peo",
use_async_output_proc=is_async)
def _stop_partial_token(llm_engine, is_async):
_test_stopping(llm_engine,
stop=["gani"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer or",
expected_reason="gani",
use_async_output_proc=is_async)
_test_stopping(llm_engine,
stop=["gani"],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organi",
expected_reason="gani",
use_async_output_proc=is_async)
def _stop_token_id(llm_engine, is_async):
# token id 13013 => " organization"
_test_stopping(llm_engine,
stop_token_ids=[13013],
include_in_output=False,
expected_output="VLLM is a 100% volunteer",
expected_reason=13013,
use_async_output_proc=is_async)
_test_stopping(llm_engine,
stop_token_ids=[13013],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=13013,
use_async_output_proc=is_async)
@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_basic(vllm_model.model.llm_engine, is_async=True)
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_basic(vllm_model.model.llm_engine, is_async=False)
@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_multi_tokens(vllm_model.model.llm_engine, is_async=True)
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_multi_tokens(vllm_model.model.llm_engine, is_async=False)
@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_partial_token(vllm_model.model.llm_engine, is_async=True)
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_partial_token(vllm_model.model.llm_engine, is_async=False)
@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_token_id(vllm_model.model.llm_engine, is_async=True)
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_token_id(vllm_model.model.llm_engine, is_async=False)
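
One behavioral detail worth spelling out: with the async postprocessor enabled, step() returns outputs that lag one iteration behind, which is why _test_stopping above issues an extra llm_engine.step() before polling. A condensed sketch of that polling pattern; the helper name, request id, and prompt are illustrative only:

# Illustrative polling loop (helper name is hypothetical). In async mode the
# first step() only schedules work, so one extra step() is issued up front.
from vllm import SamplingParams

def collect_text(llm_engine, prompt: str, use_async_output_proc: bool) -> str:
    llm_engine.add_request("req-0", prompt, SamplingParams(max_tokens=32))
    if use_async_output_proc:
        llm_engine.step()  # prime: outputs are processed one step late
    text = ""
    while llm_engine.has_unfinished_requests():
        (request_output, ) = llm_engine.step()
        (output, ) = request_output.outputs
        text = output.text
    return text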


@ -62,6 +62,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
ms_server_args = DEFAULT_SERVER_ARGS + \
["--num-scheduler-steps", f"{num_scheduler_steps}"]
# Disable the output proc callback as it's not supported
# with multi-step right now
ms_server_args += ["--disable-async-output-proc"]
if eager_mode:
ms_server_args.append("--enforce-eager")


@ -140,6 +140,7 @@ class ModelConfig:
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
) -> None:
self.model = model
self.tokenizer = tokenizer
@ -172,6 +173,7 @@ class ModelConfig:
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc
# Choose a default enforce_eager value if the user did not specify
# a value (enforce_eager is None)
@ -326,6 +328,49 @@ class ModelConfig:
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len)
def verify_async_output_proc(self, parallel_config, speculative_config,
device_config) -> None:
if not self.use_async_output_proc:
# Nothing to check
return
if parallel_config.pipeline_parallel_size > 1:
logger.warning("Async output processing can not be enabled "
"with pipeline parallel")
self.use_async_output_proc = False
return
if device_config.device_type != "cuda":
logger.warning(
"Async output processing is only supported for CUDA."
" Disabling it for other platforms.")
self.use_async_output_proc = False
return
if envs.VLLM_USE_RAY_SPMD_WORKER:
logger.warning(
"Async output processing can not be enabled with ray spmd")
self.use_async_output_proc = False
return
if self.enforce_eager:
logger.warning(
"To see benefits of async output processing, enable CUDA "
"graph. Since, enforce-eager is enabled, async output "
"processor cannot be used")
self.use_async_output_proc = not self.enforce_eager
return
# Async postprocessor is not necessary with embedding mode
# since there is no token generation
if self.embedding_mode:
self.use_async_output_proc = False
if speculative_config:
logger.warning("Async output processing is not supported with"
" speculative decoding currently.")
self.use_async_output_proc = False
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
@ -358,6 +403,11 @@ class ModelConfig:
"fallback to the eager mode.")
self.enforce_eager = True
if pipeline_parallel_size > 1 and self.use_async_output_proc:
logger.warning("Async output processor is not supported with "
"pipeline parallelism currently. Disabling it.")
self.use_async_output_proc = False
def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled."""
@ -1769,6 +1819,9 @@ class EngineConfig:
def __post_init__(self):
"""Verify configs are valid & consistent with each other.
"""
self.model_config.verify_async_output_proc(self.parallel_config,
self.speculative_config,
self.device_config)
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
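
The checks above boil down to a single predicate. The following standalone summary is only an illustration of the gating rules; the parameter names are simplified stand-ins, not the real config fields:

# Standalone summary of the verify_async_output_proc() conditions above.
def async_output_proc_allowed(pipeline_parallel_size: int,
                              device_type: str,
                              use_ray_spmd_worker: bool,
                              enforce_eager: bool,
                              embedding_mode: bool,
                              has_speculative_config: bool) -> bool:
    if pipeline_parallel_size > 1:
        return False            # not supported with pipeline parallelism
    if device_type != "cuda":
        return False            # CUDA-only for now
    if use_ray_spmd_worker:
        return False            # incompatible with Ray SPMD workers
    if enforce_eager:
        return False            # benefits require CUDA graphs
    if embedding_mode:
        return False            # no token generation to post-process
    if has_speculative_config:
        return False            # speculative decoding not supported
    return True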


@ -4,7 +4,8 @@ import random
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union)
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
@ -299,6 +300,7 @@ class Scheduler:
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
pipeline_parallel_size: int = 1,
output_proc_callback_fn: Optional[Callable] = None,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
@ -364,10 +366,36 @@ class Scheduler:
self.num_cumulative_preemption: int = 0
# Used to cache python objects
self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache(
scheduler_running_outputs_builder)
self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache(
scheduled_seq_group_builder)
self._seq_group_metadata_cache: List[PyObjectCache] = []
self._scheduler_running_outputs_cache: List[PyObjectCache] = []
self._scheduled_seq_group_cache: List[PyObjectCache] = []
# For async output processing, we need to swap cache buffers between
# iterations. I.e., since output processing lags one step behind, we
# cannot reuse the cached objects at the very next schedule() call,
# but only at the one after that.
self.output_proc_callback_fn = output_proc_callback_fn
self.use_async_output_proc = self.output_proc_callback_fn is not None
self.num_cache_iters = 2 if self.use_async_output_proc else 1
self.cache_id = 0
for i in range(self.num_cache_iters):
self._seq_group_metadata_cache.append(
PyObjectCache(seq_group_metadata_builder))
self._scheduler_running_outputs_cache.append(
PyObjectCache(scheduler_running_outputs_builder))
self._scheduled_seq_group_cache.append(
PyObjectCache(scheduled_seq_group_builder))
# For the async postprocessor, the extra decode run cannot be done
# when the request reaches max_model_len. In this case, the request
# is stopped during the schedule() call and added to this stop list
# for processing and deallocation by free_finished_seq_groups().
self._async_stopped: List[SequenceGroup] = []
@property
def next_cache_id(self):
return (self.cache_id + 1) % self.num_cache_iters
@property
def lora_enabled(self) -> bool:
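
To make the comment about swapping cache buffers concrete, here is a minimal, self-contained sketch of a two-slot cache rotation; the class names are hypothetical stand-ins for PyObjectCache, not vLLM code. Objects handed out during one schedule() call may still be read by the lagged postprocessor while the next call runs, so only the other slot is reset and used next:

# Hypothetical stand-ins illustrating the double-buffered cache rotation.
class SimpleObjectCache:
    def __init__(self, builder):
        self._builder = builder
        self._in_use = []

    def get_object(self):
        obj = self._builder()
        self._in_use.append(obj)
        return obj

    def reset(self):
        self._in_use.clear()


class DoubleBufferedCache:
    def __init__(self, builder, use_async_output_proc: bool):
        self.num_cache_iters = 2 if use_async_output_proc else 1
        self.cache_id = 0
        self.caches = [SimpleObjectCache(builder)
                       for _ in range(self.num_cache_iters)]

    @property
    def next_cache_id(self):
        return (self.cache_id + 1) % self.num_cache_iters

    def get_object(self):
        # Allocate from the slot owned by the current schedule() call.
        return self.caches[self.cache_id].get_object()

    def finish_schedule(self):
        # The current slot may still be referenced by the postprocessor one
        # step later, so reset the *other* slot and switch to it.
        self.caches[self.next_cache_id].reset()
        self.cache_id = self.next_cache_id


cache = DoubleBufferedCache(builder=dict, use_async_output_proc=True)
first = cache.get_object()   # lives until the lagged postprocessor is done
cache.finish_schedule()
second = cache.get_object()  # comes from the other slot; `first` is untouched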
@ -483,7 +511,7 @@ class Scheduler:
SchedulerRunningOutputs.
"""
ret: SchedulerRunningOutputs = \
self._scheduler_running_outputs_cache.get_object()
self._scheduler_running_outputs_cache[self.cache_id].get_object()
ret.blocks_to_swap_out.clear()
ret.blocks_to_copy.clear()
ret.decode_seq_groups.clear()
@ -510,8 +538,12 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
running_queue = self.running
# Store original running requests for the case of async + preemption
if self.use_async_output_proc:
orig_running = self.running.copy()
running_queue = self.running
assert len(self._async_stopped) == 0
while running_queue:
seq_group = running_queue[0]
num_running_tokens = self._get_num_new_tokens(
@ -521,6 +553,28 @@ class Scheduler:
break
running_queue.popleft()
# With async postprocessor, an extra decode run is done
# to process the final tokens. The check below avoids this extra
# decode run when the model max len is reached, in order to avoid
# a memory overflow.
if self.use_async_output_proc and seq_group.seqs[0].get_len(
) > self.scheduler_config.max_model_len:
self._async_stopped.append(seq_group)
continue
# With the async postprocessor, when preemption kicks in, we first
# need to drain the async postprocessor so that all async
# block_table freeing is applied before the preemption freeing.
if self.use_async_output_proc and not self._can_append_slots(
seq_group):
tmp = self.running
self.running = orig_running
assert self.output_proc_callback_fn is not None
self.output_proc_callback_fn(is_async=True)
self.running = tmp
while not self._can_append_slots(seq_group):
budget.subtract_num_batched_tokens(seq_group.request_id,
num_running_tokens)
@ -556,7 +610,7 @@ class Scheduler:
is_prefill = seq_group.is_prefill()
scheduled_seq_group: ScheduledSequenceGroup = \
self._scheduled_seq_group_cache.get_object()
self._scheduled_seq_group_cache[self.cache_id].get_object()
scheduled_seq_group.seq_group = seq_group
if is_prefill:
scheduled_seq_group.token_chunk_size = num_running_tokens
@ -579,8 +633,8 @@ class Scheduler:
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id)
self._scheduler_running_outputs_cache.reset()
self._scheduled_seq_group_cache.reset()
self._scheduler_running_outputs_cache[self.next_cache_id].reset()
self._scheduled_seq_group_cache[self.next_cache_id].reset()
return ret
@ -1031,17 +1085,31 @@ class Scheduler:
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
)
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
no_beam_search = (seq_group.sampling_params.best_of == 1
and not seq_group.sampling_params.use_beam_search)
return no_beam_search
def schedule(
self
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
scheduler_start_time = time.perf_counter()
scheduler_outputs = self._schedule()
now = time.time()
if not self.cache_config.enable_prefix_caching:
common_computed_block_nums = []
# TODO: Combine multi-step and async postprocessor
allow_async_output_proc: bool = (
self.use_async_output_proc
and not self.scheduler_config.is_multi_step)
# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
for i, scheduled_seq_group in enumerate(
@ -1050,6 +1118,11 @@ class Scheduler:
token_chunk_size = scheduled_seq_group.token_chunk_size
seq_group.maybe_set_first_scheduled_time(now)
seq_group_metadata = self._seq_group_metadata_cache[
self.cache_id].get_object()
seq_group_metadata.seq_data.clear()
seq_group_metadata.block_tables.clear()
# seq_id -> SequenceData
seq_data: Dict[int, SequenceData] = {}
# seq_id -> physical block numbers
@ -1139,6 +1212,10 @@ class Scheduler:
)
seq_group_metadata_list.append(seq_group_metadata)
if allow_async_output_proc:
allow_async_output_proc = self._allow_async_output_proc(
seq_group)
# Now that the batch has been created, we can assume all blocks in the
# batch will have been computed before the next scheduling invocation.
# This is because the engine assumes that a failure in model execution
@ -1147,6 +1224,8 @@ class Scheduler:
self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group)
self._seq_group_metadata_cache[self.next_cache_id].reset()
scheduler_time = time.perf_counter() - scheduler_start_time
# Add this to scheduler time to all the sequences that are currently
# running. This will help estimate if the scheduler is a significant
@ -1158,7 +1237,12 @@ class Scheduler:
else:
seq_group.metrics.scheduler_time = scheduler_time
return seq_group_metadata_list, scheduler_outputs
# Move to next cache (if exists)
self.cache_id = self.next_cache_id
# Return results
return (seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
self.block_manager.fork(parent_seq, child_seq)
@ -1167,6 +1251,12 @@ class Scheduler:
"""Free a sequence from a block table."""
self.block_manager.free(seq)
def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
"""Free finished seqs in a sequence group."""
for seq in seq_group.get_seqs():
if seq.is_finished():
self.free_seq(seq)
def free_finished_seq_groups(self) -> None:
remaining: Deque[SequenceGroup] = deque()
for seq_group in self.running:
@ -1179,8 +1269,24 @@ class Scheduler:
self._finished_requests_ids.append(seq_group.request_id)
else:
remaining.append(seq_group)
# Free finished seqs
self._free_finished_seqs(seq_group)
self.running = remaining
# Handle async stopped sequence groups
# (ones that reached max model len)
if self._async_stopped:
for seq_group in self._async_stopped:
self._free_seq_group_cross_attn_blocks(seq_group)
self._finished_requests_ids.append(seq_group.request_id)
# Free finished seqs
self._free_finished_seqs(seq_group)
self._async_stopped.clear()
def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):


@ -147,6 +147,7 @@ class EngineArgs:
otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False
def __post_init__(self):
if self.tokenizer is None:
@ -733,6 +734,12 @@ class EngineArgs:
"modules. This involves use of possibly costly and or blocking "
"operations and hence might have a performance impact.")
parser.add_argument(
'--disable-async-output-proc',
action='store_true',
default=EngineArgs.disable_async_output_proc,
help="Disable async output processing. This may result in "
"lower performance.")
return parser
@classmethod
@ -792,6 +799,7 @@ class EngineArgs:
skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc,
)
cache_config = CacheConfig(
block_size=self.block_size if self.device != "neuron" else


@ -277,23 +277,36 @@ class _AsyncLLMEngine(LLMEngine):
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc
# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
# If the current scheduler iteration has no async postprocessor,
# we first need to drain the pending async postprocessor output
# before moving forward.
if not allow_async_output_proc and len(self.output_queue) > 0:
self._process_model_outputs(is_async=True)
if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs)
virtual_engine, seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
assert not (self.scheduler_config.is_multi_step and \
allow_async_output_proc)
if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
@ -317,6 +330,11 @@ class _AsyncLLMEngine(LLMEngine):
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
if allow_async_output_proc:
execute_model_req.output_proc_callback_fn = \
self._process_model_outputs
# Execute the model.
output = await self.model_executor.execute_model_async(
execute_model_req)
@ -325,6 +343,9 @@ class _AsyncLLMEngine(LLMEngine):
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
if len(self.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(is_async=True)
output = []
# Finish the current step for all the sequence groups.
@ -337,19 +358,32 @@ class _AsyncLLMEngine(LLMEngine):
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
# Cache results in engine
self.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if not allow_async_output_proc:
self._process_model_outputs(is_async=False)
# Log stats.
self.do_log_stats(scheduler_outputs, output)
# Tracing
self.do_tracing(scheduler_outputs)
else:
request_outputs = []
self.request_outputs = []
# Log stats.
self.do_log_stats(scheduler_outputs, output)
# Tracing
self.do_tracing(scheduler_outputs)
return request_outputs
return self.request_outputs
async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""


@ -1,7 +1,8 @@
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
Mapping, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, Union
@ -38,9 +39,8 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
PoolerOutput, SamplerOutput, Sequence,
SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
SamplerOutput, Sequence, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
@ -82,9 +82,10 @@ DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
@dataclass
class SchedulerOutputState:
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
last_output: Optional[SamplerOutput] = None
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
allow_async_output_proc: bool = False
last_output: Optional[SamplerOutput] = None
class LLMEngine:
@ -190,6 +191,9 @@ class LLMEngine:
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
# To improve performance, only the final request outputs may be required.
# If this is set to True, no intermediate outputs will be returned.
step_return_finished_only: bool = False,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
@ -204,7 +208,8 @@ class LLMEngine:
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s)",
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"use_async_output_proc=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
@ -235,6 +240,7 @@ class LLMEngine:
scheduler_config.use_v2_block_manager,
scheduler_config.num_scheduler_steps,
cache_config.enable_prefix_caching,
model_config.use_async_output_proc,
)
# TODO(woosuk): Print more configs in debug mode.
from vllm.plugins import load_general_plugins
@ -253,6 +259,7 @@ class LLMEngine:
self.observability_config = observability_config or ObservabilityConfig(
)
self.log_stats = log_stats
self.step_return_finished_only = step_return_finished_only
if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
@ -340,8 +347,11 @@ class LLMEngine:
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = [
Scheduler(scheduler_config, cache_config, lora_config,
parallel_config.pipeline_parallel_size)
Scheduler(
scheduler_config, cache_config, lora_config,
parallel_config.pipeline_parallel_size,
self._process_model_outputs
if model_config.use_async_output_proc else None)
for _ in range(parallel_config.pipeline_parallel_size)
]
@ -396,6 +406,13 @@ class LLMEngine:
for _ in range(self.parallel_config.pipeline_parallel_size)
]
# Async output processing pointers
self.output_queue: Deque[Tuple[List[SamplerOutput],
List[SequenceGroupMetadata],
SchedulerOutputs]] = deque()
self.request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
@ -1197,34 +1214,66 @@ class LLMEngine:
return
def _process_model_outputs(
self,
output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
scheduled_seq_groups: List[ScheduledSequenceGroup],
ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
def _process_model_outputs(self,
is_async: bool,
clear_outputs: bool = True) -> None:
"""Apply the model output to the sequences in the scheduled seq groups.
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
clear_outputs: Whether to clear self.request_outputs before
appending. Set to False when existing outputs must be combined
with the outputs of this call, e.g. when draining the
postprocessor at the final stage (when sequences are finished).
The resulting RequestOutputs are stored in self.request_outputs.
"""
now = time.time()
# Organize outputs by [sequence group][step] instead of
# [step][sequence group].
output_by_sequence_group = create_output_by_sequence_group(
output, num_seq_groups=len(scheduled_seq_groups))
if clear_outputs:
self.request_outputs.clear()
if len(self.output_queue) == 0:
return None
(outputs, seq_group_metadata_list,
scheduler_outputs) = self.output_queue.popleft()
# Sanity check
assert len(seq_group_metadata_list) == len(
scheduler_outputs.scheduled_seq_groups)
# Organize outputs by [sequence group][step] instead of
# [step][sequence group].
if len(outputs) > 1:
outputs_by_sequence_group = create_output_by_sequence_group(
outputs, num_seq_groups=len(seq_group_metadata_list))
else:
outputs_by_sequence_group = outputs
finished_before: List[int] = []
for i, seq_group_meta in enumerate(seq_group_metadata_list):
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
# Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs, seq_group_meta in zip(
scheduled_seq_groups, output_by_sequence_group,
seq_group_metadata_list):
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
if output is not None and len(output) > 0:
for o in output:
if seq_group.is_finished():
finished_before.append(i)
continue
if len(outputs) > 1:
output = outputs_by_sequence_group[i]
else:
output = [outputs_by_sequence_group[0][i]]
if not is_async:
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
if outputs:
for o in outputs:
if (isinstance(o, SamplerOutput)
and seq_group.metrics is not None):
if seq_group.metrics.model_forward_time is not None:
@ -1239,30 +1288,75 @@ class LLMEngine:
else:
seq_group.metrics.model_execute_time = (
o.model_execute_time)
if self.model_config.embedding_mode:
self._process_sequence_group_outputs(seq_group, outputs)
self._process_sequence_group_outputs(seq_group, output)
continue
self.output_processor.process_prompt_logprob(seq_group, outputs)
self.output_processor.process_prompt_logprob(seq_group, output)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(seq_group, outputs)
self.output_processor.process_outputs(seq_group, output,
is_async)
# Free the finished sequence groups.
for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()
# Create the outputs.
request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
for scheduled_seq_group in scheduled_seq_groups:
for i, _ in enumerate(seq_group_metadata_list):
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
if i in finished_before:
continue # Avoids double processing
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
if (seq_group.is_finished()
if self.step_return_finished_only else True):
request_output = RequestOutputFactory.create(seq_group)
self.request_outputs.append(request_output)
for seq_group in scheduler_outputs.ignored_seq_groups:
request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output)
for seq_group in ignored_seq_groups:
request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output)
return request_outputs
self.request_outputs.append(request_output)
if is_async:
# Log stats.
self.do_log_stats(scheduler_outputs, outputs, finished_before)
# Tracing
self.do_tracing(scheduler_outputs)
return None
def _advance_to_next_step(
self, output: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
"""Given model output from a single run, append the tokens to the
sequences. This is normally done inside the output processor, but it
is required here if the worker is to start the next step's forward
pass asynchronously.
"""
for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
zip(seq_group_metadata_list, output, scheduled_seq_groups):
seq_group = scheduled_seq_group.seq_group
if seq_group.is_finished():
continue
seq_group.update_num_computed_tokens(
seq_group_metadata.token_chunk_size)
if seq_group_metadata.do_sample:
assert len(sequence_group_outputs.samples) == 1, (
"Async output processor expects a single sample"
" (i.e sampling_params.n == 1 and no "
"sampling_params.best_of > 1)")
sample = sequence_group_outputs.samples[0]
assert len(seq_group.seqs) == 1
seq = seq_group.seqs[0]
seq.append_token_id(sample.output_token, sample.logprobs)
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
@ -1325,24 +1419,32 @@ class LLMEngine:
cached_outputs = self.cached_scheduler_outputs[0]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
seq_group_metadata_list, scheduler_outputs = self.scheduler[
0].schedule()
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc) = self.scheduler[0].schedule()
if not allow_async_output_proc and len(self.output_queue) > 0:
self._process_model_outputs(is_async=True)
if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self._cache_scheduler_outputs_for_multi_step(
0, seq_group_metadata_list, scheduler_outputs)
0, seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
assert not (self.scheduler_config.is_multi_step and \
allow_async_output_proc)
if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
0].get_and_reset_finished_requests_ids()
@ -1366,6 +1468,10 @@ class LLMEngine:
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
if allow_async_output_proc:
execute_model_req.output_proc_callback_fn = \
self._process_model_outputs
output = self.model_executor.execute_model(
execute_model_req=execute_model_req)
@ -1374,6 +1480,9 @@ class LLMEngine:
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(0, output)
else:
if len(self.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(is_async=True)
output = []
# Finish the current step for all the sequence groups.
@ -1382,23 +1491,41 @@ class LLMEngine:
seq_group.finish_step()
if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps
# clear the cache if we have finished all the steps.
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[0] = SchedulerOutputState()
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
# Add results to the output_queue
# (for async or non-async postprocessing)
self.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
if output and allow_async_output_proc:
assert len(output) == 1, ("Multi step decoding does not work "
"with async output processing.")
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if not allow_async_output_proc:
self._process_model_outputs(is_async=False)
# Log stats.
self.do_log_stats(scheduler_outputs, output)
# Tracing
self.do_tracing(scheduler_outputs)
else:
request_outputs = []
# Log stats.
self.do_log_stats(scheduler_outputs, output)
# Tracing
self.do_tracing(scheduler_outputs)
self.request_outputs = []
if not self.has_unfinished_requests():
# Drain async postprocessor
if len(self.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(is_async=True, clear_outputs=False)
assert len(self.output_queue) == 0
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
@ -1406,7 +1533,7 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
self.model_executor.stop_remote_worker_execution_loop()
return request_outputs
return self.request_outputs
def _has_remaining_steps(
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
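
Putting the step() changes together: each iteration pushes its raw results onto output_queue, immediately appends the newly sampled token (so the next schedule() sees up-to-date sequences), and leaves detokenization and stop checking to a callback that fires one iteration later. The toy below is only an illustration of that control flow; every name in it is made up and it does not touch the real engine classes:

# Toy illustration of the deferred ("async") output-processing loop.
from collections import deque
from typing import Deque, List, Tuple

class ToyAsyncEngine:
    """Toy model only: queue raw step results, append the sampled token right
    away, and run the heavy postprocessing one iteration late."""

    def __init__(self) -> None:
        self.output_queue: Deque[Tuple[int, int]] = deque()  # (step, token)
        self.request_outputs: List[str] = []
        self.sequence: List[int] = []   # stands in for per-sequence state

    def _process_model_outputs(self, is_async: bool) -> None:
        # Detokenization / stop checking of a *previous* step happens here.
        if not self.output_queue:
            return
        step, token = self.output_queue.popleft()
        self.request_outputs.append(
            f"step {step}: token {token} (is_async={is_async})")

    def step(self, step_idx: int, allow_async_output_proc: bool) -> None:
        if self.output_queue:
            # Previous iteration's results are still pending: in the real
            # engine the worker's callback handles them while the GPU runs
            # the current forward pass (or step() drains them explicitly
            # when async processing is not allowed for this batch).
            self._process_model_outputs(is_async=True)
        token = 100 + step_idx          # stands in for the GPU forward pass
        self.output_queue.append((step_idx, token))
        self.sequence.append(token)     # _advance_to_next_step() analogue
        if not allow_async_output_proc:
            self._process_model_outputs(is_async=False)

engine = ToyAsyncEngine()
for i in range(3):
    engine.step(i, allow_async_output_proc=True)
engine._process_model_outputs(is_async=True)    # final drain
print(engine.request_outputs)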
@ -1431,12 +1558,14 @@ class LLMEngine:
def _cache_scheduler_outputs_for_multi_step(
self, virtual_engine: int,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
scheduler_outputs: SchedulerOutputs) -> None:
self.cached_scheduler_outputs[
virtual_engine].seq_group_metadata_list = seq_group_metadata_list
self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
scheduler_outputs
self.cached_scheduler_outputs[virtual_engine].last_output = None
scheduler_outputs: SchedulerOutputs,
allow_async_output_proc: bool) -> None:
co = self.cached_scheduler_outputs[virtual_engine]
co.seq_group_metadata_list = seq_group_metadata_list
co.scheduler_outputs = scheduler_outputs
co.allow_async_output_proc = allow_async_output_proc
co.last_output = None
def _update_cached_scheduler_output(
self, virtual_engine: int,
@ -1472,20 +1601,21 @@ class LLMEngine:
raise KeyError(f"Logger with name {logger_name} does not exist.")
del self.stat_loggers[logger_name]
def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None) -> None:
def do_log_stats(self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
finished_before: Optional[List[int]] = None) -> None:
"""Forced log when no requests active."""
if self.log_stats:
stats = self._get_stats(scheduler_outputs, model_output)
stats = self._get_stats(scheduler_outputs, model_output,
finished_before)
for logger in self.stat_loggers.values():
logger.log(stats)
def _get_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs],
model_output: Optional[List[SamplerOutput]] = None) -> Stats:
def _get_stats(self,
scheduler_outputs: Optional[SchedulerOutputs],
model_output: Optional[List[SamplerOutput]] = None,
finished_before: Optional[List[int]] = None) -> Stats:
"""Get Stats to be Logged to Prometheus.
Args:
@ -1550,6 +1680,10 @@ class LLMEngine:
# NOTE: This loop assumes prefill seq_groups are before
# decode seq_groups in scheduled_seq_groups.
if scheduler_outputs is not None:
# For the async postprocessor, sequences that have already finished
# must not be counted (to avoid double counting).
actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore
num_generation_tokens_from_prefill_groups = 0.
# NOTE: if scheduler_outputs.num_prefill_groups > 0 and
# the len of scheduler_outputs.scheduled_seq_groups is !=
@ -1558,6 +1692,11 @@ class LLMEngine:
for idx, scheduled_seq_group in enumerate(
scheduler_outputs.scheduled_seq_groups):
# Skip double logging when using async output proc
if finished_before and idx in finished_before:
actual_num_batched_tokens -= 1
continue
group_was_prefill = idx < scheduler_outputs.num_prefill_groups
seq_group = scheduled_seq_group.seq_group
@ -1592,7 +1731,6 @@ class LLMEngine:
# Latency timings
time_e2e_requests.append(now -
seq_group.metrics.arrival_time)
# Metadata
num_prompt_tokens_requests.append(
len(seq_group.prompt_token_ids))
@ -1616,7 +1754,7 @@ class LLMEngine:
# + num_generation_tokens_from_prefill_groups (since we generate
# one token on prefills on iters where the prefill finishes).
num_generation_tokens_iter = (
scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
actual_num_batched_tokens - num_prompt_tokens_iter +
num_generation_tokens_from_prefill_groups)
# Spec decode, if enabled, emits specialized metrics from the worker in


@ -40,13 +40,9 @@ class SequenceGroupOutputProcessor(ABC):
# Importing here to avoid cycle.
from vllm.engine.output_processor.single_step import (
SingleStepOutputProcessor)
return SingleStepOutputProcessor(
scheduler_config,
detokenizer,
scheduler,
seq_counter,
stop_checker,
)
return SingleStepOutputProcessor(scheduler_config, detokenizer,
scheduler, seq_counter,
stop_checker)
else:
# Importing here to avoid cycle.
from vllm.engine.output_processor.multi_step import (
@ -61,7 +57,8 @@ class SequenceGroupOutputProcessor(ABC):
@abstractmethod
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
outputs: List[SequenceGroupOutput],
is_async: bool) -> None:
"""Process new token ids for the sequence group. Handles logic such as
detokenization, stop checking, and freeing/forking sequences in the
scheduler.


@ -57,17 +57,28 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers).")
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
def process_outputs(self,
sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput],
is_async: bool = False) -> None:
"""Append new tokens in the outputs to sequences in the sequence group.
This only supports sequence groups of size 1. It supports greater than
one new token per sequence.
This applies logic like stop condition checking and detokenization,
including freeing finished sequences. It also handles cases where there
are tokens emitted after the EOS token.
This applies logic like stop condition checking and detokenization.
It also handles cases where there are tokens emitted after
the EOS token.
is_async - Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
# TODO: Add support for async if necessary
assert not is_async
seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
assert seqs, "expected running sequences"
@ -138,7 +149,3 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
)
if seq.is_finished():
break
if seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)


@ -29,14 +29,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
that is currently difficult to schedule multiple steps ahead of time.
"""
def __init__(
self,
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
stop_checker: StopChecker,
):
def __init__(self, scheduler_config: SchedulerConfig,
detokenizer: Detokenizer, scheduler: List[Scheduler],
seq_counter: Counter, stop_checker: StopChecker):
self.scheduler_config = scheduler_config
self.detokenizer = detokenizer
self.scheduler = scheduler
@ -44,16 +39,24 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
self.stop_checker = stop_checker
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
outputs: List[SequenceGroupOutput],
is_async: bool) -> None:
"""Append all new tokens to sequences in the sequence group. Fork any
surviving beam candidates; free any unsurviving ones.
Invokes detokenizer to detokenize new tokens, and also marks sequences
as finished if they meet stop conditions.
is_async - Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
assert (len(outputs) == 1
), f"{type(self)} does not support multiple outputs per step"
return self._process_sequence_group_outputs(sequence_group, outputs[0])
return self._process_sequence_group_outputs(sequence_group, outputs[0],
is_async)
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
@ -80,14 +83,16 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
seq_group.prompt_logprobs.extend(prompt_logprobs)
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None:
outputs: SequenceGroupOutput,
is_async: bool) -> None:
sampling_params = seq_group.sampling_params
if sampling_params.n == 1 and not sampling_params.use_beam_search:
# only have one output sample
sample = outputs.samples[0]
# only have one sequence
seq = seq_group.seqs[0]
seq.append_token_id(sample.output_token, sample.logprobs)
if not is_async:
seq.append_token_id(sample.output_token, sample.logprobs)
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
@ -104,6 +109,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
scheduler.free_seq(seq)
return
# TODO: Add support for async for beam search
assert not is_async
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)


@ -129,6 +129,7 @@ class LLM:
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
**kwargs,
) -> None:
'''
@ -170,6 +171,7 @@ class LLM:
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(
@ -603,7 +605,6 @@ class LLM:
inputs = [inputs]
num_requests = len(inputs)
if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
@ -678,6 +679,10 @@ class LLM:
postfix=(f"est. speed input: {0:.2f} toks/s, "
f"output: {0:.2f} toks/s"),
)
# In the loop below, only finished outputs are used
self.llm_engine.step_return_finished_only = True
# Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_in_toks = 0
@ -700,6 +705,10 @@ class LLM:
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")
pbar.update(1)
# Restore original behavior
self.llm_engine.step_return_finished_only = False
if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
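
For context, step_return_finished_only is toggled here because the loop above only ever consumes finished requests, so building intermediate RequestOutputs would be wasted work. A small hedged sketch of the same pattern used directly against the engine; the model name, request id, and prompt are placeholders:

# Hedged sketch of the step_return_finished_only pattern from LLM._run_engine.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
engine = llm.llm_engine
engine.step_return_finished_only = True     # only finished outputs are built
try:
    engine.add_request("req-0", "A story about vLLM:\n",
                       SamplingParams(max_tokens=16))
    finished = []
    while engine.has_unfinished_requests():
        for request_output in engine.step():
            if request_output.finished:
                finished.append(request_output)
finally:
    engine.step_return_finished_only = False  # restore default behavior

print(finished[0].outputs[0].text)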


@ -64,8 +64,9 @@ class DistributedGPUExecutor(GPUExecutor):
num_cpu_blocks=num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
@ -188,7 +189,7 @@ class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
@abstractmethod
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]:
"""Execute the model asynchronously in the driver worker.


@ -176,5 +176,5 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
execute_model_req: ExecuteModelRequest,
) -> List[Union[SamplerOutput, PoolerOutput]]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
)(execute_model_req=execute_model_req)
return output


@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Union, cast)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
Optional, Set, Tuple, Union, cast)
import msgspec
import torch
@ -474,11 +474,8 @@ class Sequence:
"""Reset the sequence states for recomputation."""
self.data.reset_state_for_recompute()
def append_token_id(
self,
token_id: int,
logprobs: Dict[int, Logprob],
) -> None:
def append_token_id(self, token_id: int, logprobs: Dict[int,
Logprob]) -> None:
assert token_id in logprobs
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob)
@ -1293,6 +1290,8 @@ class ExecuteModelRequest(
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None
# Async postprocessor
output_proc_callback_fn: Optional[Callable] = None
@property
def is_first_multi_step(self) -> bool:
@ -1338,4 +1337,5 @@ class ExecuteModelRequest(
num_steps=self.num_steps,
finished_requests_ids=self.finished_requests_ids,
last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None)
if self.last_sampled_token_ids is not None else None,
output_proc_callback_fn=self.output_proc_callback_fn)


@ -6,8 +6,8 @@ import time
import warnings
import weakref
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
TypeVar, Union)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
Tuple, Type, TypeVar, Union)
import numpy as np
import torch
@ -90,6 +90,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0
output_proc_callback_fn: Optional[Callable] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
@ -1327,7 +1328,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
@ -1451,6 +1452,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if not self.is_driver_worker:
return []
if model_input.output_proc_callback_fn is not None:
model_input.output_proc_callback_fn(is_async=True)
# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
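
This is the point where the overlap is realized: the driver worker fires the callback after the forward pass has been launched but before sampling, so the CPU finalizes the previous step's outputs while the device work for the current step is still in flight. A tiny stand-in for that ordering; none of these names are the real ModelRunner API:

# Stand-in ordering sketch: finalize last step's outputs before sampling.
from typing import Callable, List, Optional

def execute_step(compute_logits: Callable[[], List[float]],
                 sample: Callable[[List[float]], int],
                 output_proc_callback_fn: Optional[Callable[..., None]] = None
                 ) -> int:
    logits = compute_logits()        # forward pass (launched asynchronously)
    if output_proc_callback_fn is not None:
        # Previous iteration's outputs are detokenized / stop-checked here,
        # overlapping with the device work already enqueued above.
        output_proc_callback_fn(is_async=True)
    return sample(logits)            # sample the next token

# Example wiring with trivial stand-ins:
pending = []
execute_step(lambda: [0.1, 0.9],
             lambda logits: max(range(len(logits)), key=logits.__getitem__),
             lambda is_async: pending.append(is_async))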


@ -263,6 +263,12 @@ class LocalOrDistributedWorkerBase(WorkerBase):
broadcast_data.update(kwargs)
broadcast_tensor_dict(broadcast_data, src=0)
if execute_model_req.output_proc_callback_fn:
model_input = dataclasses.replace( # type: ignore
model_input,
output_proc_callback_fn=execute_model_req.
output_proc_callback_fn)
return model_input, worker_input, kwargs
def prepare_input(
@ -289,7 +295,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""