diff --git a/tests/reasoning/test_glm4_moe_reasoning_parser.py b/tests/reasoning/test_glm4_moe_reasoning_parser.py
new file mode 100644
index 0000000000000..4c5ec2c9b408d
--- /dev/null
+++ b/tests/reasoning/test_glm4_moe_reasoning_parser.py
@@ -0,0 +1,203 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+from transformers import AutoTokenizer
+
+from tests.reasoning.utils import run_reasoning_extraction
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
+
+parser_name = "glm45"
+start_token = "<think>"
+end_token = "</think>"
+
+REASONING_MODEL_NAME = "zai-org/GLM-4.5"
+
+
+@pytest.fixture(scope="module")
+def glm45_tokenizer():
+    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
+
+
+WITH_THINK = {
+    "output": "<think>This is a reasoning section</think>This is the rest",
+    "reasoning_content": "This is a reasoning section",
+    "content": "This is the rest",
+    "is_reasoning_end": True,
+}
+
+WITH_THINK_STREAM = {
+    "output": "<think>This is a reasoning section</think>This is the rest",
+    "reasoning_content": "This is a reasoning section",
+    "content": "This is the rest",
+    "is_reasoning_end": True,
+}
+
+WITHOUT_THINK = {
+    "output": "This is the rest",
+    "reasoning_content": None,
+    "content": "This is the rest",
+    "is_reasoning_end": False,
+}
+
+WITHOUT_THINK_STREAM = {
+    "output": "This is the rest",
+    "reasoning_content": None,
+    "content": "This is the rest",
+    "is_reasoning_end": False,
+}
+
+COMPLETE_REASONING = {
+    "output": "<think>This is a reasoning section</think>",
+    "reasoning_content": "This is a reasoning section",
+    "content": None,
+    "is_reasoning_end": True,
+}
+MULTILINE_REASONING = {
+    "output":
+    "<think>This is a reasoning\nsection</think>This is the rest\nThat",
+    "reasoning_content": "This is a reasoning\nsection",
+    "content": "This is the rest\nThat",
+    "is_reasoning_end": True,
+}
+ONLY_OPEN_TAG = {
+    "output": "<think>This is a reasoning section",
+    "reasoning_content": None,
+    "content": "This is a reasoning section",
+    "is_reasoning_end": False,
+}
+
+ONLY_OPEN_TAG_STREAM = {
+    "output": "<think>This is a reasoning section",
+    "reasoning_content": "This is a reasoning section",
+    "content": None,
+    "is_reasoning_end": False,
+}
+
+TEST_CASES = [
+    pytest.param(
+        False,
+        WITH_THINK,
+        id="with_think",
+    ),
+    pytest.param(
+        True,
+        WITH_THINK_STREAM,
+        id="with_think_stream",
+    ),
+    pytest.param(
+        False,
+        WITHOUT_THINK,
+        id="without_think",
+    ),
+    pytest.param(
+        True,
+        WITHOUT_THINK_STREAM,
+        id="without_think_stream",
+    ),
+    pytest.param(
+        False,
+        COMPLETE_REASONING,
+        id="complete_reasoning",
+    ),
+    pytest.param(
+        True,
+        COMPLETE_REASONING,
+        id="complete_reasoning_stream",
+    ),
+    pytest.param(
+        False,
+        MULTILINE_REASONING,
+        id="multiline_reasoning",
+    ),
+    pytest.param(
+        True,
+        MULTILINE_REASONING,
+        id="multiline_reasoning_stream",
+    ),
+    pytest.param(
+        False,
+        ONLY_OPEN_TAG,
+        id="only_open_tag",
+    ),
+    pytest.param(
+        True,
+        ONLY_OPEN_TAG_STREAM,
+        id="only_open_tag_stream",
+    ),
+]
+
+STILL_REASONING_PROMPT = """[gMASK]<|system|>
+You are a helpful assistant.<|user|>
+What is the capital of France?<|assistant|>
+<think>The user is asking for the capital of"""
+
+DONE_REASONING_PROMPT = """[gMASK]<|system|>
+You are a helpful assistant.<|user|>
+What is the capital of France?<|assistant|>
+<think>The user is asking for the capital of France. </think>
+The capital of France is Paris."""
+
+MULTI_TURN_STILL_REASONING_PROMPT = """[gMASK]<|system|>
+You are a helpful assistant.<|user|>
+What is the capital of France?<|assistant|>
+<think></think>
+The capital of France is Paris.<|user|>
+What about Chile?<|assistant|>
+<think>The user is asking for the capital of"""
+
+MULTI_TURN_DONE_REASONING_PROMPT = """[gMASK]<|system|>
+You are a helpful assistant.<|user|>
+What is the capital of France?<|assistant|>
+<think></think>
+The capital of France is Paris.<|user|>
+What about Chile?<|assistant|>
+<think>The user is asking for the capital of Chile. </think>
+The capital of Chile is Santiago."""
+
+REASONING_END_TEST_CASES = [
+    pytest.param(STILL_REASONING_PROMPT, False, id="still_reasoning"),
+    pytest.param(DONE_REASONING_PROMPT, True, id="done_reasoning"),
+    pytest.param(MULTI_TURN_STILL_REASONING_PROMPT,
+                 False,
+                 id="multi_turn_still_reasoning"),
+    pytest.param(MULTI_TURN_DONE_REASONING_PROMPT,
+                 True,
+                 id="multi_turn_done_reasoning")
+]
+
+
+@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
+def test_reasoning(
+    streaming: bool,
+    param_dict: dict,
+    glm45_tokenizer,
+):
+    output = glm45_tokenizer.tokenize(param_dict["output"])
+    output_tokens: list[str] = [
+        glm45_tokenizer.convert_tokens_to_string([token]) for token in output
+    ]
+    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
+        parser_name)(glm45_tokenizer)
+
+    reasoning, content = run_reasoning_extraction(parser,
+                                                  output_tokens,
+                                                  streaming=streaming)
+
+    assert reasoning == param_dict["reasoning_content"]
+    assert content == param_dict["content"]
+
+    output_ids = glm45_tokenizer.convert_tokens_to_ids(output)
+    is_reasoning_end = parser.is_reasoning_end(output_ids)
+    assert is_reasoning_end == param_dict["is_reasoning_end"]
+
+
+@pytest.mark.parametrize("prompt, is_reasoning_end", REASONING_END_TEST_CASES)
+def test_is_reasoning_end_full_prompt(prompt: str, is_reasoning_end: bool,
+                                      glm45_tokenizer):
+    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
+        parser_name)(glm45_tokenizer)
+    tokens = glm45_tokenizer.tokenize(prompt)
+    token_ids = glm45_tokenizer.convert_tokens_to_ids(tokens)
+    check_is_reasoning_end = parser.is_reasoning_end(token_ids)
+    assert check_is_reasoning_end == is_reasoning_end
diff --git a/vllm/reasoning/glm4_moe_reasoning_parser.py b/vllm/reasoning/glm4_moe_reasoning_parser.py
index 11e828a7039fa..8d7488afce68e 100644
--- a/vllm/reasoning/glm4_moe_reasoning_parser.py
+++ b/vllm/reasoning/glm4_moe_reasoning_parser.py
@@ -30,6 +30,7 @@ class Glm4MoeModelReasoningParser(ReasoningParser):
         super().__init__(tokenizer, *args, **kwargs)
         self.think_start_token = "<think>"
         self.think_end_token = "</think>"
+        self.assistant_token = "<|assistant|>"
 
         if not self.model_tokenizer:
             raise ValueError(
@@ -38,14 +39,26 @@ class Glm4MoeModelReasoningParser(ReasoningParser):
 
         self.think_start_token_id = self.vocab.get(self.think_start_token)
         self.think_end_token_id = self.vocab.get(self.think_end_token)
+        self.assistant_token_id = self.vocab.get(self.assistant_token)
         if (self.think_start_token_id is None
-                or self.think_end_token_id is None):
+                or self.think_end_token_id is None
+                or self.assistant_token_id is None):
             raise RuntimeError(
                 "Glm4MoeModel reasoning parser could not locate "
-                "think start/end tokens in the tokenizer!")
+                "think start/end or assistant tokens in the tokenizer!")
 
     def is_reasoning_end(self, input_ids: list[int]) -> bool:
-        return self.think_end_token_id in input_ids
+        """
+        GLM's chat template has <think></think> tokens after every
+        <|assistant|> token. Thus, we need to check if </think> is
+        after the most recent <|assistant|> token (if present).
+        """
+        for token_id in input_ids[::-1]:
+            if token_id == self.think_end_token_id:
+                return True
+            elif token_id == self.assistant_token_id:
+                return False
+        return False
 
     def extract_content_ids(self, input_ids: list[int]) -> list[int]:
         """
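Reviewer note (not part of the patch): a minimal, self-contained sketch of the reverse scan that `is_reasoning_end` now performs. The token IDs below are made up for illustration; the real parser looks them up in the tokenizer vocabulary. It shows why the old membership check (`self.think_end_token_id in input_ids`) wrongly reported reasoning as finished on multi-turn prompts where `</think>` appears only in earlier turns.

    # Hypothetical token IDs for illustration only.
    THINK_END_ID = 2   # stands in for </think>
    ASSISTANT_ID = 1   # stands in for <|assistant|>

    def is_reasoning_end(input_ids: list[int]) -> bool:
        # Scan from the end: the verdict is decided by whichever of
        # </think> / <|assistant|> occurs last in the sequence.
        for token_id in input_ids[::-1]:
            if token_id == THINK_END_ID:
                return True
            if token_id == ASSISTANT_ID:
                return False
        return False

    # Multi-turn shape: <|assistant|> ... </think> ... <|assistant|> <think> ...
    # The current turn is still reasoning, yet </think> from turn one is present.
    ids = [1, 5, 2, 7, 1, 5]
    assert is_reasoning_end(ids) is False    # new scan: still reasoning
    assert (THINK_END_ID in ids) is True     # old check would have said "done"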