[BUGFIX] Do not return ignored sentences twice in async llm engine (#2258)

This commit is contained in:
Zhuohan Li 2023-12-26 13:41:09 +08:00 committed by GitHub
parent face83c7ec
commit e0ff920001
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 22 deletions

View File

@ -183,20 +183,18 @@ class _AsyncLLMEngine(LLMEngine):
and updates the scheduler with the model outputs. Finally, it decodes and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results. the sequences and returns the newly generated results.
""" """
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule() seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if scheduler_outputs.is_empty():
return ignored
# Execute the model. # Execute the model.
output = await self._run_workers_async( output = (await self._run_workers_async(
"execute_model", "execute_model",
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy, blocks_to_copy=scheduler_outputs.blocks_to_copy,
) )) if not scheduler_outputs.is_empty() else []
return self._process_model_outputs(output, scheduler_outputs) + ignored return self._process_model_outputs(output, scheduler_outputs)
async def _run_workers_async( async def _run_workers_async(
self, self,

View File

@ -14,8 +14,7 @@ from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupMetadata, SequenceGroupOutput, SequenceGroupOutput, SequenceOutput, SequenceStatus)
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally, from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer) get_tokenizer)
from vllm.utils import Counter from vllm.utils import Counter
@ -328,16 +327,6 @@ class LLMEngine:
"""Returns True if there are unfinished requests.""" """Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs() return self.scheduler.has_unfinished_seqs()
def _schedule(
self
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
List[RequestOutput]]:
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
return seq_group_metadata_list, scheduler_outputs, [
RequestOutput.from_seq_group(seq_group)
for seq_group in scheduler_outputs.ignored_seq_groups
]
def _check_beam_search_early_stopping( def _check_beam_search_early_stopping(
self, self,
early_stopping: Union[bool, str], early_stopping: Union[bool, str],
@ -586,9 +575,7 @@ class LLMEngine:
and updates the scheduler with the model outputs. Finally, it decodes and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results. the sequences and returns the newly generated results.
""" """
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule() seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if scheduler_outputs.is_empty():
return ignored
# Execute the model. # Execute the model.
output = self._run_workers( output = self._run_workers(
@ -597,7 +584,7 @@ class LLMEngine:
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy, blocks_to_copy=scheduler_outputs.blocks_to_copy,
) ) if not scheduler_outputs.is_empty() else []
return self._process_model_outputs(output, scheduler_outputs) return self._process_model_outputs(output, scheduler_outputs)