Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 20:35:01 +08:00)
[Core] Add multi-step support to LLMEngine (#7789)
This commit is contained in: parent 09c7792610, commit 9db93de20c
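
For orientation: the change lets the engine run several decode iterations per call into the scheduler, controlled by a new num_scheduler_steps setting. A minimal usage sketch (not part of the diff; it assumes the public LLM constructor forwards num_scheduler_steps and use_v2_block_manager to the engine, the same way the benchmark and the new test below do):

    # Hedged sketch: enabling multi-step decoding from user code. The keyword
    # arguments mirror the ones added to the benchmark and the new test.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="JackFram/llama-160m",  # small model used by the new test
        use_v2_block_manager=True,    # the new test enables block manager v2 alongside multi-step
        num_scheduler_steps=8,        # run up to 8 forward steps per scheduler call
    )
    params = SamplingParams(temperature=0.0, max_tokens=32)
    outputs = llm.generate(["The future of AI is"], params)
    print(outputs[0].outputs[0].text)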
@@ -335,7 +335,8 @@ steps:
   - vllm/engine
   - tests/multi_step
   commands:
-    - pytest -v -s multi_step/test_correctness.py
+    - pytest -v -s multi_step/test_correctness_async_llm.py
+    - pytest -v -s multi_step/test_correctness_llm.py
 
 - label: Pipeline Parallelism Test # 23min
   working_dir: "/vllm-workspace/tests"
@@ -82,6 +82,8 @@ def run_vllm(
     max_num_batched_tokens: int,
     distributed_executor_backend: Optional[str],
     gpu_memory_utilization: float = 0.9,
+    num_scheduler_steps: int = 1,
+    use_v2_block_manager: bool = False,
     download_dir: Optional[str] = None,
     load_format: str = EngineArgs.load_format,
 ) -> float:
@@ -106,6 +108,8 @@ def run_vllm(
         max_num_batched_tokens=max_num_batched_tokens,
         distributed_executor_backend=distributed_executor_backend,
         load_format=load_format,
+        num_scheduler_steps=num_scheduler_steps,
+        use_v2_block_manager=use_v2_block_manager,
     )
 
     # Add the requests to the engine.
@@ -232,7 +236,8 @@ def main(args: argparse.Namespace):
             args.quantization_param_path, args.device,
             args.enable_prefix_caching, args.enable_chunked_prefill,
             args.max_num_batched_tokens, args.distributed_executor_backend,
-            args.gpu_memory_utilization, args.download_dir, args.load_format)
+            args.gpu_memory_utilization, args.num_scheduler_steps,
+            args.use_v2_block_manager, args.download_dir, args.load_format)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -353,10 +358,18 @@ if __name__ == "__main__":
         choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
         help='device type for vLLM execution, supporting CUDA, OpenVINO and '
         'CPU.')
+    parser.add_argument(
+        "--num-scheduler-steps",
+        type=int,
+        default=1,
+        help="Maximum number of forward steps per scheduler call.")
+    parser.add_argument("--use-v2-block-manager",
+                        action='store_true',
+                        help="Enable block manager v2.")
     parser.add_argument(
         "--enable-prefix-caching",
         action='store_true',
-        help="enable automatic prefix caching for vLLM backend.")
+        help="Enable automatic prefix caching for vLLM backend.")
     parser.add_argument("--enable-chunked-prefill",
                         action='store_true',
                         help="enable chunked prefill for vLLM backend.")
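
Taken together, the benchmark changes wire two new CLI flags through to the engine: args.num_scheduler_steps and args.use_v2_block_manager are handed to run_vllm(), which forwards them as keyword arguments (see the @@ -106 hunk above). A small self-contained sketch of that path; the argparse definitions are copied from the diff, the scaffolding around them is illustrative only:

    # Hedged sketch of the new flag path in the throughput benchmark.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--num-scheduler-steps", type=int, default=1,
                        help="Maximum number of forward steps per scheduler call.")
    parser.add_argument("--use-v2-block-manager", action="store_true",
                        help="Enable block manager v2.")
    args = parser.parse_args(["--num-scheduler-steps", "8", "--use-v2-block-manager"])

    # run_vllm(...) then passes these on to the LLM constructor:
    #     LLM(..., num_scheduler_steps=args.num_scheduler_steps,
    #         use_v2_block_manager=args.use_v2_block_manager)
    print(args.num_scheduler_steps, args.use_v2_block_manager)  # -> 8 True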
@@ -37,7 +37,7 @@ def test_gemma_lora(gemma_lora_files):
     expected_lora_output = [
         "more important than knowledge.\nAuthor: Albert Einstein\n",
         "everyone else is already taken.\nAuthor: Oscar Wilde\n",
-        "so little time.\nAuthor: Frank Zappa\n",
+        "so little time\nAuthor: Frank Zappa\n",
     ]
 
     output1 = do_sample(llm, gemma_lora_files, lora_id=1)
tests/multi_step/test_correctness_llm.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+# Test the LLMEngine with multi-step-decoding
+
+import pytest
+
+from ..models.utils import check_outputs_equal
+
+MODELS = [
+    "JackFram/llama-160m",
+]
+NUM_SCHEDULER_STEPS = [8]  # Multi-step decoding steps
+NUM_PROMPTS = [10]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("tp_size", [1])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
+@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
+def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str,
+                        dtype: str, tp_size: int, max_tokens: int,
+                        enforce_eager: int, num_scheduler_steps: int,
+                        num_prompts: int) -> None:
+
+    prompts = example_prompts
+    if len(prompts) < num_prompts:
+        prompts = prompts * ((num_prompts // len(prompts)) + 1)
+    prompts = prompts[:num_prompts]
+    assert len(prompts) == num_prompts
+
+    with vllm_runner(model,
+                     dtype=dtype,
+                     enforce_eager=enforce_eager,
+                     gpu_memory_utilization=0.7,
+                     tensor_parallel_size=tp_size,
+                     use_v2_block_manager=True,
+                     num_scheduler_steps=num_scheduler_steps) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
+
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
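
The new test drives both the HF runner and the vLLM runner with greedy decoding and asserts identical outputs per prompt. check_outputs_equal lives in tests/models/utils.py; the sketch below is only an illustrative stand-in for the kind of per-prompt comparison it is assumed to perform on (token_ids, text) pairs:

    # Hedged stand-in for check_outputs_equal; the real helper is in
    # tests/models/utils.py and may differ in details.
    from typing import List, Sequence, Tuple

    Output = Tuple[List[int], str]  # (token_ids, generated_text)

    def check_outputs_equal_sketch(outputs_0_lst: Sequence[Output],
                                   outputs_1_lst: Sequence[Output],
                                   name_0: str, name_1: str) -> None:
        assert len(outputs_0_lst) == len(outputs_1_lst)
        for i, ((ids_0, text_0), (ids_1, text_1)) in enumerate(
                zip(outputs_0_lst, outputs_1_lst)):
            assert ids_0 == ids_1 and text_0 == text_1, (
                f"prompt {i}: {name_0}={text_0!r} vs {name_1}={text_1!r}")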
@@ -1,11 +1,9 @@
 import asyncio
 import time
-from dataclasses import dataclass
 from functools import partial
 from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                     Mapping, Optional, Set, Tuple, Type, Union)
 
-import torch
 from typing_extensions import assert_never
 
 import vllm.envs as envs
@@ -15,7 +13,7 @@ from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
-                                    PromptComponents)
+                                    PromptComponents, SchedulerOutputState)
 from vllm.engine.metrics_types import StatLoggerBase
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.ray_utils import initialize_ray_cluster, ray
@@ -28,8 +26,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
-                           SequenceGroupMetadata)
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import print_warning_once
@@ -257,24 +254,11 @@ class RequestTracker:
         return not self._new_requests.empty()
 
 
-@dataclass
-class SchedulerOutputState:
-    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
-    last_output: Optional[SamplerOutput] = None
-    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
-    scheduler_outputs: Optional[SchedulerOutputs] = None
-
-
 class _AsyncLLMEngine(LLMEngine):
     """Extension of LLMEngine to add async methods."""
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        pipeline_parallel_size = \
-            self.parallel_config.pipeline_parallel_size
-        self.cached_scheduler_outputs = [
-            SchedulerOutputState() for _ in range(pipeline_parallel_size)
-        ]
 
     async def step_async(
         self, virtual_engine: int
@@ -367,60 +351,6 @@ class _AsyncLLMEngine(LLMEngine):
 
         return request_outputs
 
-    def _has_remaining_steps(
-        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
-    ) -> bool:
-        if (not self.scheduler_config.is_multi_step
-                or not seq_group_metadata_list):
-            return False
-
-        # TODO(will) this is a sanity check for now to make sure that all the
-        # seqs are on the same steps. Eventually we will want to do some sort of
-        # dynamic scheduling when doing multi-step decoding.
-        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
-        if any([
-                seq_group.state.remaining_steps != ref_remaining_steps
-                for seq_group in seq_group_metadata_list[1:]
-        ]):
-            raise AssertionError(("All running sequence groups should "
-                                  "have the same remaining steps."))
-
-        return ref_remaining_steps > 0
-
-    def _cache_scheduler_outputs_for_multi_step(
-            self, virtual_engine: int,
-            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-            scheduler_outputs: SchedulerOutputs) -> None:
-        self.cached_scheduler_outputs[
-            virtual_engine].seq_group_metadata_list = seq_group_metadata_list
-        self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
-            scheduler_outputs
-        self.cached_scheduler_outputs[virtual_engine].last_output = None
-
-    def _get_last_sampled_token_ids(
-            self, virtual_engine: int) -> Optional[torch.Tensor]:
-        cached_last_output = self.cached_scheduler_outputs[
-            virtual_engine].last_output
-        if (self.scheduler_config.is_multi_step
-                and self.parallel_config.pipeline_parallel_size > 1
-                and cached_last_output is not None
-                and cached_last_output.sampled_token_ids_cpu is not None):
-            return cached_last_output.sampled_token_ids_cpu
-        return None
-
-    def _update_cached_scheduler_output(
-            self, virtual_engine: int,
-            output: List[Optional[SamplerOutput]]) -> None:
-        if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
-                and output[0] is not None):
-            last_output = output[-1]
-            assert last_output is not None
-            assert last_output.sampled_token_ids_cpu is not None
-            assert last_output.sampled_token_ids is None
-            assert last_output.sampled_token_probs is None
-            self.cached_scheduler_outputs[
-                virtual_engine].last_output = last_output
-
     async def stop_remote_worker_execution_loop_async(self) -> None:
         """Stop the remote worker execution loop."""
         await self.model_executor.stop_remote_worker_execution_loop_async()
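
Both removal hunks above are relocations rather than deletions: SchedulerOutputState, the per-virtual-engine cached_scheduler_outputs list, and the multi-step helper methods reappear in the LLMEngine hunks below, so _AsyncLLMEngine now simply inherits them. A toy picture of the resulting layout (illustrative names only):

    # Illustrative only: the multi-step cache now lives on the base engine,
    # one slot per virtual engine (pipeline stage); the async engine inherits it.
    class BaseEngineSketch:
        def __init__(self, pipeline_parallel_size: int) -> None:
            self.cached_scheduler_outputs = [None] * pipeline_parallel_size

    class AsyncEngineSketch(BaseEngineSketch):
        pass  # inherits the cache and the _has_remaining_steps/_cache_* helpers

    assert len(AsyncEngineSketch(pipeline_parallel_size=2).cached_scheduler_outputs) == 2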
@@ -1,10 +1,12 @@
 import time
 from contextlib import contextmanager
+from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
                     Mapping, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Tuple, Type, Union
 
+import torch
 from typing_extensions import TypeVar, assert_never
 
 import vllm.envs as envs
@@ -77,6 +79,14 @@ DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
                                 Optional[MultiModalDataDict]]
 
 
+@dataclass
+class SchedulerOutputState:
+    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
+    last_output: Optional[SamplerOutput] = None
+    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
+    scheduler_outputs: Optional[SchedulerOutputs] = None
+
+
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
 
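
SchedulerOutputState is the per-virtual-engine state that makes multi-step work in the synchronous engine. A simplified, self-contained model of its lifecycle across engine steps (field values are stand-ins for the real SamplerOutput/SequenceGroupMetadata/SchedulerOutputs objects):

    # Hedged, simplified model of how the cache added in this commit is used.
    from dataclasses import dataclass
    from typing import Any, List, Optional

    @dataclass
    class SchedulerOutputStateSketch:
        last_output: Optional[Any] = None
        seq_group_metadata_list: Optional[List[Any]] = None
        scheduler_outputs: Optional[Any] = None

    cache = [SchedulerOutputStateSketch()]  # one slot per virtual engine

    # First step of a batch: the scheduler runs and its outputs are cached.
    cache[0].seq_group_metadata_list = ["seq_group_0", "seq_group_1"]
    cache[0].scheduler_outputs = "scheduler_outputs"
    cache[0].last_output = None

    # Intermediate steps: the schedule is reused; under pipeline parallelism the
    # last SamplerOutput is recorded so its token ids can seed the next step.
    cache[0].last_output = "sampler_output_from_previous_step"

    # After the final step of the batch the slot is reset.
    cache[0] = SchedulerOutputStateSketch()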
@@ -194,7 +204,7 @@ class LLMEngine:
             "quantization_param_path=%s, device_config=%s, "
             "decoding_config=%r, observability_config=%r, "
             "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
-            "enable_prefix_caching=%s)",
+            "num_scheduler_steps=%d, enable_prefix_caching=%s)",
             VLLM_VERSION,
             model_config.model,
             speculative_config,
@@ -223,6 +233,7 @@ class LLMEngine:
             model_config.seed,
             model_config.served_model_name,
             scheduler_config.use_v2_block_manager,
+            scheduler_config.num_scheduler_steps,
             cache_config.enable_prefix_caching,
         )
         # TODO(woosuk): Print more configs in debug mode.
@@ -380,6 +391,11 @@ class LLMEngine:
             ),
         ))
 
+        self.cached_scheduler_outputs = [
+            SchedulerOutputState()
+            for _ in range(self.parallel_config.pipeline_parallel_size)
+        ]
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
 
@@ -1304,16 +1320,40 @@ class LLMEngine:
             "Pipeline parallelism is only supported through AsyncLLMEngine "
             "as performance will be severely degraded otherwise.")
 
-        if self.scheduler_config.num_scheduler_steps > 1:
-            raise NotImplementedError(
-                "Multiple scheduler steps (multi-step) are only supported "
-                "through AsyncLLMEngine. ")
-        seq_group_metadata_list, scheduler_outputs = self.scheduler[
-            0].schedule()
+        # These are cached outputs from previous iterations. None if on first
+        # iteration
+        cached_outputs = self.cached_scheduler_outputs[0]
+        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
+        scheduler_outputs = cached_outputs.scheduler_outputs
+
+        # Skip the scheduler if there are any remaining steps in the seq groups.
+        # This ensures that the scheduler is only called again when the current
+        # batch has completed.
+        if not self._has_remaining_steps(seq_group_metadata_list):
+            seq_group_metadata_list, scheduler_outputs = self.scheduler[
+                0].schedule()
+
+            if (self.scheduler_config.is_multi_step
+                    and scheduler_outputs.num_lookahead_slots > 0):
+                # cache the scheduler outputs for the next iteration if we have
+                # lookahead slots
+                self._cache_scheduler_outputs_for_multi_step(
+                    0, seq_group_metadata_list, scheduler_outputs)
+
+        assert seq_group_metadata_list is not None
+        assert scheduler_outputs is not None
 
         if not scheduler_outputs.is_empty():
             finished_requests_ids = self.scheduler[
                 0].get_and_reset_finished_requests_ids()
+
+            # Check if we have a cached last_output from the previous iteration.
+            # For supporting PP this is probably the best way to pass the
+            # sampled_token_ids, as a separate broadcast over all the PP stages
+            # will cause one virtual engine's microbatch to block the pipeline.
+            last_sampled_token_ids = \
+                self._get_last_sampled_token_ids(0)
+
             execute_model_req = ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
                 blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
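
In plain terms, LLMEngine.step() now consults the cache before scheduling: while the current batch still has remaining steps, the previous schedule is reused; otherwise the scheduler runs and, if multi-step is on and lookahead slots were allocated, its outputs are cached for the following steps. A toy, self-contained model of the intended effect (not vLLM code; the decrement stands in for the seq_group.finish_step() call in the following hunk):

    # Toy model: with num_scheduler_steps=8 the scheduler runs once and the
    # next seven calls to step() reuse the cached batch.
    class ToyMultiStepEngine:
        def __init__(self, num_scheduler_steps: int) -> None:
            self.num_scheduler_steps = num_scheduler_steps
            self.remaining_steps = 0
            self.schedule_calls = 0

        def step(self) -> None:
            if self.remaining_steps == 0:     # _has_remaining_steps() is False
                self.schedule_calls += 1      # scheduler[0].schedule()
                self.remaining_steps = self.num_scheduler_steps
            self.remaining_steps -= 1         # one forward step of the batch

    engine = ToyMultiStepEngine(num_scheduler_steps=8)
    for _ in range(16):
        engine.step()
    print(engine.schedule_calls)  # -> 2: one schedule per 8 engine steps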
@@ -1321,15 +1361,36 @@ class LLMEngine:
                 blocks_to_copy=scheduler_outputs.blocks_to_copy,
                 num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                 running_queue_size=scheduler_outputs.running_queue_size,
-                finished_requests_ids=finished_requests_ids)
+                finished_requests_ids=finished_requests_ids,
+                # We use ExecuteModelRequest to pass the last sampled_token_ids
+                # to each of the non-last PP stages for in-place prepare_input.
+                last_sampled_token_ids=last_sampled_token_ids)
+
             output = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)
+
+            # we need to do this here so that last step's sampled_token_ids can
+            # be passed to the next iteration for PP.
+            if self.scheduler_config.is_multi_step:
+                self._update_cached_scheduler_output(0, output)
         else:
             output = []
 
-        request_outputs = self._process_model_outputs(
-            output, scheduler_outputs.scheduled_seq_groups,
-            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
+        # Finish the current step for all the sequence groups.
+        if self.scheduler_config.is_multi_step:
+            for seq_group in seq_group_metadata_list:
+                seq_group.finish_step()
+
+        if not self._has_remaining_steps(seq_group_metadata_list):
+            # clear the cache if we have finished all the steps
+            if self.scheduler_config.is_multi_step:
+                self.cached_scheduler_outputs[0] = SchedulerOutputState()
+            request_outputs = self._process_model_outputs(
+                output, scheduler_outputs.scheduled_seq_groups,
+                scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
+
+        else:
+            request_outputs = []
 
         # Log stats.
         self.do_log_stats(scheduler_outputs, output)
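
One behavioral consequence of this hunk: in the synchronous engine, RequestOutputs are only produced when a cached batch has finished all of its steps; intermediate multi-step iterations return an empty list. A hedged sketch of a driver loop that copes with that, using only public LLMEngine methods:

    # Hedged sketch of a synchronous driver loop under multi-step; `engine` is
    # assumed to be an LLMEngine with requests already added.
    while engine.has_unfinished_requests():
        for request_output in engine.step():  # may be [] on intermediate steps
            if request_output.finished:
                print(request_output.outputs[0].text)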
@@ -1347,6 +1408,60 @@ class LLMEngine:
 
         return request_outputs
 
+    def _has_remaining_steps(
+        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
+    ) -> bool:
+        if (not self.scheduler_config.is_multi_step
+                or not seq_group_metadata_list):
+            return False
+
+        # TODO(will) this is a sanity check for now to make sure that all the
+        # seqs are on the same steps. Eventually we will want to do some sort of
+        # dynamic scheduling when doing multi-step decoding.
+        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
+        if any([
+                seq_group.state.remaining_steps != ref_remaining_steps
+                for seq_group in seq_group_metadata_list[1:]
+        ]):
+            raise AssertionError(("All running sequence groups should "
+                                  "have the same remaining steps."))
+
+        return ref_remaining_steps > 0
+
+    def _cache_scheduler_outputs_for_multi_step(
+            self, virtual_engine: int,
+            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+            scheduler_outputs: SchedulerOutputs) -> None:
+        self.cached_scheduler_outputs[
+            virtual_engine].seq_group_metadata_list = seq_group_metadata_list
+        self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
+            scheduler_outputs
+        self.cached_scheduler_outputs[virtual_engine].last_output = None
+
+    def _update_cached_scheduler_output(
+            self, virtual_engine: int,
+            output: List[Optional[SamplerOutput]]) -> None:
+        if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
+                and output[0] is not None):
+            last_output = output[-1]
+            assert last_output is not None
+            assert last_output.sampled_token_ids_cpu is not None
+            assert last_output.sampled_token_ids is None
+            assert last_output.sampled_token_probs is None
+            self.cached_scheduler_outputs[
+                virtual_engine].last_output = last_output
+
+    def _get_last_sampled_token_ids(
+            self, virtual_engine: int) -> Optional[torch.Tensor]:
+        cached_last_output = self.cached_scheduler_outputs[
+            virtual_engine].last_output
+        if (self.scheduler_config.is_multi_step
+                and self.parallel_config.pipeline_parallel_size > 1
+                and cached_last_output is not None
+                and cached_last_output.sampled_token_ids_cpu is not None):
+            return cached_last_output.sampled_token_ids_cpu
+        return None
+
     def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
         if logger_name in self.stat_loggers:
             raise KeyError(f"Logger with name {logger_name} already exists.")