From 93269bb43e4815fc665b4a6628fa9444a5062a4d Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Mon, 28 Jul 2025 10:46:38 +0800 Subject: [PATCH] Fix GLM tool parser (#21668) Co-authored-by: Chenhui Zhang --- .../tool_parsers/glm4_moe_tool_parser.py | 443 +++++------------- 1 file changed, 113 insertions(+), 330 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py index 40cdf7275a8f6..8fd14f171d0af 100644 --- a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py @@ -1,13 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# code modified from deepseekv3_tool_parser.py +import ast +import json from collections.abc import Sequence -from typing import Union +from typing import Any, Optional, Union import regex as re from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionToolsParam, DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, @@ -34,36 +36,13 @@ class Glm4MoeModelToolParser(ToolParser): self.tool_calls_start_token = self.tool_call_start_token - # Updated regex for the XML-based format - self.tool_call_regex = re.compile( - r"\s*" - r"(?P[^\n<]+)\s*" # 函数名(到换行或 <) - r"(?P(?:\s*[^<]+\s*" - r"[^<]*\s*)*)\s*" - r"", - re.DOTALL, - ) - - # Regex for parsing individual arguments - self.arg_regex = re.compile( - r"(?P[^<]+)\s*(?P[^<]*)", - re.DOTALL, - ) - - # Streaming regex - self.stream_tool_call_portion_regex = re.compile( - r"(?P[^\n<]+)\s*" - r"(?P(?:\s*[^<]+\s*" - r"[^<]*\s*)*)", - re.DOTALL, - ) - - # For streaming, we also need a regex to match just the function name - self.stream_tool_call_name_regex = re.compile( - r"(?P[^\n<]+)", - re.DOTALL, - ) - + self.func_call_regex = re.compile(r".*?", + re.DOTALL) + self.func_detail_regex = re.compile( + r"([^\n]*)\n(.*)", re.DOTALL) + self.func_arg_regex = re.compile( + r"(.*?)\s*(.*?)", + re.DOTALL) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " @@ -72,20 +51,7 @@ class Glm4MoeModelToolParser(ToolParser): self.tool_call_start_token_id = self.vocab.get( self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) - - def _parse_arguments(self, args_text: str) -> str: - """Parse XML-based arguments into JSON format.""" - if not args_text or not args_text.strip(): - return "{}" - - args_dict = {} - matches = self.arg_regex.findall(args_text) - - for key, value in matches: - args_dict[key.strip()] = value.strip() - - import json - return json.dumps(args_dict, ensure_ascii=False) + self._buffer = "" def extract_tool_calls( self, @@ -93,52 +59,67 @@ class Glm4MoeModelToolParser(ToolParser): request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: - # sanity check; avoid unnecessary processing - if self.tool_calls_start_token not in model_output: + def _is_string_type( + tool_name: str, arg_name: str, + tools: Optional[list[ChatCompletionToolsParam]]) -> bool: + if tools is None: + return False + for tool in tools: + if tool.function.name == tool_name: + if tool.function.parameters is None: + return False + arg_type = tool.function.parameters.get( + "properties", {}).get(arg_name, {}).get("type", None) + return arg_type == "string" + logger.warning("No tool named '%s'.", tool_name) + return False + + def _deserialize(value: str) -> Any: + try: + return json.loads(value) + except Exception: + pass + + try: + return ast.literal_eval(value) + except Exception: + pass + return value + + matched_tool_calls = self.func_call_regex.findall(model_output) + logger.debug("model_output: %s", model_output) + try: + tool_calls = [] + for match in matched_tool_calls: + tc_detail = self.func_detail_regex.search(match) + tc_name = tc_detail.group(1) + tc_args = tc_detail.group(2) + pairs = self.func_arg_regex.findall(tc_args) + arg_dct = {} + for key, value in pairs: + arg_key = key.strip() + arg_val = value.strip() + if not _is_string_type(tc_name, arg_key, request.tools): + arg_val = _deserialize(arg_val) + logger.debug("arg_key = %s, arg_val = %s", arg_key, + arg_val) + arg_dct[arg_key] = arg_val + tool_calls.append( + ToolCall(type="function", + function=FunctionCall( + name=tc_name, arguments=json.dumps(arg_dct)))) + except Exception: + logger.exception("Failed to extract tool call spec") return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) - - try: - # Find all tool calls in the output - function_call_matches = self.tool_call_regex.findall(model_output) - - logger.debug("function_call_matches: %s", function_call_matches) - - if not function_call_matches: - return ExtractedToolCallInformation( - tools_called=False, - tool_calls=[], - content=model_output, - ) - - tool_calls = [] - for i, match in enumerate(function_call_matches): - function_name, function_args_xml = match - function_name = function_name.strip() - - # Parse XML arguments to JSON - function_args_json = self._parse_arguments(function_args_xml) - - tool_calls.append( - ToolCall( - id=f"call_{i}", - type='function', - function=FunctionCall(name=function_name, - arguments=function_args_json), - )) - - # Extract content before the first tool call - content = model_output[:model_output.find(self. - tool_calls_start_token)] - return ExtractedToolCallInformation( - tools_called=bool(tool_calls), - tool_calls=tool_calls, - content=content.strip() if content.strip() else None, - ) - - except Exception: - logger.exception("Error in extracting tool call from response.") + else: + if len(tool_calls) > 0: + content = model_output[:model_output. + find(self.tool_calls_start_token)] + return ExtractedToolCallInformation(tools_called=True, + tool_calls=tool_calls, + content=content) return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) @@ -153,250 +134,52 @@ class Glm4MoeModelToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - - logger.debug("delta_text: %s", delta_text) - logger.debug("delta_token_ids: %s", delta_token_ids) - # check to see if we should be streaming a tool call - is there a - if self.tool_call_start_token_id not in current_token_ids: - logger.debug("No tool call tokens found!") - return DeltaMessage(content=delta_text) - delta_text = delta_text.replace(self.tool_calls_start_token, - "").replace(self.tool_call_end_token, - "") - try: - - # figure out where we are in the parsing by counting tool call - # start & end tags - prev_tool_start_count = previous_token_ids.count( - self.tool_call_start_token_id) - prev_tool_end_count = previous_token_ids.count( - self.tool_call_end_token_id) - cur_tool_start_count = current_token_ids.count( - self.tool_call_start_token_id) - cur_tool_end_count = current_token_ids.count( - self.tool_call_end_token_id) - tool_call_portion = None - text_portion = None - - # case: if we're generating text, OR rounding out a tool call - if (cur_tool_start_count == cur_tool_end_count - and prev_tool_end_count == cur_tool_end_count - and self.tool_call_end_token not in delta_text): - logger.debug("Generating text content! skipping tool parsing.") - return DeltaMessage(content=delta_text) - - if self.tool_call_end_token in delta_text: - logger.debug("tool_call_end_token in delta_text") - full_text = current_text + delta_text - tool_call_portion = full_text.split( - self.tool_call_start_token)[-1].split( - self.tool_call_end_token)[0].rstrip() - delta_text = delta_text.split( - self.tool_call_end_token)[0].rstrip() - text_portion = delta_text.split( - self.tool_call_end_token)[-1].lstrip() - - # case -- we're starting a new tool call - if (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count > prev_tool_start_count): - if len(delta_token_ids) > 1: - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] - else: - tool_call_portion = None - delta = None - - text_portion = None - - # set cursors and state appropriately - self.current_tool_id += 1 - self.current_tool_name_sent = False - self.streamed_args_for_tool.append("") - logger.debug("Starting on a new tool %s", self.current_tool_id) - - # case -- we're updating an existing tool call - elif (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count == prev_tool_start_count): - - # get the portion of the text that's the tool call - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] - text_portion = None - - # case -- the current tool call is being closed. - elif (cur_tool_start_count == cur_tool_end_count - and cur_tool_end_count >= prev_tool_end_count): - if self.prev_tool_call_arr is None or len( - self.prev_tool_call_arr) == 0: - logger.debug( - "attempting to close tool call, but no tool call") - return None - diff = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") - if diff: - diff = (diff.encode("utf-8").decode("unicode_escape") - if diff is str else diff) - if '"}' not in delta_text: - return None - end_loc = delta_text.rindex('"}') - diff = delta_text[:end_loc] + '"}' - logger.debug( - "Finishing tool and found diff that had not " - "been streamed yet: %s", - diff, - ) - self.streamed_args_for_tool[self.current_tool_id] += diff - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump(exclude_none=True), - ) - ]) - - # case -- otherwise we're just generating text - else: - text = delta_text.replace(self.tool_call_start_token, "") - text = text.replace(self.tool_call_end_token, "") - delta = DeltaMessage(tool_calls=[], content=text) - return delta - - current_tool_call = dict() - if tool_call_portion: - current_tool_call_matches = ( - self.stream_tool_call_portion_regex.match( - tool_call_portion)) - if current_tool_call_matches: - tool_id, tool_args = (current_tool_call_matches.groups()) - tool_name = tool_id.split('.')[1].split(':')[0] - current_tool_call['id'] = tool_id - current_tool_call["name"] = tool_name - current_tool_call["arguments"] = tool_args - else: - current_tool_call_name_matches = ( - self.stream_tool_call_name_regex.match( - tool_call_portion)) - if current_tool_call_name_matches: - tool_id_str, = current_tool_call_name_matches.groups() - tool_name = tool_id_str.split('.')[1].split(':')[0] - current_tool_call['id'] = tool_id_str - current_tool_call["name"] = tool_name - current_tool_call["arguments"] = "" - else: - logger.debug("Not enough token") - return None - - # case - we haven't sent the tool name yet. If it's available, send - # it. otherwise, wait until it's available. - if not self.current_tool_name_sent: - if current_tool_call is None: - return None - function_name: Union[str, None] = current_tool_call.get("name") - tool_id = current_tool_call.get("id") - if function_name: - self.current_tool_name_sent = True - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - type="function", - id=tool_id, - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True), - ) - ]) - else: - return None - - # case -- otherwise, send the tool call delta - - # if the tool call portion is None, send the delta as text - if tool_call_portion is None: - # if there's text but not tool calls, send that - - # otherwise None to skip chunk - delta = (DeltaMessage( - content=delta_text) if text_portion is not None else None) - return delta - - # now, the nitty-gritty of tool calls - # now we have the portion to parse as tool call. - - logger.debug("Trying to parse current tool call with ID %s", - self.current_tool_id) - - # if we're starting a new tool call, push an empty object in as - # a placeholder for the arguments - if len(self.prev_tool_call_arr) <= self.current_tool_id: + self._buffer += delta_text + cur_text = self._buffer + start_idx = cur_text.find(self.tool_call_start_token) + if start_idx == -1: + self._buffer = "" + if self.current_tool_id > 0: + cur_text = "" + return DeltaMessage(content=cur_text) + logger.debug("cur_text = %s", cur_text) + end_idx = cur_text.find(self.tool_call_end_token) + if end_idx != -1: + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [] + while len(self.prev_tool_call_arr) <= self.current_tool_id: self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") - # main logic for tool parsing here - compare prev. partially-parsed - # JSON to the current partially-parsed JSON - prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") - cur_arguments = current_tool_call.get("arguments") + extracted_tool_calls = self.extract_tool_calls( + cur_text[:end_idx + len(self.tool_call_end_token)], request) - logger.debug("diffing old arguments: %s", prev_arguments) - logger.debug("against new ones: %s", cur_arguments) - - # case -- no arguments have been created yet. skip sending a delta. - if not cur_arguments and not prev_arguments: - logger.debug("Skipping text %s - no arguments", delta_text) - delta = None - - # case -- prev arguments are defined, but non are now. - # probably impossible, but not a fatal error - just keep going - elif not cur_arguments and prev_arguments: - logger.error("should be impossible to have arguments reset " - "mid-call. skipping streaming anything.") - delta = None - - # case -- we now have the first info about arguments available from - # autocompleting the JSON - elif cur_arguments and not prev_arguments: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=cur_arguments).model_dump( - exclude_none=True), - ) + if len(extracted_tool_calls.tool_calls) == 0: + logger.warning("Failed to extract any tool calls.") + return None + tool_call = extracted_tool_calls.tool_calls[0] + self.prev_tool_call_arr[self.current_tool_id] = { + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments) + } + self.streamed_args_for_tool[ + self.current_tool_id] = tool_call.function.arguments + delta = DeltaMessage( + content=extracted_tool_calls.content, + tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + id=tool_call.id, + type=tool_call.type, + function=DeltaFunctionCall( + name=tool_call.function.name, + arguments=tool_call.function.arguments)) ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments - - # last case -- we have an update to existing arguments. - elif cur_arguments and prev_arguments: - if (isinstance(delta_text, str) - and cur_arguments != prev_arguments - and len(cur_arguments) > len(prev_arguments) - and cur_arguments.startswith(prev_arguments)): - delta_arguments = cur_arguments[len(prev_arguments):] - logger.debug("got diff %s", delta_text) - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=delta_arguments).model_dump( - exclude_none=True), - ) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments - else: - delta = None - - # handle saving the state for the current tool into - # the "prev" list for use in diffing for the next iteration - if self.current_tool_id == len(self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[ - self.current_tool_id] = current_tool_call - else: - self.prev_tool_call_arr.append(current_tool_call) - + self.current_tool_id += 1 + self._buffer = cur_text[end_idx + len(self.tool_call_end_token):] return delta - except Exception: - logger.exception("Error trying to handle streaming tool call.") - return None # do not stream a delta. skip this token ID. + self._buffer = cur_text[start_idx:] + return DeltaMessage(content=cur_text[:start_idx])