[Bugfix][Model] Prevent special token leakage in KimiK2ToolParser streaming mode (#28543)

Signed-off-by: Jscaldwell55 <jay.s.caldwell@gmail.com>
This commit is contained in:
Jay Caldwell 2025-11-16 23:54:46 -06:00 committed by GitHub
parent 60e089f0b9
commit 6f37419244
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 791 additions and 6 deletions

View File

@ -209,3 +209,596 @@ def test_streaming_no_tool_calls(kimi_k2_tool_parser):
assert result is not None
assert hasattr(result, "content")
assert result.content == " without any tool calls."
def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser):
    """Ensure text between <|tool_calls_section_begin|> and <|tool_call_begin|>
    is suppressed instead of leaking out as reasoning content.

    This covers the primary vulnerability addressed by the fix.
    """
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()

    # Marker token ids from the tokenizer vocabulary.
    section_begin_id = parser.vocab.get("<|tool_calls_section_begin|>")
    tool_begin_id = parser.vocab.get("<|tool_call_begin|>")

    # Step 1: plain reasoning text must stream through untouched.
    first = parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="I'll help you with that. ",
        delta_text="I'll help you with that. ",
        previous_token_ids=[],
        current_token_ids=[1, 2, 3],  # ordinary text tokens
        delta_token_ids=[1, 2, 3],
        request=None,
    )
    assert first is not None
    assert first.content == "I'll help you with that. "

    # Step 2: the section-begin marker must be stripped and suppressed.
    prev_ids = [1, 2, 3]
    curr_ids = prev_ids + [section_begin_id]
    second = parser.extract_tool_calls_streaming(
        previous_text="I'll help you with that. ",
        current_text="I'll help you with that. <|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=prev_ids,
        current_token_ids=curr_ids,
        delta_token_ids=[section_begin_id],
        request=None,
    )
    assert second is None or (second.content is None or second.content == "")

    # Step 3 — THE LEAK SCENARIO: spurious text after the section marker.
    prev_ids, curr_ids = curr_ids, curr_ids + [4, 5]
    third = parser.extract_tool_calls_streaming(
        previous_text="I'll help you with that. <|tool_calls_section_begin|>",
        current_text=(
            "I'll help you with that. <|tool_calls_section_begin|> spurious text "
        ),
        delta_text=" spurious text ",
        previous_token_ids=prev_ids,
        current_token_ids=curr_ids,
        delta_token_ids=[4, 5],
        request=None,
    )
    # CRITICAL: this text must be suppressed, NOT surfaced as reasoning_delta.
    assert third is None or (third.content is None or third.content == "")

    # Step 4: the tool-call-begin marker arrives; what matters is that the
    # spurious text from step 3 was never leaked.  The return value here
    # depends on internal tool-parsing state, so it is not asserted.
    prev_ids, curr_ids = curr_ids, curr_ids + [tool_begin_id]
    parser.extract_tool_calls_streaming(
        previous_text=(
            "I'll help you with that. <|tool_calls_section_begin|> spurious text "
        ),
        current_text=(
            "I'll help you with that. <|tool_calls_section_begin|> spurious text "
            "<|tool_call_begin|>"
        ),
        delta_text="<|tool_call_begin|>",
        previous_token_ids=prev_ids,
        current_token_ids=curr_ids,
        delta_token_ids=[tool_begin_id],
        request=None,
    )
def test_split_markers_across_deltas(kimi_k2_tool_parser):
    """Verify the rolling buffer detects a section marker that arrives
    split across two consecutive delta chunks.
    """
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()
    section_begin_id = parser.vocab.get("<|tool_calls_section_begin|>")

    # First chunk carries only a prefix of the marker; the incomplete
    # marker may be buffered or echoed — either is acceptable here.
    parser.extract_tool_calls_streaming(
        previous_text="Some reasoning",
        current_text="Some reasoning<|tool_calls_sec",
        delta_text="<|tool_calls_sec",
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, 3],  # partial-marker token
        delta_token_ids=[3],
        request=None,
    )

    # Second chunk completes the marker.
    parser.extract_tool_calls_streaming(
        previous_text="Some reasoning<|tool_calls_sec",
        current_text="Some reasoning<|tool_calls_section_begin|> ",
        delta_text="tion_begin|> ",
        previous_token_ids=[1, 2, 3],
        current_token_ids=[1, 2, section_begin_id, 4],
        delta_token_ids=[section_begin_id, 4],
        request=None,
    )
    # The buffered pieces should now assemble into the full marker,
    # flipping the parser into tool-section mode.
    assert parser.in_tool_section is True
