From ec7cb1922478015b4e7eae73c6acde8b598a05a8 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 6 Aug 2025 10:32:21 -0700
Subject: [PATCH] [gpt-oss] Add loop for built-in tool call (#22374)

Signed-off-by: Woosuk Kwon
Co-authored-by: LiuXiaoxuanPKU
Co-authored-by: simon-mo
Co-authored-by: Chen Zhang
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
Co-authored-by: Yongye Zhu
---
 vllm/entrypoints/openai/serving_engine.py    | 56 ++++++++++++++++++++
 vllm/entrypoints/openai/serving_responses.py | 33 ++++++------
 2 files changed, 73 insertions(+), 16 deletions(-)

diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 71976fea1ee77..822f1868406c7 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -35,6 +35,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          apply_mistral_chat_template,
                                          parse_chat_messages_futures,
                                          resolve_chat_template_content_format)
+from vllm.entrypoints.context import ConversationContext
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ChatCompletionResponse,
@@ -948,6 +949,61 @@ class OpenAIServing:
 
         return conversation, [request_prompt], [engine_prompt]
 
+    async def _generate_with_builtin_tools(
+        self,
+        request_id: str,
+        request_prompt: RequestPrompt,
+        engine_prompt: EngineTokensPrompt,
+        sampling_params: SamplingParams,
+        context: ConversationContext,
+        lora_request: Optional[LoRARequest] = None,
+        priority: int = 0,
+        **kwargs,
+    ):
+        orig_priority = priority
+        while True:
+            self._log_inputs(
+                request_id,
+                request_prompt,
+                params=sampling_params,
+                lora_request=lora_request,
+            )
+            generator = self.engine_client.generate(
+                engine_prompt,
+                sampling_params,
+                request_id,
+                lora_request=lora_request,
+                priority=priority,
+                **kwargs,
+            )
+            async for res in generator:
+                context.append_output(res)
+                # NOTE(woosuk): The stop condition is handled by the engine.
+                yield context
+
+            if not context.need_builtin_tool_call():
+                # The model did not ask for a tool call, so we're done.
+                break
+
+            # Call the tool and update the context with the result.
+            tool_output = await context.call_tool()
+            context.append_output(tool_output)
+
+            # TODO: uncomment this and enable tool output streaming
+            # yield context
+
+            # Create inputs for the next turn.
+            # Render the next prompt token ids.
+            prompt_token_ids = context.render_for_completion()
+            engine_prompt = EngineTokensPrompt(
+                prompt_token_ids=prompt_token_ids)
+            request_prompt = prompt_token_ids
+            # Update the sampling params.
+            sampling_params.max_tokens = (self.max_model_len -
+                                          len(prompt_token_ids))
+            # OPTIMIZATION
+            priority = orig_priority - 1
+
     def _load_prompt_embeds(
         self,
         prompt_embeds: Optional[Union[bytes, list[bytes]]],
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index e009529fbd2ad..f340854386f82 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -16,6 +16,7 @@ from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          ChatTemplateContentFormatOption)
+from vllm.entrypoints.context import ConversationContext, SimpleContext
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -29,7 +30,6 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
-from vllm.outputs import RequestOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -187,7 +187,7 @@ class OpenAIServingResponses(OpenAIServing):
         raw_request.state.request_metadata = request_metadata
 
         # Schedule the request and get the result generator.
-        generators: list[AsyncGenerator[RequestOutput, None]] = []
+        generators: list[AsyncGenerator[ConversationContext, None]] = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
                 default_max_tokens = self.max_model_len - len(
@@ -195,21 +195,19 @@ class OpenAIServingResponses(OpenAIServing):
                 sampling_params = request.to_sampling_params(
                     default_max_tokens, self.default_sampling_params)
 
-                self._log_inputs(request.request_id,
-                                 request_prompts[i],
-                                 params=sampling_params,
-                                 lora_request=lora_request)
-
                 trace_headers = (None if raw_request is None else await
                                  self._get_trace_headers(raw_request.headers))
 
-                generator = self.engine_client.generate(
-                    engine_prompt,
-                    sampling_params,
-                    request.request_id,
+                context = SimpleContext()
+                generator = self._generate_with_builtin_tools(
+                    request_id=request.request_id,
+                    request_prompt=request_prompts[i],
+                    engine_prompt=engine_prompt,
+                    sampling_params=sampling_params,
+                    context=context,
                     lora_request=lora_request,
-                    trace_headers=trace_headers,
                     priority=request.priority,
+                    trace_headers=trace_headers,
                 )
                 generators.append(generator)
         except ValueError as e:
@@ -277,7 +275,7 @@ class OpenAIServingResponses(OpenAIServing):
         self,
         request: ResponsesRequest,
         sampling_params: SamplingParams,
-        result_generator: AsyncIterator[RequestOutput],
+        result_generator: AsyncIterator[ConversationContext],
         model_name: str,
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
@@ -285,17 +283,20 @@ class OpenAIServingResponses(OpenAIServing):
     ) -> Union[ErrorResponse, ResponsesResponse]:
         if created_time is None:
             created_time = int(time.time())
-        final_res: Optional[RequestOutput] = None
 
-        try:
-            async for res in result_generator:
-                final_res = res
+        context: Optional[ConversationContext] = None
+        try:
+            async for context in result_generator:
+                pass
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
+        assert context is not None
+        assert isinstance(context, SimpleContext)
+        final_res = context.last_output
         assert final_res is not None
         assert len(final_res.outputs) == 1
         final_output = final_res.outputs[0]
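
Note for reviewers (not part of the patch): _generate_with_builtin_tools relies on only a small contract from the context object -- append_output(), need_builtin_tool_call(), call_tool(), and render_for_completion() -- and re-renders the whole conversation into prompt token ids before each new turn. The sketch below is a minimal, self-contained mock of that contract with the same control flow. DemoContext, fake_generate, and the toy outputs are hypothetical stand-ins for illustration; they are not vLLM APIs.

    import asyncio
    from typing import Optional

    class DemoContext:
        """Hypothetical stand-in for the ConversationContext contract."""

        def __init__(self) -> None:
            self.turns: list[str] = []
            self.last_output: Optional[str] = None

        def append_output(self, output: str) -> None:
            # Record either a model turn or a tool result.
            self.last_output = output
            self.turns.append(output)

        def need_builtin_tool_call(self) -> bool:
            # Pretend the model keeps asking for a tool until one result exists.
            return not any(t.startswith("tool:") for t in self.turns)

        async def call_tool(self) -> str:
            return "tool:search-results"

        def render_for_completion(self) -> list[int]:
            # Re-render the whole conversation as prompt token ids.
            return list(range(len(self.turns)))

    async def fake_generate(prompt_token_ids: list[int]):
        # Stand-in for engine_client.generate(): one output per turn.
        yield f"model-turn-{len(prompt_token_ids)}"

    async def generate_with_tools(context: DemoContext):
        # Same loop shape as _generate_with_builtin_tools above.
        prompt_token_ids: list[int] = []
        while True:
            async for res in fake_generate(prompt_token_ids):
                context.append_output(res)
                yield context
            if not context.need_builtin_tool_call():
                break
            context.append_output(await context.call_tool())
            prompt_token_ids = context.render_for_completion()

    async def main() -> None:
        context = DemoContext()
        async for _ in generate_with_tools(context):
            pass  # non-streaming: drain and keep only the final context
        print(context.turns)
        # ['model-turn-0', 'tool:search-results', 'model-turn-2']

    asyncio.run(main())

As in responses_full_generator, a non-streaming consumer just drains the generator and reads the final context; a streaming consumer could instead act on every yielded context.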