[Frontend] split append tool output (#28333)
Signed-off-by: Andrew Xia <axia@fb.com>
Co-authored-by: Andrew Xia <axia@fb.com>
parent a1d3866dda
commit 7c38ed0f1c
@@ -34,6 +34,9 @@ class MockConversationContext(ConversationContext):
     def append_output(self, output) -> None:
         pass
 
+    def append_tool_output(self, output) -> None:
+        pass
+
     async def call_tool(self):
         return []
 
@@ -80,7 +80,11 @@ class TurnMetrics:
 
 class ConversationContext(ABC):
     @abstractmethod
-    def append_output(self, output) -> None:
+    def append_output(self, output: RequestOutput) -> None:
+        pass
+
+    @abstractmethod
+    def append_tool_output(self, output) -> None:
         pass
 
     @abstractmethod
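With this split, implementations handle engine output and tool output on separate, typed paths instead of isinstance() dispatch inside one method. A minimal self-contained sketch of the new shape (Context and RecordingContext are illustrative names, not part of vLLM):

from abc import ABC, abstractmethod


class Context(ABC):
    # Mirrors the two hooks this commit splits apart; the other
    # ConversationContext members are elided for brevity.
    @abstractmethod
    def append_output(self, output) -> None: ...

    @abstractmethod
    def append_tool_output(self, output) -> None: ...


class RecordingContext(Context):
    """Illustrative subclass that keeps the two streams separate."""

    def __init__(self) -> None:
        self.engine_outputs: list = []
        self.tool_messages: list = []

    def append_output(self, output) -> None:
        # Only engine RequestOutputs arrive here now.
        self.engine_outputs.append(output)

    def append_tool_output(self, output) -> None:
        # Tool results (e.g. list[Message]) arrive here instead.
        self.tool_messages.extend(output)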
@@ -151,6 +155,9 @@ class SimpleContext(ConversationContext):
         self.num_cached_tokens = output.num_cached_tokens or 0
         self.num_output_tokens += len(output.outputs[0].token_ids or [])
 
+    def append_tool_output(self, output) -> None:
+        raise NotImplementedError("Should not be called.")
+
     def need_builtin_tool_call(self) -> bool:
         return False
 
@@ -205,28 +212,28 @@ class HarmonyContext(ConversationContext):
         if self.parser.current_channel in {"analysis", "commentary"}:
             self.num_reasoning_tokens += 1
 
-    def append_output(self, output: RequestOutput | list[Message]) -> None:
-        if isinstance(output, RequestOutput):
-            output_token_ids = output.outputs[0].token_ids
-            self.parser = get_streamable_parser_for_assistant()
-            for token_id in output_token_ids:
-                self.parser.process(token_id)
-                # Check if the current token is part of reasoning content
-                self._update_num_reasoning_tokens()
-            self._update_prefill_token_usage(output)
-            self._update_decode_token_usage(output)
-            # Append current turn to all turn list for next turn's calculations
-            self.all_turn_metrics.append(self.current_turn_metrics.copy())
-            self.current_turn_metrics.reset()
-            # append_output is called only once before tool calling
-            # in non-streaming case
-            # so we can append all the parser messages to _messages
-            output_msgs = self.parser.messages
-            # The responses finish reason is set in the last message
-            self.finish_reason = output.outputs[0].finish_reason
-        else:
-            # Tool output.
-            output_msgs = output
-        self._messages.extend(output_msgs)
+    def append_output(self, output: RequestOutput) -> None:
+        output_token_ids = output.outputs[0].token_ids
+        self.parser = get_streamable_parser_for_assistant()
+        for token_id in output_token_ids:
+            self.parser.process(token_id)
+            # Check if the current token is part of reasoning content
+            self._update_num_reasoning_tokens()
+        self._update_prefill_token_usage(output)
+        self._update_decode_token_usage(output)
+        # Append current turn to all turn list for next turn's calculations
+        self.all_turn_metrics.append(self.current_turn_metrics.copy())
+        self.current_turn_metrics.reset()
+        # append_output is called only once before tool calling
+        # in non-streaming case
+        # so we can append all the parser messages to _messages
+        output_msgs = self.parser.messages
+        # The responses finish reason is set in the last message
+        self.finish_reason = output.outputs[0].finish_reason
+        self._messages.extend(output_msgs)
+
+    def append_tool_output(self, output: list[Message]) -> None:
+        output_msgs = output
+        self._messages.extend(output_msgs)
 
     def _update_prefill_token_usage(self, output: RequestOutput) -> None:
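append_output now assumes exactly one RequestOutput per non-streaming turn: it rebuilds the parser, replays the turn's token ids, and collects parser.messages, while append_tool_output just extends the message list. A hedged sketch of that parse step using openai_harmony directly (vLLM wraps this via get_streamable_parser_for_assistant; the rendered token ids below are an illustrative stand-in for output.outputs[0].token_ids):

from openai_harmony import (
    HarmonyEncodingName,
    Message,
    Role,
    StreamableParser,
    load_harmony_encoding,
)

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

# Fresh parser per turn, as append_output does.
parser = StreamableParser(encoding, role=Role.ASSISTANT)

# Stand-in for the turn's generated token ids: render a message and
# replay it token by token through the parser.
token_ids = encoding.render(
    Message.from_role_and_content(Role.ASSISTANT, "2 + 2 = 4")
)
for tok in token_ids:
    parser.process(tok)

print(parser.messages)  # parsed messages, appended to the context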
@@ -502,45 +509,45 @@ class StreamingHarmonyContext(HarmonyContext):
     def messages(self) -> list:
         return self._messages
 
-    def append_output(self, output: RequestOutput | list[Message]) -> None:
-        if isinstance(output, RequestOutput):
-            # append_output is called for each output token in streaming case,
-            # so we only want to add the prompt tokens once for each message.
-            if self.first_tok_of_message:
-                self._update_prefill_token_usage(output)
-            # Reset self.first_tok_of_message if needed:
-            # if the current token is the last one of the current message
-            # (finished=True), then the next token processed will mark the
-            # beginning of a new message
-            self.first_tok_of_message = output.finished
-            for tok in output.outputs[0].token_ids:
-                self.parser.process(tok)
-            self._update_decode_token_usage(output)
-
-            # For streaming, update previous turn when message is complete
-            if output.finished:
-                self.all_turn_metrics.append(self.current_turn_metrics.copy())
-                self.current_turn_metrics.reset()
-            # Check if the current token is part of reasoning content
-            self._update_num_reasoning_tokens()
-            self.last_tok = tok
-            if len(self._messages) - self.num_init_messages < len(self.parser.messages):
-                self._messages.extend(
-                    self.parser.messages[len(self._messages) - self.num_init_messages :]
-                )
-        else:
-            # Handle the case of tool output in direct message format
-            assert len(output) == 1, "Tool output should be a single message"
-            msg = output[0]
-            # Sometimes the recipient is not set for tool messages,
-            # so we set it to "assistant"
-            if msg.author.role == Role.TOOL and msg.recipient is None:
-                msg.recipient = "assistant"
-            toks = self.encoding.render(msg)
-            for tok in toks:
-                self.parser.process(tok)
-            self.last_tok = toks[-1]
-            # TODO: add tool_output messages to self._messages
+    def append_output(self, output: RequestOutput) -> None:
+        # append_output is called for each output token in streaming case,
+        # so we only want to add the prompt tokens once for each message.
+        if self.first_tok_of_message:
+            self._update_prefill_token_usage(output)
+        # Reset self.first_tok_of_message if needed:
+        # if the current token is the last one of the current message
+        # (finished=True), then the next token processed will mark the
+        # beginning of a new message
+        self.first_tok_of_message = output.finished
+        for tok in output.outputs[0].token_ids:
+            self.parser.process(tok)
+        self._update_decode_token_usage(output)
+
+        # For streaming, update previous turn when message is complete
+        if output.finished:
+            self.all_turn_metrics.append(self.current_turn_metrics.copy())
+            self.current_turn_metrics.reset()
+        # Check if the current token is part of reasoning content
+        self._update_num_reasoning_tokens()
+        self.last_tok = tok
+        if len(self._messages) - self.num_init_messages < len(self.parser.messages):
+            self._messages.extend(
+                self.parser.messages[len(self._messages) - self.num_init_messages :]
+            )
+
+    def append_tool_output(self, output: list[Message]) -> None:
+        # Handle the case of tool output in direct message format
+        assert len(output) == 1, "Tool output should be a single message"
+        msg = output[0]
+        # Sometimes the recipient is not set for tool messages,
+        # so we set it to "assistant"
+        if msg.author.role == Role.TOOL and msg.recipient is None:
+            msg.recipient = "assistant"
+        toks = self.encoding.render(msg)
+        for tok in toks:
+            self.parser.process(tok)
+        self.last_tok = toks[-1]
+        # TODO: add tool_output messages to self._messages
 
     def is_expecting_start(self) -> bool:
         return self.parser.state == StreamState.EXPECT_START
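In the streaming context, append_tool_output re-renders the single tool message into token ids and replays them through the live parser, which is why exactly one message is asserted. A hedged sketch of that render step (assumes the openai_harmony package; the streaming context itself is not constructed here):

from openai_harmony import (
    HarmonyEncodingName,
    Message,
    Role,
    load_harmony_encoding,
)

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

# One tool message per append_tool_output call (the assert enforces this).
msg = Message.from_role_and_content(Role.TOOL, '{"temperature": 21}')
if msg.recipient is None:
    msg.recipient = "assistant"  # same default the context applies

toks = encoding.render(msg)  # token ids then fed to parser.process(...)
print(len(toks))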
@@ -1227,7 +1227,7 @@ class OpenAIServing:
 
             # Call the tool and update the context with the result.
             tool_output = await context.call_tool()
-            context.append_output(tool_output)
+            context.append_tool_output(tool_output)
 
             # TODO: uncomment this and enable tool output streaming
             # yield context
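The serving loop now routes tool results through the dedicated hook. A hedged sketch of the resulting control flow (generate_once is a hypothetical stand-in for the engine call; only append_output, need_builtin_tool_call, call_tool, and append_tool_output come from the context API changed in this commit):

async def run_turns(context, generate_once) -> None:
    # generate_once() is a hypothetical async stand-in that returns one
    # finished RequestOutput for the current prompt + context state.
    while True:
        request_output = await generate_once()
        context.append_output(request_output)    # engine output path
        if not context.need_builtin_tool_call():
            break
        # Call the tool and feed the result through the new typed hook,
        # mirroring the call-site change in this hunk.
        tool_output = await context.call_tool()  # -> list[Message]
        context.append_tool_output(tool_output)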