mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-13 21:07:03 +08:00
[Bugfix][Model] Prevent special token leakage in KimiK2ToolParser streaming mode (#28543)
Signed-off-by: Jscaldwell55 <jay.s.caldwell@gmail.com>
This commit is contained in:
parent
60e089f0b9
commit
6f37419244
@ -209,3 +209,596 @@ def test_streaming_no_tool_calls(kimi_k2_tool_parser):
|
||||
assert result is not None
|
||||
assert hasattr(result, "content")
|
||||
assert result.content == " without any tool calls."
|
||||
|
||||
|
||||
def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser):
    """
    Test that text between <|tool_calls_section_begin|> and <|tool_call_begin|>
    is suppressed and does not leak into reasoning_delta.

    This is the main vulnerability being fixed.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    # Resolve the token IDs for the two markers involved.
    section_begin_token_id = kimi_k2_tool_parser.vocab.get(
        "<|tool_calls_section_begin|>"
    )
    tool_call_begin_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")

    # Delta 1: plain reasoning text must pass straight through.
    first = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="I'll help you with that. ",
        delta_text="I'll help you with that. ",
        previous_token_ids=[],
        current_token_ids=[1, 2, 3],  # Regular tokens
        delta_token_ids=[1, 2, 3],
        request=None,
    )
    assert first is not None
    assert first.content == "I'll help you with that. "

    # Delta 2: the section-begin marker itself.
    seen_ids = [1, 2, 3]
    ids_after_begin = seen_ids + [section_begin_token_id]
    second = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="I'll help you with that. ",
        current_text="I'll help you with that. <|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=seen_ids,
        current_token_ids=ids_after_begin,
        delta_token_ids=[section_begin_token_id],
        request=None,
    )
    # The section marker should be stripped and suppressed.
    assert second is None or second.content in (None, "")

    # Delta 3: spurious text inside the section (THE LEAK SCENARIO).
    ids_after_noise = ids_after_begin + [4, 5]
    third = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="I'll help you with that. <|tool_calls_section_begin|>",
        current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
        delta_text=" spurious text ",
        previous_token_ids=ids_after_begin,
        current_token_ids=ids_after_noise,
        delta_token_ids=[4, 5],
        request=None,
    )
    # CRITICAL: this text must be suppressed, NOT surfaced as reasoning_delta.
    assert third is None or third.content in (None, "")

    # Delta 4: the tool-call-begin marker arrives.
    ids_after_tool_begin = ids_after_noise + [tool_call_begin_token_id]
    _fourth = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
        current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text <|tool_call_begin|>",
        delta_text="<|tool_call_begin|>",
        previous_token_ids=ids_after_noise,
        current_token_ids=ids_after_tool_begin,
        delta_token_ids=[tool_call_begin_token_id],
        request=None,
    )
    # Now we're in tool-call mode; the result depends on internal state.
    # What matters is that the spurious text from Delta 3 never leaked.
||||
def test_split_markers_across_deltas(kimi_k2_tool_parser):
    """
    Test that markers split across delta chunks are correctly detected
    via the rolling buffer mechanism.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    begin_id = kimi_k2_tool_parser.vocab.get(
        "<|tool_calls_section_begin|>"
    )

    # Delta 1: reasoning followed by only the first half of the marker.
    _first = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Some reasoning",
        current_text="Some reasoning<|tool_calls_sec",
        delta_text="<|tool_calls_sec",
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, 3],  # Partial token
        delta_token_ids=[3],
        request=None,
    )
    # The incomplete marker may be buffered or echoed back as content;
    # either outcome is implementation-defined at this point.

    # Delta 2: the tail of the marker completes it.
    _second = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Some reasoning<|tool_calls_sec",
        current_text="Some reasoning<|tool_calls_section_begin|> ",
        delta_text="tion_begin|> ",
        previous_token_ids=[1, 2, 3],
        current_token_ids=[1, 2, begin_id, 4],
        delta_token_ids=[begin_id, 4],
        request=None,
    )
    # The full marker is now visible via the buffer, so the parser
    # must have switched into tool-section mode.
    assert kimi_k2_tool_parser.in_tool_section is True
