From 9db93de20ca282feb4dfaabbc56032c9312bde7b Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Fri, 23 Aug 2024 15:45:53 -0400 Subject: [PATCH] [Core] Add multi-step support to LLMEngine (#7789) --- .buildkite/test-pipeline.yaml | 3 +- benchmarks/benchmark_throughput.py | 17 ++- tests/lora/test_gemma.py | 2 +- ...tness.py => test_correctness_async_llm.py} | 0 tests/multi_step/test_correctness_llm.py | 49 +++++++ vllm/engine/async_llm_engine.py | 74 +--------- vllm/engine/llm_engine.py | 137 ++++++++++++++++-- 7 files changed, 195 insertions(+), 87 deletions(-) rename tests/multi_step/{test_correctness.py => test_correctness_async_llm.py} (100%) create mode 100644 tests/multi_step/test_correctness_llm.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d70a9ce24082..283776c06ed4 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -335,7 +335,8 @@ steps: - vllm/engine - tests/multi_step commands: - - pytest -v -s multi_step/test_correctness.py + - pytest -v -s multi_step/test_correctness_async_llm.py + - pytest -v -s multi_step/test_correctness_llm.py - label: Pipeline Parallelism Test # 23min working_dir: "/vllm-workspace/tests" diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index a52e67bbbe7e..1ccab2c65e69 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -82,6 +82,8 @@ def run_vllm( max_num_batched_tokens: int, distributed_executor_backend: Optional[str], gpu_memory_utilization: float = 0.9, + num_scheduler_steps: int = 1, + use_v2_block_manager: bool = False, download_dir: Optional[str] = None, load_format: str = EngineArgs.load_format, ) -> float: @@ -106,6 +108,8 @@ def run_vllm( max_num_batched_tokens=max_num_batched_tokens, distributed_executor_backend=distributed_executor_backend, load_format=load_format, + num_scheduler_steps=num_scheduler_steps, + use_v2_block_manager=use_v2_block_manager, ) # Add the requests to the engine. 
@@ -232,7 +236,8 @@ def main(args: argparse.Namespace): args.quantization_param_path, args.device, args.enable_prefix_caching, args.enable_chunked_prefill, args.max_num_batched_tokens, args.distributed_executor_backend, - args.gpu_memory_utilization, args.download_dir, args.load_format) + args.gpu_memory_utilization, args.num_scheduler_steps, + args.use_v2_block_manager, args.download_dir, args.load_format) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -353,10 +358,18 @@ if __name__ == "__main__": choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"], help='device type for vLLM execution, supporting CUDA, OpenVINO and ' 'CPU.') + parser.add_argument( + "--num-scheduler-steps", + type=int, + default=1, + help="Maximum number of forward steps per scheduler call.") + parser.add_argument("--use-v2-block-manager", + action='store_true', + help="Enable block manager v2.") parser.add_argument( "--enable-prefix-caching", action='store_true', - help="enable automatic prefix caching for vLLM backend.") + help="Enable automatic prefix caching for vLLM backend.") parser.add_argument("--enable-chunked-prefill", action='store_true', help="enable chunked prefill for vLLM backend.") diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index 478bb86b7861..709246179bfe 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -37,7 +37,7 @@ def test_gemma_lora(gemma_lora_files): expected_lora_output = [ "more important than knowledge.\nAuthor: Albert Einstein\n", "everyone else is already taken.\nAuthor: Oscar Wilde\n", - "so little time.\nAuthor: Frank Zappa\n", + "so little time\nAuthor: Frank Zappa\n", ] output1 = do_sample(llm, gemma_lora_files, lora_id=1) diff --git a/tests/multi_step/test_correctness.py b/tests/multi_step/test_correctness_async_llm.py similarity index 100% rename from tests/multi_step/test_correctness.py rename to tests/multi_step/test_correctness_async_llm.py diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py new file mode 100644 index 000000000000..36f610ba74f0 --- /dev/null +++ b/tests/multi_step/test_correctness_llm.py @@ -0,0 +1,49 @@ +# Test the LLMEngine with multi-step-decoding + +import pytest + +from ..models.utils import check_outputs_equal + +MODELS = [ + "JackFram/llama-160m", +] +NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps +NUM_PROMPTS = [10] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("tp_size", [1]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) +@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str, + dtype: str, tp_size: int, max_tokens: int, + enforce_eager: int, num_scheduler_steps: int, + num_prompts: int) -> None: + + prompts = example_prompts + if len(prompts) < num_prompts: + prompts = prompts * ((num_prompts // len(prompts)) + 1) + prompts = prompts[:num_prompts] + assert len(prompts) == num_prompts + + with vllm_runner(model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + use_v2_block_manager=True, + num_scheduler_steps=num_scheduler_steps) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens) + + with hf_runner(model, dtype=dtype) as 
hf_model: + hf_outputs = hf_model.generate_greedy(prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 8812b853c066..a2a80b141213 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,11 +1,9 @@ import asyncio import time -from dataclasses import dataclass from functools import partial from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union) -import torch from typing_extensions import assert_never import vllm.envs as envs @@ -15,7 +13,7 @@ from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine, - PromptComponents) + PromptComponents, SchedulerOutputState) from vllm.engine.metrics_types import StatLoggerBase from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.ray_utils import initialize_ray_cluster, ray @@ -28,8 +26,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import print_warning_once @@ -257,24 +254,11 @@ class RequestTracker: return not self._new_requests.empty() -@dataclass -class SchedulerOutputState: - """Caches the scheduler outputs for a virtual engine. Used for Multi-Step""" - last_output: Optional[SamplerOutput] = None - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None - scheduler_outputs: Optional[SchedulerOutputs] = None - - class _AsyncLLMEngine(LLMEngine): """Extension of LLMEngine to add async methods.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - pipeline_parallel_size = \ - self.parallel_config.pipeline_parallel_size - self.cached_scheduler_outputs = [ - SchedulerOutputState() for _ in range(pipeline_parallel_size) - ] async def step_async( self, virtual_engine: int @@ -367,60 +351,6 @@ class _AsyncLLMEngine(LLMEngine): return request_outputs - def _has_remaining_steps( - self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] - ) -> bool: - if (not self.scheduler_config.is_multi_step - or not seq_group_metadata_list): - return False - - # TODO(will) this is a sanity check for nowto make sure that all the - # seqs are on the same steps. Eventually we will want to do some sort of - # dynamic scheduling when doing multi-step decoding. 
- ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps - if any([ - seq_group.state.remaining_steps != ref_remaining_steps - for seq_group in seq_group_metadata_list[1:] - ]): - raise AssertionError(("All running sequence groups should " - "have the same remaining steps.")) - - return ref_remaining_steps > 0 - - def _cache_scheduler_outputs_for_multi_step( - self, virtual_engine: int, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - scheduler_outputs: SchedulerOutputs) -> None: - self.cached_scheduler_outputs[ - virtual_engine].seq_group_metadata_list = seq_group_metadata_list - self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \ - scheduler_outputs - self.cached_scheduler_outputs[virtual_engine].last_output = None - - def _get_last_sampled_token_ids( - self, virtual_engine: int) -> Optional[torch.Tensor]: - cached_last_output = self.cached_scheduler_outputs[ - virtual_engine].last_output - if (self.scheduler_config.is_multi_step - and self.parallel_config.pipeline_parallel_size > 1 - and cached_last_output is not None - and cached_last_output.sampled_token_ids_cpu is not None): - return cached_last_output.sampled_token_ids_cpu - return None - - def _update_cached_scheduler_output( - self, virtual_engine: int, - output: List[Optional[SamplerOutput]]) -> None: - if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0 - and output[0] is not None): - last_output = output[-1] - assert last_output is not None - assert last_output.sampled_token_ids_cpu is not None - assert last_output.sampled_token_ids is None - assert last_output.sampled_token_probs is None - self.cached_scheduler_outputs[ - virtual_engine].last_output = last_output - async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8c98b64181d0..79072e403dc1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,10 +1,12 @@ import time from contextlib import contextmanager +from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Mapping, Optional) from typing import Sequence as GenericSequence from typing import Set, Tuple, Type, Union +import torch from typing_extensions import TypeVar, assert_never import vllm.envs as envs @@ -77,6 +79,14 @@ DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], Optional[MultiModalDataDict]] +@dataclass +class SchedulerOutputState: + """Caches the scheduler outputs for a virtual engine. Used for Multi-Step""" + last_output: Optional[SamplerOutput] = None + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + scheduler_outputs: Optional[SchedulerOutputs] = None + + class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -194,7 +204,7 @@ class LLMEngine: "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "enable_prefix_caching=%s)", + "num_scheduler_steps=%d, enable_prefix_caching=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -223,6 +233,7 @@ class LLMEngine: model_config.seed, model_config.served_model_name, scheduler_config.use_v2_block_manager, + scheduler_config.num_scheduler_steps, cache_config.enable_prefix_caching, ) # TODO(woosuk): Print more configs in debug mode. 
@@ -380,6 +391,11 @@ class LLMEngine: ), )) + self.cached_scheduler_outputs = [ + SchedulerOutputState() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -1304,16 +1320,40 @@ class LLMEngine: "Pipeline parallelism is only supported through AsyncLLMEngine " "as performance will be severely degraded otherwise.") - if self.scheduler_config.num_scheduler_steps > 1: - raise NotImplementedError( - "Multiple scheduler steps (multi-step) are only supported " - "through AsyncLLMEngine. ") - seq_group_metadata_list, scheduler_outputs = self.scheduler[ - 0].schedule() + # These are cached outputs from previous iterations. None if on first + # iteration + cached_outputs = self.cached_scheduler_outputs[0] + seq_group_metadata_list = cached_outputs.seq_group_metadata_list + scheduler_outputs = cached_outputs.scheduler_outputs + + # Skip the scheduler if there are any remaining steps in the seq groups. + # This ensures that the scheduler is only called again when the current + # batch has completed. + if not self._has_remaining_steps(seq_group_metadata_list): + seq_group_metadata_list, scheduler_outputs = self.scheduler[ + 0].schedule() + + if (self.scheduler_config.is_multi_step + and scheduler_outputs.num_lookahead_slots > 0): + # cache the scheduler outputs for the next iteration if we have + # lookahead slots + self._cache_scheduler_outputs_for_multi_step( + 0, seq_group_metadata_list, scheduler_outputs) + + assert seq_group_metadata_list is not None + assert scheduler_outputs is not None if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ 0].get_and_reset_finished_requests_ids() + + # Check if we have a cached last_output from the previous iteration. + # For supporting PP this is probably the best way to pass the + # sampled_token_ids, as a separate broadcast over all the PP stages + # will cause one virtual engine's microbatch to block the pipeline. + last_sampled_token_ids = \ + self._get_last_sampled_token_ids(0) + execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, @@ -1321,15 +1361,36 @@ class LLMEngine: blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_requests_ids=finished_requests_ids) + finished_requests_ids=finished_requests_ids, + # We use ExecuteModelRequest to pass the last sampled_token_ids + # to each of the non-last PP stages for in-place prepare_input. + last_sampled_token_ids=last_sampled_token_ids) + output = self.model_executor.execute_model( execute_model_req=execute_model_req) + + # we need to do this here so that last step's sampled_token_ids can + # be passed to the next iteration for PP. + if self.scheduler_config.is_multi_step: + self._update_cached_scheduler_output(0, output) else: output = [] - request_outputs = self._process_model_outputs( - output, scheduler_outputs.scheduled_seq_groups, - scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) + # Finish the current step for all the sequence groups. 
+ if self.scheduler_config.is_multi_step: + for seq_group in seq_group_metadata_list: + seq_group.finish_step() + + if not self._has_remaining_steps(seq_group_metadata_list): + # clear the cache if we have finished all the steps + if self.scheduler_config.is_multi_step: + self.cached_scheduler_outputs[0] = SchedulerOutputState() + request_outputs = self._process_model_outputs( + output, scheduler_outputs.scheduled_seq_groups, + scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) + + else: + request_outputs = [] # Log stats. self.do_log_stats(scheduler_outputs, output) @@ -1347,6 +1408,60 @@ class LLMEngine: return request_outputs + def _has_remaining_steps( + self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] + ) -> bool: + if (not self.scheduler_config.is_multi_step + or not seq_group_metadata_list): + return False + + # TODO(will) this is a sanity check for nowto make sure that all the + # seqs are on the same steps. Eventually we will want to do some sort of + # dynamic scheduling when doing multi-step decoding. + ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps + if any([ + seq_group.state.remaining_steps != ref_remaining_steps + for seq_group in seq_group_metadata_list[1:] + ]): + raise AssertionError(("All running sequence groups should " + "have the same remaining steps.")) + + return ref_remaining_steps > 0 + + def _cache_scheduler_outputs_for_multi_step( + self, virtual_engine: int, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + scheduler_outputs: SchedulerOutputs) -> None: + self.cached_scheduler_outputs[ + virtual_engine].seq_group_metadata_list = seq_group_metadata_list + self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \ + scheduler_outputs + self.cached_scheduler_outputs[virtual_engine].last_output = None + + def _update_cached_scheduler_output( + self, virtual_engine: int, + output: List[Optional[SamplerOutput]]) -> None: + if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0 + and output[0] is not None): + last_output = output[-1] + assert last_output is not None + assert last_output.sampled_token_ids_cpu is not None + assert last_output.sampled_token_ids is None + assert last_output.sampled_token_probs is None + self.cached_scheduler_outputs[ + virtual_engine].last_output = last_output + + def _get_last_sampled_token_ids( + self, virtual_engine: int) -> Optional[torch.Tensor]: + cached_last_output = self.cached_scheduler_outputs[ + virtual_engine].last_output + if (self.scheduler_config.is_multi_step + and self.parallel_config.pipeline_parallel_size > 1 + and cached_last_output is not None + and cached_last_output.sampled_token_ids_cpu is not None): + return cached_last_output.sampled_token_ids_cpu + return None + def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: if logger_name in self.stat_loggers: raise KeyError(f"Logger with name {logger_name} already exists.")
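
Usage note: the num_scheduler_steps and use_v2_block_manager engine arguments wired up above can be exercised end-to-end through the synchronous LLM entrypoint, the same way the updated benchmarks/benchmark_throughput.py constructs its engine. The snippet below is a minimal sketch, assuming the LLM constructor forwards these arguments to EngineArgs exactly as the benchmark does; the model name and sampling settings are illustrative only (the model is the small one used by tests/multi_step/test_correctness_llm.py).

from vllm import LLM, SamplingParams

# Multi-step decoding through the synchronous LLMEngine path added in this
# patch: the scheduler is called once, then up to num_scheduler_steps forward
# steps run before it is called again. The new test enables block manager v2
# alongside multi-step, so the same is done here.
llm = LLM(
    model="JackFram/llama-160m",  # small model used by the new multi-step test
    dtype="half",
    enforce_eager=True,
    use_v2_block_manager=True,
    num_scheduler_steps=8,        # >1 turns on multi-step scheduling
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The future of AI is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)

The same configuration maps onto the benchmark flags added above, e.g. passing --num-scheduler-steps 8 --use-v2-block-manager to benchmarks/benchmark_throughput.py together with its usual model and dataset/length arguments.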