mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 23:07:26 +08:00
[Core] Async_output_proc: Add virtual engine support (towards pipeline parallel) (#7911)
This commit is contained in:
parent
51f86bf487
commit
f508e03e7f
@ -302,7 +302,7 @@ class Scheduler:
|
|||||||
cache_config: CacheConfig,
|
cache_config: CacheConfig,
|
||||||
lora_config: Optional[LoRAConfig],
|
lora_config: Optional[LoRAConfig],
|
||||||
pipeline_parallel_size: int = 1,
|
pipeline_parallel_size: int = 1,
|
||||||
output_proc_callback_fn: Optional[Callable] = None,
|
output_proc_callback: Optional[Callable] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
self.cache_config = cache_config
|
self.cache_config = cache_config
|
||||||
@ -376,8 +376,8 @@ class Scheduler:
|
|||||||
# iterations. I.e. since the output processing is lagged one step,
|
# iterations. I.e. since the output processing is lagged one step,
|
||||||
# we cannot reuse the cached objects immediately when the schedule()
|
# we cannot reuse the cached objects immediately when the schedule()
|
||||||
# is called again, but only when schedule() is called the second time.
|
# is called again, but only when schedule() is called the second time.
|
||||||
self.output_proc_callback_fn = output_proc_callback_fn
|
self.output_proc_callback = output_proc_callback
|
||||||
self.use_async_output_proc = self.output_proc_callback_fn is not None
|
self.use_async_output_proc = self.output_proc_callback is not None
|
||||||
self.num_cache_iters = 2 if self.use_async_output_proc else 1
|
self.num_cache_iters = 2 if self.use_async_output_proc else 1
|
||||||
|
|
||||||
self.cache_id = 0
|
self.cache_id = 0
|
||||||
@ -573,8 +573,8 @@ class Scheduler:
|
|||||||
seq_group):
|
seq_group):
|
||||||
tmp = self.running
|
tmp = self.running
|
||||||
self.running = orig_running
|
self.running = orig_running
|
||||||
assert self.output_proc_callback_fn is not None
|
assert self.output_proc_callback is not None
|
||||||
self.output_proc_callback_fn(is_async=True)
|
self.output_proc_callback()
|
||||||
self.running = tmp
|
self.running = tmp
|
||||||
|
|
||||||
while not self._can_append_slots(seq_group):
|
while not self._can_append_slots(seq_group):
|
||||||
@ -1091,7 +1091,6 @@ class Scheduler:
|
|||||||
no_beam_search = seq_group.sampling_params is None or (
|
no_beam_search = seq_group.sampling_params is None or (
|
||||||
seq_group.sampling_params.best_of == 1
|
seq_group.sampling_params.best_of == 1
|
||||||
and not seq_group.sampling_params.use_beam_search)
|
and not seq_group.sampling_params.use_beam_search)
|
||||||
|
|
||||||
return no_beam_search
|
return no_beam_search
|
||||||
|
|
||||||
def schedule(
|
def schedule(
|
||||||
|
|||||||
@ -279,10 +279,16 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
scheduler_outputs = cached_outputs.scheduler_outputs
|
scheduler_outputs = cached_outputs.scheduler_outputs
|
||||||
allow_async_output_proc = cached_outputs.allow_async_output_proc
|
allow_async_output_proc = cached_outputs.allow_async_output_proc
|
||||||
|
|
||||||
|
ctx = self.scheduler_contexts[virtual_engine]
|
||||||
|
|
||||||
# skip the scheduler if there are any remaining steps in the seq groups.
|
# skip the scheduler if there are any remaining steps in the seq groups.
|
||||||
# This ensures that the scheduler is only called again when the current
|
# This ensures that the scheduler is only called again when the current
|
||||||
# batch has completed.
|
# batch has completed.
|
||||||
if not self._has_remaining_steps(seq_group_metadata_list):
|
if not self._has_remaining_steps(seq_group_metadata_list):
|
||||||
|
|
||||||
|
# Clear outputs on scheduler iteration start
|
||||||
|
ctx.request_outputs.clear()
|
||||||
|
|
||||||
(seq_group_metadata_list, scheduler_outputs,
|
(seq_group_metadata_list, scheduler_outputs,
|
||||||
allow_async_output_proc
|
allow_async_output_proc
|
||||||
) = self.scheduler[virtual_engine].schedule()
|
) = self.scheduler[virtual_engine].schedule()
|
||||||
@ -290,8 +296,9 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
# If current scheduler iteration has no async postprocessor,
|
# If current scheduler iteration has no async postprocessor,
|
||||||
# then we need first to drain the pending async postprocessor
|
# then we need first to drain the pending async postprocessor
|
||||||
# before moving forward
|
# before moving forward
|
||||||
if not allow_async_output_proc and len(self.output_queue) > 0:
|
if not allow_async_output_proc and len(ctx.output_queue) > 0:
|
||||||
self._process_model_outputs(is_async=True)
|
self._process_model_outputs(virtual_engine=virtual_engine,
|
||||||
|
is_async=True)
|
||||||
|
|
||||||
if (self.scheduler_config.is_multi_step
|
if (self.scheduler_config.is_multi_step
|
||||||
and scheduler_outputs.num_lookahead_slots > 0):
|
and scheduler_outputs.num_lookahead_slots > 0):
|
||||||
@ -332,8 +339,8 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
last_sampled_token_ids=last_sampled_token_ids)
|
last_sampled_token_ids=last_sampled_token_ids)
|
||||||
|
|
||||||
if allow_async_output_proc:
|
if allow_async_output_proc:
|
||||||
execute_model_req.output_proc_callback_fn = \
|
execute_model_req.async_callback = self.async_callback[
|
||||||
self._process_model_outputs
|
virtual_engine]
|
||||||
|
|
||||||
# Execute the model.
|
# Execute the model.
|
||||||
output = await self.model_executor.execute_model_async(
|
output = await self.model_executor.execute_model_async(
|
||||||
@ -343,9 +350,10 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
if self.scheduler_config.is_multi_step:
|
if self.scheduler_config.is_multi_step:
|
||||||
self._update_cached_scheduler_output(virtual_engine, output)
|
self._update_cached_scheduler_output(virtual_engine, output)
|
||||||
else:
|
else:
|
||||||
if len(self.output_queue) > 0:
|
if len(ctx.output_queue) > 0:
|
||||||
assert not self.scheduler_config.is_multi_step
|
assert not self.scheduler_config.is_multi_step
|
||||||
self._process_model_outputs(is_async=True)
|
self._process_model_outputs(virtual_engine=virtual_engine,
|
||||||
|
is_async=True)
|
||||||
output = []
|
output = []
|
||||||
|
|
||||||
# Finish the current step for all the sequence groups.
|
# Finish the current step for all the sequence groups.
|
||||||
@ -360,7 +368,7 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
virtual_engine] = SchedulerOutputState()
|
virtual_engine] = SchedulerOutputState()
|
||||||
|
|
||||||
# Cache results in engine
|
# Cache results in engine
|
||||||
self.output_queue.append(
|
ctx.output_queue.append(
|
||||||
(output, seq_group_metadata_list, scheduler_outputs))
|
(output, seq_group_metadata_list, scheduler_outputs))
|
||||||
|
|
||||||
if output and allow_async_output_proc:
|
if output and allow_async_output_proc:
|
||||||
@ -372,7 +380,8 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
scheduler_outputs.scheduled_seq_groups)
|
scheduler_outputs.scheduled_seq_groups)
|
||||||
|
|
||||||
if not allow_async_output_proc:
|
if not allow_async_output_proc:
|
||||||
self._process_model_outputs(is_async=False)
|
self._process_model_outputs(virtual_engine=virtual_engine,
|
||||||
|
is_async=False)
|
||||||
|
|
||||||
# Log stats.
|
# Log stats.
|
||||||
self.do_log_stats(scheduler_outputs, output)
|
self.do_log_stats(scheduler_outputs, output)
|
||||||
@ -381,9 +390,17 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
self.do_tracing(scheduler_outputs)
|
self.do_tracing(scheduler_outputs)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.request_outputs = []
|
ctx.request_outputs = []
|
||||||
|
|
||||||
return self.request_outputs
|
if not self.has_unfinished_requests():
|
||||||
|
# Drain async postprocessor (if exists)
|
||||||
|
if len(ctx.output_queue) > 0:
|
||||||
|
assert not self.scheduler_config.is_multi_step
|
||||||
|
self._process_model_outputs(virtual_engine=virtual_engine,
|
||||||
|
is_async=True)
|
||||||
|
assert len(ctx.output_queue) == 0
|
||||||
|
|
||||||
|
return ctx.request_outputs
|
||||||
|
|
||||||
async def stop_remote_worker_execution_loop_async(self) -> None:
|
async def stop_remote_worker_execution_loop_async(self) -> None:
|
||||||
"""Stop the remote worker execution loop."""
|
"""Stop the remote worker execution loop."""
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
|
import functools
|
||||||
import time
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
|
from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
|
||||||
Mapping, Optional)
|
Mapping, Optional)
|
||||||
from typing import Sequence as GenericSequence
|
from typing import Sequence as GenericSequence
|
||||||
@ -88,6 +89,17 @@ class SchedulerOutputState:
|
|||||||
last_output: Optional[SamplerOutput] = None
|
last_output: Optional[SamplerOutput] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SchedulerContext:
|
||||||
|
output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata],
|
||||||
|
SchedulerOutputs]] = field(
|
||||||
|
default_factory=lambda: deque())
|
||||||
|
|
||||||
|
request_outputs: List[Union[RequestOutput,
|
||||||
|
EmbeddingRequestOutput]] = field(
|
||||||
|
default_factory=lambda: [])
|
||||||
|
|
||||||
|
|
||||||
class LLMEngine:
|
class LLMEngine:
|
||||||
"""An LLM engine that receives requests and generates texts.
|
"""An LLM engine that receives requests and generates texts.
|
||||||
|
|
||||||
@ -350,9 +362,11 @@ class LLMEngine:
|
|||||||
Scheduler(
|
Scheduler(
|
||||||
scheduler_config, cache_config, lora_config,
|
scheduler_config, cache_config, lora_config,
|
||||||
parallel_config.pipeline_parallel_size,
|
parallel_config.pipeline_parallel_size,
|
||||||
self._process_model_outputs
|
functools.partial(self._process_model_outputs,
|
||||||
|
virtual_engine=v_id,
|
||||||
|
is_async=True)
|
||||||
if model_config.use_async_output_proc else None)
|
if model_config.use_async_output_proc else None)
|
||||||
for _ in range(parallel_config.pipeline_parallel_size)
|
for v_id in range(parallel_config.pipeline_parallel_size)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Metric Logging.
|
# Metric Logging.
|
||||||
@ -406,12 +420,17 @@ class LLMEngine:
|
|||||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Async output processing pointers
|
self.scheduler_contexts = [
|
||||||
self.output_queue: Deque[Tuple[List[SamplerOutput],
|
SchedulerContext()
|
||||||
List[SequenceGroupMetadata],
|
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||||
SchedulerOutputs]] = deque()
|
]
|
||||||
self.request_outputs: List[Union[RequestOutput,
|
|
||||||
EmbeddingRequestOutput]] = []
|
self.async_callback = [
|
||||||
|
functools.partial(self._process_model_outputs,
|
||||||
|
virtual_engine=v_id,
|
||||||
|
is_async=True)
|
||||||
|
for v_id in range(self.parallel_config.pipeline_parallel_size)
|
||||||
|
]
|
||||||
|
|
||||||
def _initialize_kv_caches(self) -> None:
|
def _initialize_kv_caches(self) -> None:
|
||||||
"""Initialize the KV cache in the worker(s).
|
"""Initialize the KV cache in the worker(s).
|
||||||
@ -1221,32 +1240,28 @@ class LLMEngine:
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
def _process_model_outputs(self,
|
def _process_model_outputs(self, virtual_engine: int,
|
||||||
is_async: bool,
|
is_async: bool) -> None:
|
||||||
clear_outputs: bool = True) -> None:
|
|
||||||
"""Apply the model output to the sequences in the scheduled seq groups.
|
"""Apply the model output to the sequences in the scheduled seq groups.
|
||||||
|
|
||||||
|
virtual_engine: The engine id to operate on
|
||||||
is_async: Indicates whether this postprocessor runs in
|
is_async: Indicates whether this postprocessor runs in
|
||||||
parallel with the GPU forward pass and is processing
|
parallel with the GPU forward pass and is processing
|
||||||
tokens from the previous step. If this is true, then
|
tokens from the previous step. If this is true, then
|
||||||
no tokens need to be appended since it is already done
|
no tokens need to be appended since it is already done
|
||||||
externally (before the next schedule() call)
|
externally (before the next schedule() call)
|
||||||
clear_outputs: Sometimes existing outputs need to be combined
|
|
||||||
with outputs of this call. This happens for postprocessor
|
|
||||||
draining at the final stage (like when sequences are finished)
|
|
||||||
|
|
||||||
Returns RequestOutputs that can be returned to the client.
|
Returns RequestOutputs that can be returned to the client.
|
||||||
"""
|
"""
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
||||||
if clear_outputs:
|
ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]
|
||||||
self.request_outputs.clear()
|
|
||||||
|
|
||||||
if len(self.output_queue) == 0:
|
if len(ctx.output_queue) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
(outputs, seq_group_metadata_list,
|
(outputs, seq_group_metadata_list,
|
||||||
scheduler_outputs) = self.output_queue.popleft()
|
scheduler_outputs) = ctx.output_queue.popleft()
|
||||||
|
|
||||||
# Sanity check
|
# Sanity check
|
||||||
assert len(seq_group_metadata_list) == len(
|
assert len(seq_group_metadata_list) == len(
|
||||||
@ -1321,11 +1336,11 @@ class LLMEngine:
|
|||||||
if (seq_group.is_finished()
|
if (seq_group.is_finished()
|
||||||
if self.step_return_finished_only else True):
|
if self.step_return_finished_only else True):
|
||||||
request_output = RequestOutputFactory.create(seq_group)
|
request_output = RequestOutputFactory.create(seq_group)
|
||||||
self.request_outputs.append(request_output)
|
ctx.request_outputs.append(request_output)
|
||||||
|
|
||||||
for seq_group in scheduler_outputs.ignored_seq_groups:
|
for seq_group in scheduler_outputs.ignored_seq_groups:
|
||||||
request_output = RequestOutputFactory.create(seq_group)
|
request_output = RequestOutputFactory.create(seq_group)
|
||||||
self.request_outputs.append(request_output)
|
ctx.request_outputs.append(request_output)
|
||||||
|
|
||||||
if is_async:
|
if is_async:
|
||||||
# Log stats.
|
# Log stats.
|
||||||
@ -1421,29 +1436,43 @@ class LLMEngine:
|
|||||||
"Pipeline parallelism is only supported through AsyncLLMEngine "
|
"Pipeline parallelism is only supported through AsyncLLMEngine "
|
||||||
"as performance will be severely degraded otherwise.")
|
"as performance will be severely degraded otherwise.")
|
||||||
|
|
||||||
|
# For llm_engine, there is no pipeline parallel support, so the engine
|
||||||
|
# used is always 0
|
||||||
|
virtual_engine = 0
|
||||||
|
|
||||||
# These are cached outputs from previous iterations. None if on first
|
# These are cached outputs from previous iterations. None if on first
|
||||||
# iteration
|
# iteration
|
||||||
cached_outputs = self.cached_scheduler_outputs[0]
|
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
|
||||||
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
|
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
|
||||||
scheduler_outputs = cached_outputs.scheduler_outputs
|
scheduler_outputs = cached_outputs.scheduler_outputs
|
||||||
allow_async_output_proc = cached_outputs.allow_async_output_proc
|
allow_async_output_proc = cached_outputs.allow_async_output_proc
|
||||||
|
|
||||||
|
ctx = self.scheduler_contexts[virtual_engine]
|
||||||
|
|
||||||
# Skip the scheduler if there are any remaining steps in the seq groups.
|
# Skip the scheduler if there are any remaining steps in the seq groups.
|
||||||
# This ensures that the scheduler is only called again when the current
|
# This ensures that the scheduler is only called again when the current
|
||||||
# batch has completed.
|
# batch has completed.
|
||||||
if not self._has_remaining_steps(seq_group_metadata_list):
|
if not self._has_remaining_steps(seq_group_metadata_list):
|
||||||
(seq_group_metadata_list, scheduler_outputs,
|
|
||||||
allow_async_output_proc) = self.scheduler[0].schedule()
|
|
||||||
|
|
||||||
if not allow_async_output_proc and len(self.output_queue) > 0:
|
# Clear outputs on scheduler iteration start
|
||||||
self._process_model_outputs(is_async=True)
|
ctx.request_outputs.clear()
|
||||||
|
|
||||||
|
# Schedule iteration
|
||||||
|
(seq_group_metadata_list, scheduler_outputs,
|
||||||
|
allow_async_output_proc
|
||||||
|
) = self.scheduler[virtual_engine].schedule()
|
||||||
|
|
||||||
|
# Maybe switch from async mode to sync mode
|
||||||
|
if not allow_async_output_proc and len(ctx.output_queue) > 0:
|
||||||
|
self._process_model_outputs(virtual_engine=virtual_engine,
|
||||||
|
is_async=True)
|
||||||
|
|
||||||
if (self.scheduler_config.is_multi_step
|
if (self.scheduler_config.is_multi_step
|
||||||
and scheduler_outputs.num_lookahead_slots > 0):
|
and scheduler_outputs.num_lookahead_slots > 0):
|
||||||
# cache the scheduler outputs for the next iteration if we have
|
# cache the scheduler outputs for the next iteration if we have
|
||||||
# lookahead slots
|
# lookahead slots
|
||||||
self._cache_scheduler_outputs_for_multi_step(
|
self._cache_scheduler_outputs_for_multi_step(
|
||||||
0, seq_group_metadata_list, scheduler_outputs,
|
virtual_engine, seq_group_metadata_list, scheduler_outputs,
|
||||||
allow_async_output_proc)
|
allow_async_output_proc)
|
||||||
|
|
||||||
assert seq_group_metadata_list is not None
|
assert seq_group_metadata_list is not None
|
||||||
@ -1454,14 +1483,14 @@ class LLMEngine:
|
|||||||
|
|
||||||
if not scheduler_outputs.is_empty():
|
if not scheduler_outputs.is_empty():
|
||||||
finished_requests_ids = self.scheduler[
|
finished_requests_ids = self.scheduler[
|
||||||
0].get_and_reset_finished_requests_ids()
|
virtual_engine].get_and_reset_finished_requests_ids()
|
||||||
|
|
||||||
# Check if we have a cached last_output from the previous iteration.
|
# Check if we have a cached last_output from the previous iteration.
|
||||||
# For supporting PP this is probably the best way to pass the
|
# For supporting PP this is probably the best way to pass the
|
||||||
# sampled_token_ids, as a separate broadcast over all the PP stages
|
# sampled_token_ids, as a separate broadcast over all the PP stages
|
||||||
# will cause one virtual engine's microbatch to block the pipeline.
|
# will cause one virtual engine's microbatch to block the pipeline.
|
||||||
last_sampled_token_ids = \
|
last_sampled_token_ids = \
|
||||||
self._get_last_sampled_token_ids(0)
|
self._get_last_sampled_token_ids(virtual_engine)
|
||||||
|
|
||||||
execute_model_req = ExecuteModelRequest(
|
execute_model_req = ExecuteModelRequest(
|
||||||
seq_group_metadata_list=seq_group_metadata_list,
|
seq_group_metadata_list=seq_group_metadata_list,
|
||||||
@ -1476,20 +1505,24 @@ class LLMEngine:
|
|||||||
last_sampled_token_ids=last_sampled_token_ids)
|
last_sampled_token_ids=last_sampled_token_ids)
|
||||||
|
|
||||||
if allow_async_output_proc:
|
if allow_async_output_proc:
|
||||||
execute_model_req.output_proc_callback_fn = \
|
execute_model_req.async_callback = self.async_callback[
|
||||||
self._process_model_outputs
|
virtual_engine]
|
||||||
|
|
||||||
output = self.model_executor.execute_model(
|
output = self.model_executor.execute_model(
|
||||||
execute_model_req=execute_model_req)
|
execute_model_req=execute_model_req)
|
||||||
|
|
||||||
# we need to do this here so that last step's sampled_token_ids can
|
# We need to do this here so that last step's sampled_token_ids can
|
||||||
# be passed to the next iteration for PP.
|
# be passed to the next iteration for PP.
|
||||||
if self.scheduler_config.is_multi_step:
|
if self.scheduler_config.is_multi_step:
|
||||||
self._update_cached_scheduler_output(0, output)
|
self._update_cached_scheduler_output(virtual_engine, output)
|
||||||
else:
|
else:
|
||||||
if len(self.output_queue) > 0:
|
# Nothing scheduled => If there is pending async postprocessor,
|
||||||
|
# then finish it here.
|
||||||
|
if len(ctx.output_queue) > 0:
|
||||||
assert not self.scheduler_config.is_multi_step
|
assert not self.scheduler_config.is_multi_step
|
||||||
self._process_model_outputs(is_async=True)
|
self._process_model_outputs(virtual_engine=virtual_engine,
|
||||||
|
is_async=True)
|
||||||
|
# No outputs in this case
|
||||||
output = []
|
output = []
|
||||||
|
|
||||||
# Finish the current step for all the sequence groups.
|
# Finish the current step for all the sequence groups.
|
||||||
@ -1504,7 +1537,7 @@ class LLMEngine:
|
|||||||
|
|
||||||
# Add results to the output_queue
|
# Add results to the output_queue
|
||||||
# (for async or non-async postprocessing)
|
# (for async or non-async postprocessing)
|
||||||
self.output_queue.append(
|
ctx.output_queue.append(
|
||||||
(output, seq_group_metadata_list, scheduler_outputs))
|
(output, seq_group_metadata_list, scheduler_outputs))
|
||||||
|
|
||||||
if output and allow_async_output_proc:
|
if output and allow_async_output_proc:
|
||||||
@ -1515,8 +1548,10 @@ class LLMEngine:
|
|||||||
output[0], seq_group_metadata_list,
|
output[0], seq_group_metadata_list,
|
||||||
scheduler_outputs.scheduled_seq_groups)
|
scheduler_outputs.scheduled_seq_groups)
|
||||||
|
|
||||||
|
# Check if need to run the usual non-async path
|
||||||
if not allow_async_output_proc:
|
if not allow_async_output_proc:
|
||||||
self._process_model_outputs(is_async=False)
|
self._process_model_outputs(virtual_engine=virtual_engine,
|
||||||
|
is_async=False)
|
||||||
|
|
||||||
# Log stats.
|
# Log stats.
|
||||||
self.do_log_stats(scheduler_outputs, output)
|
self.do_log_stats(scheduler_outputs, output)
|
||||||
@ -1524,14 +1559,16 @@ class LLMEngine:
|
|||||||
# Tracing
|
# Tracing
|
||||||
self.do_tracing(scheduler_outputs)
|
self.do_tracing(scheduler_outputs)
|
||||||
else:
|
else:
|
||||||
self.request_outputs = []
|
# Multi-step case
|
||||||
|
ctx.request_outputs = []
|
||||||
|
|
||||||
if not self.has_unfinished_requests():
|
if not self.has_unfinished_requests():
|
||||||
# Drain async postprocessor
|
# Drain async postprocessor (if exists)
|
||||||
if len(self.output_queue) > 0:
|
if len(ctx.output_queue) > 0:
|
||||||
assert not self.scheduler_config.is_multi_step
|
assert not self.scheduler_config.is_multi_step
|
||||||
self._process_model_outputs(is_async=True, clear_outputs=False)
|
self._process_model_outputs(virtual_engine=virtual_engine,
|
||||||
assert len(self.output_queue) == 0
|
is_async=True)
|
||||||
|
assert len(ctx.output_queue) == 0
|
||||||
|
|
||||||
# Stop the execute model loop in parallel workers until there are
|
# Stop the execute model loop in parallel workers until there are
|
||||||
# more requests to process. This avoids waiting indefinitely in
|
# more requests to process. This avoids waiting indefinitely in
|
||||||
@ -1540,7 +1577,7 @@ class LLMEngine:
|
|||||||
# queued control plane messages, such as add/remove lora adapters.
|
# queued control plane messages, such as add/remove lora adapters.
|
||||||
self.model_executor.stop_remote_worker_execution_loop()
|
self.model_executor.stop_remote_worker_execution_loop()
|
||||||
|
|
||||||
return self.request_outputs
|
return ctx.request_outputs
|
||||||
|
|
||||||
def _has_remaining_steps(
|
def _has_remaining_steps(
|
||||||
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
|
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
|
||||||
|
|||||||
@ -811,6 +811,9 @@ class SequenceGroup:
|
|||||||
self.is_single_seq = len(self.seqs) == 1
|
self.is_single_seq = len(self.seqs) == 1
|
||||||
|
|
||||||
def is_finished(self) -> bool:
|
def is_finished(self) -> bool:
|
||||||
|
if self.is_single_seq:
|
||||||
|
return self.seqs[0].is_finished()
|
||||||
|
|
||||||
return all(seq.is_finished() for seq in self.seqs)
|
return all(seq.is_finished() for seq in self.seqs)
|
||||||
|
|
||||||
def is_prefill(self) -> bool:
|
def is_prefill(self) -> bool:
|
||||||
@ -1290,8 +1293,8 @@ class ExecuteModelRequest(
|
|||||||
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
|
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
|
||||||
# The last sampled token ids for multi step decoding.
|
# The last sampled token ids for multi step decoding.
|
||||||
last_sampled_token_ids: Optional[torch.Tensor] = None
|
last_sampled_token_ids: Optional[torch.Tensor] = None
|
||||||
# Async postprocessor
|
# Async callback
|
||||||
output_proc_callback_fn: Optional[Callable] = None
|
async_callback: Optional[Callable] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_first_multi_step(self) -> bool:
|
def is_first_multi_step(self) -> bool:
|
||||||
@ -1338,4 +1341,4 @@ class ExecuteModelRequest(
|
|||||||
finished_requests_ids=self.finished_requests_ids,
|
finished_requests_ids=self.finished_requests_ids,
|
||||||
last_sampled_token_ids=self.last_sampled_token_ids.clone()
|
last_sampled_token_ids=self.last_sampled_token_ids.clone()
|
||||||
if self.last_sampled_token_ids is not None else None,
|
if self.last_sampled_token_ids is not None else None,
|
||||||
output_proc_callback_fn=self.output_proc_callback_fn)
|
async_callback=self.async_callback)
|
||||||
|
|||||||
@ -91,7 +91,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
|
|||||||
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
|
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
|
||||||
finished_requests_ids: Optional[List[str]] = None
|
finished_requests_ids: Optional[List[str]] = None
|
||||||
virtual_engine: int = 0
|
virtual_engine: int = 0
|
||||||
output_proc_callback_fn: Optional[Callable] = None
|
async_callback: Optional[Callable] = None
|
||||||
|
|
||||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||||
tensor_dict = {
|
tensor_dict = {
|
||||||
@ -1457,8 +1457,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
|||||||
if not self.is_driver_worker:
|
if not self.is_driver_worker:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if model_input.output_proc_callback_fn is not None:
|
if model_input.async_callback is not None:
|
||||||
model_input.output_proc_callback_fn(is_async=True)
|
model_input.async_callback()
|
||||||
|
|
||||||
# Sample the next token.
|
# Sample the next token.
|
||||||
output: SamplerOutput = self.model.sample(
|
output: SamplerOutput = self.model.sample(
|
||||||
|
|||||||
@ -263,11 +263,10 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
|||||||
broadcast_data.update(kwargs)
|
broadcast_data.update(kwargs)
|
||||||
broadcast_tensor_dict(broadcast_data, src=0)
|
broadcast_tensor_dict(broadcast_data, src=0)
|
||||||
|
|
||||||
if execute_model_req.output_proc_callback_fn:
|
if execute_model_req.async_callback:
|
||||||
model_input = dataclasses.replace( # type: ignore
|
model_input = dataclasses.replace( # type: ignore
|
||||||
model_input,
|
model_input,
|
||||||
output_proc_callback_fn=execute_model_req.
|
async_callback=execute_model_req.async_callback)
|
||||||
output_proc_callback_fn)
|
|
||||||
|
|
||||||
return model_input, worker_input, kwargs
|
return model_input, worker_input, kwargs
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user