def test_marker_variants(kimi_k2_tool_parser):
    """Both the plural and singular section-marker spellings are recognized."""
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()

    # Singular spelling: <|tool_call_section_begin|> ("call", not "calls").
    singular_id = parser.vocab.get("<|tool_call_section_begin|>")
    if singular_id is None:
        # Tokenizer lacks the singular variant; nothing to verify.
        return

    parser.extract_tool_calls_streaming(
        previous_text="Reasoning ",
        current_text="Reasoning <|tool_call_section_begin|>",
        delta_text="<|tool_call_section_begin|>",
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, singular_id],
        delta_token_ids=[singular_id],
        request=None,
    )
    # The singular variant must also open a tool section.
    assert parser.in_tool_section is True
def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser):
    """After <|tool_calls_section_end|> closes a section, subsequent text
    must flow through again as normal reasoning content.
    """
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()
    begin_id = parser.vocab.get("<|tool_calls_section_begin|>")
    end_id = parser.vocab.get("<|tool_calls_section_end|>")

    # Open the tool section.
    parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="<|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[],
        current_token_ids=[begin_id],
        delta_token_ids=[begin_id],
        request=None,
    )
    assert parser.in_tool_section is True

    # Close it again.
    parser.extract_tool_calls_streaming(
        previous_text="<|tool_calls_section_begin|>",
        current_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
        delta_text="<|tool_calls_section_end|>",
        previous_token_ids=[begin_id],
        current_token_ids=[begin_id, end_id],
        delta_token_ids=[end_id],
        request=None,
    )
    assert parser.in_tool_section is False

    # Post-section reasoning should be returned normally.
    delta = parser.extract_tool_calls_streaming(
        previous_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
        current_text=(
            "<|tool_calls_section_begin|><|tool_calls_section_end|> More reasoning"
        ),
        delta_text=" More reasoning",
        previous_token_ids=[begin_id, end_id],
        current_token_ids=[begin_id, end_id, 10, 11],
        delta_token_ids=[10, 11],
        request=None,
    )
    assert delta is not None
    assert delta.content == " More reasoning"
def test_empty_tool_section(kimi_k2_tool_parser):
    """An empty section (begin immediately followed by end) exits cleanly."""
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()
    begin_id = parser.vocab.get("<|tool_calls_section_begin|>")
    end_id = parser.vocab.get("<|tool_calls_section_end|>")

    # Open the section...
    parser.extract_tool_calls_streaming(
        previous_text="Reasoning ",
        current_text="Reasoning <|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[1],
        current_token_ids=[1, begin_id],
        delta_token_ids=[begin_id],
        request=None,
    )
    # ...and close it right away with no tool calls in between.
    parser.extract_tool_calls_streaming(
        previous_text="Reasoning <|tool_calls_section_begin|>",
        current_text="Reasoning <|tool_calls_section_begin|><|tool_calls_section_end|>",
        delta_text="<|tool_calls_section_end|>",
        previous_token_ids=[1, begin_id],
        current_token_ids=[1, begin_id, end_id],
        delta_token_ids=[end_id],
        request=None,
    )
    # The parser must leave the section without raising.
    assert parser.in_tool_section is False
