[Core] Asynchronous Output Processor (#7049)

Co-authored-by: Alexander Matveev <alexm@neuralmagic.com>

parent 015e6cc252
commit 2eedede875
@@ -86,6 +86,7 @@ def run_vllm(
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
@@ -110,6 +111,7 @@ def run_vllm(
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
)

# Add the requests to the engine.
@@ -237,7 +239,8 @@ def main(args: argparse.Namespace):
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format)
args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -418,6 +421,11 @@ if __name__ == "__main__":
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
"--disable-async-output-proc",
action='store_true',
default=False,
help="Disable async output processor for vLLM backend.")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model

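As a usage sketch (not part of the diff): the new benchmark flag simply flows into the LLM constructor. The model name and prompt below are placeholders.

# Hedged sketch: wire --disable-async-output-proc through to the LLM constructor.
import argparse

from vllm import LLM, SamplingParams

parser = argparse.ArgumentParser()
parser.add_argument("--disable-async-output-proc",
                    action='store_true',
                    default=False,
                    help="Disable async output processor for vLLM backend.")
args = parser.parse_args()

llm = LLM(model="facebook/opt-125m",
          disable_async_output_proc=args.disable_async_output_proc)
outputs = llm.generate(["A story about vLLM:\n"],
                       SamplingParams(temperature=0.0, max_tokens=32))
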
@@ -88,6 +88,9 @@ def test_models(
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
# Due to low-precision numerical divergence, this test is too sensitive to
# the async postprocessor
@pytest.mark.parametrize("disable_async_output_proc", [True])
def test_models_with_fp8_kv_cache(
vllm_runner,
example_prompts,
@@ -97,6 +100,7 @@ def test_models_with_fp8_kv_cache(
chunked_prefill_token_size: int,
enforce_eager: bool,
tensor_parallel_size: int,
disable_async_output_proc: bool,
) -> None:
"""
Only checks log probs match between chunked-prefill and
@@ -126,6 +130,7 @@ def test_models_with_fp8_kv_cache(
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
**extra_kwargs,
) as vllm_model:
no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
@@ -139,6 +144,7 @@ def test_models_with_fp8_kv_cache(
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
**extra_kwargs,
) as vllm_model:
chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(

@@ -209,7 +209,6 @@ def test_swap_infeasible(
prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE
example_prompts = example_prompts[:1]

with vllm_runner(
model,
dtype=dtype,

@@ -21,7 +21,7 @@ def append_new_token(seq_group, token_id: int):


def schedule_and_update_computed_tokens(scheduler):
metas, out = scheduler.schedule()
metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out
@@ -180,7 +180,7 @@ def test_maximal_decoding():
"""Verify decoding requests are prioritized."""
block_size = 4
max_seqs = 2
max_model_len = 2
max_model_len = 8
max_num_batched_tokens = 2
scheduler_config = SchedulerConfig(max_num_batched_tokens,
max_seqs,

@@ -199,7 +199,7 @@ def append_new_token(out, token_id: int):


def schedule_and_update_computed_tokens(scheduler):
metas, out = scheduler.schedule()
metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out

@ -7,6 +7,8 @@ from vllm import CompletionOutput, LLMEngine, SamplingParams
|
||||
MODEL = "meta-llama/llama-2-7b-hf"
|
||||
MAX_TOKENS = 200
|
||||
|
||||
IS_ASYNC = False
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def vllm_model(vllm_runner):
|
||||
@ -14,77 +16,13 @@ def vllm_model(vllm_runner):
|
||||
yield vllm_model
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_basic(vllm_model):
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["."],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=".")
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["."],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization.",
|
||||
expected_reason=".")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_multi_tokens(vllm_model):
|
||||
_test_stopping(
|
||||
vllm_model.model.llm_engine,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization. We are a ",
|
||||
expected_reason="group of peo")
|
||||
|
||||
_test_stopping(
|
||||
vllm_model.model.llm_engine,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=True,
|
||||
expected_output=
|
||||
"VLLM is a 100% volunteer organization. We are a group of peo",
|
||||
expected_reason="group of peo")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_partial_token(vllm_model):
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["gani"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer or",
|
||||
expected_reason="gani")
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["gani"],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organi",
|
||||
expected_reason="gani")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_token_id(vllm_model):
|
||||
# token id 13013 => " organization"
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer",
|
||||
expected_reason=13013)
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=13013)
|
||||
|
||||
|
||||
def _test_stopping(llm_engine: LLMEngine,
|
||||
expected_output: str,
|
||||
expected_reason: Any,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
include_in_output: bool = False) -> None:
|
||||
include_in_output: bool = False,
|
||||
use_async_output_proc: bool = False) -> None:
|
||||
llm_engine.add_request(
|
||||
"id", "A story about vLLM:\n",
|
||||
SamplingParams(
|
||||
@ -98,6 +36,10 @@ def _test_stopping(llm_engine: LLMEngine,
|
||||
output: Optional[CompletionOutput] = None
|
||||
output_text = ""
|
||||
stop_reason = None
|
||||
|
||||
if use_async_output_proc:
|
||||
llm_engine.step()
|
||||
|
||||
while llm_engine.has_unfinished_requests():
|
||||
(request_output, ) = llm_engine.step()
|
||||
(output, ) = request_output.outputs
|
||||
@ -110,3 +52,112 @@ def _test_stopping(llm_engine: LLMEngine,
|
||||
assert output is not None
|
||||
assert output_text == expected_output
|
||||
assert stop_reason == expected_reason
|
||||
|
||||
|
||||
def _set_async_mode(llm_engine, is_async):
|
||||
llm_engine.scheduler[0].use_async_output_proc = is_async
|
||||
|
||||
|
||||
def _stop_basic(llm_engine, is_async):
|
||||
_test_stopping(llm_engine,
|
||||
stop=["."],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=".",
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
_test_stopping(llm_engine,
|
||||
stop=["."],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization.",
|
||||
expected_reason=".",
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
|
||||
def _stop_multi_tokens(llm_engine, is_async):
|
||||
_test_stopping(
|
||||
llm_engine,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization. We are a ",
|
||||
expected_reason="group of peo",
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
_test_stopping(
|
||||
llm_engine,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=True,
|
||||
expected_output=
|
||||
"VLLM is a 100% volunteer organization. We are a group of peo",
|
||||
expected_reason="group of peo",
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
|
||||
def _stop_partial_token(llm_engine, is_async):
|
||||
_test_stopping(llm_engine,
|
||||
stop=["gani"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer or",
|
||||
expected_reason="gani",
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
_test_stopping(llm_engine,
|
||||
stop=["gani"],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organi",
|
||||
expected_reason="gani",
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
|
||||
def _stop_token_id(llm_engine, is_async):
|
||||
# token id 13013 => " organization"
|
||||
|
||||
_test_stopping(llm_engine,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer",
|
||||
expected_reason=13013,
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
_test_stopping(llm_engine,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=13013,
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_basic(vllm_model):
|
||||
_set_async_mode(vllm_model.model.llm_engine, True)
|
||||
_stop_basic(vllm_model.model.llm_engine, is_async=True)
|
||||
|
||||
_set_async_mode(vllm_model.model.llm_engine, False)
|
||||
_stop_basic(vllm_model.model.llm_engine, is_async=False)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_multi_tokens(vllm_model):
|
||||
_set_async_mode(vllm_model.model.llm_engine, True)
|
||||
_stop_multi_tokens(vllm_model.model.llm_engine, is_async=True)
|
||||
|
||||
_set_async_mode(vllm_model.model.llm_engine, False)
|
||||
_stop_multi_tokens(vllm_model.model.llm_engine, is_async=False)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_partial_token(vllm_model):
|
||||
_set_async_mode(vllm_model.model.llm_engine, True)
|
||||
_stop_partial_token(vllm_model.model.llm_engine, is_async=True)
|
||||
|
||||
_set_async_mode(vllm_model.model.llm_engine, False)
|
||||
_stop_partial_token(vllm_model.model.llm_engine, is_async=False)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_token_id(vllm_model):
|
||||
_set_async_mode(vllm_model.model.llm_engine, True)
|
||||
_stop_token_id(vllm_model.model.llm_engine, is_async=True)
|
||||
|
||||
_set_async_mode(vllm_model.model.llm_engine, False)
|
||||
_stop_token_id(vllm_model.model.llm_engine, is_async=False)
|
||||
|
||||
@@ -62,6 +62,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
ms_server_args = DEFAULT_SERVER_ARGS + \
["--num-scheduler-steps", f"{num_scheduler_steps}"]

# Disable output proc callback as its not supported
# with multi-step right now
ms_server_args += ["--disable-async-output-proc"]
if eager_mode:
ms_server_args.append("--enforce-eager")


@@ -140,6 +140,7 @@ class ModelConfig:
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
) -> None:
self.model = model
self.tokenizer = tokenizer
@@ -172,6 +173,7 @@ class ModelConfig:
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc

# Choose a default enforce_eager value if the user did not specify
# a value (enforce_eager is None)
@@ -326,6 +328,49 @@ class ModelConfig:
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len)

def verify_async_output_proc(self, parallel_config, speculative_config,
device_config) -> None:
if not self.use_async_output_proc:
# Nothing to check
return

if parallel_config.pipeline_parallel_size > 1:
logger.warning("Async output processing can not be enabled "
"with pipeline parallel")
self.use_async_output_proc = False
return

if device_config.device_type != "cuda":
logger.warning(
"Async output processing is only supported for CUDA."
" Disabling it for other platforms.")
self.use_async_output_proc = False
return

if envs.VLLM_USE_RAY_SPMD_WORKER:
logger.warning(
"Async output processing can not be enabled with ray spmd")
self.use_async_output_proc = False
return

if self.enforce_eager:
logger.warning(
"To see benefits of async output processing, enable CUDA "
"graph. Since, enforce-eager is enabled, async output "
"processor cannot be used")
self.use_async_output_proc = not self.enforce_eager
return

# Async postprocessor is not necessary with embedding mode
# since there is no token generation
if self.embedding_mode:
self.use_async_output_proc = False

if speculative_config:
logger.warning("Async output processing is not supported with"
" speculative decoding currently.")
self.use_async_output_proc = False

def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
@@ -358,6 +403,11 @@ class ModelConfig:
"fallback to the eager mode.")
self.enforce_eager = True

if pipeline_parallel_size > 1 and self.use_async_output_proc:
logger.warning("Async output processor is not supported with "
"pipeline parallelism currently. Disabling it.")
self.use_async_output_proc = False

def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled."""

@@ -1769,6 +1819,9 @@ class EngineConfig:
def __post_init__(self):
"""Verify configs are valid & consistent with each other.
"""
self.model_config.verify_async_output_proc(self.parallel_config,
self.speculative_config,
self.device_config)
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)


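The checks added in verify_async_output_proc() above can be summarized with a small standalone sketch; the Cfg dataclass is a stand-in for vLLM's config objects, not part of this change.

# Simplified model of the gating: async output processing stays on only when
# none of the disqualifying conditions above hold. Cfg is an illustrative
# stand-in, not vLLM's real config classes.
from dataclasses import dataclass

@dataclass
class Cfg:
    use_async_output_proc: bool = True
    pipeline_parallel_size: int = 1
    device_type: str = "cuda"
    use_ray_spmd_worker: bool = False
    enforce_eager: bool = False
    embedding_mode: bool = False
    has_speculative_config: bool = False

def resolve_async_output_proc(cfg: Cfg) -> bool:
    if not cfg.use_async_output_proc:
        return False
    if cfg.pipeline_parallel_size > 1:      # no pipeline parallelism
        return False
    if cfg.device_type != "cuda":           # CUDA only
        return False
    if cfg.use_ray_spmd_worker:             # no Ray SPMD worker
        return False
    if cfg.enforce_eager:                   # needs CUDA graphs to pay off
        return False
    if cfg.embedding_mode:                  # no token generation to overlap
        return False
    if cfg.has_speculative_config:          # not supported with spec decode
        return False
    return True

assert resolve_async_output_proc(Cfg()) is True
assert resolve_async_output_proc(Cfg(enforce_eager=True)) is False
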
@@ -4,7 +4,8 @@ import random
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union)

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
@@ -299,6 +300,7 @@ class Scheduler:
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
pipeline_parallel_size: int = 1,
output_proc_callback_fn: Optional[Callable] = None,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
@@ -364,10 +366,36 @@ class Scheduler:
self.num_cumulative_preemption: int = 0

# Used to cache python objects
self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache(
scheduler_running_outputs_builder)
self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache(
scheduled_seq_group_builder)
self._seq_group_metadata_cache: List[PyObjectCache] = []
self._scheduler_running_outputs_cache: List[PyObjectCache] = []
self._scheduled_seq_group_cache: List[PyObjectCache] = []

# For async output processing, we need to swap cache buffers between
# iterations. I.e. since the output processing is lagged one step,
# we cannot reuse the cached objects immediately when the schedule()
# is called again, but only when schedule() is called the second time.
self.output_proc_callback_fn = output_proc_callback_fn
self.use_async_output_proc = self.output_proc_callback_fn is not None
self.num_cache_iters = 2 if self.use_async_output_proc else 1

self.cache_id = 0
for i in range(self.num_cache_iters):
self._seq_group_metadata_cache.append(
PyObjectCache(seq_group_metadata_builder))
self._scheduler_running_outputs_cache.append(
PyObjectCache(scheduler_running_outputs_builder))
self._scheduled_seq_group_cache.append(
PyObjectCache(scheduled_seq_group_builder))

# For async postprocessor, the extra decode run cannot be done
# when the request reaches max_model_len. In this case, the request
# will be stopped during schedule() call and added to this stop list
# for processing and deallocation by the free_finished_seq_groups()
self._async_stopped: List[SequenceGroup] = []

@property
def next_cache_id(self):
return (self.cache_id + 1) % self.num_cache_iters

@property
def lora_enabled(self) -> bool:
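A standalone sketch of the double-buffered object-cache rotation described in the comment above; PyObjectCache here is a toy stand-in, not vLLM's implementation.

# Because output processing lags the forward pass by one step, objects handed
# out during schedule() call N may still be read while call N+1 runs, so each
# call uses the other cache slot and resets only the slot it is switching to.
class PyObjectCache:
    def __init__(self, builder):
        self._builder = builder
        self._objs = []
        self._idx = 0

    def get_object(self):
        if self._idx == len(self._objs):
            self._objs.append(self._builder())
        obj = self._objs[self._idx]
        self._idx += 1
        return obj

    def reset(self):
        self._idx = 0


num_cache_iters = 2   # double buffering when async output proc is enabled
caches = [PyObjectCache(dict) for _ in range(num_cache_iters)]
cache_id = 0

def schedule_step():
    global cache_id
    obj = caches[cache_id].get_object()        # objects handed out this call
    next_id = (cache_id + 1) % num_cache_iters
    caches[next_id].reset()                    # reset the slot used next call
    cache_id = next_id                         # swap buffers
    return obj

schedule_step()
schedule_step()
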
@ -483,7 +511,7 @@ class Scheduler:
|
||||
SchedulerRunningOutputs.
|
||||
"""
|
||||
ret: SchedulerRunningOutputs = \
|
||||
self._scheduler_running_outputs_cache.get_object()
|
||||
self._scheduler_running_outputs_cache[self.cache_id].get_object()
|
||||
ret.blocks_to_swap_out.clear()
|
||||
ret.blocks_to_copy.clear()
|
||||
ret.decode_seq_groups.clear()
|
||||
@ -510,8 +538,12 @@ class Scheduler:
|
||||
# NOTE(woosuk): Preemption happens only when there is no available slot
|
||||
# to keep all the sequence groups in the RUNNING state.
|
||||
|
||||
running_queue = self.running
|
||||
# Store original running requests for the case of async + preemption
|
||||
if self.use_async_output_proc:
|
||||
orig_running = self.running.copy()
|
||||
|
||||
running_queue = self.running
|
||||
assert len(self._async_stopped) == 0
|
||||
while running_queue:
|
||||
seq_group = running_queue[0]
|
||||
num_running_tokens = self._get_num_new_tokens(
|
||||
@ -521,6 +553,28 @@ class Scheduler:
|
||||
break
|
||||
|
||||
running_queue.popleft()
|
||||
|
||||
# With async postprocessor, an extra decode run is done
|
||||
# to process the final tokens. The check below avoids this extra
|
||||
# decode run when the model max len is reached, in order to avoid
|
||||
# a memory overflow.
|
||||
if self.use_async_output_proc and seq_group.seqs[0].get_len(
|
||||
) > self.scheduler_config.max_model_len:
|
||||
self._async_stopped.append(seq_group)
|
||||
continue
|
||||
|
||||
# With async postprocessor, when preemption kicks in, we need
|
||||
# first to drain the async postprocessor, so that all async
|
||||
# block_table freeing is applied before the preemption freeing
|
||||
# is applied.
|
||||
if self.use_async_output_proc and not self._can_append_slots(
|
||||
seq_group):
|
||||
tmp = self.running
|
||||
self.running = orig_running
|
||||
assert self.output_proc_callback_fn is not None
|
||||
self.output_proc_callback_fn(is_async=True)
|
||||
self.running = tmp
|
||||
|
||||
while not self._can_append_slots(seq_group):
|
||||
budget.subtract_num_batched_tokens(seq_group.request_id,
|
||||
num_running_tokens)
|
||||
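A toy model of the two async-specific branches added to _schedule_running() above: parking requests that reached max_model_len in _async_stopped, and draining the postprocessor callback before any preemption frees blocks. All classes below are stand-ins, not vLLM code.

from dataclasses import dataclass, field
from typing import Callable, List

@dataclass
class ToyScheduler:
    max_model_len: int
    use_async_output_proc: bool
    output_proc_callback_fn: Callable[..., None]
    async_stopped: List[str] = field(default_factory=list)

    def schedule_running(self, running):
        # running: list of (request_id, current_len, fits_in_free_blocks)
        scheduled = []
        for req_id, cur_len, fits in running:
            if self.use_async_output_proc and cur_len > self.max_model_len:
                self.async_stopped.append(req_id)   # stop now, free later
                continue
            if self.use_async_output_proc and not fits:
                # drain lagged postprocessing before any preemption/freeing
                self.output_proc_callback_fn(is_async=True)
            scheduled.append(req_id)
        return scheduled

sched = ToyScheduler(max_model_len=8, use_async_output_proc=True,
                     output_proc_callback_fn=lambda is_async: None)
print(sched.schedule_running([("a", 9, True), ("b", 4, False), ("c", 4, True)]))
# -> ['b', 'c'] with 'a' parked in sched.async_stopped
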
@ -556,7 +610,7 @@ class Scheduler:
|
||||
is_prefill = seq_group.is_prefill()
|
||||
|
||||
scheduled_seq_group: ScheduledSequenceGroup = \
|
||||
self._scheduled_seq_group_cache.get_object()
|
||||
self._scheduled_seq_group_cache[self.cache_id].get_object()
|
||||
scheduled_seq_group.seq_group = seq_group
|
||||
if is_prefill:
|
||||
scheduled_seq_group.token_chunk_size = num_running_tokens
|
||||
@ -579,8 +633,8 @@ class Scheduler:
|
||||
if curr_loras is not None and seq_group.lora_int_id > 0:
|
||||
curr_loras.add(seq_group.lora_int_id)
|
||||
|
||||
self._scheduler_running_outputs_cache.reset()
|
||||
self._scheduled_seq_group_cache.reset()
|
||||
self._scheduler_running_outputs_cache[self.next_cache_id].reset()
|
||||
self._scheduled_seq_group_cache[self.next_cache_id].reset()
|
||||
|
||||
return ret
|
||||
|
||||
@ -1031,17 +1085,31 @@ class Scheduler:
|
||||
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
|
||||
)
|
||||
|
||||
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
|
||||
def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
|
||||
no_beam_search = (seq_group.sampling_params.best_of == 1
|
||||
and not seq_group.sampling_params.use_beam_search)
|
||||
|
||||
return no_beam_search
|
||||
|
||||
def schedule(
|
||||
self
|
||||
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
|
||||
# Schedule sequence groups.
|
||||
# This function call changes the internal states of the scheduler
|
||||
# such as self.running, self.swapped, and self.waiting.
|
||||
scheduler_start_time = time.perf_counter()
|
||||
|
||||
scheduler_outputs = self._schedule()
|
||||
now = time.time()
|
||||
|
||||
if not self.cache_config.enable_prefix_caching:
|
||||
common_computed_block_nums = []
|
||||
|
||||
# TODO: Combine multi-step and async postprocessor
|
||||
allow_async_output_proc: bool = (
|
||||
self.use_async_output_proc
|
||||
and not self.scheduler_config.is_multi_step)
|
||||
|
||||
# Create input data structures.
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
for i, scheduled_seq_group in enumerate(
|
||||
@ -1050,6 +1118,11 @@ class Scheduler:
|
||||
token_chunk_size = scheduled_seq_group.token_chunk_size
|
||||
seq_group.maybe_set_first_scheduled_time(now)
|
||||
|
||||
seq_group_metadata = self._seq_group_metadata_cache[
|
||||
self.cache_id].get_object()
|
||||
seq_group_metadata.seq_data.clear()
|
||||
seq_group_metadata.block_tables.clear()
|
||||
|
||||
# seq_id -> SequenceData
|
||||
seq_data: Dict[int, SequenceData] = {}
|
||||
# seq_id -> physical block numbers
|
||||
@ -1139,6 +1212,10 @@ class Scheduler:
|
||||
)
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
if allow_async_output_proc:
|
||||
allow_async_output_proc = self._allow_async_output_proc(
|
||||
seq_group)
|
||||
|
||||
# Now that the batch has been created, we can assume all blocks in the
|
||||
# batch will have been computed before the next scheduling invocation.
|
||||
# This is because the engine assumes that a failure in model execution
|
||||
@ -1147,6 +1224,8 @@ class Scheduler:
|
||||
self.block_manager.mark_blocks_as_computed(
|
||||
scheduled_seq_group.seq_group)
|
||||
|
||||
self._seq_group_metadata_cache[self.next_cache_id].reset()
|
||||
|
||||
scheduler_time = time.perf_counter() - scheduler_start_time
|
||||
# Add this to scheduler time to all the sequences that are currently
|
||||
# running. This will help estimate if the scheduler is a significant
|
||||
@ -1158,7 +1237,12 @@ class Scheduler:
|
||||
else:
|
||||
seq_group.metrics.scheduler_time = scheduler_time
|
||||
|
||||
return seq_group_metadata_list, scheduler_outputs
|
||||
# Move to next cache (if exists)
|
||||
self.cache_id = self.next_cache_id
|
||||
|
||||
# Return results
|
||||
return (seq_group_metadata_list, scheduler_outputs,
|
||||
allow_async_output_proc)
|
||||
|
||||
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||
self.block_manager.fork(parent_seq, child_seq)
|
||||
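A standalone sketch of the per-request gate behind the new allow_async_output_proc value that schedule() now returns as its third element; the SamplingParams dataclass below is a minimal stand-in, not vLLM's.

# Async postprocessing is allowed only for non-multi-step scheduling and only
# while every scheduled group uses best_of == 1 and no beam search.
from dataclasses import dataclass
from typing import List

@dataclass
class SamplingParams:
    best_of: int = 1
    use_beam_search: bool = False

def allow_async_output_proc(batch: List[SamplingParams],
                            use_async: bool = True,
                            is_multi_step: bool = False) -> bool:
    allow = use_async and not is_multi_step
    for sp in batch:
        allow = allow and sp.best_of == 1 and not sp.use_beam_search
    return allow

print(allow_async_output_proc([SamplingParams(), SamplingParams()]))       # True
print(allow_async_output_proc([SamplingParams(use_beam_search=True)]))     # False
print(allow_async_output_proc([SamplingParams()], is_multi_step=True))     # False
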
@ -1167,6 +1251,12 @@ class Scheduler:
|
||||
"""Free a sequence from a block table."""
|
||||
self.block_manager.free(seq)
|
||||
|
||||
def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
|
||||
"""Free finished seqs in a sequence group."""
|
||||
for seq in seq_group.get_seqs():
|
||||
if seq.is_finished():
|
||||
self.free_seq(seq)
|
||||
|
||||
def free_finished_seq_groups(self) -> None:
|
||||
remaining: Deque[SequenceGroup] = deque()
|
||||
for seq_group in self.running:
|
||||
@ -1179,8 +1269,24 @@ class Scheduler:
|
||||
self._finished_requests_ids.append(seq_group.request_id)
|
||||
else:
|
||||
remaining.append(seq_group)
|
||||
|
||||
# Free finished seqs
|
||||
self._free_finished_seqs(seq_group)
|
||||
|
||||
self.running = remaining
|
||||
|
||||
# Handle async stopped sequence groups
|
||||
# (ones that reached max model len)
|
||||
if self._async_stopped:
|
||||
for seq_group in self._async_stopped:
|
||||
self._free_seq_group_cross_attn_blocks(seq_group)
|
||||
self._finished_requests_ids.append(seq_group.request_id)
|
||||
|
||||
# Free finished seqs
|
||||
self._free_finished_seqs(seq_group)
|
||||
|
||||
self._async_stopped.clear()
|
||||
|
||||
def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
|
||||
self.block_manager.allocate(seq_group)
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
|
||||
|
||||
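A companion sketch for free_finished_seq_groups() above, showing how async-stopped groups are released together with normally finished ones, after the lagged postprocessor has had its chance to run. Toy objects only; free_seq stands in for block-table deallocation.

def free_finished_seq_groups(running, async_stopped, free_seq, finished_ids):
    remaining = []
    for group in running:
        if group["finished"]:
            finished_ids.append(group["id"])
            free_seq(group)
        else:
            remaining.append(group)
    for group in async_stopped:          # deferred max_model_len stops
        finished_ids.append(group["id"])
        free_seq(group)
    async_stopped.clear()
    return remaining

ids = []
left = free_finished_seq_groups(
    running=[{"id": "a", "finished": True}, {"id": "b", "finished": False}],
    async_stopped=[{"id": "c", "finished": False}],
    free_seq=lambda g: None,
    finished_ids=ids)
print(left, ids)   # [{'id': 'b', 'finished': False}] ['a', 'c']
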
@@ -147,6 +147,7 @@ class EngineArgs:

otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False

def __post_init__(self):
if self.tokenizer is None:
@@ -733,6 +734,12 @@ class EngineArgs:
"modules. This involves use of possibly costly and or blocking "
"operations and hence might have a performance impact.")

parser.add_argument(
'--disable-async-output-proc',
action='store_true',
default=EngineArgs.disable_async_output_proc,
help="Disable async output processing. This may result in "
"lower performance.")
return parser

@classmethod
@@ -792,6 +799,7 @@ class EngineArgs:
skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc,
)
cache_config = CacheConfig(
block_size=self.block_size if self.device != "neuron" else

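A short usage sketch of the new engine argument; EngineArgs and create_engine_config are the real names from this codebase, the model is a placeholder, and the printed value assumes no other condition re-enables or further disables the feature.

# Hedged sketch: the flag flips ModelConfig.use_async_output_proc off when the
# engine config is built (fetching the placeholder model's config from the Hub).
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="facebook/opt-125m",
                         disable_async_output_proc=True)
engine_config = engine_args.create_engine_config()
print(engine_config.model_config.use_async_output_proc)   # expected: False
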
@ -277,23 +277,36 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
|
||||
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
|
||||
scheduler_outputs = cached_outputs.scheduler_outputs
|
||||
allow_async_output_proc = cached_outputs.allow_async_output_proc
|
||||
|
||||
# skip the scheduler if there are any remaining steps in the seq groups.
|
||||
# This ensures that the scheduler is only called again when the current
|
||||
# batch has completed.
|
||||
if not self._has_remaining_steps(seq_group_metadata_list):
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler[
|
||||
virtual_engine].schedule()
|
||||
(seq_group_metadata_list, scheduler_outputs,
|
||||
allow_async_output_proc
|
||||
) = self.scheduler[virtual_engine].schedule()
|
||||
|
||||
# If current scheduler iteration has no async postprocessor,
|
||||
# then we need first to drain the pending async postprocessor
|
||||
# before moving forward
|
||||
if not allow_async_output_proc and len(self.output_queue) > 0:
|
||||
self._process_model_outputs(is_async=True)
|
||||
|
||||
if (self.scheduler_config.is_multi_step
|
||||
and scheduler_outputs.num_lookahead_slots > 0):
|
||||
# cache the scheduler outputs for the next iteration if we have
|
||||
# lookahead slots
|
||||
self._cache_scheduler_outputs_for_multi_step(
|
||||
virtual_engine, seq_group_metadata_list, scheduler_outputs)
|
||||
virtual_engine, seq_group_metadata_list, scheduler_outputs,
|
||||
allow_async_output_proc)
|
||||
|
||||
assert seq_group_metadata_list is not None
|
||||
assert scheduler_outputs is not None
|
||||
|
||||
assert not (self.scheduler_config.is_multi_step and \
|
||||
allow_async_output_proc)
|
||||
|
||||
if not scheduler_outputs.is_empty():
|
||||
finished_requests_ids = self.scheduler[
|
||||
virtual_engine].get_and_reset_finished_requests_ids()
|
||||
@ -317,6 +330,11 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
# We use ExecuteModelRequest to pass the last sampled_token_ids
|
||||
# to each of the non-last PP stages for in-place prepare_input.
|
||||
last_sampled_token_ids=last_sampled_token_ids)
|
||||
|
||||
if allow_async_output_proc:
|
||||
execute_model_req.output_proc_callback_fn = \
|
||||
self._process_model_outputs
|
||||
|
||||
# Execute the model.
|
||||
output = await self.model_executor.execute_model_async(
|
||||
execute_model_req)
|
||||
@ -325,6 +343,9 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
if self.scheduler_config.is_multi_step:
|
||||
self._update_cached_scheduler_output(virtual_engine, output)
|
||||
else:
|
||||
if len(self.output_queue) > 0:
|
||||
assert not self.scheduler_config.is_multi_step
|
||||
self._process_model_outputs(is_async=True)
|
||||
output = []
|
||||
|
||||
# Finish the current step for all the sequence groups.
|
||||
@ -337,19 +358,32 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
if self.scheduler_config.is_multi_step:
|
||||
self.cached_scheduler_outputs[
|
||||
virtual_engine] = SchedulerOutputState()
|
||||
request_outputs = self._process_model_outputs(
|
||||
output, scheduler_outputs.scheduled_seq_groups,
|
||||
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
|
||||
|
||||
# Cache results in engine
|
||||
self.output_queue.append(
|
||||
(output, seq_group_metadata_list, scheduler_outputs))
|
||||
|
||||
if output and allow_async_output_proc:
|
||||
assert len(
|
||||
output
|
||||
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
|
||||
self._advance_to_next_step(
|
||||
output[0], seq_group_metadata_list,
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
|
||||
if not allow_async_output_proc:
|
||||
self._process_model_outputs(is_async=False)
|
||||
|
||||
# Log stats.
|
||||
self.do_log_stats(scheduler_outputs, output)
|
||||
|
||||
# Tracing
|
||||
self.do_tracing(scheduler_outputs)
|
||||
|
||||
else:
|
||||
request_outputs = []
|
||||
self.request_outputs = []
|
||||
|
||||
# Log stats.
|
||||
self.do_log_stats(scheduler_outputs, output)
|
||||
|
||||
# Tracing
|
||||
self.do_tracing(scheduler_outputs)
|
||||
|
||||
return request_outputs
|
||||
return self.request_outputs
|
||||
|
||||
async def stop_remote_worker_execution_loop_async(self) -> None:
|
||||
"""Stop the remote worker execution loop."""
|
||||
|
||||
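A condensed, standalone model of the reworked step flow above; everything below is a toy stand-in, not vLLM code.

# Outputs are queued instead of processed inline; the lagged queue is drained
# either by the callback during the next forward pass or explicitly before an
# iteration that cannot run asynchronously.
from collections import deque

output_queue = deque()
request_outputs = []

def process_model_outputs(is_async):
    if not output_queue:
        return
    step_output = output_queue.popleft()
    request_outputs.append(f"processed({step_output}, is_async={is_async})")

def engine_step(step_id, allow_async):
    # If this iteration cannot run asynchronously, drain pending outputs first.
    if not allow_async and output_queue:
        process_model_outputs(is_async=True)
    # "Forward pass": with async enabled, the previous step's outputs are
    # processed by the callback while the GPU works on the current step.
    if allow_async and output_queue:
        process_model_outputs(is_async=True)
    output_queue.append(f"sampler_output_{step_id}")
    if not allow_async:
        process_model_outputs(is_async=False)
    return list(request_outputs)

engine_step(0, allow_async=True)    # queue holds step 0; nothing processed yet
engine_step(1, allow_async=True)    # step 0 processed while step 1 "runs"
engine_step(2, allow_async=False)   # sync step: drain step 1, then process step 2
print(request_outputs)
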
@ -1,7 +1,8 @@
|
||||
import time
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
|
||||
from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
|
||||
Mapping, Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set, Tuple, Type, Union
|
||||
@ -38,9 +39,8 @@ from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
|
||||
PoolerOutput, SamplerOutput, Sequence,
|
||||
SequenceGroup, SequenceGroupMetadata,
|
||||
SequenceStatus)
|
||||
SamplerOutput, Sequence, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceStatus)
|
||||
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
|
||||
init_tracer)
|
||||
from vllm.transformers_utils.config import try_get_generation_config
|
||||
@ -82,9 +82,10 @@ DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
|
||||
@dataclass
|
||||
class SchedulerOutputState:
|
||||
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
|
||||
last_output: Optional[SamplerOutput] = None
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
|
||||
scheduler_outputs: Optional[SchedulerOutputs] = None
|
||||
allow_async_output_proc: bool = False
|
||||
last_output: Optional[SamplerOutput] = None
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
@ -190,6 +191,9 @@ class LLMEngine:
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
input_registry: InputRegistry = INPUT_REGISTRY,
|
||||
# To improve performance, only final requests outputs may be required.
|
||||
# If this set to true, then no intermediate outputs will be returned.
|
||||
step_return_finished_only: bool = False,
|
||||
) -> None:
|
||||
logger.info(
|
||||
"Initializing an LLM engine (v%s) with config: "
|
||||
@ -204,7 +208,8 @@ class LLMEngine:
|
||||
"quantization_param_path=%s, device_config=%s, "
|
||||
"decoding_config=%r, observability_config=%r, "
|
||||
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
|
||||
"num_scheduler_steps=%d, enable_prefix_caching=%s)",
|
||||
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
|
||||
"use_async_output_proc=%s)",
|
||||
VLLM_VERSION,
|
||||
model_config.model,
|
||||
speculative_config,
|
||||
@ -235,6 +240,7 @@ class LLMEngine:
|
||||
scheduler_config.use_v2_block_manager,
|
||||
scheduler_config.num_scheduler_steps,
|
||||
cache_config.enable_prefix_caching,
|
||||
model_config.use_async_output_proc,
|
||||
)
|
||||
# TODO(woosuk): Print more configs in debug mode.
|
||||
from vllm.plugins import load_general_plugins
|
||||
@ -253,6 +259,7 @@ class LLMEngine:
|
||||
self.observability_config = observability_config or ObservabilityConfig(
|
||||
)
|
||||
self.log_stats = log_stats
|
||||
self.step_return_finished_only = step_return_finished_only
|
||||
|
||||
if not self.model_config.skip_tokenizer_init:
|
||||
self.tokenizer = self._init_tokenizer()
|
||||
@ -340,8 +347,11 @@ class LLMEngine:
|
||||
# NOTE: the cache_config here have been updated with the numbers of
|
||||
# GPU and CPU blocks, which are profiled in the distributed executor.
|
||||
self.scheduler = [
|
||||
Scheduler(scheduler_config, cache_config, lora_config,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
Scheduler(
|
||||
scheduler_config, cache_config, lora_config,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
self._process_model_outputs
|
||||
if model_config.use_async_output_proc else None)
|
||||
for _ in range(parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
@ -396,6 +406,13 @@ class LLMEngine:
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
# Async output processing pointers
|
||||
self.output_queue: Deque[Tuple[List[SamplerOutput],
|
||||
List[SequenceGroupMetadata],
|
||||
SchedulerOutputs]] = deque()
|
||||
self.request_outputs: List[Union[RequestOutput,
|
||||
EmbeddingRequestOutput]] = []
|
||||
|
||||
def _initialize_kv_caches(self) -> None:
|
||||
"""Initialize the KV cache in the worker(s).
|
||||
|
||||
@ -1197,34 +1214,66 @@ class LLMEngine:
|
||||
|
||||
return
|
||||
|
||||
def _process_model_outputs(
|
||||
self,
|
||||
output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
|
||||
scheduled_seq_groups: List[ScheduledSequenceGroup],
|
||||
ignored_seq_groups: List[SequenceGroup],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
def _process_model_outputs(self,
|
||||
is_async: bool,
|
||||
clear_outputs: bool = True) -> None:
|
||||
"""Apply the model output to the sequences in the scheduled seq groups.
|
||||
|
||||
is_async: Indicates whether this postprocessor runs in
|
||||
parallel with the GPU forward pass and is processing
|
||||
tokens from the previous step. If this is true, then
|
||||
no tokens need to be appended since it is already done
|
||||
externally (before the next schedule() call)
|
||||
clear_outputs: Sometimes existing outputs need to be combined
|
||||
with outputs of this call. This happens for postprocessor
|
||||
draining at the final stage (like when sequences are finished)
|
||||
|
||||
Returns RequestOutputs that can be returned to the client.
|
||||
"""
|
||||
|
||||
now = time.time()
|
||||
|
||||
# Organize outputs by [sequence group][step] instead of
|
||||
# [step][sequence group].
|
||||
output_by_sequence_group = create_output_by_sequence_group(
|
||||
output, num_seq_groups=len(scheduled_seq_groups))
|
||||
if clear_outputs:
|
||||
self.request_outputs.clear()
|
||||
|
||||
if len(self.output_queue) == 0:
|
||||
return None
|
||||
|
||||
(outputs, seq_group_metadata_list,
|
||||
scheduler_outputs) = self.output_queue.popleft()
|
||||
|
||||
# Sanity check
|
||||
assert len(seq_group_metadata_list) == len(
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
|
||||
# Organize outputs by [step][sequence group] instead of
|
||||
# [sequence group][step].
|
||||
if len(outputs) > 1:
|
||||
outputs_by_sequence_group = create_output_by_sequence_group(
|
||||
outputs, num_seq_groups=len(seq_group_metadata_list))
|
||||
else:
|
||||
outputs_by_sequence_group = outputs
|
||||
|
||||
finished_before: List[int] = []
|
||||
for i, seq_group_meta in enumerate(seq_group_metadata_list):
|
||||
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
|
||||
|
||||
# Update the scheduled sequence groups with the model outputs.
|
||||
for scheduled_seq_group, outputs, seq_group_meta in zip(
|
||||
scheduled_seq_groups, output_by_sequence_group,
|
||||
seq_group_metadata_list):
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
seq_group.update_num_computed_tokens(
|
||||
scheduled_seq_group.token_chunk_size)
|
||||
if output is not None and len(output) > 0:
|
||||
for o in output:
|
||||
|
||||
if seq_group.is_finished():
|
||||
finished_before.append(i)
|
||||
continue
|
||||
|
||||
if len(outputs) > 1:
|
||||
output = outputs_by_sequence_group[i]
|
||||
else:
|
||||
output = [outputs_by_sequence_group[0][i]]
|
||||
|
||||
if not is_async:
|
||||
seq_group.update_num_computed_tokens(
|
||||
scheduled_seq_group.token_chunk_size)
|
||||
|
||||
if outputs:
|
||||
for o in outputs:
|
||||
if (isinstance(o, SamplerOutput)
|
||||
and seq_group.metrics is not None):
|
||||
if seq_group.metrics.model_forward_time is not None:
|
||||
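A minimal stand-in for the reworked _process_model_outputs() above: it now consumes one entry from the engine's output queue rather than taking outputs as arguments, optionally keeps previously accumulated request outputs (clear_outputs=False is used for the final drain), and skips groups that had already finished so they are not processed or counted twice. Toy engine, not vLLM code.

from collections import deque

class ToyEngine:
    def __init__(self):
        self.output_queue = deque()
        self.request_outputs = []

    def process_model_outputs(self, is_async, clear_outputs=True):
        if clear_outputs:
            self.request_outputs.clear()
        if not self.output_queue:
            return
        outputs, groups = self.output_queue.popleft()
        finished_before = [i for i, g in enumerate(groups) if g["finished"]]
        for i, g in enumerate(groups):
            if i in finished_before:
                continue                       # avoid double processing
            if not is_async:
                g["num_computed"] += 1         # async path already advanced this
            self.request_outputs.append((g["id"], outputs[i]))

engine = ToyEngine()
engine.output_queue.append((["tok_a", "tok_b"],
                            [{"id": "a", "finished": False, "num_computed": 0},
                             {"id": "b", "finished": True, "num_computed": 3}]))
engine.process_model_outputs(is_async=False)
print(engine.request_outputs)    # [('a', 'tok_a')] -- 'b' skipped
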
@ -1239,30 +1288,75 @@ class LLMEngine:
|
||||
else:
|
||||
seq_group.metrics.model_execute_time = (
|
||||
o.model_execute_time)
|
||||
|
||||
if self.model_config.embedding_mode:
|
||||
self._process_sequence_group_outputs(seq_group, outputs)
|
||||
self._process_sequence_group_outputs(seq_group, output)
|
||||
continue
|
||||
|
||||
self.output_processor.process_prompt_logprob(seq_group, outputs)
|
||||
self.output_processor.process_prompt_logprob(seq_group, output)
|
||||
if seq_group_meta.do_sample:
|
||||
self.output_processor.process_outputs(seq_group, outputs)
|
||||
self.output_processor.process_outputs(seq_group, output,
|
||||
is_async)
|
||||
|
||||
# Free the finished sequence groups.
|
||||
for scheduler in self.scheduler:
|
||||
scheduler.free_finished_seq_groups()
|
||||
|
||||
# Create the outputs.
|
||||
request_outputs: List[Union[RequestOutput,
|
||||
EmbeddingRequestOutput]] = []
|
||||
for scheduled_seq_group in scheduled_seq_groups:
|
||||
for i, _ in enumerate(seq_group_metadata_list):
|
||||
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
|
||||
|
||||
if i in finished_before:
|
||||
continue # Avoids double processing
|
||||
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
seq_group.maybe_set_first_token_time(now)
|
||||
if (seq_group.is_finished()
|
||||
if self.step_return_finished_only else True):
|
||||
request_output = RequestOutputFactory.create(seq_group)
|
||||
self.request_outputs.append(request_output)
|
||||
|
||||
for seq_group in scheduler_outputs.ignored_seq_groups:
|
||||
request_output = RequestOutputFactory.create(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
for seq_group in ignored_seq_groups:
|
||||
request_output = RequestOutputFactory.create(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
return request_outputs
|
||||
self.request_outputs.append(request_output)
|
||||
|
||||
if is_async:
|
||||
# Log stats.
|
||||
self.do_log_stats(scheduler_outputs, outputs, finished_before)
|
||||
|
||||
# Tracing
|
||||
self.do_tracing(scheduler_outputs)
|
||||
|
||||
return None
|
||||
|
||||
def _advance_to_next_step(
|
||||
self, output: List[SamplerOutput],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
|
||||
"""Given model output from a single run, append the tokens to the
|
||||
sequences. This is normally done inside output processor, but it is
|
||||
required if the worker is to perform async forward pass to next step.
|
||||
"""
|
||||
for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
|
||||
zip(seq_group_metadata_list, output, scheduled_seq_groups):
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
|
||||
if seq_group.is_finished():
|
||||
continue
|
||||
|
||||
seq_group.update_num_computed_tokens(
|
||||
seq_group_metadata.token_chunk_size)
|
||||
|
||||
if seq_group_metadata.do_sample:
|
||||
assert len(sequence_group_outputs.samples) == 1, (
|
||||
"Async output processor expects a single sample"
|
||||
" (i.e sampling_params.n == 1 and no "
|
||||
"sampling_params.best_of > 1)")
|
||||
sample = sequence_group_outputs.samples[0]
|
||||
|
||||
assert len(seq_group.seqs) == 1
|
||||
seq = seq_group.seqs[0]
|
||||
seq.append_token_id(sample.output_token, sample.logprobs)
|
||||
|
||||
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
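A standalone sketch of what _advance_to_next_step() does above: when outputs are processed asynchronously, the freshly sampled token ids still have to reach the sequences before the next schedule() and forward pass, so only the append (no detokenization, no stop checks) is done eagerly here, and full postprocessing follows one step later. Sequence below is a toy.

class Sequence:
    def __init__(self):
        self.token_ids = []

    def append_token_id(self, token_id, logprobs=None):
        self.token_ids.append(token_id)

def advance_to_next_step(sampled, sequences):
    # sampled: one (token_id, logprobs) pair per still-running sequence group
    for (token_id, logprobs), seq in zip(sampled, sequences):
        seq.append_token_id(token_id, logprobs)

seqs = [Sequence(), Sequence()]
advance_to_next_step([(11, None), (42, None)], seqs)
print([s.token_ids for s in seqs])   # [[11], [42]]
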
@ -1325,24 +1419,32 @@ class LLMEngine:
|
||||
cached_outputs = self.cached_scheduler_outputs[0]
|
||||
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
|
||||
scheduler_outputs = cached_outputs.scheduler_outputs
|
||||
allow_async_output_proc = cached_outputs.allow_async_output_proc
|
||||
|
||||
# Skip the scheduler if there are any remaining steps in the seq groups.
|
||||
# This ensures that the scheduler is only called again when the current
|
||||
# batch has completed.
|
||||
if not self._has_remaining_steps(seq_group_metadata_list):
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler[
|
||||
0].schedule()
|
||||
(seq_group_metadata_list, scheduler_outputs,
|
||||
allow_async_output_proc) = self.scheduler[0].schedule()
|
||||
|
||||
if not allow_async_output_proc and len(self.output_queue) > 0:
|
||||
self._process_model_outputs(is_async=True)
|
||||
|
||||
if (self.scheduler_config.is_multi_step
|
||||
and scheduler_outputs.num_lookahead_slots > 0):
|
||||
# cache the scheduler outputs for the next iteration if we have
|
||||
# lookahead slots
|
||||
self._cache_scheduler_outputs_for_multi_step(
|
||||
0, seq_group_metadata_list, scheduler_outputs)
|
||||
0, seq_group_metadata_list, scheduler_outputs,
|
||||
allow_async_output_proc)
|
||||
|
||||
assert seq_group_metadata_list is not None
|
||||
assert scheduler_outputs is not None
|
||||
|
||||
assert not (self.scheduler_config.is_multi_step and \
|
||||
allow_async_output_proc)
|
||||
|
||||
if not scheduler_outputs.is_empty():
|
||||
finished_requests_ids = self.scheduler[
|
||||
0].get_and_reset_finished_requests_ids()
|
||||
@ -1366,6 +1468,10 @@ class LLMEngine:
|
||||
# to each of the non-last PP stages for in-place prepare_input.
|
||||
last_sampled_token_ids=last_sampled_token_ids)
|
||||
|
||||
if allow_async_output_proc:
|
||||
execute_model_req.output_proc_callback_fn = \
|
||||
self._process_model_outputs
|
||||
|
||||
output = self.model_executor.execute_model(
|
||||
execute_model_req=execute_model_req)
|
||||
|
||||
@ -1374,6 +1480,9 @@ class LLMEngine:
|
||||
if self.scheduler_config.is_multi_step:
|
||||
self._update_cached_scheduler_output(0, output)
|
||||
else:
|
||||
if len(self.output_queue) > 0:
|
||||
assert not self.scheduler_config.is_multi_step
|
||||
self._process_model_outputs(is_async=True)
|
||||
output = []
|
||||
|
||||
# Finish the current step for all the sequence groups.
|
||||
@ -1382,23 +1491,41 @@ class LLMEngine:
|
||||
seq_group.finish_step()
|
||||
|
||||
if not self._has_remaining_steps(seq_group_metadata_list):
|
||||
# clear the cache if we have finished all the steps
|
||||
# clear the cache if we have finished all the steps.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
self.cached_scheduler_outputs[0] = SchedulerOutputState()
|
||||
request_outputs = self._process_model_outputs(
|
||||
output, scheduler_outputs.scheduled_seq_groups,
|
||||
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
|
||||
|
||||
# Add results to the output_queue
|
||||
# (for async or non-async postprocessing)
|
||||
self.output_queue.append(
|
||||
(output, seq_group_metadata_list, scheduler_outputs))
|
||||
|
||||
if output and allow_async_output_proc:
|
||||
assert len(output) == 1, ("Multi step decoding does not work "
|
||||
"with async output processing.")
|
||||
|
||||
self._advance_to_next_step(
|
||||
output[0], seq_group_metadata_list,
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
|
||||
if not allow_async_output_proc:
|
||||
self._process_model_outputs(is_async=False)
|
||||
|
||||
# Log stats.
|
||||
self.do_log_stats(scheduler_outputs, output)
|
||||
|
||||
# Tracing
|
||||
self.do_tracing(scheduler_outputs)
|
||||
else:
|
||||
request_outputs = []
|
||||
|
||||
# Log stats.
|
||||
self.do_log_stats(scheduler_outputs, output)
|
||||
|
||||
# Tracing
|
||||
self.do_tracing(scheduler_outputs)
|
||||
self.request_outputs = []
|
||||
|
||||
if not self.has_unfinished_requests():
|
||||
# Drain async postprocessor
|
||||
if len(self.output_queue) > 0:
|
||||
assert not self.scheduler_config.is_multi_step
|
||||
self._process_model_outputs(is_async=True, clear_outputs=False)
|
||||
assert len(self.output_queue) == 0
|
||||
|
||||
# Stop the execute model loop in parallel workers until there are
|
||||
# more requests to process. This avoids waiting indefinitely in
|
||||
# torch.distributed ops which may otherwise timeout, and unblocks
|
||||
@ -1406,7 +1533,7 @@ class LLMEngine:
|
||||
# queued control plane messages, such as add/remove lora adapters.
|
||||
self.model_executor.stop_remote_worker_execution_loop()
|
||||
|
||||
return request_outputs
|
||||
return self.request_outputs
|
||||
|
||||
def _has_remaining_steps(
|
||||
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
|
||||
@ -1431,12 +1558,14 @@ class LLMEngine:
|
||||
def _cache_scheduler_outputs_for_multi_step(
|
||||
self, virtual_engine: int,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
scheduler_outputs: SchedulerOutputs) -> None:
|
||||
self.cached_scheduler_outputs[
|
||||
virtual_engine].seq_group_metadata_list = seq_group_metadata_list
|
||||
self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
|
||||
scheduler_outputs
|
||||
self.cached_scheduler_outputs[virtual_engine].last_output = None
|
||||
scheduler_outputs: SchedulerOutputs,
|
||||
allow_async_output_proc: bool) -> None:
|
||||
co = self.cached_scheduler_outputs[virtual_engine]
|
||||
|
||||
co.seq_group_metadata_list = seq_group_metadata_list
|
||||
co.scheduler_outputs = scheduler_outputs
|
||||
co.allow_async_output_proc = allow_async_output_proc
|
||||
co.last_output = None
|
||||
|
||||
def _update_cached_scheduler_output(
|
||||
self, virtual_engine: int,
|
||||
@ -1472,20 +1601,21 @@ class LLMEngine:
|
||||
raise KeyError(f"Logger with name {logger_name} does not exist.")
|
||||
del self.stat_loggers[logger_name]
|
||||
|
||||
def do_log_stats(
|
||||
self,
|
||||
scheduler_outputs: Optional[SchedulerOutputs] = None,
|
||||
model_output: Optional[List[SamplerOutput]] = None) -> None:
|
||||
def do_log_stats(self,
|
||||
scheduler_outputs: Optional[SchedulerOutputs] = None,
|
||||
model_output: Optional[List[SamplerOutput]] = None,
|
||||
finished_before: Optional[List[int]] = None) -> None:
|
||||
"""Forced log when no requests active."""
|
||||
if self.log_stats:
|
||||
stats = self._get_stats(scheduler_outputs, model_output)
|
||||
stats = self._get_stats(scheduler_outputs, model_output,
|
||||
finished_before)
|
||||
for logger in self.stat_loggers.values():
|
||||
logger.log(stats)
|
||||
|
||||
def _get_stats(
|
||||
self,
|
||||
scheduler_outputs: Optional[SchedulerOutputs],
|
||||
model_output: Optional[List[SamplerOutput]] = None) -> Stats:
|
||||
def _get_stats(self,
|
||||
scheduler_outputs: Optional[SchedulerOutputs],
|
||||
model_output: Optional[List[SamplerOutput]] = None,
|
||||
finished_before: Optional[List[int]] = None) -> Stats:
|
||||
"""Get Stats to be Logged to Prometheus.
|
||||
|
||||
Args:
|
||||
@ -1550,6 +1680,10 @@ class LLMEngine:
|
||||
# NOTE: This loop assumes prefill seq_groups are before
|
||||
# decode seq_groups in scheduled_seq_groups.
|
||||
if scheduler_outputs is not None:
|
||||
# For async postprocessor, already finished sequences need to be
|
||||
# not counted (to avoid double counting)
|
||||
actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore
|
||||
|
||||
num_generation_tokens_from_prefill_groups = 0.
|
||||
# NOTE: if scheduler_outputs.num_prefill_groups > 0 and
|
||||
# the len of scheduler_outputs.scheduled_seq_groups is !=
|
||||
@ -1558,6 +1692,11 @@ class LLMEngine:
|
||||
|
||||
for idx, scheduled_seq_group in enumerate(
|
||||
scheduler_outputs.scheduled_seq_groups):
|
||||
# Skip double logging when using async output proc
|
||||
if finished_before and idx in finished_before:
|
||||
actual_num_batched_tokens -= 1
|
||||
continue
|
||||
|
||||
group_was_prefill = idx < scheduler_outputs.num_prefill_groups
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
|
||||
@ -1592,7 +1731,6 @@ class LLMEngine:
|
||||
# Latency timings
|
||||
time_e2e_requests.append(now -
|
||||
seq_group.metrics.arrival_time)
|
||||
|
||||
# Metadata
|
||||
num_prompt_tokens_requests.append(
|
||||
len(seq_group.prompt_token_ids))
|
||||
@ -1616,7 +1754,7 @@ class LLMEngine:
|
||||
# + num_generation_tokens_from_prefill_groups (since we generate
|
||||
# one token on prefills on iters where the prefill finishes).
|
||||
num_generation_tokens_iter = (
|
||||
scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
|
||||
actual_num_batched_tokens - num_prompt_tokens_iter +
|
||||
num_generation_tokens_from_prefill_groups)
|
||||
|
||||
# Spec decode, if enabled, emits specialized metrics from the worker in
|
||||
|
||||
@ -40,13 +40,9 @@ class SequenceGroupOutputProcessor(ABC):
|
||||
# Importing here to avoid cycle.
|
||||
from vllm.engine.output_processor.single_step import (
|
||||
SingleStepOutputProcessor)
|
||||
return SingleStepOutputProcessor(
|
||||
scheduler_config,
|
||||
detokenizer,
|
||||
scheduler,
|
||||
seq_counter,
|
||||
stop_checker,
|
||||
)
|
||||
return SingleStepOutputProcessor(scheduler_config, detokenizer,
|
||||
scheduler, seq_counter,
|
||||
stop_checker)
|
||||
else:
|
||||
# Importing here to avoid cycle.
|
||||
from vllm.engine.output_processor.multi_step import (
|
||||
@ -61,7 +57,8 @@ class SequenceGroupOutputProcessor(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def process_outputs(self, sequence_group: SequenceGroup,
|
||||
outputs: List[SequenceGroupOutput]) -> None:
|
||||
outputs: List[SequenceGroupOutput],
|
||||
is_async: bool) -> None:
|
||||
"""Process new token ids for the sequence group. Handles logic such as
|
||||
detokenization, stop checking, and freeing/forking sequences in the
|
||||
scheduler.
|
||||
|
||||
@ -57,17 +57,28 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
"Prompt logprob is not supported by multi step workers. "
|
||||
"(e.g., speculative decode uses multi step workers).")
|
||||
|
||||
def process_outputs(self, sequence_group: SequenceGroup,
|
||||
outputs: List[SequenceGroupOutput]) -> None:
|
||||
def process_outputs(self,
|
||||
sequence_group: SequenceGroup,
|
||||
outputs: List[SequenceGroupOutput],
|
||||
is_async: bool = False) -> None:
|
||||
"""Append new tokens in the outputs to sequences in the sequence group.
|
||||
|
||||
This only supports sequence groups of size 1. It supports greater than
|
||||
one new token per sequence.
|
||||
|
||||
This applies logic like stop condition checking and detokenization,
|
||||
including freeing finished sequences. It also handles cases where there
|
||||
are tokens emitted after the EOS token.
|
||||
This applies logic like stop condition checking and detokenization.
|
||||
It also handles cases where there are tokens emitted after
|
||||
the EOS token.
|
||||
|
||||
is_async - Indicates whether this postprocessor runs in
|
||||
parallel with the GPU forward pass and is processing
|
||||
tokens from the previous step. If this is true, then
|
||||
no tokens need to be appended since it is already done
|
||||
externally (before the next schedule() call)
|
||||
"""
|
||||
# TODO: Add support for async if necessary
|
||||
assert not is_async
|
||||
|
||||
seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
|
||||
|
||||
assert seqs, "expected running sequences"
|
||||
@ -138,7 +149,3 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
)
|
||||
if seq.is_finished():
|
||||
break
|
||||
|
||||
if seq.is_finished():
|
||||
for scheduler in self.scheduler:
|
||||
scheduler.free_seq(seq)
|
||||
|
||||
@ -29,14 +29,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
that is currently difficult to schedule multiple steps ahead of time.
"""

def __init__(
self,
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
stop_checker: StopChecker,
):
def __init__(self, scheduler_config: SchedulerConfig,
detokenizer: Detokenizer, scheduler: List[Scheduler],
seq_counter: Counter, stop_checker: StopChecker):
self.scheduler_config = scheduler_config
self.detokenizer = detokenizer
self.scheduler = scheduler
@ -44,16 +39,24 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
self.stop_checker = stop_checker

def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
outputs: List[SequenceGroupOutput],
is_async: bool) -> None:
"""Append all new tokens to sequences in the sequence group. Fork any
surviving beam candidates; free any unsurviving ones.

Invokes detokenizer to detokenize new tokens, and also marks sequences
as finished if they meet stop conditions.

is_async - Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
assert (len(outputs) == 1
), f"{type(self)} does not support multiple outputs per step"
return self._process_sequence_group_outputs(sequence_group, outputs[0])
return self._process_sequence_group_outputs(sequence_group, outputs[0],
is_async)

def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
@ -80,14 +83,16 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
seq_group.prompt_logprobs.extend(prompt_logprobs)

def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None:
outputs: SequenceGroupOutput,
is_async: bool) -> None:
sampling_params = seq_group.sampling_params
if sampling_params.n == 1 and not sampling_params.use_beam_search:
# only have one output sample
sample = outputs.samples[0]
# only have one sequence
seq = seq_group.seqs[0]
seq.append_token_id(sample.output_token, sample.logprobs)
if not is_async:
seq.append_token_id(sample.output_token, sample.logprobs)
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
@ -104,6 +109,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
scheduler.free_seq(seq)
return

# TODO: Add support for async for beam search
assert not is_async

# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)

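The hunks above only skip the token append on the common n == 1, non-beam-search path and assert not is_async otherwise. A rough sketch of that dispatch; ToySamplingParams, ToySeqGroup, and process_group are illustrative names, and the synchronous fallback for beam search is my simplification, not taken from the diff:

from dataclasses import dataclass, field
from typing import List


@dataclass
class ToySamplingParams:
    n: int = 1
    use_beam_search: bool = False


@dataclass
class ToySeqGroup:
    sampling_params: ToySamplingParams
    tokens: List[int] = field(default_factory=list)


def process_group(group: ToySeqGroup, new_token: int, is_async: bool) -> None:
    params = group.sampling_params
    if params.n == 1 and not params.use_beam_search:
        # Fast path: when async, the token was already appended by the engine.
        if not is_async:
            group.tokens.append(new_token)
        return
    # Beam search / multi-sample path has no async support yet.
    assert not is_async, "async postprocessing only supports the n == 1 path"
    group.tokens.append(new_token)


greedy = ToySeqGroup(ToySamplingParams())
process_group(greedy, 5, is_async=True)   # token assumed appended externally
beam = ToySeqGroup(ToySamplingParams(n=4, use_beam_search=True))
process_group(beam, 5, is_async=False)    # synchronous fallback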
@ -129,6 +129,7 @@ class LLM:
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
**kwargs,
) -> None:
'''
@ -170,6 +171,7 @@ class LLM:
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(
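If the async postprocessor needs to be turned off for a workload, the flag added to the LLM constructor above does that per engine instance. A usage sketch; the model name is an arbitrary example:

from vllm import LLM, SamplingParams

# Fall back to the synchronous output processor for this engine instance.
llm = LLM(
    model="facebook/opt-125m",   # example model; any supported model works
    disable_async_output_proc=True,
)

outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(temperature=0.0, max_tokens=16),
)
print(outputs[0].outputs[0].text)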
@ -603,7 +605,6 @@ class LLM:
inputs = [inputs]

num_requests = len(inputs)

if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
@ -678,6 +679,10 @@ class LLM:
postfix=(f"est. speed input: {0:.2f} toks/s, "
f"output: {0:.2f} toks/s"),
)

# In the loop below, only finished outputs are used
self.llm_engine.step_return_finished_only = True

# Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_in_toks = 0
@ -700,6 +705,10 @@ class LLM:
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")
pbar.update(1)

# Restore original behavior
self.llm_engine.step_return_finished_only = False

if use_tqdm:
pbar.close()
# Sort the outputs by request ID.

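The hunks above toggle step_return_finished_only around the collection loop so that step() only hands back finished requests while the progress bar runs, then restore the default. A small sketch of the same set-and-restore idea, written as a context manager purely for illustration (ToyEngine and finished_only are hypothetical; vLLM itself sets and resets the attribute inline):

from contextlib import contextmanager
from typing import Iterator


class ToyEngine:
    """Stand-in for LLMEngine; only models the flag used above."""

    def __init__(self) -> None:
        self.step_return_finished_only = False


@contextmanager
def finished_only(engine: ToyEngine) -> Iterator[None]:
    # Return only finished outputs from step() while the block is active.
    engine.step_return_finished_only = True
    try:
        yield
    finally:
        # Restore original behavior even if the loop raises.
        engine.step_return_finished_only = False


engine = ToyEngine()
with finished_only(engine):
    assert engine.step_return_finished_only
assert not engine.step_return_finished_only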
@ -64,8 +64,9 @@ class DistributedGPUExecutor(GPUExecutor):
num_cpu_blocks=num_cpu_blocks)

def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
@ -188,7 +189,7 @@ class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
@abstractmethod
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]:
"""Execute the model asynchronously in the driver worker.


@ -176,5 +176,5 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
execute_model_req: ExecuteModelRequest,
) -> List[Union[SamplerOutput, PoolerOutput]]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
)(execute_model_req=execute_model_req)
return output

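For context on the awaited call above: make_async wraps a blocking callable so the async executor can await it. The following is a rough standalone approximation under my own assumptions, not vLLM's actual implementation; blocking_execute is a hypothetical stand-in for the driver worker's execute_model:

import asyncio
import functools
import time
from typing import Any, Callable, Coroutine


def make_async(func: Callable[..., Any]) -> Callable[..., Coroutine[Any, Any, Any]]:
    """Run a blocking function in the event loop's default thread pool."""

    async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
        loop = asyncio.get_running_loop()
        bound = functools.partial(func, *args, **kwargs)
        return await loop.run_in_executor(None, bound)

    return _async_wrapper


def blocking_execute(execute_model_req: str) -> str:
    time.sleep(0.01)               # stands in for the GPU forward pass
    return f"outputs for {execute_model_req}"


async def main() -> None:
    output = await make_async(blocking_execute)(execute_model_req="req-0")
    print(output)


asyncio.run(main())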
@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Union, cast)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
Optional, Set, Tuple, Union, cast)

import msgspec
import torch
@ -474,11 +474,8 @@ class Sequence:
"""Reset the sequence states for recomputation."""
self.data.reset_state_for_recompute()

def append_token_id(
self,
token_id: int,
logprobs: Dict[int, Logprob],
) -> None:
def append_token_id(self, token_id: int, logprobs: Dict[int,
Logprob]) -> None:
assert token_id in logprobs
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob)
@ -1293,6 +1290,8 @@ class ExecuteModelRequest(
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None
# Async postprocessor
output_proc_callback_fn: Optional[Callable] = None

@property
def is_first_multi_step(self) -> bool:
@ -1338,4 +1337,5 @@ class ExecuteModelRequest(
num_steps=self.num_steps,
finished_requests_ids=self.finished_requests_ids,
last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None)
if self.last_sampled_token_ids is not None else None,
output_proc_callback_fn=self.output_proc_callback_fn)

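The sequence.py hunks add an optional output_proc_callback_fn to ExecuteModelRequest and make clone() carry it forward. A stripped-down illustration of that carry-the-callback pattern; ToyExecuteModelRequest is a simplified plain dataclass, not the msgspec struct vLLM uses:

from dataclasses import dataclass, replace
from typing import Callable, List, Optional


@dataclass
class ToyExecuteModelRequest:
    seq_group_metadata_list: List[str]
    # Async postprocessor: invoked by the model runner for the previous
    # step's outputs while the current step executes.
    output_proc_callback_fn: Optional[Callable] = None

    def clone(self, seq_group_metadata_list: List[str]) -> "ToyExecuteModelRequest":
        # The callback must survive cloning, otherwise later steps would
        # silently fall back to synchronous output processing.
        return replace(self, seq_group_metadata_list=seq_group_metadata_list)


def callback(is_async: bool = False) -> None:
    print(f"processing previous step outputs, is_async={is_async}")


req = ToyExecuteModelRequest(["group-0"], output_proc_callback_fn=callback)
cloned = req.clone(["group-1"])
assert cloned.output_proc_callback_fn is callback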
@ -6,8 +6,8 @@ import time
import warnings
import weakref
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
TypeVar, Union)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
Tuple, Type, TypeVar, Union)

import numpy as np
import torch
@ -90,6 +90,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0
output_proc_callback_fn: Optional[Callable] = None

def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
@ -1327,7 +1328,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
@ -1451,6 +1452,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if not self.is_driver_worker:
return []

if model_input.output_proc_callback_fn is not None:
model_input.output_proc_callback_fn(is_async=True)

# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,

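The model_runner hunk calls model_input.output_proc_callback_fn(is_async=True) on the driver just before sampling, so CPU-side processing of the previous step's outputs can overlap with the current step's GPU work. A toy, fully serialized event ordering to show when the callback fires; forward_pass, sample, and make_callback are illustrative stand-ins and there is no real overlap in this sketch:

import time
from typing import Callable, List, Optional

events: List[str] = []


def forward_pass(step: int) -> None:
    time.sleep(0.001)                      # pretend GPU forward pass
    events.append(f"forward[{step}]")


def sample(step: int) -> None:
    events.append(f"sample[{step}]")


def make_callback(step: int) -> Callable[..., None]:
    def _cb(is_async: bool = False) -> None:
        events.append(f"postprocess[{step}] is_async={is_async}")
    return _cb


callback: Optional[Callable[..., None]] = None
for step in range(2):
    forward_pass(step)
    if callback is not None:
        # Process the previous step's outputs before sampling this step.
        callback(is_async=True)
    sample(step)
    callback = make_callback(step)

print(events)
# ['forward[0]', 'sample[0]', 'forward[1]', 'postprocess[0] is_async=True', 'sample[1]']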
@ -263,6 +263,12 @@ class LocalOrDistributedWorkerBase(WorkerBase):
broadcast_data.update(kwargs)
broadcast_tensor_dict(broadcast_data, src=0)

if execute_model_req.output_proc_callback_fn:
model_input = dataclasses.replace( # type: ignore
model_input,
output_proc_callback_fn=execute_model_req.
output_proc_callback_fn)

return model_input, worker_input, kwargs

def prepare_input(
@ -289,7 +295,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):

def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""

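The worker_base hunk attaches the callback to the already-built model input with dataclasses.replace rather than mutating it; my reading (not stated in the diff) is that the model input is an immutable dataclass and the callback is a driver-local Python object added only after the broadcastable fields have been sent to the other workers. A minimal frozen-dataclass illustration; ToyModelInput is a hypothetical stand-in:

import dataclasses
from typing import Callable, Optional


@dataclasses.dataclass(frozen=True)
class ToyModelInput:
    virtual_engine: int = 0
    output_proc_callback_fn: Optional[Callable] = None


model_input = ToyModelInput(virtual_engine=0)

# Frozen dataclasses cannot be mutated in place, so build a copy that
# carries the driver-local callback.
model_input = dataclasses.replace(
    model_input,
    output_proc_callback_fn=lambda is_async=False: None,
)
assert model_input.output_proc_callback_fn is not None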