||||
def test_marker_variants(kimi_k2_tool_parser):
    """Test that both singular and plural marker variants are recognized."""
    kimi_k2_tool_parser.reset_streaming_state()

    # Singular variant: <|tool_call_section_begin|> (note: singular "call")
    variant_id = kimi_k2_tool_parser.vocab.get("<|tool_call_section_begin|>")

    if variant_id is not None:  # Only test if tokenizer supports it
        _result = kimi_k2_tool_parser.extract_tool_calls_streaming(
            previous_text="Reasoning ",
            current_text="Reasoning <|tool_call_section_begin|>",
            delta_text="<|tool_call_section_begin|>",
            previous_token_ids=[1, 2],
            current_token_ids=[1, 2, variant_id],
            delta_token_ids=[variant_id],
            request=None,
        )
        # The singular spelling must also flip the parser into section mode.
        assert kimi_k2_tool_parser.in_tool_section is True
||||
def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser):
    """
    Test that after exiting a tool section with <|tool_calls_section_end|>,
    subsequent text is correctly returned as reasoning content.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")

    # Step 1: the begin marker puts the parser inside the section.
    _entered = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="<|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[],
        current_token_ids=[begin_id],
        delta_token_ids=[begin_id],
        request=None,
    )
    assert kimi_k2_tool_parser.in_tool_section is True

    # Step 2: the end marker takes it back out.
    _exited = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="<|tool_calls_section_begin|>",
        current_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
        delta_text="<|tool_calls_section_end|>",
        previous_token_ids=[begin_id],
        current_token_ids=[begin_id, end_id],
        delta_token_ids=[end_id],
        request=None,
    )
    assert kimi_k2_tool_parser.in_tool_section is False

    # Step 3: reasoning text after the section must flow through untouched.
    after = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
        current_text="<|tool_calls_section_begin|><|tool_calls_section_end|> More reasoning",
        delta_text=" More reasoning",
        previous_token_ids=[begin_id, end_id],
        current_token_ids=[begin_id, end_id, 10, 11],
        delta_token_ids=[10, 11],
        request=None,
    )
    assert after is not None
    assert after.content == " More reasoning"
||||
def test_empty_tool_section(kimi_k2_tool_parser):
    """Test an empty tool section (begin immediately followed by end)."""
    kimi_k2_tool_parser.reset_streaming_state()

    begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")

    # Open the section...
    _opened = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Reasoning ",
        current_text="Reasoning <|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[1],
        current_token_ids=[1, begin_id],
        delta_token_ids=[begin_id],
        request=None,
    )

    # ...and close it on the very next delta, with no tool call in between.
    _closed = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Reasoning <|tool_calls_section_begin|>",
        current_text="Reasoning <|tool_calls_section_begin|><|tool_calls_section_end|>",
        delta_text="<|tool_calls_section_end|>",
        previous_token_ids=[1, begin_id],
        current_token_ids=[1, begin_id, end_id],
        delta_token_ids=[end_id],
        request=None,
    )
    # The parser must wind up outside the section without raising.
    assert kimi_k2_tool_parser.in_tool_section is False
||||
def test_malformed_tool_section_recovery(kimi_k2_tool_parser):
    """
    Test that the parser recovers from a malformed tool section
    that never closes properly.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")

    # Put the parser inside a tool section.
    _opened = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="<|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[],
        current_token_ids=[begin_id],
        delta_token_ids=[begin_id],
        request=None,
    )
    assert kimi_k2_tool_parser.in_tool_section is True

    # Feed a huge run of text with neither a tool call nor a section end;
    # this should trip the parser's error-recovery path.
    flood = "x" * 10000  # Exceeds max_section_chars
    flood_ids = list(range(100, 100 + len(flood)))

    recovered = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="<|tool_calls_section_begin|>",
        current_text="<|tool_calls_section_begin|>" + flood,
        delta_text=flood,
        previous_token_ids=[begin_id],
        current_token_ids=[begin_id] + flood_ids,
        delta_token_ids=flood_ids,
        request=None,
    )

    # The parser must have force-exited the section...
    assert kimi_k2_tool_parser.in_tool_section is False
    # ...and surfaced the oversized text as reasoning content.
    assert recovered is not None
    assert recovered.content == flood
