mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 15:07:01 +08:00
[gpt-oss] Add loop for built-in tool call (#22374)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com> Co-authored-by: simon-mo <xmo@berkeley.edu> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
parent
2435ea7ed5
commit
ec7cb19224
@ -35,6 +35,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
|||||||
apply_mistral_chat_template,
|
apply_mistral_chat_template,
|
||||||
parse_chat_messages_futures,
|
parse_chat_messages_futures,
|
||||||
resolve_chat_template_content_format)
|
resolve_chat_template_content_format)
|
||||||
|
from vllm.entrypoints.context import ConversationContext
|
||||||
from vllm.entrypoints.logger import RequestLogger
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
@ -948,6 +949,61 @@ class OpenAIServing:
|
|||||||
|
|
||||||
return conversation, [request_prompt], [engine_prompt]
|
return conversation, [request_prompt], [engine_prompt]
|
||||||
|
|
||||||
|
async def _generate_with_builtin_tools(
    self,
    request_id: str,
    request_prompt: RequestPrompt,
    engine_prompt: EngineTokensPrompt,
    sampling_params: SamplingParams,
    context: ConversationContext,
    lora_request: Optional[LoRARequest] = None,
    priority: int = 0,
    **kwargs,
):
    """Run generation turn by turn, executing built-in tool calls in between.

    This is an async generator. Each turn logs the inputs, asks
    ``self.engine_client`` to generate, and appends every engine output to
    ``context``. After a turn finishes, ``context`` decides whether a
    built-in tool call is needed; if so, the tool is called, its output is
    appended to ``context``, and a new prompt is rendered from ``context``
    for the next turn. The loop ends as soon as no tool call is needed.

    Args:
        request_id: Id passed to the logger and to the engine on every turn.
        request_prompt: Prompt handed to ``_log_inputs`` for the first turn;
            later turns log the freshly rendered token ids instead.
        engine_prompt: Prompt handed to the engine for the first turn; later
            turns use an ``EngineTokensPrompt`` built from the rendered
            token ids.
        sampling_params: Sampling parameters for every turn. NOTE: this
            object is mutated in place -- ``max_tokens`` is overwritten
            before each follow-up turn.
        context: Conversation state; receives engine and tool outputs and
            renders the prompt for follow-up turns.
        lora_request: Optional LoRA request forwarded to the logger and
            the engine.
        priority: Priority forwarded to the engine for the first turn.
            Follow-up turns use ``priority - 1``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``self.engine_client.generate`` on every turn.

    Yields:
        The same ``context`` object, once after each engine output has been
        appended to it. Tool outputs are not yielded separately.
    """
    first_turn_priority = priority

    while True:
        self._log_inputs(
            request_id,
            request_prompt,
            params=sampling_params,
            lora_request=lora_request,
        )
        output_stream = self.engine_client.generate(
            engine_prompt,
            sampling_params,
            request_id,
            lora_request=lora_request,
            priority=priority,
            **kwargs,
        )
        async for engine_output in output_stream:
            context.append_output(engine_output)
            # NOTE(woosuk): The stop condition is handled by the engine.
            yield context

        # No tool call requested by the model: the conversation is complete.
        if not context.need_builtin_tool_call():
            return

        # Execute the requested tool and record its result in the context.
        context.append_output(await context.call_tool())

        # TODO: uncomment this and enable tool output streaming
        # yield context

        # Prepare the inputs of the next turn from the updated context.
        next_token_ids = context.render_for_completion()
        request_prompt = next_token_ids
        engine_prompt = EngineTokensPrompt(prompt_token_ids=next_token_ids)

        # Give the next turn whatever room is left in the model's context
        # window (this replaces any previous max_tokens value).
        sampling_params.max_tokens = self.max_model_len - len(next_token_ids)

        # OPTIMIZATION: follow-up turns always run at the original priority
        # minus one (not cumulative across turns). Presumably this lets an
        # in-flight conversation be scheduled ahead of new requests --
        # TODO confirm against the engine's priority ordering.
        priority = first_turn_priority - 1
