diff --git a/tests/tool_use/test_kimi_k2_tool_parser.py b/tests/tool_use/test_kimi_k2_tool_parser.py index c358589dbc292..33dabbc7e7b91 100644 --- a/tests/tool_use/test_kimi_k2_tool_parser.py +++ b/tests/tool_use/test_kimi_k2_tool_parser.py @@ -209,3 +209,596 @@ def test_streaming_no_tool_calls(kimi_k2_tool_parser): assert result is not None assert hasattr(result, "content") assert result.content == " without any tool calls." + + +def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser): + """ + Test that text between <|tool_calls_section_begin|> and <|tool_call_begin|> + is suppressed and does not leak into reasoning_delta. + This is the main vulnerability being fixed. + """ + kimi_k2_tool_parser.reset_streaming_state() + + # Get token IDs for the markers + section_begin_token_id = kimi_k2_tool_parser.vocab.get( + "<|tool_calls_section_begin|>" + ) + tool_call_begin_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>") + + # Simulate streaming sequence: + # Delta 1: "I'll help you with that. " + result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text="I'll help you with that. ", + delta_text="I'll help you with that. ", + previous_token_ids=[], + current_token_ids=[1, 2, 3], # Regular tokens + delta_token_ids=[1, 2, 3], + request=None, + ) + assert result1 is not None + assert result1.content == "I'll help you with that. " + + # Delta 2: "<|tool_calls_section_begin|>" + prev_ids = [1, 2, 3] + curr_ids = prev_ids + [section_begin_token_id] + result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="I'll help you with that. ", + current_text="I'll help you with that. 
<|tool_calls_section_begin|>", + delta_text="<|tool_calls_section_begin|>", + previous_token_ids=prev_ids, + current_token_ids=curr_ids, + delta_token_ids=[section_begin_token_id], + request=None, + ) + # Section marker should be stripped and suppressed + assert result2 is None or (result2.content is None or result2.content == "") + + # Delta 3: " spurious text or tokens " (THE LEAK SCENARIO) + prev_ids = curr_ids + curr_ids = curr_ids + [4, 5] + result3 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="I'll help you with that. <|tool_calls_section_begin|>", + current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ", + delta_text=" spurious text ", + previous_token_ids=prev_ids, + current_token_ids=curr_ids, + delta_token_ids=[4, 5], + request=None, + ) + # CRITICAL: This text should be suppressed, NOT returned as reasoning_delta + assert result3 is None or (result3.content is None or result3.content == "") + + # Delta 4: "<|tool_call_begin|>..." + prev_ids = curr_ids + curr_ids = curr_ids + [tool_call_begin_token_id] + _result4 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ", + current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text <|tool_call_begin|>", + delta_text="<|tool_call_begin|>", + previous_token_ids=prev_ids, + current_token_ids=curr_ids, + delta_token_ids=[tool_call_begin_token_id], + request=None, + ) + # Now we're in tool call mode, result depends on internal state + # The key is that the spurious text from Delta 3 was not leaked + + +def test_split_markers_across_deltas(kimi_k2_tool_parser): + """ + Test that markers split across delta chunks are correctly detected + via the rolling buffer mechanism. 
+ """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_token_id = kimi_k2_tool_parser.vocab.get( + "<|tool_calls_section_begin|>" + ) + + # Delta 1: "...reasoning<|tool_calls_sec" + _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Some reasoning", + current_text="Some reasoning<|tool_calls_sec", + delta_text="<|tool_calls_sec", + previous_token_ids=[1, 2], + current_token_ids=[1, 2, 3], # Partial token + delta_token_ids=[3], + request=None, + ) + # Partial token not recognized yet, might be buffered + # Should return as content or None (depends on implementation) + + # Delta 2: "tion_begin|> " (completes the marker) + _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Some reasoning<|tool_calls_sec", + current_text="Some reasoning<|tool_calls_section_begin|> ", + delta_text="tion_begin|> ", + previous_token_ids=[1, 2, 3], + current_token_ids=[1, 2, section_begin_token_id, 4], + delta_token_ids=[section_begin_token_id, 4], + request=None, + ) + # Now the complete marker should be detected via buffer + # The parser should enter tool section mode + assert kimi_k2_tool_parser.in_tool_section is True + + +def test_marker_variants(kimi_k2_tool_parser): + """Test that both singular and plural marker variants are recognized.""" + kimi_k2_tool_parser.reset_streaming_state() + + # Test singular variant: <|tool_call_section_begin|> (note: singular "call") + singular_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_section_begin|>") + + if singular_token_id is not None: # Only test if tokenizer supports it + _result = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Reasoning ", + current_text="Reasoning <|tool_call_section_begin|>", + delta_text="<|tool_call_section_begin|>", + previous_token_ids=[1, 2], + current_token_ids=[1, 2, singular_token_id], + delta_token_ids=[singular_token_id], + request=None, + ) + # Should enter tool section mode with singular variant too + assert 
kimi_k2_tool_parser.in_tool_section is True + + +def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser): + """ + Test that after exiting a tool section with <|tool_calls_section_end|>, + subsequent text is correctly returned as reasoning content. + """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>") + + # Enter tool section + _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text="<|tool_calls_section_begin|>", + delta_text="<|tool_calls_section_begin|>", + previous_token_ids=[], + current_token_ids=[section_begin_id], + delta_token_ids=[section_begin_id], + request=None, + ) + assert kimi_k2_tool_parser.in_tool_section is True + + # Exit tool section + _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="<|tool_calls_section_begin|>", + current_text="<|tool_calls_section_begin|><|tool_calls_section_end|>", + delta_text="<|tool_calls_section_end|>", + previous_token_ids=[section_begin_id], + current_token_ids=[section_begin_id, section_end_id], + delta_token_ids=[section_end_id], + request=None, + ) + assert kimi_k2_tool_parser.in_tool_section is False + + # Subsequent reasoning text should be returned normally + result3 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="<|tool_calls_section_begin|><|tool_calls_section_end|>", + current_text="<|tool_calls_section_begin|><|tool_calls_section_end|> More reasoning", + delta_text=" More reasoning", + previous_token_ids=[section_begin_id, section_end_id], + current_token_ids=[section_begin_id, section_end_id, 10, 11], + delta_token_ids=[10, 11], + request=None, + ) + assert result3 is not None + assert result3.content == " More reasoning" + + +def test_empty_tool_section(kimi_k2_tool_parser): + """Test an empty tool section (begin immediately followed by end).""" 
+ kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>") + + # Section begin + _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Reasoning ", + current_text="Reasoning <|tool_calls_section_begin|>", + delta_text="<|tool_calls_section_begin|>", + previous_token_ids=[1], + current_token_ids=[1, section_begin_id], + delta_token_ids=[section_begin_id], + request=None, + ) + + # Immediate section end + _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Reasoning <|tool_calls_section_begin|>", + current_text="Reasoning <|tool_calls_section_begin|><|tool_calls_section_end|>", + delta_text="<|tool_calls_section_end|>", + previous_token_ids=[1, section_begin_id], + current_token_ids=[1, section_begin_id, section_end_id], + delta_token_ids=[section_end_id], + request=None, + ) + # Should exit cleanly without errors + assert kimi_k2_tool_parser.in_tool_section is False + + +def test_malformed_tool_section_recovery(kimi_k2_tool_parser): + """ + Test that the parser recovers from a malformed tool section + that never closes properly. 
+ """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + + # Enter tool section + _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text="<|tool_calls_section_begin|>", + delta_text="<|tool_calls_section_begin|>", + previous_token_ids=[], + current_token_ids=[section_begin_id], + delta_token_ids=[section_begin_id], + request=None, + ) + assert kimi_k2_tool_parser.in_tool_section is True + + # Simulate a lot of text without proper tool calls or section end + # This should trigger the error recovery mechanism + large_text = "x" * 10000 # Exceeds max_section_chars + + result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="<|tool_calls_section_begin|>", + current_text="<|tool_calls_section_begin|>" + large_text, + delta_text=large_text, + previous_token_ids=[section_begin_id], + current_token_ids=[section_begin_id] + list(range(100, 100 + len(large_text))), + delta_token_ids=list(range(100, 100 + len(large_text))), + request=None, + ) + + # Parser should have force-exited the tool section + assert kimi_k2_tool_parser.in_tool_section is False + # And returned the content as reasoning + assert result2 is not None + assert result2.content == large_text + + +def test_state_reset(kimi_k2_tool_parser): + """Test that reset_streaming_state() properly clears all state.""" + # Put parser in a complex state + kimi_k2_tool_parser.in_tool_section = True + kimi_k2_tool_parser.token_buffer = "some buffer" + kimi_k2_tool_parser.current_tool_id = 5 + kimi_k2_tool_parser.prev_tool_call_arr = [{"id": "test"}] + kimi_k2_tool_parser.section_char_count = 1000 + + # Reset + kimi_k2_tool_parser.reset_streaming_state() + + # Verify all state is cleared + assert kimi_k2_tool_parser.in_tool_section is False + assert kimi_k2_tool_parser.token_buffer == "" + assert kimi_k2_tool_parser.current_tool_id == -1 + assert kimi_k2_tool_parser.prev_tool_call_arr 
== [] + assert kimi_k2_tool_parser.section_char_count == 0 + assert kimi_k2_tool_parser.current_tool_name_sent is False + assert kimi_k2_tool_parser.streamed_args_for_tool == [] + + +def test_section_begin_noise_tool_begin_same_chunk(kimi_k2_tool_parser): + """ + Test that begin→noise→tool_begin within the SAME chunk suppresses + the noise text correctly (not just across chunks). + """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + tool_call_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>") + + # Single delta containing: section_begin + spurious text + tool_call_begin + combined_text = "<|tool_calls_section_begin|> noise text <|tool_call_begin|>" + + result = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Reasoning ", + current_text="Reasoning " + combined_text, + delta_text=combined_text, + previous_token_ids=[1, 2], + current_token_ids=[1, 2, section_begin_id, 3, 4, tool_call_begin_id], + delta_token_ids=[section_begin_id, 3, 4, tool_call_begin_id], + request=None, + ) + + # The noise text should NOT leak into content + # Result should either be None/empty or start tool call parsing + if result is not None and result.content is not None: + # If content is returned, it should not contain the noise + assert "noise text" not in result.content + assert result.content == "" or result.content.strip() == "" + + +def test_stream_ends_without_section_end_marker(kimi_k2_tool_parser): + """ + Test that if the stream ends (EOF) without a proper section end marker, + the parser doesn't leak text, doesn't crash, and resets state cleanly. 
+ """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + + # Enter tool section + _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text="<|tool_calls_section_begin|>", + delta_text="<|tool_calls_section_begin|>", + previous_token_ids=[], + current_token_ids=[section_begin_id], + delta_token_ids=[section_begin_id], + request=None, + ) + assert kimi_k2_tool_parser.in_tool_section is True + + # Some content in tool section + result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="<|tool_calls_section_begin|>", + current_text="<|tool_calls_section_begin|> partial content", + delta_text=" partial content", + previous_token_ids=[section_begin_id], + current_token_ids=[section_begin_id, 10, 11], + delta_token_ids=[10, 11], + request=None, + ) + # Content should be suppressed + assert result2.content == "" or result2.content is None + + # Stream ends (EOF) - no more deltas, no section_end marker + # Simulate this by manually checking state and resetting + # (In real usage, the request handler would call reset_streaming_state) + assert kimi_k2_tool_parser.in_tool_section is True # Still in section + + # Reset state (as would happen between requests) + kimi_k2_tool_parser.reset_streaming_state() + + # Verify clean slate + assert kimi_k2_tool_parser.in_tool_section is False + assert kimi_k2_tool_parser.token_buffer == "" + + # Next request should work normally + result3 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text="New reasoning", + delta_text="New reasoning", + previous_token_ids=[], + current_token_ids=[20, 21], + delta_token_ids=[20, 21], + request=None, + ) + assert result3 is not None + assert result3.content == "New reasoning" + + +def test_same_chunk_begin_and_end_markers(kimi_k2_tool_parser): + """ + CRITICAL TEST: Verify that when both section_begin and section_end + markers appear 
in the SAME chunk, the parser correctly: + 1. Enters the tool section + 2. Immediately exits the tool section + 3. Does NOT get stuck in in_tool_section=True state + + This tests the bug fix where elif was changed to if to handle + both state transitions in a single delta. + """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>") + + # Single chunk with both markers (e.g., empty tool section) + combined_delta = "<|tool_calls_section_begin|><|tool_calls_section_end|>" + + result = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Some reasoning ", + current_text="Some reasoning " + combined_delta, + delta_text=combined_delta, + previous_token_ids=[1, 2], + current_token_ids=[1, 2, section_begin_id, section_end_id], + delta_token_ids=[section_begin_id, section_end_id], + request=None, + ) + + # CRITICAL: Parser should NOT be stuck in tool section + assert kimi_k2_tool_parser.in_tool_section is False, ( + "Parser stuck in tool section after processing both begin/end in same chunk. " + "This indicates the elif bug was not fixed." 
+ ) + + # Result should be empty or contain only stripped content + assert result is not None + assert result.content == "" or result.content is None + + # Verify subsequent content streams correctly (not suppressed) + result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Some reasoning " + combined_delta, + current_text="Some reasoning " + combined_delta + " More reasoning", + delta_text=" More reasoning", + previous_token_ids=[1, 2, section_begin_id, section_end_id], + current_token_ids=[1, 2, section_begin_id, section_end_id, 10, 11], + delta_token_ids=[10, 11], + request=None, + ) + + # This content should NOT be suppressed (we're out of tool section) + assert result2 is not None + assert result2.content == " More reasoning" + + +def test_same_chunk_begin_content_end_markers(kimi_k2_tool_parser): + """ + Test the same-chunk scenario with actual content between markers. + Example: <|tool_calls_section_begin|> text <|tool_calls_section_end|> + all arriving in one delta. The key is that the state machine correctly + transitions in and out within the same chunk. 
+ """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>") + + # Chunk with begin, some whitespace/noise, and end all together + # This simulates a tool section that opens and closes in the same chunk + combined_delta = "<|tool_calls_section_begin|> <|tool_calls_section_end|>" + + _result = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Reasoning ", + current_text="Reasoning " + combined_delta, + delta_text=combined_delta, + previous_token_ids=[1], + current_token_ids=[1, section_begin_id, 100, section_end_id], + delta_token_ids=[section_begin_id, 100, section_end_id], + request=None, + ) + + # Parser should exit cleanly (not stuck in tool section) + assert kimi_k2_tool_parser.in_tool_section is False + + # Verify the fix: next content should stream normally, not be suppressed + result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Reasoning " + combined_delta, + current_text="Reasoning " + combined_delta + " Done", + delta_text=" Done", + previous_token_ids=[1, section_begin_id, 100, section_end_id], + current_token_ids=[1, section_begin_id, 100, section_end_id, 200], + delta_token_ids=[200], + request=None, + ) + + # Content after section should be returned (not suppressed) + assert result2 is not None + assert result2.content == " Done" + + +def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser): + """ + CRITICAL TEST (P1): Verify that when both <|tool_call_end|> and + <|tool_calls_section_end|> appear in the SAME chunk, the parser: + 1. Processes the tool_call_end first (emits final arguments) + 2. THEN exits the section + 3. Does NOT drop the final tool call update + 4. Does NOT leak special tokens into reasoning + + This tests the deferred section exit fix. 
+ """ + kimi_k2_tool_parser.reset_streaming_state() + + section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") + section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>") + tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>") + tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>") + + # Simulate a streaming sequence for a SHORT tool call (all in one chunk): + # 1. Reasoning text + result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text="Let me help. ", + delta_text="Let me help. ", + previous_token_ids=[], + current_token_ids=[1, 2], + delta_token_ids=[1, 2], + request=None, + ) + assert result1 is not None + assert result1.content == "Let me help. " + + # 2. Section begin + _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="Let me help. ", + current_text="Let me help. <|tool_calls_section_begin|>", + delta_text="<|tool_calls_section_begin|>", + previous_token_ids=[1, 2], + current_token_ids=[1, 2, section_begin_id], + delta_token_ids=[section_begin_id], + request=None, + ) + assert kimi_k2_tool_parser.in_tool_section is True + + # 3. Tool call begin + full content + tool_end + section_end ALL IN ONE CHUNK + # This is the critical scenario for short tool calls + combined = ( + '<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} ' + "<|tool_call_end|><|tool_calls_section_end|>" + ) + + # Build up the previous text gradually to simulate realistic streaming + prev_text = "Let me help. 
<|tool_calls_section_begin|>" + curr_text = prev_text + combined + + result3 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text=prev_text, + current_text=curr_text, + delta_text=combined, + previous_token_ids=[1, 2, section_begin_id], + current_token_ids=[ + 1, + 2, + section_begin_id, + tool_begin_id, + 10, + 11, + 12, + tool_end_id, + section_end_id, + ], + delta_token_ids=[tool_begin_id, 10, 11, 12, tool_end_id, section_end_id], + request=None, + ) + + # CRITICAL: Parser should have exited section AFTER processing tool + assert kimi_k2_tool_parser.in_tool_section is False + + # Tool call should have been emitted (not dropped) + # The result might be the tool name or None depending on state, but + # importantly, it shouldn't be returning the literal tokens as content + + if result3 is not None and result3.content is not None: + # Verify no special tokens leaked into content + assert "<|tool_call_end|>" not in result3.content + assert "<|tool_calls_section_end|>" not in result3.content + + # 4. 
Verify subsequent content streams normally + result4 = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text=curr_text, + current_text=curr_text + " Done", + delta_text=" Done", + previous_token_ids=[ + 1, + 2, + section_begin_id, + tool_begin_id, + 10, + 11, + 12, + tool_end_id, + section_end_id, + ], + current_token_ids=[ + 1, + 2, + section_begin_id, + tool_begin_id, + 10, + 11, + 12, + tool_end_id, + section_end_id, + 20, + ], + delta_token_ids=[20], + request=None, + ) + + # Content after tool section should stream normally + assert result4 is not None + assert result4.content == " Done" diff --git a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py index 0453db58361a9..a84c9e4547168 100644 --- a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py @@ -34,8 +34,27 @@ class KimiK2ToolParser(ToolParser): str ] = [] # map what has been streamed for each tool so far to a list + # Section-level state management to prevent token leakage + self.in_tool_section: bool = False + self.token_buffer: str = "" + # Buffer size: empirical worst-case for longest marker (~30 chars) * 2 + # + safety margin for unicode + partial overlap. Prevents unbounded growth. 
+ self.buffer_max_size: int = 1024 + self.section_char_count: int = 0 # Track characters processed in tool section + self.max_section_chars: int = 8192 # Force exit if section exceeds this + self._buffer_overflow_logged: bool = False # Log overflow once per session + + # Support both singular and plural variants self.tool_calls_start_token: str = "<|tool_calls_section_begin|>" self.tool_calls_end_token: str = "<|tool_calls_section_end|>" + self.tool_calls_start_token_variants: list[str] = [ + "<|tool_calls_section_begin|>", + "<|tool_call_section_begin|>", # singular variant + ] + self.tool_calls_end_token_variants: list[str] = [ + "<|tool_calls_section_end|>", + "<|tool_call_section_end|>", # singular variant + ] self.tool_call_start_token: str = "<|tool_call_begin|>" self.tool_call_end_token: str = "<|tool_call_end|>" @@ -58,6 +77,18 @@ class KimiK2ToolParser(ToolParser): self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token) self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token) + # Get token IDs for all variants + self.tool_calls_start_token_ids: list[int] = [ + tid + for variant in self.tool_calls_start_token_variants + if (tid := self.vocab.get(variant)) is not None + ] + self.tool_calls_end_token_ids: list[int] = [ + tid + for variant in self.tool_calls_end_token_variants + if (tid := self.vocab.get(variant)) is not None + ] + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) @@ -70,6 +101,51 @@ class KimiK2ToolParser(ToolParser): "tokens in the tokenizer!" ) + def _check_and_strip_markers(self, text: str) -> tuple[str, bool, bool]: + """ + Check for section begin/end markers in text and strip them. 
+ Returns: (cleaned_text, found_section_begin, found_section_end) + """ + found_begin = False + found_end = False + cleaned = text + + # Check for section begin markers (any variant) + for variant in self.tool_calls_start_token_variants: + if variant in cleaned: + cleaned = cleaned.replace(variant, "") + found_begin = True + + # Check for section end markers (any variant) + for variant in self.tool_calls_end_token_variants: + if variant in cleaned: + cleaned = cleaned.replace(variant, "") + found_end = True + + return cleaned, found_begin, found_end + + def _reset_section_state(self) -> None: + """Reset state when exiting tool section.""" + self.in_tool_section = False + self.token_buffer = "" + self.section_char_count = 0 + + def reset_streaming_state(self) -> None: + """ + Reset all streaming state. Call this between requests to prevent + state leakage when parser instance is reused. + """ + # Reset section state + self._reset_section_state() + + # Reset parent class state + self.current_tool_name_sent = False + self.prev_tool_call_arr = [] + self.current_tool_id = -1 + self.streamed_args_for_tool = [] + + logger.debug("Streaming state reset") + def extract_tool_calls( self, model_output: str, @@ -131,13 +207,94 @@ class KimiK2ToolParser(ToolParser): ) -> DeltaMessage | None: logger.debug("delta_text: %s", delta_text) logger.debug("delta_token_ids: %s", delta_token_ids) - # check to see if we should be streaming a tool call - is there a - if self.tool_calls_start_token_id not in current_token_ids: - logger.debug("No tool call tokens found!") - return DeltaMessage(content=delta_text) - delta_text = delta_text.replace(self.tool_calls_start_token, "").replace( - self.tool_calls_end_token, "" + + # Flag to defer section exit until after tool parsing completes + deferred_section_exit = False + + # Add delta to buffer for split marker detection + self.token_buffer += delta_text + + # Enforce buffer size limit to prevent memory issues + if len(self.token_buffer) > 
self.buffer_max_size: + if not self._buffer_overflow_logged: + logger.warning( + "Token buffer exceeded max size (%d bytes), flushing excess. " + "This may indicate very long markers or unusual tokenization.", + self.buffer_max_size, + ) + self._buffer_overflow_logged = True + # Keep only the most recent content that might contain partial markers + self.token_buffer = self.token_buffer[-self.buffer_max_size // 2 :] + + # Check buffer for section markers (handles split tokens) + buffered_text, found_section_begin, found_section_end = ( + self._check_and_strip_markers(self.token_buffer) ) + + # Track section state transitions + if found_section_begin and not self.in_tool_section: + logger.debug("Entering tool section") + self.in_tool_section = True + self.token_buffer = buffered_text # Use cleaned buffer + self.section_char_count = 0 # Reset counter for new section + if found_section_end and self.in_tool_section: + logger.debug("Detected section end marker") + # CRITICAL: Don't exit early if tool_call_end is in this chunk. + # Tool parser must emit final arguments/close first to avoid dropping + # the final tool update and leaking tokens into reasoning channel. 
+ has_tool_end = self.tool_call_end_token_id in delta_token_ids + if has_tool_end: + # Defer exit until after tool parsing completes + deferred_section_exit = True + logger.debug("Deferring section exit: tool_call_end in same chunk") + self.token_buffer = buffered_text + else: + # No tool call ending, safe to exit immediately + logger.debug("Exiting tool section") + remaining = buffered_text + self._reset_section_state() + # Return remaining text as reasoning content if non-empty + if remaining.strip(): + return DeltaMessage(content=remaining) + # Return an empty delta (rather than None) so this chunk + # still yields a DeltaMessage with no content + return DeltaMessage(content="") + else: + self.token_buffer = buffered_text + + # Check if any variant of section start token is in current_token_ids + has_section_token = any( + tid in current_token_ids for tid in self.tool_calls_start_token_ids + ) + + # Early return: if no section token detected yet, return as reasoning content + if not has_section_token and not self.in_tool_section: + logger.debug("No tool call tokens found!") + # Don't clear buffer - it needs to accumulate partial markers across deltas + # Buffer growth is already bounded by the buffer_max_size cap above + return DeltaMessage(content=delta_text) + + # Strip section markers from delta_text for subsequent processing + # NOTE: This preprocessing happens BEFORE the regex-based tool call + # parsing (from PR #24847) to ensure markers are removed cleanly + # before pattern matching. No double-stripping occurs because + # section markers and tool call markers are distinct. + delta_text, _, _ = self._check_and_strip_markers(delta_text) + + # Error recovery: If in tool section for too long, force exit + if self.in_tool_section: + self.section_char_count += len(delta_text) + if self.section_char_count > self.max_section_chars: + logger.warning( + "Tool section exceeded max length (%d chars), forcing exit. 
" + "This may indicate malformed model output.", + self.max_section_chars, + ) + self._reset_section_state() + # Deferred exit already handled by forced exit above + # Return remaining content as reasoning (or empty delta if no content) + return DeltaMessage(content=delta_text if delta_text.strip() else "") + try: # figure out where we are in the parsing by counting tool call # start & end tags @@ -158,6 +315,16 @@ class KimiK2ToolParser(ToolParser): and prev_tool_end_count == cur_tool_end_count and self.tool_call_end_token not in delta_text ): + # CRITICAL FIX: Suppress content if in tool section but + # no tool calls started + if self.in_tool_section and cur_tool_start_count == 0: + logger.debug( + "In tool section but no tool calls started yet. " + "Suppressing: %s", + delta_text, + ) + # Return empty delta to maintain iterator contract + return DeltaMessage(content="") logger.debug("Generating text content! skipping tool parsing.") return DeltaMessage(content=delta_text) @@ -209,6 +376,9 @@ class KimiK2ToolParser(ToolParser): ): if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0: logger.debug("attempting to close tool call, but no tool call") + # Handle deferred section exit before returning + if deferred_section_exit and self.in_tool_section: + self._reset_section_state() return None diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") if diff: @@ -218,6 +388,9 @@ class KimiK2ToolParser(ToolParser): else diff ) if '"}' not in delta_text: + # Handle deferred section exit before returning + if deferred_section_exit and self.in_tool_section: + self._reset_section_state() return None end_loc = delta_text.rindex('"}') diff = delta_text[:end_loc] + '"}' @@ -227,6 +400,10 @@ class KimiK2ToolParser(ToolParser): diff, ) self.streamed_args_for_tool[self.current_tool_id] += diff + # Handle deferred section exit before returning + if deferred_section_exit and self.in_tool_section: + logger.debug("Completing deferred section exit") + 
self._reset_section_state() return DeltaMessage( tool_calls=[ DeltaToolCall( @@ -240,9 +417,19 @@ class KimiK2ToolParser(ToolParser): # case -- otherwise we're just generating text else: + # Check if we're in tool section - if so, suppress + if self.in_tool_section: + logger.debug("In tool section, suppressing text generation") + # Handle deferred section exit before returning + if deferred_section_exit: + self._reset_section_state() + return DeltaMessage(content="") text = delta_text.replace(self.tool_call_start_token, "") text = text.replace(self.tool_call_end_token, "") delta = DeltaMessage(tool_calls=[], content=text) + # Handle deferred section exit before returning + if deferred_section_exit and self.in_tool_section: + self._reset_section_state() return delta current_tool_call = dict() @@ -390,6 +577,11 @@ class KimiK2ToolParser(ToolParser): else: self.prev_tool_call_arr.append(current_tool_call) + # Handle deferred section exit after tool parsing completes + if deferred_section_exit and self.in_tool_section: + logger.debug("Completing deferred section exit") + self._reset_section_state() + return delta except Exception: