diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py
index 0de853ba6e5e..388f7f45e051 100644
--- a/tests/v1/engine/test_output_processor.py
+++ b/tests/v1/engine/test_output_processor.py
@@ -470,22 +470,184 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
     assert not output_processor.has_unfinished_requests()
 
 
+@pytest.mark.parametrize(
+    "include_stop_str_in_output,stop_token_type,ignore_eos,num_sample_logprobs",
+    [(False, "stop_token_ids", False, None),
+     (True, "stop_token_ids", False, None),
+     (False, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
+     (True, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
+     (False, "eos_token_id", False, None), (True, "eos_token_id", False, None),
+     (False, "eos_token_id", True, None)])
+def test_stop_token(include_stop_str_in_output: bool,
+                    num_sample_logprobs: Optional[int], stop_token_type: str,
+                    ignore_eos: bool, dummy_test_vectors):
+    """Test output processor EOS/stop token handling.
+
+    Send mock engine core request to mock engine core and pass core outputs
+    to output processor. Validate output processor tokens, text and
+    (if enabled) sample logprobs. Batch-size one.
+
+    The test emulates a scenario where a model outputs text tokens followed
+    by two identical control tokens:
+    <token>...<token><control><control>
+
+    If EOS is under test, the control tokens are EOS; otherwise, they are
+    some other token id.
+
+    Test behavior:
+
+    * If EOS is under test and `ignore_eos=True`, the detokenized string
+      should be <token>...<token><control><control> and the finish
+      reason should be "length" (i.e. no stop occurs)
+
+    * else, if `include_stop_str_in_output==True`, the detokenized
+      string should be <token>...<token><control> and the finish
+      reason should be "stop" (i.e. first control token causes stop
+      and is represented in output text)
+
+    * else, the detokenized string should be
+      <token>...<token> and the finish reason should be "stop"
+      (i.e. first control token causes stop but is not represented
+      in output text.)
+
+    Note: some test details are tuned for meta-llama/Llama-3.2-1B;
+    another model will work only if the test is modified.
+
+    Args:
+        include_stop_str_in_output: stop token str appears in output text
+        num_sample_logprobs: number of sample logprobs (`None` for no logprobs)
+        stop_token_type: "eos_token_id" for EOS, "stop_token_ids" for stop token
+        ignore_eos: if True, EOS stops are disabled
+        dummy_test_vectors: dummy engine core outputs and other data structures
+    """
+    model_id = dummy_test_vectors.tokenizer.name_or_path
+    if model_id != 'meta-llama/Llama-3.2-1B':
+        raise AssertionError("Test requires meta-llama/Llama-3.2-1B but "
+                             f"{model_id} is in use.")
+    do_logprobs = num_sample_logprobs is not None
+    # EOS under test; if False, stop_token_ids under test
+    is_eos_test = stop_token_type == "eos_token_id"
+    # EOS under test but ignore_eos enabled
+    is_eos_ignore_test = is_eos_test and ignore_eos
+    eos_token_id = (
+        dummy_test_vectors.tokenizer.eos_token_id if is_eos_test else None
+    )  # '<|end_of_text|>'
+    stop_token_ids = [128009] if not is_eos_test else None  # '<|eot_id|>'
+
+    output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
+                                       log_stats=False)
+    # Dummy engine core outputs, with control tokens suffixed to test stops
+    suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids)
+    assert suffix_token is not None and isinstance(suffix_token[0], int)
+    generation_string = dummy_test_vectors.generation_strings[0]
+    generation_tokens = (dummy_test_vectors.generation_tokens[0] +
+                         2 * suffix_token)
+    if do_logprobs:
+        generation_logprobs = (
+            dummy_test_vectors.generation_logprobs[0] +
+            2 * [dummy_test_vectors.generation_logprobs[0][-1]])
+    prompt_string = dummy_test_vectors.prompt_strings[0]
+    prompt_tokens = dummy_test_vectors.prompt_tokens[0]
+    engine_core = MockEngineCore(
+        tokens_list=[generation_tokens],
+        generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
+        prompt_logprobs_raw=None,
+        eos_token_id=eos_token_id,
+        stop_token_ids=stop_token_ids,
+        ignore_eos=ignore_eos)
+
+    # Make request.
+    request_id = "request-0"
+    request = EngineCoreRequest(
+        request_id=request_id,
+        prompt=prompt_string,
+        prompt_token_ids=prompt_tokens,
+        arrival_time=0,
+        mm_inputs=None,
+        mm_hashes=None,
+        mm_placeholders=None,
+        eos_token_id=eos_token_id,
+        lora_request=None,
+        sampling_params=SamplingParams(
+            skip_special_tokens=False,
+            spaces_between_special_tokens=False,
+            output_kind=RequestOutputKind.DELTA,
+            stop=[],
+            stop_token_ids=stop_token_ids,
+            include_stop_str_in_output=include_stop_str_in_output,
+            logprobs=num_sample_logprobs,
+            prompt_logprobs=None,
+            ignore_eos=ignore_eos,
+        ))
+
+    # Add request to the detokenizer.
+    output_processor.add_request(request)
+
+    # Loop over engine core steps; run output processor
+    gen_string = ""
+    gen_tokens = []
+    gen_logprobs = []
+    while True:
+        # Mock output from the EngineCore.
+        outputs = engine_core.get_outputs()
+        if len(outputs) == 0:
+            break
+
+        # Step the Detokenizer.
+        processed_outputs = output_processor.process_outputs(outputs)
+        request_outputs = processed_outputs.request_outputs
+        assert len(request_outputs) == 1
+        # Stop token does not rely on abort
+        assert not processed_outputs.reqs_to_abort
+
+        # Update tracking.
+        request_output = request_outputs[0]
+        if request_output.finished:
+            finish_reason = ("length" if is_eos_ignore_test else "stop")
+            assert request_output.outputs[0].finish_reason == finish_reason
+
+        gen_string += request_output.outputs[0].text
+        gen_tokens.extend(request_output.outputs[0].token_ids)
+        if do_logprobs:
+            gen_logprobs.extend(request_output.outputs[0].logprobs)
+
+    # Validate generated text
+    control_token = '<|end_of_text|>' if is_eos_test else '<|eot_id|>'
+    if is_eos_ignore_test:
+        # Length-based stop; expect full string
+        ref_str = generation_string + 2 * control_token
+    elif include_stop_str_in_output:
+        # Stop token triggered; include in output
+        ref_str = generation_string + control_token
+    else:
+        # Stop token triggered but not in output
+        ref_str = generation_string
+    assert gen_string == ref_str, (f"{gen_string=}, {ref_str=}")
+
+    if do_logprobs:
+        # Validate number of sample logprobs
+        num_tokens = len(gen_tokens)
+        num_logprobs = len(gen_logprobs)
+        assert num_tokens == num_logprobs, (
+            f"Token count ({num_tokens}) != logprobs count ({num_logprobs})")
+
+    # Check requests are finished
+    assert output_processor.get_num_unfinished_requests() == 0
+    assert not output_processor.has_unfinished_requests()
+
+
 @pytest.mark.parametrize("include_stop_str_in_output", [True, False])
 @pytest.mark.parametrize("num_sample_logprobs",
                          [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
-@pytest.mark.parametrize("num_prompt_logprobs",
-                         [None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
 def test_stop_string(include_stop_str_in_output: bool,
-                     num_sample_logprobs: Optional[int],
-                     num_prompt_logprobs: Optional[int], dummy_test_vectors):
+                     num_sample_logprobs: Optional[int], dummy_test_vectors):
     output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
                                        log_stats=False)
     engine_core = MockEngineCore(
         tokens_list=dummy_test_vectors.generation_tokens,
         generated_logprobs_raw=dummy_test_vectors.generation_logprobs
         if num_sample_logprobs else None,
-        prompt_logprobs_raw=dummy_test_vectors.prompt_logprobs
-        if num_prompt_logprobs else None)
+        prompt_logprobs_raw=None)
 
     # Make N requests.
     request_id_list = [
@@ -510,7 +672,7 @@ def test_stop_string(include_stop_str_in_output: bool,
                 stop=STOP_STRINGS,
                 include_stop_str_in_output=include_stop_str_in_output,
                 logprobs=num_sample_logprobs,
-                prompt_logprobs=num_prompt_logprobs,
+                prompt_logprobs=None,
             )) for idx, (prompt, prompt_tokens) in enumerate(
                 zip(dummy_test_vectors.prompt_strings,
                     dummy_test_vectors.prompt_tokens))
@@ -594,8 +756,7 @@ def test_stop_string(include_stop_str_in_output: bool,
     # Confirmed tracked logprobs match what we expect
     _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
                        gen_cumulative_logprobs, dummy_test_vectors,
-                       request_id_list, num_sample_logprobs,
-                       num_prompt_logprobs)
+                       request_id_list, num_sample_logprobs, None)
 
     assert output_processor.get_num_unfinished_requests() == 0
     assert not output_processor.has_unfinished_requests()
diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py
index f0e344cfa6fc..1ee93c72cd26 100644
--- a/tests/v1/engine/utils.py
+++ b/tests/v1/engine/utils.py
@@ -20,7 +20,7 @@ NUM_SAMPLE_LOGPROBS_UNDER_TEST = 5
 # Number of prompt logprobs to request when testing prompt logprobs
 NUM_PROMPT_LOGPROBS_UNDER_TEST = 7
 
-TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
+TOKENIZER_NAME = "meta-llama/Llama-3.2-1B"
 
 FULL_STRINGS = [
     "My name is Robert from Neural Magic and I love working on vLLM so much!",
@@ -330,13 +330,21 @@ class MockEngineCore:
         # each matrix has dimensions
         # (num prompt toks) x (num prompt logprobs+1)
         prompt_logprobs_raw: Optional[list[LogprobsTensors]] = None,
+        eos_token_id: Optional[int] = None,
+        stop_token_ids: Optional[list[int]] = None,
+        ignore_eos: bool = False,
     ) -> None:
+        self.num_requests = len(tokens_list)
         self.tokens_list = tokens_list
         self.current_idx = 0
         self.generated_logprobs_raw = generated_logprobs_raw
         self.do_logprobs = generated_logprobs_raw is not None
         self.prompt_logprobs_raw = prompt_logprobs_raw
         self.do_prompt_logprobs = prompt_logprobs_raw is not None
