From e0ff920001988f2240cf05551bef566898ed6e4b Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Tue, 26 Dec 2023 13:41:09 +0800
Subject: [PATCH] [BUGFIX] Do not return ignored sentences twice in async llm
 engine (#2258)

---
 vllm/engine/async_llm_engine.py | 10 ++++------
 vllm/engine/llm_engine.py       | 19 +++----------------
 2 files changed, 7 insertions(+), 22 deletions(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index d854a20b8b95..611da51f6193 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -183,20 +183,18 @@ class _AsyncLLMEngine(LLMEngine):
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
-        if scheduler_outputs.is_empty():
-            return ignored
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
 
         # Execute the model.
-        output = await self._run_workers_async(
+        output = (await self._run_workers_async(
             "execute_model",
             seq_group_metadata_list=seq_group_metadata_list,
             blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
             blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )
+        )) if not scheduler_outputs.is_empty() else []
 
-        return self._process_model_outputs(output, scheduler_outputs) + ignored
+        return self._process_model_outputs(output, scheduler_outputs)
 
     async def _run_workers_async(
         self,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 93ff0fc05d50..43bf9747ee18 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -14,8 +14,7 @@ from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
-                           SequenceGroupMetadata, SequenceGroupOutput,
-                           SequenceOutput, SequenceStatus)
+                           SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
 from vllm.utils import Counter
@@ -328,16 +327,6 @@ class LLMEngine:
         """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()
 
-    def _schedule(
-        self
-    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
-               List[RequestOutput]]:
-        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
-        return seq_group_metadata_list, scheduler_outputs, [
-            RequestOutput.from_seq_group(seq_group)
-            for seq_group in scheduler_outputs.ignored_seq_groups
-        ]
-
     def _check_beam_search_early_stopping(
         self,
         early_stopping: Union[bool, str],
@@ -586,9 +575,7 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
-        if scheduler_outputs.is_empty():
-            return ignored
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
 
         # Execute the model.
         output = self._run_workers(
@@ -597,7 +584,7 @@ class LLMEngine:
             blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
             blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )
+        ) if not scheduler_outputs.is_empty() else []
 
         return self._process_model_outputs(output, scheduler_outputs)
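
Note on the fix: before this change, `_schedule()` eagerly converted the
scheduler's `ignored_seq_groups` (requests the scheduler declines to run, such
as over-length prompts) into `RequestOutput`s, while `_process_model_outputs`
also emits outputs for those same groups; the async engine's
`... + ignored` therefore returned each ignored request twice. After the
change, `scheduler.schedule()` is called directly, model execution is skipped
when the schedule is empty, and `_process_model_outputs` is the single place
that reports ignored groups. Below is a minimal sketch of the before/after
control flow; `ToySchedulerOutputs`, `process_model_outputs`, `step_old`, and
`step_new` are simplified stand-ins for illustration, not the real vLLM
classes or methods.

    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class ToySchedulerOutputs:
        # Requests scheduled to run this step, and requests the scheduler ignored.
        scheduled: List[str] = field(default_factory=list)
        ignored_seq_groups: List[str] = field(default_factory=list)

        def is_empty(self) -> bool:
            return not self.scheduled


    def process_model_outputs(output: List[str],
                              so: ToySchedulerOutputs) -> List[str]:
        # Stand-in for LLMEngine._process_model_outputs: it already emits one
        # entry per ignored sequence group in addition to the model outputs.
        return output + list(so.ignored_seq_groups)


    def step_old(so: ToySchedulerOutputs) -> List[str]:
        # Old flow: ignored groups were materialized up front by _schedule() ...
        ignored = list(so.ignored_seq_groups)
        if so.is_empty():
            return ignored
        output = [f"out({req})" for req in so.scheduled]
        # ... and appended again here, although process_model_outputs already
        # includes them -> each ignored request is returned twice.
        return process_model_outputs(output, so) + ignored


    def step_new(so: ToySchedulerOutputs) -> List[str]:
        # New flow: skip model execution when nothing is scheduled, and let
        # process_model_outputs be the only place that reports ignored groups.
        output = ([f"out({req})" for req in so.scheduled]
                  if not so.is_empty() else [])
        return process_model_outputs(output, so)


    if __name__ == "__main__":
        so = ToySchedulerOutputs(scheduled=["req0"], ignored_seq_groups=["req1"])
        print(step_old(so))   # ['out(req0)', 'req1', 'req1']  <- req1 twice
        print(step_new(so))   # ['out(req0)', 'req1']          <- req1 once

        empty = ToySchedulerOutputs(ignored_seq_groups=["req2"])
        print(step_new(empty))  # ['req2'] -- no model call, reported once

Running the sketch prints `['out(req0)', 'req1', 'req1']` for the old flow and
`['out(req0)', 'req1']` for the new one: the duplication only occurred when a
step had both scheduled and ignored groups, which is exactly the case this
patch fixes.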