||||
def test_state_reset(kimi_k2_tool_parser):
    """Test that reset_streaming_state() properly clears all state."""
    # Drive the parser into a messy, mid-stream-looking state.
    kimi_k2_tool_parser.in_tool_section = True
    kimi_k2_tool_parser.token_buffer = "some buffer"
    kimi_k2_tool_parser.current_tool_id = 5
    kimi_k2_tool_parser.prev_tool_call_arr = [{"id": "test"}]
    kimi_k2_tool_parser.section_char_count = 1000

    kimi_k2_tool_parser.reset_streaming_state()

    # Every field must be back at its pristine default.
    assert kimi_k2_tool_parser.in_tool_section is False
    assert kimi_k2_tool_parser.token_buffer == ""
    assert kimi_k2_tool_parser.current_tool_id == -1
    assert kimi_k2_tool_parser.prev_tool_call_arr == []
    assert kimi_k2_tool_parser.section_char_count == 0
    assert kimi_k2_tool_parser.current_tool_name_sent is False
    assert kimi_k2_tool_parser.streamed_args_for_tool == []
||||
def test_section_begin_noise_tool_begin_same_chunk(kimi_k2_tool_parser):
    """
    Test that begin→noise→tool_begin within the SAME chunk suppresses
    the noise text correctly (not just across chunks).
    """
    kimi_k2_tool_parser.reset_streaming_state()

    begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")

    # One delta carrying section_begin, junk text, and tool_call_begin.
    payload = "<|tool_calls_section_begin|> noise text <|tool_call_begin|>"

    outcome = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Reasoning ",
        current_text="Reasoning " + payload,
        delta_text=payload,
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, begin_id, 3, 4, tool_begin_id],
        delta_token_ids=[begin_id, 3, 4, tool_begin_id],
        request=None,
    )

    # The junk must not reach the reasoning channel: either nothing comes
    # back, or whatever does come back is effectively empty.
    if outcome is not None and outcome.content is not None:
        assert "noise text" not in outcome.content
        assert outcome.content.strip() == ""
||||
def test_stream_ends_without_section_end_marker(kimi_k2_tool_parser):
    """
    Test that if the stream ends (EOF) without a proper section end marker,
    the parser doesn't leak text, doesn't crash, and resets state cleanly.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")

    # Enter tool section
    _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="<|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[],
        current_token_ids=[section_begin_id],
        delta_token_ids=[section_begin_id],
        request=None,
    )
    assert kimi_k2_tool_parser.in_tool_section is True

    # Some content in tool section
    result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="<|tool_calls_section_begin|>",
        current_text="<|tool_calls_section_begin|> partial content",
        delta_text=" partial content",
        previous_token_ids=[section_begin_id],
        current_token_ids=[section_begin_id, 10, 11],
        delta_token_ids=[10, 11],
        request=None,
    )
    # Content should be suppressed.
    # BUGFIX: the original assertion dereferenced result2.content without a
    # None guard, so a parser that suppressed the chunk by returning None
    # (an outcome the sibling tests in this file explicitly accept) crashed
    # this test with AttributeError instead of passing.
    assert result2 is None or result2.content in ("", None)

    # Stream ends (EOF) - no more deltas, no section_end marker.
    # Simulate this by manually checking state and resetting
    # (in real usage, the request handler would call reset_streaming_state).
    assert kimi_k2_tool_parser.in_tool_section is True  # Still in section

    # Reset state (as would happen between requests)
    kimi_k2_tool_parser.reset_streaming_state()

    # Verify clean slate
    assert kimi_k2_tool_parser.in_tool_section is False
    assert kimi_k2_tool_parser.token_buffer == ""

    # Next request should work normally
    result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="New reasoning",
        delta_text="New reasoning",
        previous_token_ids=[],
        current_token_ids=[20, 21],
        delta_token_ids=[20, 21],
        request=None,
    )
    assert result3 is not None
    assert result3.content == "New reasoning"
||||
def test_same_chunk_begin_and_end_markers(kimi_k2_tool_parser):
    """
    CRITICAL TEST: Verify that when both section_begin and section_end
    markers appear in the SAME chunk, the parser correctly:
    1. Enters the tool section
    2. Immediately exits the tool section
    3. Does NOT get stuck in in_tool_section=True state

    This tests the bug fix where elif was changed to if to handle
    both state transitions in a single delta.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")

    # Both markers arrive back to back in a single delta (empty section).
    both_markers = "<|tool_calls_section_begin|><|tool_calls_section_end|>"

    first = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Some reasoning ",
        current_text="Some reasoning " + both_markers,
        delta_text=both_markers,
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, begin_id, end_id],
        delta_token_ids=[begin_id, end_id],
        request=None,
    )

    # CRITICAL: the parser must not remain wedged inside the section.
    assert kimi_k2_tool_parser.in_tool_section is False, (
        "Parser stuck in tool section after processing both begin/end in same chunk. "
        "This indicates the elif bug was not fixed."
    )

    # Whatever was returned carries no residual marker text.
    assert first is not None
    assert first.content in ("", None)

    # And the very next delta must stream as ordinary reasoning again.
    second = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Some reasoning " + both_markers,
        current_text="Some reasoning " + both_markers + " More reasoning",
        delta_text=" More reasoning",
        previous_token_ids=[1, 2, begin_id, end_id],
        current_token_ids=[1, 2, begin_id, end_id, 10, 11],
        delta_token_ids=[10, 11],
        request=None,
    )

    # This content should NOT be suppressed (we're out of the section).
    assert second is not None
    assert second.content == " More reasoning"