+        self.request_finished = [False for _ in range(self.num_requests)]
+        self.eos_token_id = eos_token_id
+        self.stop_token_ids = stop_token_ids
+        self.ignore_eos = ignore_eos
 
     def get_outputs(self) -> list[EngineCoreOutput]:
         do_logprobs = self.do_logprobs
@@ -345,7 +353,7 @@ class MockEngineCore:
 
         outputs = []
         for req_idx, token_ids in enumerate(self.tokens_list):
-            if len(token_ids) > token_idx:
+            if not self.request_finished[req_idx]:
                 if do_logprobs:
                     assert self.generated_logprobs_raw is not None
                     (logprobs_token_ids_, logprobs_, sampled_token_ranks_) = (
@@ -365,14 +373,23 @@ class MockEngineCore:
                     prompt_logprobs = None
                 else:
                     prompt_logprobs = None
+                new_token_id = token_ids[token_idx]
                 output = EngineCoreOutput(
                     request_id=f"request-{req_idx}",
-                    new_token_ids=[token_ids[token_idx]],
+                    new_token_ids=[new_token_id],
                     new_logprobs=logprobs,
                     new_prompt_logprobs_tensors=prompt_logprobs,
                 )
                 if token_idx == len(token_ids) - 1:
+                    output.finish_reason = FinishReason.LENGTH
+                    self.request_finished[req_idx] = True
+                if not self.ignore_eos and new_token_id == self.eos_token_id:
                     output.finish_reason = FinishReason.STOP
+                    self.request_finished[req_idx] = True
+                if new_token_id in (self.stop_token_ids or ()):
+                    output.finish_reason = FinishReason.STOP
+                    output.stop_reason = new_token_id
+                    self.request_finished[req_idx] = True
                 outputs.append(output)
 
         self.current_idx += 1
diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py
index 92754920b62d..bf06a17507b2 100644
--- a/vllm/v1/engine/detokenizer.py
+++ b/vllm/v1/engine/detokenizer.py
@@ -88,7 +88,8 @@ class IncrementalDetokenizer:
             stop_buffer_length=stop_buffer_length,
         )
 
-    def update(self, new_token_ids: list[int]) -> Optional[str]:
+    def update(self, new_token_ids: list[int],
+               stop_terminated: bool) -> Optional[str]:
         """
         Update RequestState for the request_id by:
             1) Detokenize the new token ids incrementally.
@@ -96,11 +97,22 @@ class IncrementalDetokenizer:
 
         Return matched stop string or None.
         """
-
+        if not new_token_ids:
+            # Skip detokenization if no new token ids
+            return None
         if self.tokenizer is None:
+            # Skip detokenization if no tokenizer
            self.token_ids.extend(new_token_ids)
            return None
 
+        if stop_terminated and not self.include_stop_str_in_output:
+            # If stop-terminated, exclude last token from detokenization
+            # based on include_stop_str_in_output parameter.
+            skipped_stop_token_id = new_token_ids[-1]
+            new_token_ids = new_token_ids[:-1]
+        else:
+            skipped_stop_token_id = None
+
         # 1) Detokenize the new token ids incrementally.
         # TODO(woosuk): This method becomes very inefficient when the number of
         # new_token_ids is more than 1. We need to optimize this.
@@ -127,7 +139,14 @@ class IncrementalDetokenizer:
 
         self.output_text += decoded_text
 
-        # 2) Evaluate stop criteria.
+        if stop_terminated:
+            if skipped_stop_token_id is not None:
+                # Cleanup after skipping detokenization
+                self.token_ids.append(skipped_stop_token_id)
+            # Stop token triggered; skip stop string check
+            return None
+
+        # 2) Evaluate stop strings.
         stop_string = None
         if self.stop:
             stop = StopChecker.check_stop_strings(
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 83180b66bea0..04235eda0926 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -299,9 +299,9 @@ class OutputProcessor:
                 # in the EngineCore.
                 req_state.is_prefilling = not new_token_ids
 
-                # 2) Detokenize the token ids into text and check for stop
-                # strings.
-                stop_string = req_state.detokenizer.update(new_token_ids)
+                # 2) Detokenize the token ids into text and perform stop checks.
+                stop_string = req_state.detokenizer.update(
+                    new_token_ids, finish_reason == FinishReason.STOP)
                 if stop_string and finish_reason != FinishReason.STOP:
                     finish_reason = FinishReason.STOP
                     stop_reason = stop_string