|
||||||
|
|
||||||
def _load_prompt_embeds(
|
def _load_prompt_embeds(
|
||||||
self,
|
self,
|
||||||
prompt_embeds: Optional[Union[bytes, list[bytes]]],
|
prompt_embeds: Optional[Union[bytes, list[bytes]]],
|
||||||
|
|||||||
@ -16,6 +16,7 @@ from vllm.config import ModelConfig
|
|||||||
from vllm.engine.protocol import EngineClient
|
from vllm.engine.protocol import EngineClient
|
||||||
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||||
ChatTemplateContentFormatOption)
|
ChatTemplateContentFormatOption)
|
||||||
|
from vllm.entrypoints.context import ConversationContext, SimpleContext
|
||||||
from vllm.entrypoints.logger import RequestLogger
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
@ -29,7 +30,6 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
|||||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.outputs import RequestOutput
|
|
||||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
@ -187,7 +187,7 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
raw_request.state.request_metadata = request_metadata
|
raw_request.state.request_metadata = request_metadata
|
||||||
|
|
||||||
# Schedule the request and get the result generator.
|
# Schedule the request and get the result generator.
|
||||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
generators: list[AsyncGenerator[ConversationContext, None]] = []
|
||||||
try:
|
try:
|
||||||
for i, engine_prompt in enumerate(engine_prompts):
|
for i, engine_prompt in enumerate(engine_prompts):
|
||||||
default_max_tokens = self.max_model_len - len(
|
default_max_tokens = self.max_model_len - len(
|
||||||
@ -195,21 +195,19 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
sampling_params = request.to_sampling_params(
|
sampling_params = request.to_sampling_params(
|
||||||
default_max_tokens, self.default_sampling_params)
|
default_max_tokens, self.default_sampling_params)
|
||||||
|
|
||||||
self._log_inputs(request.request_id,
|
|
||||||
request_prompts[i],
|
|
||||||
params=sampling_params,
|
|
||||||
lora_request=lora_request)
|
|
||||||
|
|
||||||
trace_headers = (None if raw_request is None else await
|
trace_headers = (None if raw_request is None else await
|
||||||
self._get_trace_headers(raw_request.headers))
|
self._get_trace_headers(raw_request.headers))
|
||||||
|
|
||||||
generator = self.engine_client.generate(
|
context = SimpleContext()
|
||||||
engine_prompt,
|
generator = self._generate_with_builtin_tools(
|
||||||
sampling_params,
|
request_id=request.request_id,
|
||||||
request.request_id,
|
request_prompt=request_prompts[i],
|
||||||
|
engine_prompt=engine_prompt,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
context=context,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
trace_headers=trace_headers,
|
|
||||||
priority=request.priority,
|
priority=request.priority,
|
||||||
|
trace_headers=trace_headers,
|
||||||
)
|
)
|
||||||
generators.append(generator)
|
generators.append(generator)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@ -277,7 +275,7 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
self,
|
self,
|
||||||
request: ResponsesRequest,
|
request: ResponsesRequest,
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
result_generator: AsyncIterator[RequestOutput],
|
result_generator: AsyncIterator[ConversationContext],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
tokenizer: AnyTokenizer,
|
tokenizer: AnyTokenizer,
|
||||||
request_metadata: RequestResponseMetadata,
|
request_metadata: RequestResponseMetadata,
|
||||||
@ -285,17 +283,20 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
) -> Union[ErrorResponse, ResponsesResponse]:
|
) -> Union[ErrorResponse, ResponsesResponse]:
|
||||||
if created_time is None:
|
if created_time is None:
|
||||||
created_time = int(time.time())
|
created_time = int(time.time())
|
||||||
final_res: Optional[RequestOutput] = None
|
|
||||||
|
|
||||||
|
context: Optional[ConversationContext] = None
|
||||||
try:
|
try:
|
||||||
async for res in result_generator:
|
async for context in result_generator:
|
||||||
final_res = res
|
pass
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
return self.create_error_response("Client disconnected")
|
return self.create_error_response("Client disconnected")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# TODO: Use a vllm-specific Validation Error
|
# TODO: Use a vllm-specific Validation Error
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
assert context is not None
|
||||||
|
assert isinstance(context, SimpleContext)
|
||||||
|
final_res = context.last_output
|
||||||
assert final_res is not None
|
assert final_res is not None
|
||||||
assert len(final_res.outputs) == 1
|
assert len(final_res.outputs) == 1
|
||||||
final_output = final_res.outputs[0]
|
final_output = final_res.outputs[0]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user