||||
def test_same_chunk_begin_content_end_markers(kimi_k2_tool_parser):
    """
    Test the same-chunk scenario with actual content between markers.
    Example: <|tool_calls_section_begin|> text <|tool_calls_section_end|>
    all arriving in one delta. The key is that the state machine correctly
    transitions in and out within the same chunk.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")

    # A section that opens, contains whitespace, and closes in one chunk.
    one_shot_section = "<|tool_calls_section_begin|> <|tool_calls_section_end|>"

    _first = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Reasoning ",
        current_text="Reasoning " + one_shot_section,
        delta_text=one_shot_section,
        previous_token_ids=[1],
        current_token_ids=[1, begin_id, 100, end_id],
        delta_token_ids=[begin_id, 100, end_id],
        request=None,
    )

    # The state machine must have passed through and come back out.
    assert kimi_k2_tool_parser.in_tool_section is False

    # Post-section text streams through normally, proving nothing is
    # still being suppressed.
    second = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Reasoning " + one_shot_section,
        current_text="Reasoning " + one_shot_section + " Done",
        delta_text=" Done",
        previous_token_ids=[1, begin_id, 100, end_id],
        current_token_ids=[1, begin_id, 100, end_id, 200],
        delta_token_ids=[200],
        request=None,
    )

    assert second is not None
    assert second.content == " Done"
||||
def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser):
    """
    CRITICAL TEST (P1): Verify that when both <|tool_call_end|> and
    <|tool_calls_section_end|> appear in the SAME chunk, the parser:
    1. Processes the tool_call_end first (emits final arguments)
    2. THEN exits the section
    3. Does NOT drop the final tool call update
    4. Does NOT leak special tokens into reasoning

    This tests the deferred section exit fix.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
    tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
    tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")

    # Step 1: ordinary reasoning flows through untouched.
    opening = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="Let me help. ",
        delta_text="Let me help. ",
        previous_token_ids=[],
        current_token_ids=[1, 2],
        delta_token_ids=[1, 2],
        request=None,
    )
    assert opening is not None
    assert opening.content == "Let me help. "

    # Step 2: the section-begin marker flips the parser into section mode.
    _entered = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="Let me help. ",
        current_text="Let me help. <|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, section_begin_id],
        delta_token_ids=[section_begin_id],
        request=None,
    )
    assert kimi_k2_tool_parser.in_tool_section is True

    # Step 3: an entire short tool call — begin, payload, tool_end AND
    # section_end — lands in ONE delta. This is the critical scenario.
    whole_call = (
        '<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} '
        "<|tool_call_end|><|tool_calls_section_end|>"
    )

    text_before = "Let me help. <|tool_calls_section_begin|>"
    text_after = text_before + whole_call

    ids_before = [1, 2, section_begin_id]
    ids_delta = [tool_begin_id, 10, 11, 12, tool_end_id, section_end_id]
    ids_after = ids_before + ids_delta

    closing = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text=text_before,
        current_text=text_after,
        delta_text=whole_call,
        previous_token_ids=ids_before,
        current_token_ids=ids_after,
        delta_token_ids=ids_delta,
        request=None,
    )

    # CRITICAL: the section exit must have happened AFTER tool processing.
    assert kimi_k2_tool_parser.in_tool_section is False

    # The returned delta may carry tool data or nothing at all, but the
    # raw marker tokens must never surface as reasoning content.
    if closing is not None and closing.content is not None:
        assert "<|tool_call_end|>" not in closing.content
        assert "<|tool_calls_section_end|>" not in closing.content

    # Step 4: trailing reasoning after the section streams normally.
    trailing = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text=text_after,
        current_text=text_after + " Done",
        delta_text=" Done",
        previous_token_ids=ids_after,
        current_token_ids=ids_after + [20],
        delta_token_ids=[20],
        request=None,
    )

    assert trailing is not None
    assert trailing.content == " Done"
