[Engine][Chore] use local variable and remove output var assignment (#24554)

Signed-off-by: Guy Stone <guys@spotify.com>
This commit is contained in:
Guy Stone 2025-09-11 02:05:42 -04:00 committed by GitHub
parent e2d8c27f68
commit 8a894084d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -78,6 +78,7 @@ class EngineClient(ABC):
preprocessor = await self.get_input_preprocessor() preprocessor = await self.get_input_preprocessor()
tokenizer_group = preprocessor.get_tokenizer_group() tokenizer_group = preprocessor.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async() tokenizer = await tokenizer_group.get_lora_tokenizer_async()
eos_token_id = tokenizer.eos_token_id
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError raise NotImplementedError
@ -104,7 +105,7 @@ class EngineClient(ABC):
tokenized_length = len(prompt_token_ids) tokenized_length = len(prompt_token_ids)
sort_beams_key = create_sort_beams_key_function( sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty) eos_token_id, length_penalty)
beam_search_params = SamplingParams( beam_search_params = SamplingParams(
logprobs=2 * beam_width, logprobs=2 * beam_width,
@ -154,7 +155,7 @@ class EngineClient(ABC):
if result.outputs[0].logprobs is not None: if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0] logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items(): for token_id, logprob_obj in logprobs.items():
if token_id == tokenizer.eos_token_id and \ if token_id == eos_token_id and \
not ignore_eos: not ignore_eos:
completed.append( completed.append(
BeamSearchSequence( BeamSearchSequence(
@ -166,7 +167,7 @@ class EngineClient(ABC):
cum_logprob=current_beam.cum_logprob + cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob, logprob_obj.logprob,
finish_reason="stop", finish_reason="stop",
stop_reason=tokenizer.eos_token_id)) stop_reason=eos_token_id))
else: else:
new_beams.append( new_beams.append(
BeamSearchSequence( BeamSearchSequence(
@ -189,14 +190,14 @@ class EngineClient(ABC):
best_beams = sorted_completed[:beam_width] best_beams = sorted_completed[:beam_width]
for beam in best_beams: for beam in best_beams:
if (beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos): if (beam.tokens[-1] == eos_token_id and not ignore_eos):
# Skip the eos token in the text. # Skip the eos token in the text.
tokens = beam.tokens[tokenized_length:-1] tokens = beam.tokens[tokenized_length:-1]
else: else:
tokens = beam.tokens[tokenized_length:] tokens = beam.tokens[tokenized_length:]
beam.text = tokenizer.decode(tokens) beam.text = tokenizer.decode(tokens)
beam_search_output = RequestOutput( yield RequestOutput(
request_id=request_id, request_id=request_id,
prompt=prompt_text, prompt=prompt_text,
outputs=[ outputs=[
@ -214,8 +215,6 @@ class EngineClient(ABC):
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
prompt_logprobs=None) prompt_logprobs=None)
yield beam_search_output
@abstractmethod @abstractmethod
def encode( def encode(
self, self,