Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 07:04:53 +08:00)

[Core] Add multi-step support to LLMEngine (#7789)

This commit is contained in:
parent 09c7792610
commit 9db93de20c
.buildkite/test-pipeline.yaml

@@ -335,7 +335,8 @@ steps:
   - vllm/engine
   - tests/multi_step
   commands:
-  - pytest -v -s multi_step/test_correctness.py
+  - pytest -v -s multi_step/test_correctness_async_llm.py
+  - pytest -v -s multi_step/test_correctness_llm.py

 - label: Pipeline Parallelism Test # 23min
   working_dir: "/vllm-workspace/tests"
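The commands above now exercise both multi-step suites. As a rough local equivalent (a sketch only, assuming a vLLM development checkout with its test dependencies installed and a GPU available), the same files can be run from the tests/ directory through pytest's Python API:

# run_multi_step_tests.py: illustrative local runner, not part of this commit.
import sys

import pytest

if __name__ == "__main__":
    # Mirror the CI commands above: run both multi-step correctness suites.
    sys.exit(
        pytest.main([
            "-v", "-s",
            "multi_step/test_correctness_async_llm.py",
            "multi_step/test_correctness_llm.py",
        ]))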
benchmarks/benchmark_throughput.py

@@ -82,6 +82,8 @@ def run_vllm(
     max_num_batched_tokens: int,
     distributed_executor_backend: Optional[str],
     gpu_memory_utilization: float = 0.9,
+    num_scheduler_steps: int = 1,
+    use_v2_block_manager: bool = False,
     download_dir: Optional[str] = None,
     load_format: str = EngineArgs.load_format,
 ) -> float:
@@ -106,6 +108,8 @@ def run_vllm(
         max_num_batched_tokens=max_num_batched_tokens,
         distributed_executor_backend=distributed_executor_backend,
         load_format=load_format,
+        num_scheduler_steps=num_scheduler_steps,
+        use_v2_block_manager=use_v2_block_manager,
     )

     # Add the requests to the engine.
@@ -232,7 +236,8 @@ def main(args: argparse.Namespace):
             args.quantization_param_path, args.device,
             args.enable_prefix_caching, args.enable_chunked_prefill,
             args.max_num_batched_tokens, args.distributed_executor_backend,
-            args.gpu_memory_utilization, args.download_dir, args.load_format)
+            args.gpu_memory_utilization, args.num_scheduler_steps,
+            args.use_v2_block_manager, args.download_dir, args.load_format)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -353,10 +358,18 @@ if __name__ == "__main__":
         choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
         help='device type for vLLM execution, supporting CUDA, OpenVINO and '
         'CPU.')
+    parser.add_argument(
+        "--num-scheduler-steps",
+        type=int,
+        default=1,
+        help="Maximum number of forward steps per scheduler call.")
+    parser.add_argument("--use-v2-block-manager",
+                        action='store_true',
+                        help="Enable block manager v2.")
     parser.add_argument(
         "--enable-prefix-caching",
         action='store_true',
-        help="enable automatic prefix caching for vLLM backend.")
+        help="Enable automatic prefix caching for vLLM backend.")
     parser.add_argument("--enable-chunked-prefill",
                         action='store_true',
                         help="enable chunked prefill for vLLM backend.")
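The two new flags flow from argparse through run_vllm into the engine construction shown in the hunks above. A minimal sketch of the equivalent direct API use (illustrative only; it assumes the LLM entrypoint accepts the same num_scheduler_steps and use_v2_block_manager keyword arguments that run_vllm forwards, and it reuses the small JackFram/llama-160m model from the new tests):

# Illustrative only: drive multi-step decoding through the LLM API with the
# same engine arguments that benchmark_throughput.py now forwards.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",   # small model, as in the new multi-step tests
    num_scheduler_steps=8,         # up to 8 forward steps per scheduler call
    use_v2_block_manager=True,     # the new tests enable block manager v2 alongside it
    gpu_memory_utilization=0.7,
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=5)
outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)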
tests/lora/test_gemma.py

@@ -37,7 +37,7 @@ def test_gemma_lora(gemma_lora_files):
     expected_lora_output = [
         "more important than knowledge.\nAuthor: Albert Einstein\n",
         "everyone else is already taken.\nAuthor: Oscar Wilde\n",
-        "so little time.\nAuthor: Frank Zappa\n",
+        "so little time\nAuthor: Frank Zappa\n",
     ]

     output1 = do_sample(llm, gemma_lora_files, lora_id=1)
tests/multi_step/test_correctness_llm.py  (new file, 49 lines)

@@ -0,0 +1,49 @@
# Test the LLMEngine with multi-step-decoding

import pytest

from ..models.utils import check_outputs_equal

MODELS = [
    "JackFram/llama-160m",
]
NUM_SCHEDULER_STEPS = [8]  # Multi-step decoding steps
NUM_PROMPTS = [10]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str,
                        dtype: str, tp_size: int, max_tokens: int,
                        enforce_eager: int, num_scheduler_steps: int,
                        num_prompts: int) -> None:

    prompts = example_prompts
    if len(prompts) < num_prompts:
        prompts = prompts * ((num_prompts // len(prompts)) + 1)
    prompts = prompts[:num_prompts]
    assert len(prompts) == num_prompts

    with vllm_runner(model,
                     dtype=dtype,
                     enforce_eager=enforce_eager,
                     gpu_memory_utilization=0.7,
                     tensor_parallel_size=tp_size,
                     use_v2_block_manager=True,
                     num_scheduler_steps=num_scheduler_steps) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
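The test tiles the example prompts up to num_prompts and requires greedy decoding to match HuggingFace exactly. The fixtures (hf_runner, vllm_runner) and check_outputs_equal live elsewhere in the vLLM test suite; the following is only a hypothetical stand-in that sketches the comparison pattern, with an assumed (token ids, text) output shape:

# Hypothetical stand-in for the helpers used above; not the real fixtures.
from typing import List, Tuple

Output = Tuple[List[int], str]  # assumed (token ids, decoded text) pairs


def replicate_prompts(prompts: List[str], num_prompts: int) -> List[str]:
    # Same arithmetic as the test body: tile the prompt list, then truncate.
    if len(prompts) < num_prompts:
        prompts = prompts * ((num_prompts // len(prompts)) + 1)
    return prompts[:num_prompts]


def compare_greedy(outputs_0: List[Output], outputs_1: List[Output],
                   name_0: str, name_1: str) -> None:
    # Greedy decoding is deterministic, so the two backends should agree exactly.
    for i, ((ids_0, text_0), (ids_1, text_1)) in enumerate(
            zip(outputs_0, outputs_1)):
        assert ids_0 == ids_1, (
            f"prompt {i}: {name_0}={text_0!r} vs {name_1}={text_1!r}")


# The tiling keeps the original prompt order and simply wraps around.
assert replicate_prompts(["a", "b", "c"], 10) == [
    "a", "b", "c", "a", "b", "c", "a", "b", "c", "a"
]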
vllm/engine/async_llm_engine.py

@@ -1,11 +1,9 @@
 import asyncio
 import time
-from dataclasses import dataclass
 from functools import partial
 from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                     Mapping, Optional, Set, Tuple, Type, Union)

-import torch
 from typing_extensions import assert_never

 import vllm.envs as envs
@@ -15,7 +13,7 @@ from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
-                                    PromptComponents)
+                                    PromptComponents, SchedulerOutputState)
 from vllm.engine.metrics_types import StatLoggerBase
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.ray_utils import initialize_ray_cluster, ray
@@ -28,8 +26,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
-                           SequenceGroupMetadata)
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import print_warning_once
@@ -257,24 +254,11 @@ class RequestTracker:
         return not self._new_requests.empty()


-@dataclass
-class SchedulerOutputState:
-    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
-    last_output: Optional[SamplerOutput] = None
-    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
-    scheduler_outputs: Optional[SchedulerOutputs] = None
-
-
 class _AsyncLLMEngine(LLMEngine):
     """Extension of LLMEngine to add async methods."""

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        pipeline_parallel_size = \
-            self.parallel_config.pipeline_parallel_size
-        self.cached_scheduler_outputs = [
-            SchedulerOutputState() for _ in range(pipeline_parallel_size)
-        ]

     async def step_async(
         self, virtual_engine: int
@@ -367,60 +351,6 @@ class _AsyncLLMEngine(LLMEngine):

         return request_outputs

-    def _has_remaining_steps(
-        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
-    ) -> bool:
-        if (not self.scheduler_config.is_multi_step
-                or not seq_group_metadata_list):
-            return False
-
-        # TODO(will) this is a sanity check for now to make sure that all the
-        # seqs are on the same steps. Eventually we will want to do some sort of
-        # dynamic scheduling when doing multi-step decoding.
-        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
-        if any([
-                seq_group.state.remaining_steps != ref_remaining_steps
-                for seq_group in seq_group_metadata_list[1:]
-        ]):
-            raise AssertionError(("All running sequence groups should "
-                                  "have the same remaining steps."))
-
-        return ref_remaining_steps > 0
-
-    def _cache_scheduler_outputs_for_multi_step(
-            self, virtual_engine: int,
-            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-            scheduler_outputs: SchedulerOutputs) -> None:
-        self.cached_scheduler_outputs[
-            virtual_engine].seq_group_metadata_list = seq_group_metadata_list
-        self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
-            scheduler_outputs
-        self.cached_scheduler_outputs[virtual_engine].last_output = None
-
-    def _get_last_sampled_token_ids(
-            self, virtual_engine: int) -> Optional[torch.Tensor]:
-        cached_last_output = self.cached_scheduler_outputs[
-            virtual_engine].last_output
-        if (self.scheduler_config.is_multi_step
-                and self.parallel_config.pipeline_parallel_size > 1
-                and cached_last_output is not None
-                and cached_last_output.sampled_token_ids_cpu is not None):
-            return cached_last_output.sampled_token_ids_cpu
-        return None
-
-    def _update_cached_scheduler_output(
-            self, virtual_engine: int,
-            output: List[Optional[SamplerOutput]]) -> None:
-        if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
-                and output[0] is not None):
-            last_output = output[-1]
-            assert last_output is not None
-            assert last_output.sampled_token_ids_cpu is not None
-            assert last_output.sampled_token_ids is None
-            assert last_output.sampled_token_probs is None
-            self.cached_scheduler_outputs[
-                virtual_engine].last_output = last_output
-
     async def stop_remote_worker_execution_loop_async(self) -> None:
         """Stop the remote worker execution loop."""
         await self.model_executor.stop_remote_worker_execution_loop_async()
vllm/engine/llm_engine.py

@@ -1,10 +1,12 @@
 import time
 from contextlib import contextmanager
+from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
                     Mapping, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Tuple, Type, Union

+import torch
 from typing_extensions import TypeVar, assert_never

 import vllm.envs as envs
@@ -77,6 +79,14 @@ DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
                                 Optional[MultiModalDataDict]]


+@dataclass
+class SchedulerOutputState:
+    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
+    last_output: Optional[SamplerOutput] = None
+    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
+    scheduler_outputs: Optional[SchedulerOutputs] = None
+
+
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.

@@ -194,7 +204,7 @@ class LLMEngine:
             "quantization_param_path=%s, device_config=%s, "
             "decoding_config=%r, observability_config=%r, "
             "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
-            "enable_prefix_caching=%s)",
+            "num_scheduler_steps=%d, enable_prefix_caching=%s)",
             VLLM_VERSION,
             model_config.model,
             speculative_config,
@@ -223,6 +233,7 @@ class LLMEngine:
             model_config.seed,
             model_config.served_model_name,
             scheduler_config.use_v2_block_manager,
+            scheduler_config.num_scheduler_steps,
             cache_config.enable_prefix_caching,
         )
         # TODO(woosuk): Print more configs in debug mode.
@@ -380,6 +391,11 @@ class LLMEngine:
             ),
         ))

+        self.cached_scheduler_outputs = [
+            SchedulerOutputState()
+            for _ in range(self.parallel_config.pipeline_parallel_size)
+        ]
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
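LLMEngine now owns one SchedulerOutputState per virtual engine (pipeline_parallel_size entries), filled when a batch is scheduled and reset once all of its steps have run. A toy sketch of that caching pattern in isolation (placeholder field values; the real fields hold a SamplerOutput, the sequence-group metadata list and the SchedulerOutputs, exactly as in the dataclass above):

# Toy illustration of the per-virtual-engine cache used for multi-step decoding.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SchedulerOutputState:
    # Placeholder types; the real fields are SamplerOutput,
    # List[SequenceGroupMetadata] and SchedulerOutputs.
    last_output: Optional[Any] = None
    seq_group_metadata_list: Optional[Any] = None
    scheduler_outputs: Optional[Any] = None


pipeline_parallel_size = 2
cached_scheduler_outputs = [
    SchedulerOutputState() for _ in range(pipeline_parallel_size)
]

# Cache the outputs of a fresh schedule() call for virtual engine 0 ...
cached_scheduler_outputs[0] = SchedulerOutputState(
    seq_group_metadata_list=["seq_group_0", "seq_group_1"],
    scheduler_outputs="scheduler_outputs")

# ... and clear that entry again once every cached sequence group has no steps left.
cached_scheduler_outputs[0] = SchedulerOutputState()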
@@ -1304,16 +1320,40 @@
                 "Pipeline parallelism is only supported through AsyncLLMEngine "
                 "as performance will be severely degraded otherwise.")

-        if self.scheduler_config.num_scheduler_steps > 1:
-            raise NotImplementedError(
-                "Multiple scheduler steps (multi-step) are only supported "
-                "through AsyncLLMEngine. ")
-        seq_group_metadata_list, scheduler_outputs = self.scheduler[
-            0].schedule()
+        # These are cached outputs from previous iterations. None if on first
+        # iteration
+        cached_outputs = self.cached_scheduler_outputs[0]
+        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
+        scheduler_outputs = cached_outputs.scheduler_outputs
+
+        # Skip the scheduler if there are any remaining steps in the seq groups.
+        # This ensures that the scheduler is only called again when the current
+        # batch has completed.
+        if not self._has_remaining_steps(seq_group_metadata_list):
+            seq_group_metadata_list, scheduler_outputs = self.scheduler[
+                0].schedule()
+
+            if (self.scheduler_config.is_multi_step
+                    and scheduler_outputs.num_lookahead_slots > 0):
+                # cache the scheduler outputs for the next iteration if we have
+                # lookahead slots
+                self._cache_scheduler_outputs_for_multi_step(
+                    0, seq_group_metadata_list, scheduler_outputs)
+
+        assert seq_group_metadata_list is not None
+        assert scheduler_outputs is not None

         if not scheduler_outputs.is_empty():
             finished_requests_ids = self.scheduler[
                 0].get_and_reset_finished_requests_ids()
+
+            # Check if we have a cached last_output from the previous iteration.
+            # For supporting PP this is probably the best way to pass the
+            # sampled_token_ids, as a separate broadcast over all the PP stages
+            # will cause one virtual engine's microbatch to block the pipeline.
+            last_sampled_token_ids = \
+                self._get_last_sampled_token_ids(0)
+
             execute_model_req = ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
                 blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
@@ -1321,15 +1361,36 @@
                 blocks_to_copy=scheduler_outputs.blocks_to_copy,
                 num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                 running_queue_size=scheduler_outputs.running_queue_size,
-                finished_requests_ids=finished_requests_ids)
+                finished_requests_ids=finished_requests_ids,
+                # We use ExecuteModelRequest to pass the last sampled_token_ids
+                # to each of the non-last PP stages for in-place prepare_input.
+                last_sampled_token_ids=last_sampled_token_ids)
+
             output = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)
+
+            # we need to do this here so that last step's sampled_token_ids can
+            # be passed to the next iteration for PP.
+            if self.scheduler_config.is_multi_step:
+                self._update_cached_scheduler_output(0, output)
         else:
             output = []

-        request_outputs = self._process_model_outputs(
-            output, scheduler_outputs.scheduled_seq_groups,
-            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
+        # Finish the current step for all the sequence groups.
+        if self.scheduler_config.is_multi_step:
+            for seq_group in seq_group_metadata_list:
+                seq_group.finish_step()
+
+        if not self._has_remaining_steps(seq_group_metadata_list):
+            # clear the cache if we have finished all the steps
+            if self.scheduler_config.is_multi_step:
+                self.cached_scheduler_outputs[0] = SchedulerOutputState()
+            request_outputs = self._process_model_outputs(
+                output, scheduler_outputs.scheduled_seq_groups,
+                scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
+
+        else:
+            request_outputs = []

         # Log stats.
         self.do_log_stats(scheduler_outputs, output)
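Taken together, step() now calls the scheduler once per batch, keeps the scheduler outputs cached while the batch still has remaining steps, and only produces request outputs (and clears the cache) after the final step. A toy, self-contained model of that control flow, using invented minimal classes rather than real engine objects:

# Toy model of the multi-step control flow in LLMEngine.step(): the scheduler
# is skipped while the cached batch still has remaining steps.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToySeqGroup:
    remaining_steps: int

    def finish_step(self) -> None:
        self.remaining_steps -= 1


@dataclass
class ToyCache:
    seq_groups: Optional[List[ToySeqGroup]] = None


class ToyEngine:

    def __init__(self, num_scheduler_steps: int):
        self.num_scheduler_steps = num_scheduler_steps
        self.cache = ToyCache()
        self.schedule_calls = 0

    def _has_remaining_steps(self, seq_groups) -> bool:
        return bool(seq_groups) and seq_groups[0].remaining_steps > 0

    def step(self) -> bool:
        seq_groups = self.cache.seq_groups
        if not self._has_remaining_steps(seq_groups):
            # Schedule a new batch and cache it for the following steps.
            self.schedule_calls += 1
            seq_groups = [ToySeqGroup(self.num_scheduler_steps)]
            self.cache.seq_groups = seq_groups

        # "Execute" one forward pass, then finish the current step.
        for seq_group in seq_groups:
            seq_group.finish_step()

        if not self._has_remaining_steps(seq_groups):
            self.cache = ToyCache()   # all steps done: clear the cache
            return True               # request outputs would be returned here
        return False                  # intermediate step: no outputs yet


engine = ToyEngine(num_scheduler_steps=8)
outputs_returned = sum(engine.step() for _ in range(8))
# One scheduler call serves eight forward steps; outputs appear once at the end.
assert engine.schedule_calls == 1 and outputs_returned == 1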
@@ -1347,6 +1408,60 @@

         return request_outputs

+    def _has_remaining_steps(
+        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
+    ) -> bool:
+        if (not self.scheduler_config.is_multi_step
+                or not seq_group_metadata_list):
+            return False
+
+        # TODO(will) this is a sanity check for now to make sure that all the
+        # seqs are on the same steps. Eventually we will want to do some sort of
+        # dynamic scheduling when doing multi-step decoding.
+        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
+        if any([
+                seq_group.state.remaining_steps != ref_remaining_steps
+                for seq_group in seq_group_metadata_list[1:]
+        ]):
+            raise AssertionError(("All running sequence groups should "
+                                  "have the same remaining steps."))
+
+        return ref_remaining_steps > 0
+
+    def _cache_scheduler_outputs_for_multi_step(
+            self, virtual_engine: int,
+            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+            scheduler_outputs: SchedulerOutputs) -> None:
+        self.cached_scheduler_outputs[
+            virtual_engine].seq_group_metadata_list = seq_group_metadata_list
+        self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
+            scheduler_outputs
+        self.cached_scheduler_outputs[virtual_engine].last_output = None
+
+    def _update_cached_scheduler_output(
+            self, virtual_engine: int,
+            output: List[Optional[SamplerOutput]]) -> None:
+        if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
+                and output[0] is not None):
+            last_output = output[-1]
+            assert last_output is not None
+            assert last_output.sampled_token_ids_cpu is not None
+            assert last_output.sampled_token_ids is None
+            assert last_output.sampled_token_probs is None
+            self.cached_scheduler_outputs[
+                virtual_engine].last_output = last_output
+
+    def _get_last_sampled_token_ids(
+            self, virtual_engine: int) -> Optional[torch.Tensor]:
+        cached_last_output = self.cached_scheduler_outputs[
+            virtual_engine].last_output
+        if (self.scheduler_config.is_multi_step
+                and self.parallel_config.pipeline_parallel_size > 1
+                and cached_last_output is not None
+                and cached_last_output.sampled_token_ids_cpu is not None):
+            return cached_last_output.sampled_token_ids_cpu
+        return None
+
     def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
         if logger_name in self.stat_loggers:
             raise KeyError(f"Logger with name {logger_name} already exists.")