||||
@ -34,8 +34,27 @@ class KimiK2ToolParser(ToolParser):
|
||||
str
|
||||
] = [] # map what has been streamed for each tool so far to a list
|
||||
|
||||
# Section-level state management to prevent token leakage
|
||||
self.in_tool_section: bool = False
|
||||
self.token_buffer: str = ""
|
||||
# Buffer size: empirical worst-case for longest marker (~30 chars) * 2
|
||||
# + safety margin for unicode + partial overlap. Prevents unbounded growth.
|
||||
self.buffer_max_size: int = 1024
|
||||
self.section_char_count: int = 0 # Track characters processed in tool section
|
||||
self.max_section_chars: int = 8192 # Force exit if section exceeds this
|
||||
self._buffer_overflow_logged: bool = False # Log overflow once per session
|
||||
|
||||
# Support both singular and plural variants
|
||||
self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
|
||||
self.tool_calls_end_token: str = "<|tool_calls_section_end|>"
|
||||
self.tool_calls_start_token_variants: list[str] = [
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_call_section_begin|>", # singular variant
|
||||
]
|
||||
self.tool_calls_end_token_variants: list[str] = [
|
||||
"<|tool_calls_section_end|>",
|
||||
"<|tool_call_section_end|>", # singular variant
|
||||
]
|
||||
|
||||
self.tool_call_start_token: str = "<|tool_call_begin|>"
|
||||
self.tool_call_end_token: str = "<|tool_call_end|>"
|
||||
@ -58,6 +77,18 @@ class KimiK2ToolParser(ToolParser):
|
||||
self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
|
||||
self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)
|
||||
|
||||
# Get token IDs for all variants
|
||||
self.tool_calls_start_token_ids: list[int] = [
|
||||
tid
|
||||
for variant in self.tool_calls_start_token_variants
|
||||
if (tid := self.vocab.get(variant)) is not None
|
||||
]
|
||||
self.tool_calls_end_token_ids: list[int] = [
|
||||
tid
|
||||
for variant in self.tool_calls_end_token_variants
|
||||
if (tid := self.vocab.get(variant)) is not None
|
||||
]
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
|
||||
@ -70,6 +101,51 @@ class KimiK2ToolParser(ToolParser):
|
||||
"tokens in the tokenizer!"
|
||||
)
|
||||
|
||||
def _check_and_strip_markers(self, text: str) -> tuple[str, bool, bool]:
    """
    Remove every section begin/end marker variant from *text*.

    Returns:
        (cleaned_text, found_section_begin, found_section_end) — the text
        with all marker occurrences deleted, plus flags indicating whether
        a begin marker and/or an end marker was present.
    """
    stripped = text
    saw_begin = False
    saw_end = False

    # Delete every begin-marker variant that occurs in the text.
    for marker in self.tool_calls_start_token_variants:
        if marker in stripped:
            saw_begin = True
            stripped = stripped.replace(marker, "")

    # Likewise for every end-marker variant.
    for marker in self.tool_calls_end_token_variants:
        if marker in stripped:
            saw_end = True
            stripped = stripped.replace(marker, "")

    return stripped, saw_begin, saw_end
||||
def _reset_section_state(self) -> None:
    """Clear the per-section tracking fields when leaving a tool section."""
    # The three fields are independent, so assignment order is irrelevant.
    self.token_buffer = ""
    self.section_char_count = 0
    self.in_tool_section = False