def test_malformed_tool_section_recovery(kimi_k2_tool_parser):
    """A tool section that never closes must trigger forced-exit recovery."""
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()
    begin_id = parser.vocab.get("<|tool_calls_section_begin|>")

    # Open a tool section normally.
    parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="<|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[],
        current_token_ids=[begin_id],
        delta_token_ids=[begin_id],
        request=None,
    )
    assert parser.in_tool_section is True

    # Flood the section with text but no tool calls and no end marker:
    # the length exceeds max_section_chars and should force recovery.
    flood = "x" * 10000
    flood_ids = list(range(100, 100 + len(flood)))
    recovered = parser.extract_tool_calls_streaming(
        previous_text="<|tool_calls_section_begin|>",
        current_text="<|tool_calls_section_begin|>" + flood,
        delta_text=flood,
        previous_token_ids=[begin_id],
        current_token_ids=[begin_id] + flood_ids,
        delta_token_ids=flood_ids,
        request=None,
    )
    # The section must have been force-exited...
    assert parser.in_tool_section is False
    # ...and the overflowing text surfaced as reasoning content.
    assert recovered is not None
    assert recovered.content == flood
def test_state_reset(kimi_k2_tool_parser):
    """reset_streaming_state() must wipe every piece of streaming state."""
    parser = kimi_k2_tool_parser

    # Drive the parser into a messy mid-stream state.
    parser.in_tool_section = True
    parser.token_buffer = "some buffer"
    parser.current_tool_id = 5
    parser.prev_tool_call_arr = [{"id": "test"}]
    parser.section_char_count = 1000

    parser.reset_streaming_state()

    # Everything should be back at the pristine defaults.
    assert parser.in_tool_section is False
    assert parser.token_buffer == ""
    assert parser.current_tool_id == -1
    assert parser.prev_tool_call_arr == []
    assert parser.section_char_count == 0
    assert parser.current_tool_name_sent is False
    assert parser.streamed_args_for_tool == []
def test_section_begin_noise_tool_begin_same_chunk(kimi_k2_tool_parser):
    """Noise sandwiched between section-begin and tool-call-begin markers
    inside a SINGLE chunk must be suppressed (not only across chunks).
    """
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()
    begin_id = parser.vocab.get("<|tool_calls_section_begin|>")
    tool_begin_id = parser.vocab.get("<|tool_call_begin|>")

    # One delta carrying: section_begin + spurious text + tool_call_begin.
    chunk = "<|tool_calls_section_begin|> noise text <|tool_call_begin|>"
    delta = parser.extract_tool_calls_streaming(
        previous_text="Reasoning ",
        current_text="Reasoning " + chunk,
        delta_text=chunk,
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, begin_id, 3, 4, tool_begin_id],
        delta_token_ids=[begin_id, 3, 4, tool_begin_id],
        request=None,
    )
    # Whatever is emitted (None/empty or the start of tool-call parsing),
    # the noise must never appear in the content channel.
    if delta is not None and delta.content is not None:
        assert "noise text" not in delta.content
        assert delta.content == "" or delta.content.strip() == ""
def test_stream_ends_without_section_end_marker(kimi_k2_tool_parser):
    """
    Test that if the stream ends (EOF) without a proper section end marker,
    the parser doesn't leak text, doesn't crash, and resets state cleanly.
    """
    kimi_k2_tool_parser.reset_streaming_state()
    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    # Enter tool section
    _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="<|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[],
        current_token_ids=[section_begin_id],
        delta_token_ids=[section_begin_id],
        request=None,
    )
    assert kimi_k2_tool_parser.in_tool_section is True
    # Some content in tool section
    result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="<|tool_calls_section_begin|>",
        current_text="<|tool_calls_section_begin|> partial content",
        delta_text=" partial content",
        previous_token_ids=[section_begin_id],
        current_token_ids=[section_begin_id, 10, 11],
        delta_token_ids=[10, 11],
        request=None,
    )
    # Content should be suppressed.
    # BUGFIX: tolerate result2 being None — the parser may legitimately
    # return None for a fully-suppressed delta; the previous assertion
    # dereferenced .content unconditionally and raised AttributeError
    # instead of failing (or passing) cleanly.  All sibling tests already
    # use this None-safe pattern.
    assert result2 is None or result2.content == "" or result2.content is None
    # Stream ends (EOF) - no more deltas, no section_end marker.
    # Simulate this by manually checking state and resetting
    # (in real usage, the request handler would call reset_streaming_state).
    assert kimi_k2_tool_parser.in_tool_section is True  # Still in section
    # Reset state (as would happen between requests)
    kimi_k2_tool_parser.reset_streaming_state()
    # Verify clean slate
    assert kimi_k2_tool_parser.in_tool_section is False
    assert kimi_k2_tool_parser.token_buffer == ""
    # Next request should work normally
    result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="New reasoning",
        delta_text="New reasoning",
        previous_token_ids=[],
        current_token_ids=[20, 21],
        delta_token_ids=[20, 21],
        request=None,
    )
    assert result3 is not None
    assert result3.content == "New reasoning"
