[gpt-oss] Add loop for built-in tool call (#22374)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Woosuk Kwon 2025-08-06 10:32:21 -07:00 committed by GitHub
parent 2435ea7ed5
commit ec7cb19224
2 changed files with 73 additions and 16 deletions
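
Note: the new helper added below is written against a ConversationContext abstraction from vllm/entrypoints/context, which is imported here but not itself part of this diff. The following is a minimal sketch of the interface the loop relies on, inferred only from the calls visible in the diff; it is illustrative, not the actual module:

# Sketch only: the real definitions live in vllm/entrypoints/context.py,
# which is not shown in this commit.
from abc import ABC, abstractmethod
from typing import Any


class ConversationContext(ABC):
    """Tracks one multi-turn conversation across engine turns and tool calls."""

    @abstractmethod
    def append_output(self, output: Any) -> None:
        """Record a model turn (a RequestOutput) or a tool result."""

    @abstractmethod
    def need_builtin_tool_call(self) -> bool:
        """Whether the latest model turn requested a built-in tool call."""

    @abstractmethod
    async def call_tool(self) -> Any:
        """Run the requested built-in tool and return its output."""

    @abstractmethod
    def render_for_completion(self) -> list[int]:
        """Render the conversation so far as prompt token ids for the next turn."""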

vllm/entrypoints/openai/serving_engine.py

@@ -35,6 +35,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                           apply_mistral_chat_template,
                                           parse_chat_messages_futures,
                                           resolve_chat_template_content_format)
+from vllm.entrypoints.context import ConversationContext
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                                ChatCompletionResponse,
@@ -948,6 +949,61 @@ class OpenAIServing:
         return conversation, [request_prompt], [engine_prompt]
 
+    async def _generate_with_builtin_tools(
+        self,
+        request_id: str,
+        request_prompt: RequestPrompt,
+        engine_prompt: EngineTokensPrompt,
+        sampling_params: SamplingParams,
+        context: ConversationContext,
+        lora_request: Optional[LoRARequest] = None,
+        priority: int = 0,
+        **kwargs,
+    ):
+        orig_priority = priority
+        while True:
+            self._log_inputs(
+                request_id,
+                request_prompt,
+                params=sampling_params,
+                lora_request=lora_request,
+            )
+            generator = self.engine_client.generate(
+                engine_prompt,
+                sampling_params,
+                request_id,
+                lora_request=lora_request,
+                priority=priority,
+                **kwargs,
+            )
+            async for res in generator:
+                context.append_output(res)
+                # NOTE(woosuk): The stop condition is handled by the engine.
+                yield context
+
+            if not context.need_builtin_tool_call():
+                # The model did not ask for a tool call, so we're done.
+                break
+
+            # Call the tool and update the context with the result.
+            tool_output = await context.call_tool()
+            context.append_output(tool_output)
+            # TODO: uncomment this and enable tool output streaming
+            # yield context
+
+            # Create inputs for the next turn.
+            # Render the next prompt token ids.
+            prompt_token_ids = context.render_for_completion()
+            engine_prompt = EngineTokensPrompt(
+                prompt_token_ids=prompt_token_ids)
+            request_prompt = prompt_token_ids
+
+            # Update the sampling params.
+            sampling_params.max_tokens = (self.max_model_len -
+                                          len(prompt_token_ids))
+            # OPTIMIZATION: lower the priority value (i.e. raise the
+            # scheduling priority) for follow-up turns of this request.
+            priority = orig_priority - 1
+
     def _load_prompt_embeds(
         self,
         prompt_embeds: Optional[Union[bytes, list[bytes]]],
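
Note: the responses path in the next file drives this loop with a SimpleContext and, once the generator is exhausted, reads the final RequestOutput from context.last_output. SimpleContext is also defined in vllm/entrypoints/context and not shown here; a plausible sketch, assuming it is the single-turn, no-tool default (continuing the sketch above; any names beyond those used in the diff are hypothetical):

# Sketch only, assuming SimpleContext never requests a tool call, so
# _generate_with_builtin_tools performs exactly one engine generation.
from typing import Any, Optional


class SimpleContext(ConversationContext):

    def __init__(self) -> None:
        # Keep only the most recent RequestOutput; responses_full_generator
        # reads it via context.last_output after the loop finishes.
        self.last_output: Optional[Any] = None

    def append_output(self, output: Any) -> None:
        self.last_output = output

    def need_builtin_tool_call(self) -> bool:
        return False

    async def call_tool(self) -> Any:
        raise NotImplementedError("SimpleContext does not call tools")

    def render_for_completion(self) -> list[int]:
        raise NotImplementedError("SimpleContext is single-turn")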

vllm/entrypoints/openai/serving_responses.py

@@ -16,6 +16,7 @@ from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                           ChatTemplateContentFormatOption)
+from vllm.entrypoints.context import ConversationContext, SimpleContext
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -29,7 +30,6 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
-from vllm.outputs import RequestOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -187,7 +187,7 @@ class OpenAIServingResponses(OpenAIServing):
             raw_request.state.request_metadata = request_metadata
 
         # Schedule the request and get the result generator.
-        generators: list[AsyncGenerator[RequestOutput, None]] = []
+        generators: list[AsyncGenerator[ConversationContext, None]] = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
                 default_max_tokens = self.max_model_len - len(
@@ -195,21 +195,19 @@
                 sampling_params = request.to_sampling_params(
                     default_max_tokens, self.default_sampling_params)
 
-                self._log_inputs(request.request_id,
-                                 request_prompts[i],
-                                 params=sampling_params,
-                                 lora_request=lora_request)
-
                 trace_headers = (None if raw_request is None else await
                                  self._get_trace_headers(raw_request.headers))
 
-                generator = self.engine_client.generate(
-                    engine_prompt,
-                    sampling_params,
-                    request.request_id,
+                context = SimpleContext()
+                generator = self._generate_with_builtin_tools(
+                    request_id=request.request_id,
+                    request_prompt=request_prompts[i],
+                    engine_prompt=engine_prompt,
+                    sampling_params=sampling_params,
+                    context=context,
                     lora_request=lora_request,
-                    trace_headers=trace_headers,
                     priority=request.priority,
+                    trace_headers=trace_headers,
                 )
                 generators.append(generator)
         except ValueError as e:
@@ -277,7 +275,7 @@
         self,
         request: ResponsesRequest,
         sampling_params: SamplingParams,
-        result_generator: AsyncIterator[RequestOutput],
+        result_generator: AsyncIterator[ConversationContext],
         model_name: str,
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
@@ -285,17 +283,20 @@
     ) -> Union[ErrorResponse, ResponsesResponse]:
         if created_time is None:
             created_time = int(time.time())
-        final_res: Optional[RequestOutput] = None
+        context: Optional[ConversationContext] = None
 
         try:
-            async for res in result_generator:
-                final_res = res
+            async for context in result_generator:
+                pass
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
+        assert context is not None
+        assert isinstance(context, SimpleContext)
+        final_res = context.last_output
         assert final_res is not None
         assert len(final_res.outputs) == 1
         final_output = final_res.outputs[0]
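
Note: putting the two files together, one request flows roughly as follows. This is a condensed, hypothetical driver mirroring the serving_responses.py changes; run_request and its arguments are illustrative and not part of the commit:

# Hypothetical sketch of the end-to-end flow added by this commit.
async def run_request(serving, engine_prompt, sampling_params, request_id):
    context = SimpleContext()
    async for ctx in serving._generate_with_builtin_tools(
            request_id=request_id,
            request_prompt=engine_prompt["prompt_token_ids"],
            engine_prompt=engine_prompt,
            sampling_params=sampling_params,
            context=context,
    ):
        # Each yield is the same context, updated with a fresh RequestOutput
        # (and, on later turns, with built-in tool results).
        pass
    assert context.last_output is not None
    # Same shape as the original code path: a single completion per request.
    return context.last_output.outputs[0]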