||||
def reset_streaming_state(self) -> None:
    """
    Reset all streaming state. Call this between requests to prevent
    state leakage when parser instance is reused.
    """
    # Section-level bookkeeping first...
    self._reset_section_state()

    # ...then the tool-call tracking fields inherited from the base parser.
    self.current_tool_id = -1
    self.current_tool_name_sent = False
    self.streamed_args_for_tool = []
    self.prev_tool_call_arr = []

    logger.debug("Streaming state reset")
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
@ -131,13 +207,94 @@ class KimiK2ToolParser(ToolParser):
|
||||
) -> DeltaMessage | None:
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
# check to see if we should be streaming a tool call - is there a
|
||||
if self.tool_calls_start_token_id not in current_token_ids:
|
||||
logger.debug("No tool call tokens found!")
|
||||
return DeltaMessage(content=delta_text)
|
||||
delta_text = delta_text.replace(self.tool_calls_start_token, "").replace(
|
||||
self.tool_calls_end_token, ""
|
||||
|
||||
# Flag to defer section exit until after tool parsing completes
|
||||
deferred_section_exit = False
|
||||
|
||||
# Add delta to buffer for split marker detection
|
||||
self.token_buffer += delta_text
|
||||
|
||||
# Enforce buffer size limit to prevent memory issues
|
||||
if len(self.token_buffer) > self.buffer_max_size:
|
||||
if not self._buffer_overflow_logged:
|
||||
logger.warning(
|
||||
"Token buffer exceeded max size (%d bytes), flushing excess. "
|
||||
"This may indicate very long markers or unusual tokenization.",
|
||||
self.buffer_max_size,
|
||||
)
|
||||
self._buffer_overflow_logged = True
|
||||
# Keep only the most recent content that might contain partial markers
|
||||
self.token_buffer = self.token_buffer[-self.buffer_max_size // 2 :]
|
||||
|
||||
# Check buffer for section markers (handles split tokens)
|
||||
buffered_text, found_section_begin, found_section_end = (
|
||||
self._check_and_strip_markers(self.token_buffer)
|
||||
)
|
||||
|
||||
# Track section state transitions
|
||||
if found_section_begin and not self.in_tool_section:
|
||||
logger.debug("Entering tool section")
|
||||
self.in_tool_section = True
|
||||
self.token_buffer = buffered_text # Use cleaned buffer
|
||||
self.section_char_count = 0 # Reset counter for new section
|
||||
if found_section_end and self.in_tool_section:
|
||||
logger.debug("Detected section end marker")
|
||||
# CRITICAL: Don't exit early if tool_call_end is in this chunk.
|
||||
# Tool parser must emit final arguments/close first to avoid dropping
|
||||
# the final tool update and leaking tokens into reasoning channel.
|
||||
has_tool_end = self.tool_call_end_token_id in delta_token_ids
|
||||
if has_tool_end:
|
||||
# Defer exit until after tool parsing completes
|
||||
deferred_section_exit = True
|
||||
logger.debug("Deferring section exit: tool_call_end in same chunk")
|
||||
self.token_buffer = buffered_text
|
||||
else:
|
||||
# No tool call ending, safe to exit immediately
|
||||
logger.debug("Exiting tool section")
|
||||
remaining = buffered_text
|
||||
self._reset_section_state()
|
||||
# Return remaining text as reasoning content if non-empty
|
||||
if remaining.strip():
|
||||
return DeltaMessage(content=remaining)
|
||||
# Return empty delta to maintain function contract
|
||||
# (always returns DeltaMessage)
|
||||
return DeltaMessage(content="")
|
||||
else:
|
||||
self.token_buffer = buffered_text
|
||||
|
||||
# Check if any variant of section start token is in current_token_ids
|
||||
has_section_token = any(
|
||||
tid in current_token_ids for tid in self.tool_calls_start_token_ids
|
||||
)
|
||||
|
||||
# Early return: if no section token detected yet, return as reasoning content
|
||||
if not has_section_token and not self.in_tool_section:
|
||||
logger.debug("No tool call tokens found!")
|
||||
# Don't clear buffer - it needs to accumulate partial markers across deltas
|
||||
# Buffer overflow is already protected by lines 215-224
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# Strip section markers from delta_text for subsequent processing
|
||||
# NOTE: This preprocessing happens BEFORE the regex-based tool call
|
||||
# parsing (from PR #24847) to ensure markers are removed cleanly
|
||||
# before pattern matching. No double-stripping occurs because
|
||||
# section markers and tool call markers are distinct.
|
||||
delta_text, _, _ = self._check_and_strip_markers(delta_text)
|
||||
|
||||
# Error recovery: If in tool section for too long, force exit
|
||||
if self.in_tool_section:
|
||||
self.section_char_count += len(delta_text)
|
||||
if self.section_char_count > self.max_section_chars:
|
||||
logger.warning(
|
||||
"Tool section exceeded max length (%d chars), forcing exit. "
|
||||
"This may indicate malformed model output.",
|
||||
self.max_section_chars,
|
||||
)
|
||||
self._reset_section_state()
|
||||
# Deferred exit already handled by forced exit above
|
||||
# Return remaining content as reasoning (or empty delta if no content)
|
||||
return DeltaMessage(content=delta_text if delta_text.strip() else "")
|
||||
|
||||
try:
|
||||
# figure out where we are in the parsing by counting tool call
|
||||
# start & end tags
|
||||
@ -158,6 +315,16 @@ class KimiK2ToolParser(ToolParser):
|
||||
and prev_tool_end_count == cur_tool_end_count
|
||||
and self.tool_call_end_token not in delta_text
|
||||
):
|
||||
# CRITICAL FIX: Suppress content if in tool section but
|
||||
# no tool calls started
|
||||
if self.in_tool_section and cur_tool_start_count == 0:
|
||||
logger.debug(
|
||||
"In tool section but no tool calls started yet. "
|
||||
"Suppressing: %s",
|
||||
delta_text,
|
||||
)
|
||||
# Return empty delta to maintain iterator contract
|
||||
return DeltaMessage(content="")
|
||||
logger.debug("Generating text content! skipping tool parsing.")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
@ -209,6 +376,9 @@ class KimiK2ToolParser(ToolParser):
|
||||
):
|
||||
if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
|
||||
logger.debug("attempting to close tool call, but no tool call")
|
||||
# Handle deferred section exit before returning
|
||||
if deferred_section_exit and self.in_tool_section:
|
||||
self._reset_section_state()
|
||||
return None
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
|
||||
if diff:
|
||||
@ -218,6 +388,9 @@ class KimiK2ToolParser(ToolParser):
|
||||
else diff
|
||||
)
|
||||
if '"}' not in delta_text:
|
||||
# Handle deferred section exit before returning
|
||||
if deferred_section_exit and self.in_tool_section:
|
||||
self._reset_section_state()
|
||||
return None
|
||||
end_loc = delta_text.rindex('"}')
|
||||
diff = delta_text[:end_loc] + '"}'
|
||||
@ -227,6 +400,10 @@ class KimiK2ToolParser(ToolParser):
|
||||
diff,
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += diff
|
||||
# Handle deferred section exit before returning
|
||||
if deferred_section_exit and self.in_tool_section:
|
||||
logger.debug("Completing deferred section exit")
|
||||
self._reset_section_state()
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
@ -240,9 +417,19 @@ class KimiK2ToolParser(ToolParser):
|
||||
|
||||
# case -- otherwise we're just generating text
|
||||
else:
|
||||
# Check if we're in tool section - if so, suppress
|
||||
if self.in_tool_section:
|
||||
logger.debug("In tool section, suppressing text generation")
|
||||
# Handle deferred section exit before returning
|
||||
if deferred_section_exit:
|
||||
self._reset_section_state()
|
||||
return DeltaMessage(content="")
|
||||
text = delta_text.replace(self.tool_call_start_token, "")
|
||||
text = text.replace(self.tool_call_end_token, "")
|
||||
delta = DeltaMessage(tool_calls=[], content=text)
|
||||
# Handle deferred section exit before returning
|
||||
if deferred_section_exit and self.in_tool_section:
|
||||
self._reset_section_state()
|
||||
return delta
|
||||
|
||||
current_tool_call = dict()
|
||||
@ -390,6 +577,11 @@ class KimiK2ToolParser(ToolParser):
|
||||
else:
|
||||
self.prev_tool_call_arr.append(current_tool_call)
|
||||
|
||||
# Handle deferred section exit after tool parsing completes
|
||||
if deferred_section_exit and self.in_tool_section:
|
||||
logger.debug("Completing deferred section exit")
|
||||
self._reset_section_state()
|
||||
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user