def test_same_chunk_begin_and_end_markers(kimi_k2_tool_parser):
    """Both section markers in one chunk: enter, then immediately exit.

    Regression test for the elif -> if fix: when begin and end arrive
    together the parser must perform BOTH state transitions and must not
    remain stuck with in_tool_section=True.
    """
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()
    begin_id = parser.vocab.get("<|tool_calls_section_begin|>")
    end_id = parser.vocab.get("<|tool_calls_section_end|>")

    # A single chunk holding both markers (i.e. an empty tool section).
    chunk = "<|tool_calls_section_begin|><|tool_calls_section_end|>"
    delta = parser.extract_tool_calls_streaming(
        previous_text="Some reasoning ",
        current_text="Some reasoning " + chunk,
        delta_text=chunk,
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, begin_id, end_id],
        delta_token_ids=[begin_id, end_id],
        request=None,
    )
    # CRITICAL: both transitions must have happened within this one call.
    assert parser.in_tool_section is False, (
        "Parser stuck in tool section after processing both begin/end in same chunk. "
        "This indicates the elif bug was not fixed."
    )
    # Markers are stripped, so at most empty content remains.
    assert delta is not None
    assert delta.content == "" or delta.content is None

    # Follow-up reasoning must stream normally (no lingering suppression).
    follow_up = parser.extract_tool_calls_streaming(
        previous_text="Some reasoning " + chunk,
        current_text="Some reasoning " + chunk + " More reasoning",
        delta_text=" More reasoning",
        previous_token_ids=[1, 2, begin_id, end_id],
        current_token_ids=[1, 2, begin_id, end_id, 10, 11],
        delta_token_ids=[10, 11],
        request=None,
    )
    assert follow_up is not None
    assert follow_up.content == " More reasoning"
def test_same_chunk_begin_content_end_markers(kimi_k2_tool_parser):
    """Same-chunk begin + content + end markers.

    Example: <|tool_calls_section_begin|> text <|tool_calls_section_end|>
    delivered in one delta — the state machine must transition into and
    back out of the tool section within that single chunk.
    """
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()
    begin_id = parser.vocab.get("<|tool_calls_section_begin|>")
    end_id = parser.vocab.get("<|tool_calls_section_end|>")

    # A section that opens, holds some whitespace/noise, and closes at once.
    chunk = "<|tool_calls_section_begin|> <|tool_calls_section_end|>"
    parser.extract_tool_calls_streaming(
        previous_text="Reasoning ",
        current_text="Reasoning " + chunk,
        delta_text=chunk,
        previous_token_ids=[1],
        current_token_ids=[1, begin_id, 100, end_id],
        delta_token_ids=[begin_id, 100, end_id],
        request=None,
    )
    # No lingering tool-section state after the combined chunk.
    assert parser.in_tool_section is False

    # Verify the fix: subsequent content streams normally, not suppressed.
    delta = parser.extract_tool_calls_streaming(
        previous_text="Reasoning " + chunk,
        current_text="Reasoning " + chunk + " Done",
        delta_text=" Done",
        previous_token_ids=[1, begin_id, 100, end_id],
        current_token_ids=[1, begin_id, 100, end_id, 200],
        delta_token_ids=[200],
        request=None,
    )
    assert delta is not None
    assert delta.content == " Done"
