diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index b0b11a33a4443..94eacfbdfb301 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -78,6 +78,7 @@ class EngineClient(ABC):
         preprocessor = await self.get_input_preprocessor()
         tokenizer_group = preprocessor.get_tokenizer_group()
         tokenizer = await tokenizer_group.get_lora_tokenizer_async()
+        eos_token_id = tokenizer.eos_token_id
 
         if is_explicit_encoder_decoder_prompt(prompt):
             raise NotImplementedError
@@ -104,7 +105,7 @@ class EngineClient(ABC):
         tokenized_length = len(prompt_token_ids)
 
         sort_beams_key = create_sort_beams_key_function(
-            tokenizer.eos_token_id, length_penalty)
+            eos_token_id, length_penalty)
 
         beam_search_params = SamplingParams(
             logprobs=2 * beam_width,
@@ -154,7 +155,7 @@ class EngineClient(ABC):
                 if result.outputs[0].logprobs is not None:
                     logprobs = result.outputs[0].logprobs[0]
                     for token_id, logprob_obj in logprobs.items():
-                        if token_id == tokenizer.eos_token_id and \
+                        if token_id == eos_token_id and \
                             not ignore_eos:
                             completed.append(
                                 BeamSearchSequence(
@@ -166,7 +167,7 @@ class EngineClient(ABC):
                                     cum_logprob=current_beam.cum_logprob +
                                     logprob_obj.logprob,
                                     finish_reason="stop",
-                                    stop_reason=tokenizer.eos_token_id))
+                                    stop_reason=eos_token_id))
                         else:
                             new_beams.append(
                                 BeamSearchSequence(
@@ -189,14 +190,14 @@ class EngineClient(ABC):
         best_beams = sorted_completed[:beam_width]
 
         for beam in best_beams:
-            if (beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos):
+            if (beam.tokens[-1] == eos_token_id and not ignore_eos):
                 # Skip the eos token in the text.
                 tokens = beam.tokens[tokenized_length:-1]
             else:
                 tokens = beam.tokens[tokenized_length:]
             beam.text = tokenizer.decode(tokens)
 
-        beam_search_output = RequestOutput(
+        yield RequestOutput(
             request_id=request_id,
             prompt=prompt_text,
             outputs=[
@@ -214,8 +215,6 @@
             prompt_token_ids=prompt_token_ids,
             prompt_logprobs=None)
 
-        yield beam_search_output
-
     @abstractmethod
     def encode(
         self,
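A minimal consumption sketch, not part of the patch: after this change, `EngineClient.beam_search` hoists `eos_token_id` out of the loop and yields its `RequestOutput` directly instead of staging it in the local `beam_search_output`, so callers iterate the async generator as before. The `client` object, prompt string, and request id below are hypothetical placeholders; the call assumes the `beam_search(prompt, request_id, params)` signature visible in this file.

```python
# Usage sketch only -- `client` is some concrete EngineClient
# implementation, constructed elsewhere.
import asyncio

from vllm.sampling_params import BeamSearchParams


async def main(client) -> None:
    params = BeamSearchParams(beam_width=4, max_tokens=32)
    # beam_search is an async generator; with this patch it yields the
    # final RequestOutput directly rather than via a local variable.
    async for output in client.beam_search("Hello, my name is",
                                           "beam-req-0", params):
        for seq in output.outputs:
            print(seq.text)


# asyncio.run(main(client))  # client construction omitted here
```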