Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[gpt-oss] Add loop for built-in tool call (#22374)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
parent 2435ea7ed5
commit ec7cb19224
@@ -35,6 +35,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          apply_mistral_chat_template,
                                          parse_chat_messages_futures,
                                          resolve_chat_template_content_format)
+from vllm.entrypoints.context import ConversationContext
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ChatCompletionResponse,
@@ -948,6 +949,61 @@ class OpenAIServing:
 
         return conversation, [request_prompt], [engine_prompt]
 
+    async def _generate_with_builtin_tools(
+        self,
+        request_id: str,
+        request_prompt: RequestPrompt,
+        engine_prompt: EngineTokensPrompt,
+        sampling_params: SamplingParams,
+        context: ConversationContext,
+        lora_request: Optional[LoRARequest] = None,
+        priority: int = 0,
+        **kwargs,
+    ):
+        orig_priority = priority
+        while True:
+            self._log_inputs(
+                request_id,
+                request_prompt,
+                params=sampling_params,
+                lora_request=lora_request,
+            )
+            generator = self.engine_client.generate(
+                engine_prompt,
+                sampling_params,
+                request_id,
+                lora_request=lora_request,
+                priority=priority,
+                **kwargs,
+            )
+            async for res in generator:
+                context.append_output(res)
+                # NOTE(woosuk): The stop condition is handled by the engine.
+                yield context
+
+            if not context.need_builtin_tool_call():
+                # The model did not ask for a tool call, so we're done.
+                break
+
+            # Call the tool and update the context with the result.
+            tool_output = await context.call_tool()
+            context.append_output(tool_output)
+
+            # TODO: uncomment this and enable tool output streaming
+            # yield context
+
+            # Create inputs for the next turn.
+            # Render the next prompt token ids.
+            prompt_token_ids = context.render_for_completion()
+            engine_prompt = EngineTokensPrompt(
+                prompt_token_ids=prompt_token_ids)
+            request_prompt = prompt_token_ids
+            # Update the sampling params.
+            sampling_params.max_tokens = (self.max_model_len -
+                                          len(prompt_token_ids))
+            # OPTIMIZATION
+            priority = orig_priority - 1
+
     def _load_prompt_embeds(
         self,
         prompt_embeds: Optional[Union[bytes, list[bytes]]],
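The new `_generate_with_builtin_tools` helper is an agentic loop: run the engine, stream each output through the context, and, when the model requests a built-in tool call, execute the tool, re-render the conversation into prompt token ids, and start the next turn with the remaining token budget. Follow-up turns are resubmitted at `orig_priority - 1`, presumably so the priority scheduler (which treats smaller values as more urgent) does not starve continuations behind fresh requests. Below is a minimal sketch of the context contract the loop depends on; the method names come from the diff above, while the signatures and docstrings are illustrative assumptions rather than vLLM's actual definitions.

# Hedged sketch of the contract _generate_with_builtin_tools relies on.
# Method names are taken from the diff; signatures and docstrings are
# assumptions, not vLLM's real ConversationContext.
from abc import ABC, abstractmethod
from typing import Any


class ConversationContextSketch(ABC):

    @abstractmethod
    def append_output(self, output: Any) -> None:
        """Record one engine output (or tool output) in the conversation."""

    @abstractmethod
    def need_builtin_tool_call(self) -> bool:
        """Whether the last model turn requested a built-in tool call."""

    @abstractmethod
    async def call_tool(self) -> Any:
        """Execute the requested built-in tool and return its output."""

    @abstractmethod
    def render_for_completion(self) -> list[int]:
        """Re-render the conversation as prompt token ids for the next turn."""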
@@ -16,6 +16,7 @@ from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          ChatTemplateContentFormatOption)
+from vllm.entrypoints.context import ConversationContext, SimpleContext
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -29,7 +30,6 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
-from vllm.outputs import RequestOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -187,7 +187,7 @@ class OpenAIServingResponses(OpenAIServing):
         raw_request.state.request_metadata = request_metadata
 
         # Schedule the request and get the result generator.
-        generators: list[AsyncGenerator[RequestOutput, None]] = []
+        generators: list[AsyncGenerator[ConversationContext, None]] = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
                 default_max_tokens = self.max_model_len - len(
@@ -195,21 +195,19 @@ class OpenAIServingResponses(OpenAIServing):
                 sampling_params = request.to_sampling_params(
                     default_max_tokens, self.default_sampling_params)
 
-                self._log_inputs(request.request_id,
-                                 request_prompts[i],
-                                 params=sampling_params,
-                                 lora_request=lora_request)
-
                 trace_headers = (None if raw_request is None else await
                                  self._get_trace_headers(raw_request.headers))
 
-                generator = self.engine_client.generate(
-                    engine_prompt,
-                    sampling_params,
-                    request.request_id,
+                context = SimpleContext()
+                generator = self._generate_with_builtin_tools(
+                    request_id=request.request_id,
+                    request_prompt=request_prompts[i],
+                    engine_prompt=engine_prompt,
+                    sampling_params=sampling_params,
+                    context=context,
                     lora_request=lora_request,
-                    trace_headers=trace_headers,
                     priority=request.priority,
+                    trace_headers=trace_headers,
                 )
                 generators.append(generator)
         except ValueError as e:
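Note that the call site now passes everything by keyword: `trace_headers` is not a named parameter of `_generate_with_builtin_tools` (see the signature above), so it lands in `**kwargs` and is forwarded untouched to `engine_client.generate`, while `priority=request.priority` seeds the `orig_priority` that later turns decrement.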
@@ -277,7 +275,7 @@ class OpenAIServingResponses(OpenAIServing):
         self,
         request: ResponsesRequest,
         sampling_params: SamplingParams,
-        result_generator: AsyncIterator[RequestOutput],
+        result_generator: AsyncIterator[ConversationContext],
         model_name: str,
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
@@ -285,17 +283,20 @@
     ) -> Union[ErrorResponse, ResponsesResponse]:
         if created_time is None:
             created_time = int(time.time())
-        final_res: Optional[RequestOutput] = None
 
+        context: Optional[ConversationContext] = None
         try:
-            async for res in result_generator:
-                final_res = res
+            async for context in result_generator:
+                pass
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
+        assert context is not None
+        assert isinstance(context, SimpleContext)
+        final_res = context.last_output
         assert final_res is not None
         assert len(final_res.outputs) == 1
         final_output = final_res.outputs[0]
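On the Responses side, the non-streaming path now drains the generator to exhaustion and reads the final engine output off the context instead of tracking `final_res` directly. The following self-contained toy reproduces that control flow; `ToyContext` and `generate_with_tools` are hypothetical stand-ins written for illustration, not vLLM code.

# Toy reproduction of the generate -> tool-call -> generate loop.
# Everything here is a hypothetical stand-in for illustration only.
import asyncio


class ToyContext:
    def __init__(self):
        self.turns = []
        self.last_output = None

    def append_output(self, output):
        self.turns.append(output)
        self.last_output = output

    def need_builtin_tool_call(self):
        # Pretend the first model turn asks for a tool, later ones do not.
        return len(self.turns) == 1

    async def call_tool(self):
        return "tool-result"

    def render_for_completion(self):
        # Stand-in for re-tokenizing the conversation so far.
        return list(range(len(self.turns)))


async def generate_with_tools(context):
    # Mirrors the shape of _generate_with_builtin_tools: yield after each
    # model turn, run the tool between turns, stop when no tool is needed.
    while True:
        context.append_output("model-turn")
        yield context
        if not context.need_builtin_tool_call():
            break
        context.append_output(await context.call_tool())
        _ = context.render_for_completion()  # next turn's prompt token ids


async def main():
    context = None
    # Like the non-streaming Responses path: drain the generator, then
    # read the final output off the context.
    async for context in generate_with_tools(ToyContext()):
        pass
    assert context is not None
    print(context.last_output)  # -> "model-turn"


asyncio.run(main())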