def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser):
    """P1 regression: <|tool_call_end|> and <|tool_calls_section_end|> in
    the SAME chunk.

    The parser must (1) finish the tool call first, emitting its final
    arguments, (2) only then exit the section, (3) never drop the final
    tool update, and (4) never leak special tokens into the reasoning
    channel.  Exercises the deferred section-exit path.
    """
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()
    section_begin_id = parser.vocab.get("<|tool_calls_section_begin|>")
    section_end_id = parser.vocab.get("<|tool_calls_section_end|>")
    tool_begin_id = parser.vocab.get("<|tool_call_begin|>")
    tool_end_id = parser.vocab.get("<|tool_call_end|>")

    # Step 1: ordinary reasoning text streams through.
    opening = parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="Let me help. ",
        delta_text="Let me help. ",
        previous_token_ids=[],
        current_token_ids=[1, 2],
        delta_token_ids=[1, 2],
        request=None,
    )
    assert opening is not None
    assert opening.content == "Let me help. "

    # Step 2: section begin.
    parser.extract_tool_calls_streaming(
        previous_text="Let me help. ",
        current_text="Let me help. <|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, section_begin_id],
        delta_token_ids=[section_begin_id],
        request=None,
    )
    assert parser.in_tool_section is True

    # Step 3: an entire SHORT tool call — begin, full payload, tool_end
    # AND section_end — all inside one chunk (the critical scenario).
    combined = (
        '<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} '
        "<|tool_call_end|><|tool_calls_section_end|>"
    )
    prev_text = "Let me help. <|tool_calls_section_begin|>"
    curr_text = prev_text + combined
    prev_ids = [1, 2, section_begin_id]
    delta_ids = [tool_begin_id, 10, 11, 12, tool_end_id, section_end_id]
    curr_ids = prev_ids + delta_ids
    closing = parser.extract_tool_calls_streaming(
        previous_text=prev_text,
        current_text=curr_text,
        delta_text=combined,
        previous_token_ids=prev_ids,
        current_token_ids=curr_ids,
        delta_token_ids=delta_ids,
        request=None,
    )
    # CRITICAL: the section is exited only AFTER the tool was processed.
    assert parser.in_tool_section is False
    # Whatever was emitted (tool name, arguments, or None), the raw
    # special tokens must not reach the content channel.
    if closing is not None and closing.content is not None:
        assert "<|tool_call_end|>" not in closing.content
        assert "<|tool_calls_section_end|>" not in closing.content

    # Step 4: trailing reasoning after the section streams normally again.
    trailing = parser.extract_tool_calls_streaming(
        previous_text=curr_text,
        current_text=curr_text + " Done",
        delta_text=" Done",
        previous_token_ids=curr_ids,
        current_token_ids=curr_ids + [20],
        delta_token_ids=[20],
        request=None,
    )
    assert trailing is not None
    assert trailing.content == " Done"

View File

