diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 51c618e9d51d7..94c24ce9b307a 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -6,7 +6,7 @@ import json
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
 from collections.abc import Sequence as GenericSequence
-from typing import Callable, Final, Optional, Union
+from typing import Final, Optional, Union
 
 import jinja2
 import partial_json_parser
@@ -56,14 +56,13 @@ from vllm.entrypoints.openai.protocol import (
 )
 from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
-from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+from vllm.entrypoints.openai.tool_parsers import ToolParser
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
 from vllm.entrypoints.utils import get_max_tokens
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.outputs import CompletionOutput, RequestOutput
-from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.transformers_utils.tokenizers import (
@@ -112,42 +111,15 @@ class OpenAIServingChat(OpenAIServing):
         self.trust_request_chat_template = trust_request_chat_template
         self.enable_log_outputs = enable_log_outputs
 
+        # set up reasoning parser
+        self.reasoning_parser = self._get_reasoning_parser(
+            reasoning_parser_name=reasoning_parser
+        )
         # set up tool use
         self.enable_auto_tools: bool = enable_auto_tools
-        if self.enable_auto_tools:
-            logger.info(
-                '"auto" tool choice has been enabled please note that while'
-                " the parallel_tool_calls client option is preset for "
-                "compatibility reasons, it will be ignored."
-            )
-
-        self.reasoning_parser: Optional[Callable[[AnyTokenizer], ReasoningParser]] = (
-            None
+        self.tool_parser = self._get_tool_parser(
+            tool_parser_name=tool_parser, enable_auto_tools=enable_auto_tools
         )
-        if reasoning_parser:
-            try:
-                self.reasoning_parser = ReasoningParserManager.get_reasoning_parser(
-                    reasoning_parser
-                )
-                assert self.reasoning_parser is not None
-            except Exception as e:
-                raise TypeError(f"{reasoning_parser=} has not been registered") from e
-        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
-        if self.enable_auto_tools:
-            try:
-                if tool_parser == "pythonic" and self.model_config.model.startswith(
-                    "meta-llama/Llama-3.2"
-                ):
-                    logger.warning(
-                        "Llama3.2 models may struggle to emit valid pythonic tool calls"
-                    )
-                self.tool_parser = ToolParserManager.get_tool_parser(tool_parser)
-            except Exception as e:
-                raise TypeError(
-                    "Error: --enable-auto-tool-choice requires "
-                    f"tool_parser:'{tool_parser}' which has not "
-                    "been registered"
-                ) from e
         self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none
 
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index edb8ecc94382a..0d1a525c6d3da 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -63,7 +63,7 @@ from vllm.entrypoints.openai.protocol import (
     TranslationRequest,
 )
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
-from vllm.entrypoints.openai.tool_parsers import ToolParser
+from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.inputs.data import PromptType
@@ -82,6 +82,7 @@ from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error
 )
 from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.tracing import (
     contains_trace_headers,
@@ -274,6 +275,50 @@ class OpenAIServing:
         self.model_config = self.models.model_config
         self.max_model_len = self.model_config.max_model_len
 
+    def _get_tool_parser(
+        self, tool_parser_name: Optional[str] = None, enable_auto_tools: bool = False
+    ) -> Optional[Callable[[AnyTokenizer], ToolParser]]:
+        """Get the tool parser based on the name."""
+        parser = None
+        if not enable_auto_tools or tool_parser_name is None:
+            return parser
+        logger.info(
+            '"auto" tool choice has been enabled please note that while'
+            " the parallel_tool_calls client option is preset for "
+            "compatibility reasons, it will be ignored."
+        )
+
+        try:
+            if tool_parser_name == "pythonic" and self.model_config.model.startswith(
+                "meta-llama/Llama-3.2"
+            ):
+                logger.warning(
+                    "Llama3.2 models may struggle to emit valid pythonic tool calls"
+                )
+            parser = ToolParserManager.get_tool_parser(tool_parser_name)
+        except Exception as e:
+            raise TypeError(
+                "Error: --enable-auto-tool-choice requires "
+                f"tool_parser:'{tool_parser_name}' which has not "
+                "been registered"
+            ) from e
+        return parser
+
+    def _get_reasoning_parser(
+        self,
+        reasoning_parser_name: str,
+    ) -> Optional[Callable[[AnyTokenizer], ReasoningParser]]:
+        """Get the reasoning parser based on the name."""
+        parser = None
+        if not reasoning_parser_name:
+            return None
+        try:
+            parser = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
+            assert parser is not None
+        except Exception as e:
+            raise TypeError(f"{reasoning_parser_name=} has not been registered") from e
+        return parser
+
     async def reset_mm_cache(self) -> None:
         self.processor.clear_mm_cache()
         await self.engine_client.reset_mm_cache()
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index 48c5222bccc95..60f8b78ed1757 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -96,7 +96,6 @@ from vllm.logger import init_logger
 from vllm.logprobs import Logprob as SampleLogprob
 from vllm.logprobs import SampleLogprobs
 from vllm.outputs import CompletionOutput
-from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import random_uuid
@@ -136,18 +135,9 @@ class OpenAIServingResponses(OpenAIServing):
         self.chat_template_content_format: Final = chat_template_content_format
         self.enable_log_outputs = enable_log_outputs
 
-        self.reasoning_parser: Optional[Callable[[AnyTokenizer], ReasoningParser]] = (
-            None
+        self.reasoning_parser = self._get_reasoning_parser(
+            reasoning_parser_name=reasoning_parser
         )
-        if reasoning_parser:
-            try:
-                self.reasoning_parser = ReasoningParserManager.get_reasoning_parser(
-                    reasoning_parser
-                )
-                assert self.reasoning_parser is not None
-            except Exception as e:
-                raise TypeError(f"{reasoning_parser=} has not been registered") from e
-
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
         self.enable_force_include_usage = enable_force_include_usage
         self.default_sampling_params = self.model_config.get_diff_sampling_param()
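
For review purposes, here is a minimal sketch of the contract of the two helpers the patch adds to `OpenAIServing`. Only the helper signatures, the `None` short-circuits, and the raised `TypeError` come from the diff above; the `serving` object (any constructed subclass such as `OpenAIServingChat`), the `tokenizer` variable, and the parser names `"hermes"` / `"not-a-registered-parser"` are illustrative assumptions, not part of the patch.

```python
# Review sketch only - not part of the patch. Assumes `serving` is a constructed
# OpenAIServing subclass (e.g. OpenAIServingChat) and `tokenizer` an AnyTokenizer.

# Disabled auto tool choice or an empty reasoning-parser name short-circuits to None.
assert serving._get_tool_parser(tool_parser_name="hermes", enable_auto_tools=False) is None
assert serving._get_reasoning_parser(reasoning_parser_name="") is None

# A registered name resolves to a factory (Callable[[AnyTokenizer], ToolParser]),
# which callers apply per request to the request's tokenizer.
tool_factory = serving._get_tool_parser(tool_parser_name="hermes", enable_auto_tools=True)
if tool_factory is not None:
    tool_parser = tool_factory(tokenizer)  # ToolParser bound to this tokenizer

# An unregistered name surfaces as TypeError, matching the old inline logic.
try:
    serving._get_reasoning_parser(reasoning_parser_name="not-a-registered-parser")
except TypeError as exc:
    print(exc)
```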