[Frontend] split append tool output (#28333)

Signed-off-by: Andrew Xia <axia@fb.com>
Co-authored-by: Andrew Xia <axia@fb.com>
parent a1d3866dda
commit 7c38ed0f1c
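In short, this commit splits the overloaded `append_output(RequestOutput | list[Message])` on `ConversationContext` into two methods: `append_output` now accepts only model output (`RequestOutput`), and tool results go through a new `append_tool_output(list[Message])` hook. Below is a minimal sketch of the resulting contract, assuming vLLM's `RequestOutput` and openai_harmony's `Message` types; `DemoContext` is hypothetical and not part of the diff:

```python
# Hypothetical sketch of the split contract introduced by this commit;
# DemoContext is illustrative, not code from the diff.
from openai_harmony import Message  # assumed: the Message type used by the Harmony contexts
from vllm.outputs import RequestOutput  # assumed import path


class DemoContext:
    """Stand-in for a ConversationContext subclass."""

    def __init__(self) -> None:
        self._messages: list[Message] = []

    def append_output(self, output: RequestOutput) -> None:
        # Only model output arrives here now; the old isinstance() branch is gone.
        ...

    def append_tool_output(self, output: list[Message]) -> None:
        # Tool results arrive as already-parsed messages and are appended directly.
        self._messages.extend(output)
```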
@@ -34,6 +34,9 @@ class MockConversationContext(ConversationContext):
     def append_output(self, output) -> None:
         pass
 
+    def append_tool_output(self, output) -> None:
+        pass
+
     async def call_tool(self):
         return []
 
@@ -80,7 +80,11 @@ class TurnMetrics:
 
 class ConversationContext(ABC):
     @abstractmethod
-    def append_output(self, output) -> None:
+    def append_output(self, output: RequestOutput) -> None:
         pass
 
+    @abstractmethod
+    def append_tool_output(self, output) -> None:
+        pass
+
     @abstractmethod
@@ -151,6 +155,9 @@ class SimpleContext(ConversationContext):
         self.num_cached_tokens = output.num_cached_tokens or 0
         self.num_output_tokens += len(output.outputs[0].token_ids or [])
 
+    def append_tool_output(self, output) -> None:
+        raise NotImplementedError("Should not be called.")
+
     def need_builtin_tool_call(self) -> bool:
         return False
 
@@ -205,8 +212,7 @@ class HarmonyContext(ConversationContext):
            if self.parser.current_channel in {"analysis", "commentary"}:
                self.num_reasoning_tokens += 1
 
-    def append_output(self, output: RequestOutput | list[Message]) -> None:
-        if isinstance(output, RequestOutput):
+    def append_output(self, output: RequestOutput) -> None:
        output_token_ids = output.outputs[0].token_ids
        self.parser = get_streamable_parser_for_assistant()
        for token_id in output_token_ids:
@@ -224,8 +230,9 @@ class HarmonyContext(ConversationContext):
        output_msgs = self.parser.messages
        # The responses finish reason is set in the last message
        self.finish_reason = output.outputs[0].finish_reason
-        else:
-            # Tool output.
-            output_msgs = output
        self._messages.extend(output_msgs)
 
+    def append_tool_output(self, output: list[Message]) -> None:
+        output_msgs = output
+        self._messages.extend(output_msgs)
+
@@ -502,8 +509,7 @@ class StreamingHarmonyContext(HarmonyContext):
    def messages(self) -> list:
        return self._messages
 
-    def append_output(self, output: RequestOutput | list[Message]) -> None:
-        if isinstance(output, RequestOutput):
+    def append_output(self, output: RequestOutput) -> None:
        # append_output is called for each output token in streaming case,
        # so we only want to add the prompt tokens once for each message.
        if self.first_tok_of_message:
@@ -528,7 +534,8 @@ class StreamingHarmonyContext(HarmonyContext):
            self._messages.extend(
                self.parser.messages[len(self._messages) - self.num_init_messages :]
            )
-        else:
+
+    def append_tool_output(self, output: list[Message]) -> None:
        # Handle the case of tool output in direct message format
        assert len(output) == 1, "Tool output should be a single message"
        msg = output[0]
@@ -1227,7 +1227,7 @@ class OpenAIServing:
 
            # Call the tool and update the context with the result.
            tool_output = await context.call_tool()
-            context.append_output(tool_output)
+            context.append_tool_output(tool_output)
 
            # TODO: uncomment this and enable tool output streaming
            # yield context
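For reference, a hedged sketch of how the updated call site sits in the tool loop; `need_builtin_tool_call`, `call_tool`, and `append_tool_output` come from the diff above, while the loop itself is illustrative and not vLLM's actual serving code:

```python
# Illustrative only: a simplified tool-calling turn using the new hook.
async def run_tool_turn(context) -> None:
    while context.need_builtin_tool_call():
        # Execute the built-in tool, then record its result through the
        # dedicated append_tool_output hook instead of the old overloaded
        # append_output.
        tool_output = await context.call_tool()
        context.append_tool_output(tool_output)
```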