@ -34,8 +34,27 @@ class KimiK2ToolParser(ToolParser):
str
] = [] # map what has been streamed for each tool so far to a list
# Section-level state management to prevent token leakage
self.in_tool_section: bool = False
self.token_buffer: str = ""
# Buffer size: empirical worst-case for longest marker (~30 chars) * 2
# + safety margin for unicode + partial overlap. Prevents unbounded growth.
self.buffer_max_size: int = 1024
self.section_char_count: int = 0 # Track characters processed in tool section
self.max_section_chars: int = 8192 # Force exit if section exceeds this
self._buffer_overflow_logged: bool = False # Log overflow once per session
# Support both singular and plural variants
self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
self.tool_calls_end_token: str = "<|tool_calls_section_end|>"
self.tool_calls_start_token_variants: list[str] = [
"<|tool_calls_section_begin|>",
"<|tool_call_section_begin|>", # singular variant
]
self.tool_calls_end_token_variants: list[str] = [
"<|tool_calls_section_end|>",
"<|tool_call_section_end|>", # singular variant
]
self.tool_call_start_token: str = "<|tool_call_begin|>"
self.tool_call_end_token: str = "<|tool_call_end|>"
@ -58,6 +77,18 @@ class KimiK2ToolParser(ToolParser):
self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)
# Get token IDs for all variants
self.tool_calls_start_token_ids: list[int] = [
tid
for variant in self.tool_calls_start_token_variants
if (tid := self.vocab.get(variant)) is not None
]
self.tool_calls_end_token_ids: list[int] = [
tid
for variant in self.tool_calls_end_token_variants
if (tid := self.vocab.get(variant)) is not None
]
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
@ -70,6 +101,51 @@ class KimiK2ToolParser(ToolParser):
"tokens in the tokenizer!"
)
def _check_and_strip_markers(self, text: str) -> tuple[str, bool, bool]:
"""
Check for section begin/end markers in text and strip them.
Returns: (cleaned_text, found_section_begin, found_section_end)
"""
found_begin = False
found_end = False
cleaned = text
# Check for section begin markers (any variant)
for variant in self.tool_calls_start_token_variants:
if variant in cleaned:
cleaned = cleaned.replace(variant, "")
found_begin = True
# Check for section end markers (any variant)
for variant in self.tool_calls_end_token_variants:
if variant in cleaned:
cleaned = cleaned.replace(variant, "")
found_end = True
return cleaned, found_begin, found_end
def _reset_section_state(self) -> None:
"""Reset state when exiting tool section."""
self.in_tool_section = False
self.token_buffer = ""
self.section_char_count = 0
def reset_streaming_state(self) -> None:
    """Restore the parser to a pristine per-request state.

    Must be invoked between requests whenever a parser instance is
    reused, so no section or tool-call state bleeds into the next stream.
    """
    # Drop section-level tracking first...
    self._reset_section_state()
    # ...then the tool-call bookkeeping inherited from the parent class.
    self.prev_tool_call_arr = []
    self.current_tool_id = -1
    self.current_tool_name_sent = False
    self.streamed_args_for_tool = []
    logger.debug("Streaming state reset")
