[Bugfix] Consistent ascii handling in tool parsers (#17704)

Signed-off-by: Sebastian Schönnenbeck <sebastian.schoennenbeck@comma-soft.com>
This commit is contained in:
Sebastian Schoennenbeck 2025-05-21 22:41:23 +02:00 committed by GitHub
parent 94d8ec8d2b
commit 1f079540db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 53 additions and 29 deletions

View File

@ -80,7 +80,8 @@ class Granite20bFCToolParser(ToolParser):
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"]),
arguments=json.dumps(function_call["arguments"],
ensure_ascii=False),
),
) for function_call in raw_function_calls
]
@ -166,7 +167,8 @@ class Granite20bFCToolParser(ToolParser):
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments)
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]
@ -218,7 +220,8 @@ class Granite20bFCToolParser(ToolParser):
if cur_arguments:
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments)
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
@ -226,7 +229,8 @@ class Granite20bFCToolParser(ToolParser):
if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:]
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments)
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
if cur_args_json != prev_args_json:
prefix = find_common_prefix(

View File

@ -67,7 +67,8 @@ class GraniteToolParser(ToolParser):
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"]),
arguments=json.dumps(function_call["arguments"],
ensure_ascii=False),
),
) for function_call in raw_function_calls
]
@ -151,7 +152,8 @@ class GraniteToolParser(ToolParser):
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments)
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]
@ -197,7 +199,8 @@ class GraniteToolParser(ToolParser):
if cur_arguments:
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments)
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
@ -205,7 +208,8 @@ class GraniteToolParser(ToolParser):
if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:]
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments)
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
if cur_args_json != prev_args_json:
prefix = find_common_prefix(
prev_args_json, cur_args_json)

View File

@ -133,7 +133,8 @@ class Internlm2ToolParser(ToolParser):
delta = None
# first time to get parameters
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments)
cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
arguments_delta = cur_arguments_json[:cur_arguments_json.
index(delta_text) +
@ -148,8 +149,10 @@ class Internlm2ToolParser(ToolParser):
self.current_tool_id] += arguments_delta
# both prev and cur parameters, send the increase parameters
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments)
prev_args_json = json.dumps(prev_arguments)
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json)
@ -190,7 +193,8 @@ class Internlm2ToolParser(ToolParser):
action_dict = json.loads(action)
name, parameters = action_dict['name'], json.dumps(
action_dict.get('parameters', action_dict.get('arguments',
{})))
{})),
ensure_ascii=False)
if not tools or name not in [t.function.name for t in tools]:
ExtractedToolCallInformation(tools_called=False,

View File

@ -96,8 +96,9 @@ class JambaToolParser(ToolParser):
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"])))
for function_call in raw_function_calls
arguments=json.dumps(function_call["arguments"],
ensure_ascii=False),
)) for function_call in raw_function_calls
]
content = model_output[:model_output.
@ -187,7 +188,7 @@ class JambaToolParser(ToolParser):
diff: Union[str, None] = current_tool_call.get("arguments")
if diff:
diff = json.dumps(diff).replace(
diff = json.dumps(diff, ensure_ascii=False).replace(
self.streamed_args_for_tool[self.current_tool_id],
"")
delta = DeltaMessage(tool_calls=[
@ -248,7 +249,8 @@ class JambaToolParser(ToolParser):
"mid-arguments")
delta = None
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments)
cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
logger.debug("finding %s in %s", new_text,
cur_arguments_json)
@ -267,8 +269,10 @@ class JambaToolParser(ToolParser):
self.current_tool_id] += arguments_delta
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments)
prev_args_json = json.dumps(prev_arguments)
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
logger.debug("Searching for diff between \n%s\n%s",
cur_args_json, prev_args_json)

View File

@ -88,7 +88,8 @@ class Llama3JsonToolParser(ToolParser):
# function call args are JSON but as a string
arguments=json.dumps(raw_function_call["arguments"] \
if "arguments" in raw_function_call \
else raw_function_call["parameters"])))
else raw_function_call["parameters"],
ensure_ascii=False)))
for raw_function_call in function_call_arr
]
@ -174,7 +175,8 @@ class Llama3JsonToolParser(ToolParser):
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments)
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]
@ -226,7 +228,8 @@ class Llama3JsonToolParser(ToolParser):
if cur_arguments:
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments)
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
@ -234,7 +237,8 @@ class Llama3JsonToolParser(ToolParser):
if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:]
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments)
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
if cur_args_json != prev_args_json:
prefix = find_common_prefix(

View File

@ -79,10 +79,11 @@ class Phi4MiniJsonToolParser(ToolParser):
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(
raw_function_call["arguments"] if "arguments" in
raw_function_call else
raw_function_call["parameters"])))
for raw_function_call in function_call_arr
raw_function_call["arguments"]
if "arguments" in raw_function_call else
raw_function_call["parameters"],
ensure_ascii=False),
)) for raw_function_call in function_call_arr
]
# get any content before the tool call

View File

@ -200,9 +200,12 @@ def _handle_single_tool(call: ast.Call) -> ToolCall:
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(type="function",
function=FunctionCall(name=function_name,
arguments=json.dumps(arguments)))
return ToolCall(
type="function",
function=FunctionCall(name=function_name,
arguments=json.dumps(arguments,
ensure_ascii=False)),
)
def _make_valid_python(text: str) -> Union[tuple[str, str], None]: