diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 1ccab2c65e69..eaf256f7cb8c 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -86,6 +86,7 @@ def run_vllm( use_v2_block_manager: bool = False, download_dir: Optional[str] = None, load_format: str = EngineArgs.load_format, + disable_async_output_proc: bool = False, ) -> float: from vllm import LLM, SamplingParams llm = LLM( @@ -110,6 +111,7 @@ def run_vllm( load_format=load_format, num_scheduler_steps=num_scheduler_steps, use_v2_block_manager=use_v2_block_manager, + disable_async_output_proc=disable_async_output_proc, ) # Add the requests to the engine. @@ -237,7 +239,8 @@ def main(args: argparse.Namespace): args.enable_prefix_caching, args.enable_chunked_prefill, args.max_num_batched_tokens, args.distributed_executor_backend, args.gpu_memory_utilization, args.num_scheduler_steps, - args.use_v2_block_manager, args.download_dir, args.load_format) + args.use_v2_block_manager, args.download_dir, args.load_format, + args.disable_async_output_proc) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -418,6 +421,11 @@ if __name__ == "__main__": 'section for more information.\n' '* "bitsandbytes" will load the weights using bitsandbytes ' 'quantization.\n') + parser.add_argument( + "--disable-async-output-proc", + action='store_true', + default=False, + help="Disable async output processor for vLLM backend.") args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 9c6364ecc679..1211e6ba5aaf 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -88,6 +88,9 @@ def test_models( # NOTE: Increasing this in this suite will fail CI because we currently cannot # reset distributed env properly. Use a value > 1 just when you test. 
@pytest.mark.parametrize("tensor_parallel_size", [1]) +# Due to low-precision numerical divergence, this test is too sensitive to +# the async postprocessor +@pytest.mark.parametrize("disable_async_output_proc", [True]) def test_models_with_fp8_kv_cache( vllm_runner, example_prompts, @@ -97,6 +100,7 @@ def test_models_with_fp8_kv_cache( chunked_prefill_token_size: int, enforce_eager: bool, tensor_parallel_size: int, + disable_async_output_proc: bool, ) -> None: """ Only checks log probs match between chunked-prefill and @@ -126,6 +130,7 @@ def test_models_with_fp8_kv_cache( enforce_eager=enforce_eager, max_num_seqs=max_num_seqs, kv_cache_dtype=kv_cache_dtype, + disable_async_output_proc=disable_async_output_proc, **extra_kwargs, ) as vllm_model: no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( @@ -139,6 +144,7 @@ def test_models_with_fp8_kv_cache( enforce_eager=enforce_eager, max_num_seqs=max_num_seqs, kv_cache_dtype=kv_cache_dtype, + disable_async_output_proc=disable_async_output_proc, **extra_kwargs, ) as vllm_model: chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 7c62de9fa9e3..7e77037da07d 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -209,7 +209,6 @@ def test_swap_infeasible( prefill_blocks = 2 decode_blocks = max_tokens // BLOCK_SIZE example_prompts = example_prompts[:1] - with vllm_runner( model, dtype=dtype, diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index a3b76327e0a5..6d9c2f3ebba4 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -21,7 +21,7 @@ def append_new_token(seq_group, token_id: int): def schedule_and_update_computed_tokens(scheduler): - metas, out = scheduler.schedule() + metas, out, _ = scheduler.schedule() for s, meta in zip(out.scheduled_seq_groups, metas): s.seq_group.update_num_computed_tokens(meta.token_chunk_size) return metas, out @@ -180,7 +180,7 @@ def test_maximal_decoding(): """Verify decoding requests are prioritized.""" block_size = 4 max_seqs = 2 - max_model_len = 2 + max_model_len = 8 max_num_batched_tokens = 2 scheduler_config = SchedulerConfig(max_num_batched_tokens, max_seqs, diff --git a/tests/core/utils.py b/tests/core/utils.py index 12b66d50749d..40d8f51fc186 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -199,7 +199,7 @@ def append_new_token(out, token_id: int): def schedule_and_update_computed_tokens(scheduler): - metas, out = scheduler.schedule() + metas, out, _ = scheduler.schedule() for s, meta in zip(out.scheduled_seq_groups, metas): s.seq_group.update_num_computed_tokens(meta.token_chunk_size) return metas, out diff --git a/tests/engine/test_stop_strings.py b/tests/engine/test_stop_strings.py index 1584b85aeb06..499935620c16 100644 --- a/tests/engine/test_stop_strings.py +++ b/tests/engine/test_stop_strings.py @@ -7,6 +7,8 @@ from vllm import CompletionOutput, LLMEngine, SamplingParams MODEL = "meta-llama/llama-2-7b-hf" MAX_TOKENS = 200 +IS_ASYNC = False + @pytest.fixture(scope="session") def vllm_model(vllm_runner): @@ -14,77 +16,13 @@ def vllm_model(vllm_runner): yield vllm_model -@pytest.mark.skip_global_cleanup -def test_stop_basic(vllm_model): - _test_stopping(vllm_model.model.llm_engine, - stop=["."], - include_in_output=False, - expected_output="VLLM is a 100% volunteer organization", - 
expected_reason=".") - - _test_stopping(vllm_model.model.llm_engine, - stop=["."], - include_in_output=True, - expected_output="VLLM is a 100% volunteer organization.", - expected_reason=".") - - -@pytest.mark.skip_global_cleanup -def test_stop_multi_tokens(vllm_model): - _test_stopping( - vllm_model.model.llm_engine, - stop=["group of peo", "short"], - include_in_output=False, - expected_output="VLLM is a 100% volunteer organization. We are a ", - expected_reason="group of peo") - - _test_stopping( - vllm_model.model.llm_engine, - stop=["group of peo", "short"], - include_in_output=True, - expected_output= - "VLLM is a 100% volunteer organization. We are a group of peo", - expected_reason="group of peo") - - -@pytest.mark.skip_global_cleanup -def test_stop_partial_token(vllm_model): - _test_stopping(vllm_model.model.llm_engine, - stop=["gani"], - include_in_output=False, - expected_output="VLLM is a 100% volunteer or", - expected_reason="gani") - - _test_stopping(vllm_model.model.llm_engine, - stop=["gani"], - include_in_output=True, - expected_output="VLLM is a 100% volunteer organi", - expected_reason="gani") - - -@pytest.mark.skip_global_cleanup -def test_stop_token_id(vllm_model): - # token id 13013 => " organization" - - _test_stopping(vllm_model.model.llm_engine, - stop_token_ids=[13013], - include_in_output=False, - expected_output="VLLM is a 100% volunteer", - expected_reason=13013) - - _test_stopping(vllm_model.model.llm_engine, - stop_token_ids=[13013], - include_in_output=True, - expected_output="VLLM is a 100% volunteer organization", - expected_reason=13013) - - def _test_stopping(llm_engine: LLMEngine, expected_output: str, expected_reason: Any, stop: Optional[List[str]] = None, stop_token_ids: Optional[List[int]] = None, - include_in_output: bool = False) -> None: + include_in_output: bool = False, + use_async_output_proc: bool = False) -> None: llm_engine.add_request( "id", "A story about vLLM:\n", SamplingParams( @@ -98,6 +36,10 @@ def _test_stopping(llm_engine: LLMEngine, output: Optional[CompletionOutput] = None output_text = "" stop_reason = None + + if use_async_output_proc: + llm_engine.step() + while llm_engine.has_unfinished_requests(): (request_output, ) = llm_engine.step() (output, ) = request_output.outputs @@ -110,3 +52,112 @@ def _test_stopping(llm_engine: LLMEngine, assert output is not None assert output_text == expected_output assert stop_reason == expected_reason + + +def _set_async_mode(llm_engine, is_async): + llm_engine.scheduler[0].use_async_output_proc = is_async + + +def _stop_basic(llm_engine, is_async): + _test_stopping(llm_engine, + stop=["."], + include_in_output=False, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=".", + use_async_output_proc=is_async) + + _test_stopping(llm_engine, + stop=["."], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization.", + expected_reason=".", + use_async_output_proc=is_async) + + +def _stop_multi_tokens(llm_engine, is_async): + _test_stopping( + llm_engine, + stop=["group of peo", "short"], + include_in_output=False, + expected_output="VLLM is a 100% volunteer organization. We are a ", + expected_reason="group of peo", + use_async_output_proc=is_async) + + _test_stopping( + llm_engine, + stop=["group of peo", "short"], + include_in_output=True, + expected_output= + "VLLM is a 100% volunteer organization. 
We are a group of peo", + expected_reason="group of peo", + use_async_output_proc=is_async) + + +def _stop_partial_token(llm_engine, is_async): + _test_stopping(llm_engine, + stop=["gani"], + include_in_output=False, + expected_output="VLLM is a 100% volunteer or", + expected_reason="gani", + use_async_output_proc=is_async) + + _test_stopping(llm_engine, + stop=["gani"], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organi", + expected_reason="gani", + use_async_output_proc=is_async) + + +def _stop_token_id(llm_engine, is_async): + # token id 13013 => " organization" + + _test_stopping(llm_engine, + stop_token_ids=[13013], + include_in_output=False, + expected_output="VLLM is a 100% volunteer", + expected_reason=13013, + use_async_output_proc=is_async) + + _test_stopping(llm_engine, + stop_token_ids=[13013], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=13013, + use_async_output_proc=is_async) + + +@pytest.mark.skip_global_cleanup +def test_stop_basic(vllm_model): + _set_async_mode(vllm_model.model.llm_engine, True) + _stop_basic(vllm_model.model.llm_engine, is_async=True) + + _set_async_mode(vllm_model.model.llm_engine, False) + _stop_basic(vllm_model.model.llm_engine, is_async=False) + + +@pytest.mark.skip_global_cleanup +def test_stop_multi_tokens(vllm_model): + _set_async_mode(vllm_model.model.llm_engine, True) + _stop_multi_tokens(vllm_model.model.llm_engine, is_async=True) + + _set_async_mode(vllm_model.model.llm_engine, False) + _stop_multi_tokens(vllm_model.model.llm_engine, is_async=False) + + +@pytest.mark.skip_global_cleanup +def test_stop_partial_token(vllm_model): + _set_async_mode(vllm_model.model.llm_engine, True) + _stop_partial_token(vllm_model.model.llm_engine, is_async=True) + + _set_async_mode(vllm_model.model.llm_engine, False) + _stop_partial_token(vllm_model.model.llm_engine, is_async=False) + + +@pytest.mark.skip_global_cleanup +def test_stop_token_id(vllm_model): + _set_async_mode(vllm_model.model.llm_engine, True) + _stop_token_id(vllm_model.model.llm_engine, is_async=True) + + _set_async_mode(vllm_model.model.llm_engine, False) + _stop_token_id(vllm_model.model.llm_engine, is_async=False) diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index bc14311c6642..c5182cfd2fc0 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -62,6 +62,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, ms_server_args = DEFAULT_SERVER_ARGS + \ ["--num-scheduler-steps", f"{num_scheduler_steps}"] + # Disable output proc callback as its not supported + # with multi-step right now + ms_server_args += ["--disable-async-output-proc"] if eager_mode: ms_server_args.append("--enforce-eager") diff --git a/vllm/config.py b/vllm/config.py index 4cbdde5e113a..74b18341e5ac 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -140,6 +140,7 @@ class ModelConfig: skip_tokenizer_init: bool = False, served_model_name: Optional[Union[str, List[str]]] = None, limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, ) -> None: self.model = model self.tokenizer = tokenizer @@ -172,6 +173,7 @@ class ModelConfig: self.hf_image_processor_config = get_hf_image_processor_config( self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + self.use_async_output_proc = use_async_output_proc # Choose a default enforce_eager value 
if the user did not specify # a value (enforce_eager is None) @@ -326,6 +328,49 @@ class ModelConfig: self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_model_len) + def verify_async_output_proc(self, parallel_config, speculative_config, + device_config) -> None: + if not self.use_async_output_proc: + # Nothing to check + return + + if parallel_config.pipeline_parallel_size > 1: + logger.warning("Async output processing cannot be enabled " + "with pipeline parallelism") + self.use_async_output_proc = False + return + + if device_config.device_type != "cuda": + logger.warning( + "Async output processing is only supported for CUDA." + " Disabling it for other platforms.") + self.use_async_output_proc = False + return + + if envs.VLLM_USE_RAY_SPMD_WORKER: + logger.warning( + "Async output processing cannot be enabled with Ray SPMD") + self.use_async_output_proc = False + return + + if self.enforce_eager: + logger.warning( + "To see the benefits of async output processing, enable CUDA " + "graphs. Since enforce-eager is enabled, the async output " + "processor cannot be used") + self.use_async_output_proc = not self.enforce_eager + return + + # Async postprocessor is not necessary with embedding mode + # since there is no token generation + if self.embedding_mode: + self.use_async_output_proc = False + + if speculative_config: + logger.warning("Async output processing is not supported with" + " speculative decoding currently.") + self.use_async_output_proc = False + def verify_with_parallel_config( self, parallel_config: "ParallelConfig", @@ -358,6 +403,11 @@ class ModelConfig: "fallback to the eager mode.") self.enforce_eager = True + if pipeline_parallel_size > 1 and self.use_async_output_proc: + logger.warning("Async output processor is not supported with " + "pipeline parallelism currently. Disabling it.") + self.use_async_output_proc = False + def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled.""" @@ -1769,6 +1819,9 @@ class EngineConfig: def __post_init__(self): """Verify configs are valid & consistent with each other.
""" + self.model_config.verify_async_output_proc(self.parallel_config, + self.speculative_config, + self.device_config) self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 3b716e32032c..280d7b7e61e2 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -4,7 +4,8 @@ import random import time from collections import deque from dataclasses import dataclass, field -from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set, + Tuple, Union) from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager @@ -299,6 +300,7 @@ class Scheduler: cache_config: CacheConfig, lora_config: Optional[LoRAConfig], pipeline_parallel_size: int = 1, + output_proc_callback_fn: Optional[Callable] = None, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config @@ -364,10 +366,36 @@ class Scheduler: self.num_cumulative_preemption: int = 0 # Used to cache python objects - self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache( - scheduler_running_outputs_builder) - self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache( - scheduled_seq_group_builder) + self._seq_group_metadata_cache: List[PyObjectCache] = [] + self._scheduler_running_outputs_cache: List[PyObjectCache] = [] + self._scheduled_seq_group_cache: List[PyObjectCache] = [] + + # For async output processing, we need to swap cache buffers between + # iterations. I.e. since the output processing is lagged one step, + # we cannot reuse the cached objects immediately when the schedule() + # is called again, but only when schedule() is called the second time. + self.output_proc_callback_fn = output_proc_callback_fn + self.use_async_output_proc = self.output_proc_callback_fn is not None + self.num_cache_iters = 2 if self.use_async_output_proc else 1 + + self.cache_id = 0 + for i in range(self.num_cache_iters): + self._seq_group_metadata_cache.append( + PyObjectCache(seq_group_metadata_builder)) + self._scheduler_running_outputs_cache.append( + PyObjectCache(scheduler_running_outputs_builder)) + self._scheduled_seq_group_cache.append( + PyObjectCache(scheduled_seq_group_builder)) + + # For async postprocessor, the extra decode run cannot be done + # when the request reaches max_model_len. In this case, the request + # will be stopped during schedule() call and added to this stop list + # for processing and deallocation by the free_finished_seq_groups() + self._async_stopped: List[SequenceGroup] = [] + + @property + def next_cache_id(self): + return (self.cache_id + 1) % self.num_cache_iters @property def lora_enabled(self) -> bool: @@ -483,7 +511,7 @@ class Scheduler: SchedulerRunningOutputs. """ ret: SchedulerRunningOutputs = \ - self._scheduler_running_outputs_cache.get_object() + self._scheduler_running_outputs_cache[self.cache_id].get_object() ret.blocks_to_swap_out.clear() ret.blocks_to_copy.clear() ret.decode_seq_groups.clear() @@ -510,8 +538,12 @@ class Scheduler: # NOTE(woosuk): Preemption happens only when there is no available slot # to keep all the sequence groups in the RUNNING state. 
- running_queue = self.running + # Store original running requests for the case of async + preemption + if self.use_async_output_proc: + orig_running = self.running.copy() + running_queue = self.running + assert len(self._async_stopped) == 0 while running_queue: seq_group = running_queue[0] num_running_tokens = self._get_num_new_tokens( @@ -521,6 +553,28 @@ class Scheduler: break running_queue.popleft() + + # With async postprocessor, an extra decode run is done + # to process the final tokens. The check below avoids this extra + # decode run when the model max len is reached, in order to avoid + # a memory overflow. + if self.use_async_output_proc and seq_group.seqs[0].get_len( + ) > self.scheduler_config.max_model_len: + self._async_stopped.append(seq_group) + continue + + # With async postprocessor, when preemption kicks in, we need + # first to drain the async postprocessor, so that all async + # block_table freeing is applied before the preemption freeing + # is applied. + if self.use_async_output_proc and not self._can_append_slots( + seq_group): + tmp = self.running + self.running = orig_running + assert self.output_proc_callback_fn is not None + self.output_proc_callback_fn(is_async=True) + self.running = tmp + while not self._can_append_slots(seq_group): budget.subtract_num_batched_tokens(seq_group.request_id, num_running_tokens) @@ -556,7 +610,7 @@ class Scheduler: is_prefill = seq_group.is_prefill() scheduled_seq_group: ScheduledSequenceGroup = \ - self._scheduled_seq_group_cache.get_object() + self._scheduled_seq_group_cache[self.cache_id].get_object() scheduled_seq_group.seq_group = seq_group if is_prefill: scheduled_seq_group.token_chunk_size = num_running_tokens @@ -579,8 +633,8 @@ class Scheduler: if curr_loras is not None and seq_group.lora_int_id > 0: curr_loras.add(seq_group.lora_int_id) - self._scheduler_running_outputs_cache.reset() - self._scheduled_seq_group_cache.reset() + self._scheduler_running_outputs_cache[self.next_cache_id].reset() + self._scheduled_seq_group_cache[self.next_cache_id].reset() return ret @@ -1031,17 +1085,31 @@ class Scheduler: num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), ) - def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: + def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: + no_beam_search = (seq_group.sampling_params.best_of == 1 + and not seq_group.sampling_params.use_beam_search) + + return no_beam_search + + def schedule( + self + ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: # Schedule sequence groups. # This function call changes the internal states of the scheduler # such as self.running, self.swapped, and self.waiting. scheduler_start_time = time.perf_counter() + scheduler_outputs = self._schedule() now = time.time() if not self.cache_config.enable_prefix_caching: common_computed_block_nums = [] + # TODO: Combine multi-step and async postprocessor + allow_async_output_proc: bool = ( + self.use_async_output_proc + and not self.scheduler_config.is_multi_step) + # Create input data structures. 
seq_group_metadata_list: List[SequenceGroupMetadata] = [] for i, scheduled_seq_group in enumerate( @@ -1050,6 +1118,11 @@ class Scheduler: token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.maybe_set_first_scheduled_time(now) + seq_group_metadata = self._seq_group_metadata_cache[ + self.cache_id].get_object() + seq_group_metadata.seq_data.clear() + seq_group_metadata.block_tables.clear() + # seq_id -> SequenceData seq_data: Dict[int, SequenceData] = {} # seq_id -> physical block numbers @@ -1139,6 +1212,10 @@ class Scheduler: ) seq_group_metadata_list.append(seq_group_metadata) + if allow_async_output_proc: + allow_async_output_proc = self._allow_async_output_proc( + seq_group) + # Now that the batch has been created, we can assume all blocks in the # batch will have been computed before the next scheduling invocation. # This is because the engine assumes that a failure in model execution @@ -1147,6 +1224,8 @@ class Scheduler: self.block_manager.mark_blocks_as_computed( scheduled_seq_group.seq_group) + self._seq_group_metadata_cache[self.next_cache_id].reset() + scheduler_time = time.perf_counter() - scheduler_start_time # Add this to scheduler time to all the sequences that are currently # running. This will help estimate if the scheduler is a significant @@ -1158,7 +1237,12 @@ class Scheduler: else: seq_group.metrics.scheduler_time = scheduler_time - return seq_group_metadata_list, scheduler_outputs + # Move to next cache (if exists) + self.cache_id = self.next_cache_id + + # Return results + return (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: self.block_manager.fork(parent_seq, child_seq) @@ -1167,6 +1251,12 @@ class Scheduler: """Free a sequence from a block table.""" self.block_manager.free(seq) + def _free_finished_seqs(self, seq_group: SequenceGroup) -> None: + """Free finished seqs in a sequence group.""" + for seq in seq_group.get_seqs(): + if seq.is_finished(): + self.free_seq(seq) + def free_finished_seq_groups(self) -> None: remaining: Deque[SequenceGroup] = deque() for seq_group in self.running: @@ -1179,8 +1269,24 @@ class Scheduler: self._finished_requests_ids.append(seq_group.request_id) else: remaining.append(seq_group) + + # Free finished seqs + self._free_finished_seqs(seq_group) + self.running = remaining + # Handle async stopped sequence groups + # (ones that reached max model len) + if self._async_stopped: + for seq_group in self._async_stopped: + self._free_seq_group_cross_attn_blocks(seq_group) + self._finished_requests_ids.append(seq_group.request_id) + + # Free finished seqs + self._free_finished_seqs(seq_group) + + self._async_stopped.clear() + def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d759ce04d75e..efcc646d0e8e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -147,6 +147,7 @@ class EngineArgs: otlp_traces_endpoint: Optional[str] = None collect_detailed_traces: Optional[str] = None + disable_async_output_proc: bool = False def __post_init__(self): if self.tokenizer is None: @@ -733,6 +734,12 @@ class EngineArgs: "modules. 
This involves use of possibly costly and or blocking " "operations and hence might have a performance impact.") + parser.add_argument( + '--disable-async-output-proc', + action='store_true', + default=EngineArgs.disable_async_output_proc, + help="Disable async output processing. This may result in " + "lower performance.") return parser @classmethod @@ -792,6 +799,7 @@ class EngineArgs: skip_tokenizer_init=self.skip_tokenizer_init, served_model_name=self.served_model_name, limit_mm_per_prompt=self.limit_mm_per_prompt, + use_async_output_proc=not self.disable_async_output_proc, ) cache_config = CacheConfig( block_size=self.block_size if self.device != "neuron" else diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index a2a80b141213..3445b7084bbc 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -277,23 +277,36 @@ class _AsyncLLMEngine(LLMEngine): cached_outputs = self.cached_scheduler_outputs[virtual_engine] seq_group_metadata_list = cached_outputs.seq_group_metadata_list scheduler_outputs = cached_outputs.scheduler_outputs + allow_async_output_proc = cached_outputs.allow_async_output_proc + # skip the scheduler if there are any remaining steps in the seq groups. # This ensures that the scheduler is only called again when the current # batch has completed. if not self._has_remaining_steps(seq_group_metadata_list): - seq_group_metadata_list, scheduler_outputs = self.scheduler[ - virtual_engine].schedule() + (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc + ) = self.scheduler[virtual_engine].schedule() + + # If current scheduler iteration has no async postprocessor, + # then we need first to drain the pending async postprocessor + # before moving forward + if not allow_async_output_proc and len(self.output_queue) > 0: + self._process_model_outputs(is_async=True) if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have # lookahead slots self._cache_scheduler_outputs_for_multi_step( - virtual_engine, seq_group_metadata_list, scheduler_outputs) + virtual_engine, seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) assert seq_group_metadata_list is not None assert scheduler_outputs is not None + assert not (self.scheduler_config.is_multi_step and \ + allow_async_output_proc) + if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() @@ -317,6 +330,11 @@ class _AsyncLLMEngine(LLMEngine): # We use ExecuteModelRequest to pass the last sampled_token_ids # to each of the non-last PP stages for in-place prepare_input. last_sampled_token_ids=last_sampled_token_ids) + + if allow_async_output_proc: + execute_model_req.output_proc_callback_fn = \ + self._process_model_outputs + # Execute the model. output = await self.model_executor.execute_model_async( execute_model_req) @@ -325,6 +343,9 @@ class _AsyncLLMEngine(LLMEngine): if self.scheduler_config.is_multi_step: self._update_cached_scheduler_output(virtual_engine, output) else: + if len(self.output_queue) > 0: + assert not self.scheduler_config.is_multi_step + self._process_model_outputs(is_async=True) output = [] # Finish the current step for all the sequence groups. 
@@ -337,19 +358,32 @@ class _AsyncLLMEngine(LLMEngine): if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() - request_outputs = self._process_model_outputs( - output, scheduler_outputs.scheduled_seq_groups, - scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) + + # Cache results in engine + self.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs)) + + if output and allow_async_output_proc: + assert len( + output + ) == 1, "Multi-step decoding does not work with async output processing." # noqa: E501 + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) + + if not allow_async_output_proc: + self._process_model_outputs(is_async=False) + + # Log stats. + self.do_log_stats(scheduler_outputs, output) + + # Tracing + self.do_tracing(scheduler_outputs) + + else: - request_outputs = [] + self.request_outputs = [] - # Log stats. - self.do_log_stats(scheduler_outputs, output) - - # Tracing - self.do_tracing(scheduler_outputs) - - return request_outputs + return self.request_outputs async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 79072e403dc1..7356c1abbfa8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,7 +1,8 @@ import time +from collections import deque from contextlib import contextmanager from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, +from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List, Mapping, Optional) from typing import Sequence as GenericSequence from typing import Set, Tuple, Type, Union @@ -38,9 +39,8 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - PoolerOutput, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupMetadata, - SequenceStatus) + SamplerOutput, Sequence, SequenceGroup, + SequenceGroupMetadata, SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -82,9 +82,10 @@ DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], @dataclass class SchedulerOutputState: """Caches the scheduler outputs for a virtual engine. Used for Multi-Step""" - last_output: Optional[SamplerOutput] = None seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None scheduler_outputs: Optional[SchedulerOutputs] = None + allow_async_output_proc: bool = False + last_output: Optional[SamplerOutput] = None class LLMEngine: @@ -190,6 +191,9 @@ class LLMEngine: usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, + # To improve performance, only the final request outputs may be required. + # If this is set to True, then no intermediate outputs will be returned.
+ step_return_finished_only: bool = False, ) -> None: logger.info( "Initializing an LLM engine (v%s) with config: " @@ -204,7 +208,8 @@ class LLMEngine: "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "num_scheduler_steps=%d, enable_prefix_caching=%s)", + "num_scheduler_steps=%d, enable_prefix_caching=%s, " + "use_async_output_proc=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -235,6 +240,7 @@ class LLMEngine: scheduler_config.use_v2_block_manager, scheduler_config.num_scheduler_steps, cache_config.enable_prefix_caching, + model_config.use_async_output_proc, ) # TODO(woosuk): Print more configs in debug mode. from vllm.plugins import load_general_plugins @@ -253,6 +259,7 @@ class LLMEngine: self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats + self.step_return_finished_only = step_return_finished_only if not self.model_config.skip_tokenizer_init: self.tokenizer = self._init_tokenizer() @@ -340,8 +347,11 @@ class LLMEngine: # NOTE: the cache_config here have been updated with the numbers of # GPU and CPU blocks, which are profiled in the distributed executor. self.scheduler = [ - Scheduler(scheduler_config, cache_config, lora_config, - parallel_config.pipeline_parallel_size) + Scheduler( + scheduler_config, cache_config, lora_config, + parallel_config.pipeline_parallel_size, + self._process_model_outputs + if model_config.use_async_output_proc else None) for _ in range(parallel_config.pipeline_parallel_size) ] @@ -396,6 +406,13 @@ class LLMEngine: for _ in range(self.parallel_config.pipeline_parallel_size) ] + # Async output processing pointers + self.output_queue: Deque[Tuple[List[SamplerOutput], + List[SequenceGroupMetadata], + SchedulerOutputs]] = deque() + self.request_outputs: List[Union[RequestOutput, + EmbeddingRequestOutput]] = [] + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -1197,34 +1214,66 @@ class LLMEngine: return - def _process_model_outputs( - self, - output: GenericSequence[Union[SamplerOutput, PoolerOutput]], - scheduled_seq_groups: List[ScheduledSequenceGroup], - ignored_seq_groups: List[SequenceGroup], - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + def _process_model_outputs(self, + is_async: bool, + clear_outputs: bool = True) -> None: """Apply the model output to the sequences in the scheduled seq groups. + is_async: Indicates whether this postprocessor runs in + parallel with the GPU forward pass and is processing + tokens from the previous step. If this is true, then + no tokens need to be appended since it is already done + externally (before the next schedule() call) + clear_outputs: Sometimes existing outputs need to be combined + with outputs of this call. This happens for postprocessor + draining at the final stage (like when sequences are finished) + Returns RequestOutputs that can be returned to the client. """ - now = time.time() - # Organize outputs by [sequence group][step] instead of - # [step][sequence group]. 
- output_by_sequence_group = create_output_by_sequence_group( - output, num_seq_groups=len(scheduled_seq_groups)) + if clear_outputs: + self.request_outputs.clear() + + if len(self.output_queue) == 0: + return None + + (outputs, seq_group_metadata_list, + scheduler_outputs) = self.output_queue.popleft() + + # Sanity check + assert len(seq_group_metadata_list) == len( + scheduler_outputs.scheduled_seq_groups) + + # Organize outputs by [step][sequence group] instead of + # [sequence group][step]. + if len(outputs) > 1: + outputs_by_sequence_group = create_output_by_sequence_group( + outputs, num_seq_groups=len(seq_group_metadata_list)) + else: + outputs_by_sequence_group = outputs + + finished_before: List[int] = [] + for i, seq_group_meta in enumerate(seq_group_metadata_list): + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - # Update the scheduled sequence groups with the model outputs. - for scheduled_seq_group, outputs, seq_group_meta in zip( - scheduled_seq_groups, output_by_sequence_group, - seq_group_metadata_list): seq_group = scheduled_seq_group.seq_group - seq_group.update_num_computed_tokens( - scheduled_seq_group.token_chunk_size) - if output is not None and len(output) > 0: - for o in output: + + if seq_group.is_finished(): + finished_before.append(i) + continue + + if len(outputs) > 1: + output = outputs_by_sequence_group[i] + else: + output = [outputs_by_sequence_group[0][i]] + + if not is_async: + seq_group.update_num_computed_tokens( + scheduled_seq_group.token_chunk_size) + + if outputs: + for o in outputs: if (isinstance(o, SamplerOutput) and seq_group.metrics is not None): if seq_group.metrics.model_forward_time is not None: @@ -1239,30 +1288,75 @@ class LLMEngine: else: seq_group.metrics.model_execute_time = ( o.model_execute_time) + if self.model_config.embedding_mode: - self._process_sequence_group_outputs(seq_group, outputs) + self._process_sequence_group_outputs(seq_group, output) continue - self.output_processor.process_prompt_logprob(seq_group, outputs) + self.output_processor.process_prompt_logprob(seq_group, output) if seq_group_meta.do_sample: - self.output_processor.process_outputs(seq_group, outputs) + self.output_processor.process_outputs(seq_group, output, + is_async) # Free the finished sequence groups. for scheduler in self.scheduler: scheduler.free_finished_seq_groups() # Create the outputs. - request_outputs: List[Union[RequestOutput, - EmbeddingRequestOutput]] = [] - for scheduled_seq_group in scheduled_seq_groups: + for i, _ in enumerate(seq_group_metadata_list): + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + if i in finished_before: + continue # Avoids double processing + seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) + if (seq_group.is_finished() + if self.step_return_finished_only else True): + request_output = RequestOutputFactory.create(seq_group) + self.request_outputs.append(request_output) + + for seq_group in scheduler_outputs.ignored_seq_groups: request_output = RequestOutputFactory.create(seq_group) - request_outputs.append(request_output) - for seq_group in ignored_seq_groups: - request_output = RequestOutputFactory.create(seq_group) - request_outputs.append(request_output) - return request_outputs + self.request_outputs.append(request_output) + + if is_async: + # Log stats. 
+ self.do_log_stats(scheduler_outputs, outputs, finished_before) + + # Tracing + self.do_tracing(scheduler_outputs) + + return None + + def _advance_to_next_step( + self, output: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: + """Given model output from a single run, append the tokens to the + sequences. This is normally done inside output processor, but it is + required if the worker is to perform async forward pass to next step. + """ + for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \ + zip(seq_group_metadata_list, output, scheduled_seq_groups): + seq_group = scheduled_seq_group.seq_group + + if seq_group.is_finished(): + continue + + seq_group.update_num_computed_tokens( + seq_group_metadata.token_chunk_size) + + if seq_group_metadata.do_sample: + assert len(sequence_group_outputs.samples) == 1, ( + "Async output processor expects a single sample" + " (i.e sampling_params.n == 1 and no " + "sampling_params.best_of > 1)") + sample = sequence_group_outputs.samples[0] + + assert len(seq_group.seqs) == 1 + seq = seq_group.seqs[0] + seq.append_token_id(sample.output_token, sample.logprobs) def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. @@ -1325,24 +1419,32 @@ class LLMEngine: cached_outputs = self.cached_scheduler_outputs[0] seq_group_metadata_list = cached_outputs.seq_group_metadata_list scheduler_outputs = cached_outputs.scheduler_outputs + allow_async_output_proc = cached_outputs.allow_async_output_proc # Skip the scheduler if there are any remaining steps in the seq groups. # This ensures that the scheduler is only called again when the current # batch has completed. if not self._has_remaining_steps(seq_group_metadata_list): - seq_group_metadata_list, scheduler_outputs = self.scheduler[ - 0].schedule() + (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) = self.scheduler[0].schedule() + + if not allow_async_output_proc and len(self.output_queue) > 0: + self._process_model_outputs(is_async=True) if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have # lookahead slots self._cache_scheduler_outputs_for_multi_step( - 0, seq_group_metadata_list, scheduler_outputs) + 0, seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) assert seq_group_metadata_list is not None assert scheduler_outputs is not None + assert not (self.scheduler_config.is_multi_step and \ + allow_async_output_proc) + if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ 0].get_and_reset_finished_requests_ids() @@ -1366,6 +1468,10 @@ class LLMEngine: # to each of the non-last PP stages for in-place prepare_input. last_sampled_token_ids=last_sampled_token_ids) + if allow_async_output_proc: + execute_model_req.output_proc_callback_fn = \ + self._process_model_outputs + output = self.model_executor.execute_model( execute_model_req=execute_model_req) @@ -1374,6 +1480,9 @@ class LLMEngine: if self.scheduler_config.is_multi_step: self._update_cached_scheduler_output(0, output) else: + if len(self.output_queue) > 0: + assert not self.scheduler_config.is_multi_step + self._process_model_outputs(is_async=True) output = [] # Finish the current step for all the sequence groups. 
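The engine-side queueing above follows a simple pattern: raw sampler outputs are appended to output_queue, and the previous iteration's entry is drained either through the callback (overlapped with the next GPU forward pass) or synchronously when async processing is not allowed. A minimal, self-contained sketch of that pattern is shown below; it is not part of the patch, and TinyAsyncEngine, run_model, and the token/request-id types are illustrative stand-ins rather than vLLM APIs.

from collections import deque
from typing import Callable, Deque, List, Optional, Tuple

# Illustrative stand-ins (not vLLM types): a step output is a list of sampled
# token ids, and the metadata is the list of request ids scheduled that step.
StepOutput = List[int]
StepMetadata = List[str]


class TinyAsyncEngine:
    """Sketch of one-step-lagged ("async") output processing."""

    def __init__(self) -> None:
        self.output_queue: Deque[Tuple[StepOutput, StepMetadata]] = deque()
        self.request_outputs: List[str] = []

    def _process_model_outputs(self) -> None:
        # Drain one queued step; the real engine detokenizes, checks stop
        # conditions, and builds RequestOutput objects here.
        if not self.output_queue:
            return
        outputs, metadata = self.output_queue.popleft()
        self.request_outputs.extend(
            f"{req_id}: token {tok}" for req_id, tok in zip(metadata, outputs))

    def step(self, metadata: StepMetadata, allow_async: bool,
             run_model: Callable[[Optional[Callable[[], None]]], StepOutput]
             ) -> List[str]:
        self.request_outputs.clear()
        if allow_async:
            # The callback is invoked by the model runner while the GPU is
            # busy, so the previous step's outputs are processed in parallel
            # with the current forward pass.
            outputs = run_model(self._process_model_outputs)
            self.output_queue.append((outputs, metadata))
        else:
            # Synchronous path: drain anything pending, run, then process.
            self._process_model_outputs()
            outputs = run_model(None)
            self.output_queue.append((outputs, metadata))
            self._process_model_outputs()
        return self.request_outputs

The one-step lag is also what motivates the doubled PyObjectCache buffers in the scheduler changes earlier in this patch: objects cached during iteration N are still referenced by the pending postprocessing, so they can only be reset once iteration N+1 has been scheduled.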
@@ -1382,23 +1491,41 @@ class LLMEngine: seq_group.finish_step() if not self._has_remaining_steps(seq_group_metadata_list): - # clear the cache if we have finished all the steps + # clear the cache if we have finished all the steps. if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[0] = SchedulerOutputState() - request_outputs = self._process_model_outputs( - output, scheduler_outputs.scheduled_seq_groups, - scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) + # Add results to the output_queue + # (for async or non-async postprocessing) + self.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs)) + + if output and allow_async_output_proc: + assert len(output) == 1, ("Multi step decoding does not work " + "with async output processing.") + + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) + + if not allow_async_output_proc: + self._process_model_outputs(is_async=False) + + # Log stats. + self.do_log_stats(scheduler_outputs, output) + + # Tracing + self.do_tracing(scheduler_outputs) else: - request_outputs = [] - - # Log stats. - self.do_log_stats(scheduler_outputs, output) - - # Tracing - self.do_tracing(scheduler_outputs) + self.request_outputs = [] if not self.has_unfinished_requests(): + # Drain async postprocessor + if len(self.output_queue) > 0: + assert not self.scheduler_config.is_multi_step + self._process_model_outputs(is_async=True, clear_outputs=False) + assert len(self.output_queue) == 0 + # Stop the execute model loop in parallel workers until there are # more requests to process. This avoids waiting indefinitely in # torch.distributed ops which may otherwise timeout, and unblocks @@ -1406,7 +1533,7 @@ class LLMEngine: # queued control plane messages, such as add/remove lora adapters. 
self.model_executor.stop_remote_worker_execution_loop() - return request_outputs + return self.request_outputs def _has_remaining_steps( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] @@ -1431,12 +1558,14 @@ class LLMEngine: def _cache_scheduler_outputs_for_multi_step( self, virtual_engine: int, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - scheduler_outputs: SchedulerOutputs) -> None: - self.cached_scheduler_outputs[ - virtual_engine].seq_group_metadata_list = seq_group_metadata_list - self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \ - scheduler_outputs - self.cached_scheduler_outputs[virtual_engine].last_output = None + scheduler_outputs: SchedulerOutputs, + allow_async_output_proc: bool) -> None: + co = self.cached_scheduler_outputs[virtual_engine] + + co.seq_group_metadata_list = seq_group_metadata_list + co.scheduler_outputs = scheduler_outputs + co.allow_async_output_proc = allow_async_output_proc + co.last_output = None def _update_cached_scheduler_output( self, virtual_engine: int, @@ -1472,20 +1601,21 @@ class LLMEngine: raise KeyError(f"Logger with name {logger_name} does not exist.") del self.stat_loggers[logger_name] - def do_log_stats( - self, - scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[List[SamplerOutput]] = None) -> None: + def do_log_stats(self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + finished_before: Optional[List[int]] = None) -> None: """Forced log when no requests active.""" if self.log_stats: - stats = self._get_stats(scheduler_outputs, model_output) + stats = self._get_stats(scheduler_outputs, model_output, + finished_before) for logger in self.stat_loggers.values(): logger.log(stats) - def _get_stats( - self, - scheduler_outputs: Optional[SchedulerOutputs], - model_output: Optional[List[SamplerOutput]] = None) -> Stats: + def _get_stats(self, + scheduler_outputs: Optional[SchedulerOutputs], + model_output: Optional[List[SamplerOutput]] = None, + finished_before: Optional[List[int]] = None) -> Stats: """Get Stats to be Logged to Prometheus. Args: @@ -1550,6 +1680,10 @@ class LLMEngine: # NOTE: This loop assumes prefill seq_groups are before # decode seq_groups in scheduled_seq_groups. if scheduler_outputs is not None: + # For async postprocessor, already finished sequences need to be + # not counted (to avoid double counting) + actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore + num_generation_tokens_from_prefill_groups = 0. # NOTE: if scheduler_outputs.num_prefill_groups > 0 and # the len of scheduler_outputs.scheduled_seq_groups is != @@ -1558,6 +1692,11 @@ class LLMEngine: for idx, scheduled_seq_group in enumerate( scheduler_outputs.scheduled_seq_groups): + # Skip double logging when using async output proc + if finished_before and idx in finished_before: + actual_num_batched_tokens -= 1 + continue + group_was_prefill = idx < scheduler_outputs.num_prefill_groups seq_group = scheduled_seq_group.seq_group @@ -1592,7 +1731,6 @@ class LLMEngine: # Latency timings time_e2e_requests.append(now - seq_group.metrics.arrival_time) - # Metadata num_prompt_tokens_requests.append( len(seq_group.prompt_token_ids)) @@ -1616,7 +1754,7 @@ class LLMEngine: # + num_generation_tokens_from_prefill_groups (since we generate # one token on prefills on iters where the prefill finishes). 
num_generation_tokens_iter = ( - scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter + + actual_num_batched_tokens - num_prompt_tokens_iter + num_generation_tokens_from_prefill_groups) # Spec decode, if enabled, emits specialized metrics from the worker in diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index a385f37d807a..50adaf4e5918 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -40,13 +40,9 @@ class SequenceGroupOutputProcessor(ABC): # Importing here to avoid cycle. from vllm.engine.output_processor.single_step import ( SingleStepOutputProcessor) - return SingleStepOutputProcessor( - scheduler_config, - detokenizer, - scheduler, - seq_counter, - stop_checker, - ) + return SingleStepOutputProcessor(scheduler_config, detokenizer, + scheduler, seq_counter, + stop_checker) else: # Importing here to avoid cycle. from vllm.engine.output_processor.multi_step import ( @@ -61,7 +57,8 @@ class SequenceGroupOutputProcessor(ABC): @abstractmethod def process_outputs(self, sequence_group: SequenceGroup, - outputs: List[SequenceGroupOutput]) -> None: + outputs: List[SequenceGroupOutput], + is_async: bool) -> None: """Process new token ids for the sequence group. Handles logic such as detokenization, stop checking, and freeing/forking sequences in the scheduler. diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 6c472528a7a9..49a33ded5fca 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -57,17 +57,28 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): "Prompt logprob is not supported by multi step workers. " "(e.g., speculative decode uses multi step workers).") - def process_outputs(self, sequence_group: SequenceGroup, - outputs: List[SequenceGroupOutput]) -> None: + def process_outputs(self, + sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput], + is_async: bool = False) -> None: """Append new tokens in the outputs to sequences in the sequence group. This only supports sequence groups of size 1. It supports greater than one new token per sequence. - This applies logic like stop condition checking and detokenization, - including freeing finished sequences. It also handles cases where there - are tokens emitted after the EOS token. + This applies logic like stop condition checking and detokenization. + It also handles cases where there are tokens emitted after + the EOS token. + + is_async - Indicates whether this postprocessor runs in + parallel with the GPU forward pass and is processing + tokens from the previous step. 
If this is true, then + no tokens need to be appended since it is already done + externally (before the next schedule() call) """ + # TODO: Add support for async if necessary + assert not is_async + seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) assert seqs, "expected running sequences" @@ -138,7 +149,3 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ) if seq.is_finished(): break - - if seq.is_finished(): - for scheduler in self.scheduler: - scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 4a46c93f8425..4b0c3f37a5e2 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -29,14 +29,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): that is currently difficult to schedule multiple steps ahead of time. """ - def __init__( - self, - scheduler_config: SchedulerConfig, - detokenizer: Detokenizer, - scheduler: List[Scheduler], - seq_counter: Counter, - stop_checker: StopChecker, - ): + def __init__(self, scheduler_config: SchedulerConfig, + detokenizer: Detokenizer, scheduler: List[Scheduler], + seq_counter: Counter, stop_checker: StopChecker): self.scheduler_config = scheduler_config self.detokenizer = detokenizer self.scheduler = scheduler @@ -44,16 +39,24 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): self.stop_checker = stop_checker def process_outputs(self, sequence_group: SequenceGroup, - outputs: List[SequenceGroupOutput]) -> None: + outputs: List[SequenceGroupOutput], + is_async: bool) -> None: """Append all new tokens to sequences in the sequence group. Fork any surviving beam candidates; free any unsurviving ones. Invokes detokenizer to detokenize new tokens, and also marks sequences as finished if they meet stop conditions. + + is_async - Indicates whether this postprocessor runs in + parallel with the GPU forward pass and is processing + tokens from the previous step. 
If this is true, then + no tokens need to be appended since it is already done + externally (before the next schedule() call) """ assert (len(outputs) == 1 ), f"{type(self)} does not support multiple outputs per step" - return self._process_sequence_group_outputs(sequence_group, outputs[0]) + return self._process_sequence_group_outputs(sequence_group, outputs[0], + is_async) def process_prompt_logprob(self, seq_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: @@ -80,14 +83,16 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): seq_group.prompt_logprobs.extend(prompt_logprobs) def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - outputs: SequenceGroupOutput) -> None: + outputs: SequenceGroupOutput, + is_async: bool) -> None: sampling_params = seq_group.sampling_params if sampling_params.n == 1 and not sampling_params.use_beam_search: # only have one output sample sample = outputs.samples[0] # only have one sequence seq = seq_group.seqs[0] - seq.append_token_id(sample.output_token, sample.logprobs) + if not is_async: + seq.append_token_id(sample.output_token, sample.logprobs) if sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( seq, sampling_params) @@ -104,6 +109,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): scheduler.free_seq(seq) return + # TODO: Add support for async for beam search + assert not is_async + # Process samples samples = outputs.samples parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 31175724c6c7..ecc3c4004bbf 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -129,6 +129,7 @@ class LLM: max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, + disable_async_output_proc: bool = False, **kwargs, ) -> None: ''' @@ -170,6 +171,7 @@ class LLM: max_context_len_to_capture=max_context_len_to_capture, max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, + disable_async_output_proc=disable_async_output_proc, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args( @@ -603,7 +605,6 @@ class LLM: inputs = [inputs] num_requests = len(inputs) - if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") @@ -678,6 +679,10 @@ class LLM: postfix=(f"est. speed input: {0:.2f} toks/s, " f"output: {0:.2f} toks/s"), ) + + # In the loop below, only finished outputs are used + self.llm_engine.step_return_finished_only = True + # Run the engine. outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] total_in_toks = 0 @@ -700,6 +705,10 @@ class LLM: f"est. speed input: {in_spd:.2f} toks/s, " f"output: {out_spd:.2f} toks/s") pbar.update(1) + + # Restore original behavior + self.llm_engine.step_return_finished_only = False + if use_tqdm: pbar.close() # Sort the outputs by request ID. 
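As a usage reference (not part of the patch): async output processing is on by default, and the opt-out added above is reachable from the LLM entrypoint as well as from the engine CLI via --disable-async-output-proc, the same flag the benchmark script now forwards. A minimal sketch, with an arbitrarily chosen model name:

from vllm import LLM, SamplingParams

# Async output processing is enabled by default; pass the new flag to fall
# back to the synchronous postprocessing path (e.g. for debugging or A/B
# performance comparisons).
llm = LLM(model="facebook/opt-125m", disable_async_output_proc=True)
outputs = llm.generate(["A story about vLLM:\n"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)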
diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 4df54a09e5e8..1a35a7c3b8f7 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -64,8 +64,9 @@ class DistributedGPUExecutor(GPUExecutor): num_cpu_blocks=num_cpu_blocks) def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: if self.parallel_worker_tasks is None: self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", @@ -188,7 +189,7 @@ class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase): @abstractmethod async def _driver_execute_model_async( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] = None, ) -> List[SamplerOutput]: """Execute the model asynchronously in the driver worker. diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 8346c3cc1d3e..795692195f84 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -176,5 +176,5 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): execute_model_req: ExecuteModelRequest, ) -> List[Union[SamplerOutput, PoolerOutput]]: output = await make_async(self.driver_worker.execute_model - )(execute_model_req=execute_model_req, ) + )(execute_model_req=execute_model_req) return output diff --git a/vllm/sequence.py b/vllm/sequence.py index 2fe8ae9d7b27..964072dd7c8f 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,8 +5,8 @@ from abc import ABC, abstractmethod from array import array from collections import defaultdict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, - Tuple, Union, cast) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, + Optional, Set, Tuple, Union, cast) import msgspec import torch @@ -474,11 +474,8 @@ class Sequence: """Reset the sequence states for recomputation.""" self.data.reset_state_for_recompute() - def append_token_id( - self, - token_id: int, - logprobs: Dict[int, Logprob], - ) -> None: + def append_token_id(self, token_id: int, logprobs: Dict[int, + Logprob]) -> None: assert token_id in logprobs self.output_logprobs.append(logprobs) self.data.append_token_id(token_id, logprobs[token_id].logprob) @@ -1293,6 +1290,8 @@ class ExecuteModelRequest( finished_requests_ids: List[str] = msgspec.field(default_factory=list) # The last sampled token ids for multi step decoding. 
last_sampled_token_ids: Optional[torch.Tensor] = None + # Async postprocessor + output_proc_callback_fn: Optional[Callable] = None @property def is_first_multi_step(self) -> bool: @@ -1338,4 +1337,5 @@ class ExecuteModelRequest( num_steps=self.num_steps, finished_requests_ids=self.finished_requests_ids, last_sampled_token_ids=self.last_sampled_token_ids.clone() - if self.last_sampled_token_ids is not None else None) + if self.last_sampled_token_ids is not None else None, + output_proc_callback_fn=self.output_proc_callback_fn) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5d930919b8ae..adfdfdd32cb4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -6,8 +6,8 @@ import time import warnings import weakref from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, - TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, + Tuple, Type, TypeVar, Union) import numpy as np import torch @@ -90,6 +90,7 @@ class ModelInputForGPU(ModelRunnerInputBase): request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 + output_proc_callback_fn: Optional[Callable] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -1327,7 +1328,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None + finished_requests_ids: Optional[List[str]] = None, ) -> ModelInputForGPUWithSamplingMetadata: """Prepare the model input based on a given sequence group, including metadata for the sampling step. @@ -1451,6 +1452,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): if not self.is_driver_worker: return [] + if model_input.output_proc_callback_fn is not None: + model_input.output_proc_callback_fn(is_async=True) + # Sample the next token. output: SamplerOutput = self.model.sample( logits=logits, diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 516e38659519..e35d5c962a48 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -263,6 +263,12 @@ class LocalOrDistributedWorkerBase(WorkerBase): broadcast_data.update(kwargs) broadcast_tensor_dict(broadcast_data, src=0) + if execute_model_req.output_proc_callback_fn: + model_input = dataclasses.replace( # type: ignore + model_input, + output_proc_callback_fn=execute_model_req. + output_proc_callback_fn) + return model_input, worker_input, kwargs def prepare_input( @@ -289,7 +295,7 @@ class LocalOrDistributedWorkerBase(WorkerBase): def execute_model( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] = None, ) -> Optional[List[SamplerOutput]]: """Executes at least one model step on the given sequences, unless no sequences are provided."""