def extract_tool_calls(
self,
model_output: str,
@ -131,13 +207,94 @@ class KimiK2ToolParser(ToolParser):
) -> DeltaMessage | None:
logger.debug("delta_text: %s", delta_text)
logger.debug("delta_token_ids: %s", delta_token_ids)
# check to see if we should be streaming a tool call - is there a
if self.tool_calls_start_token_id not in current_token_ids:
logger.debug("No tool call tokens found!")
return DeltaMessage(content=delta_text)
delta_text = delta_text.replace(self.tool_calls_start_token, "").replace(
self.tool_calls_end_token, ""
# Flag to defer section exit until after tool parsing completes
deferred_section_exit = False
# Add delta to buffer for split marker detection
self.token_buffer += delta_text
# Enforce buffer size limit to prevent memory issues
if len(self.token_buffer) > self.buffer_max_size:
if not self._buffer_overflow_logged:
logger.warning(
"Token buffer exceeded max size (%d bytes), flushing excess. "
"This may indicate very long markers or unusual tokenization.",
self.buffer_max_size,
)
self._buffer_overflow_logged = True
# Keep only the most recent content that might contain partial markers
self.token_buffer = self.token_buffer[-self.buffer_max_size // 2 :]
# Check buffer for section markers (handles split tokens)
buffered_text, found_section_begin, found_section_end = (
self._check_and_strip_markers(self.token_buffer)
)
# Track section state transitions
if found_section_begin and not self.in_tool_section:
logger.debug("Entering tool section")
self.in_tool_section = True
self.token_buffer = buffered_text # Use cleaned buffer
self.section_char_count = 0 # Reset counter for new section
if found_section_end and self.in_tool_section:
logger.debug("Detected section end marker")
# CRITICAL: Don't exit early if tool_call_end is in this chunk.
# Tool parser must emit final arguments/close first to avoid dropping
# the final tool update and leaking tokens into reasoning channel.
has_tool_end = self.tool_call_end_token_id in delta_token_ids
if has_tool_end:
# Defer exit until after tool parsing completes
deferred_section_exit = True
logger.debug("Deferring section exit: tool_call_end in same chunk")
self.token_buffer = buffered_text
else:
# No tool call ending, safe to exit immediately
logger.debug("Exiting tool section")
remaining = buffered_text
self._reset_section_state()
# Return remaining text as reasoning content if non-empty
if remaining.strip():
return DeltaMessage(content=remaining)
# Return empty delta to maintain function contract
# (always returns DeltaMessage)
return DeltaMessage(content="")
else:
self.token_buffer = buffered_text
# Check if any variant of section start token is in current_token_ids
has_section_token = any(
tid in current_token_ids for tid in self.tool_calls_start_token_ids
)
# Early return: if no section token detected yet, return as reasoning content
if not has_section_token and not self.in_tool_section:
logger.debug("No tool call tokens found!")
# Don't clear buffer - it needs to accumulate partial markers across deltas
# Buffer overflow is already protected by lines 215-224
return DeltaMessage(content=delta_text)
# Strip section markers from delta_text for subsequent processing
# NOTE: This preprocessing happens BEFORE the regex-based tool call
# parsing (from PR #24847) to ensure markers are removed cleanly
# before pattern matching. No double-stripping occurs because
# section markers and tool call markers are distinct.
delta_text, _, _ = self._check_and_strip_markers(delta_text)
# Error recovery: If in tool section for too long, force exit
if self.in_tool_section:
self.section_char_count += len(delta_text)
if self.section_char_count > self.max_section_chars:
logger.warning(
"Tool section exceeded max length (%d chars), forcing exit. "
"This may indicate malformed model output.",
self.max_section_chars,
)
self._reset_section_state()
# Deferred exit already handled by forced exit above
# Return remaining content as reasoning (or empty delta if no content)
return DeltaMessage(content=delta_text if delta_text.strip() else "")
try:
# figure out where we are in the parsing by counting tool call
# start & end tags
@ -158,6 +315,16 @@ class KimiK2ToolParser(ToolParser):
and prev_tool_end_count == cur_tool_end_count
and self.tool_call_end_token not in delta_text
):
# CRITICAL FIX: Suppress content if in tool section but
# no tool calls started
if self.in_tool_section and cur_tool_start_count == 0:
logger.debug(
"In tool section but no tool calls started yet. "
"Suppressing: %s",
delta_text,
)
# Return empty delta to maintain iterator contract
return DeltaMessage(content="")
logger.debug("Generating text content! skipping tool parsing.")
return DeltaMessage(content=delta_text)
@ -209,6 +376,9 @@ class KimiK2ToolParser(ToolParser):
):
if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
logger.debug("attempting to close tool call, but no tool call")
# Handle deferred section exit before returning
if deferred_section_exit and self.in_tool_section:
self._reset_section_state()
return None
diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
if diff:
@ -218,6 +388,9 @@ class KimiK2ToolParser(ToolParser):
else diff
)
if '"}' not in delta_text:
# Handle deferred section exit before returning
if deferred_section_exit and self.in_tool_section:
self._reset_section_state()
return None
end_loc = delta_text.rindex('"}')
diff = delta_text[:end_loc] + '"}'
@ -227,6 +400,10 @@ class KimiK2ToolParser(ToolParser):
diff,
)
self.streamed_args_for_tool[self.current_tool_id] += diff
# Handle deferred section exit before returning
if deferred_section_exit and self.in_tool_section:
logger.debug("Completing deferred section exit")
self._reset_section_state()
return DeltaMessage(
tool_calls=[
DeltaToolCall(
@ -240,9 +417,19 @@ class KimiK2ToolParser(ToolParser):
# case -- otherwise we're just generating text
else:
# Check if we're in tool section - if so, suppress
if self.in_tool_section:
logger.debug("In tool section, suppressing text generation")
# Handle deferred section exit before returning
if deferred_section_exit:
self._reset_section_state()
return DeltaMessage(content="")
text = delta_text.replace(self.tool_call_start_token, "")
text = text.replace(self.tool_call_end_token, "")
delta = DeltaMessage(tool_calls=[], content=text)
# Handle deferred section exit before returning
if deferred_section_exit and self.in_tool_section:
self._reset_section_state()
return delta
current_tool_call = dict()
@ -390,6 +577,11 @@ class KimiK2ToolParser(ToolParser):
else:
self.prev_tool_call_arr.append(current_tool_call)
# Handle deferred section exit after tool parsing completes
if deferred_section_exit and self.in_tool_section:
logger.debug("Completing deferred section exit")
self._reset_section_state()
return delta
except Exception: