[Frontend] split append tool output (#28333)

Signed-off-by: Andrew Xia <axia@fb.com>
Co-authored-by: Andrew Xia <axia@fb.com>
This commit is contained in:
Andrew Xia 2025-11-12 20:03:23 -08:00 committed by GitHub
parent a1d3866dda
commit 7c38ed0f1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 72 additions and 62 deletions

View File

@ -34,6 +34,9 @@ class MockConversationContext(ConversationContext):
def append_output(self, output) -> None: def append_output(self, output) -> None:
pass pass
def append_tool_output(self, output) -> None:
pass
async def call_tool(self): async def call_tool(self):
return [] return []

View File

@ -80,7 +80,11 @@ class TurnMetrics:
class ConversationContext(ABC): class ConversationContext(ABC):
@abstractmethod @abstractmethod
def append_output(self, output) -> None: def append_output(self, output: RequestOutput) -> None:
pass
@abstractmethod
def append_tool_output(self, output) -> None:
pass pass
@abstractmethod @abstractmethod
@ -151,6 +155,9 @@ class SimpleContext(ConversationContext):
self.num_cached_tokens = output.num_cached_tokens or 0 self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or []) self.num_output_tokens += len(output.outputs[0].token_ids or [])
def append_tool_output(self, output) -> None:
raise NotImplementedError("Should not be called.")
def need_builtin_tool_call(self) -> bool: def need_builtin_tool_call(self) -> bool:
return False return False
@ -205,8 +212,7 @@ class HarmonyContext(ConversationContext):
if self.parser.current_channel in {"analysis", "commentary"}: if self.parser.current_channel in {"analysis", "commentary"}:
self.num_reasoning_tokens += 1 self.num_reasoning_tokens += 1
def append_output(self, output: RequestOutput | list[Message]) -> None: def append_output(self, output: RequestOutput) -> None:
if isinstance(output, RequestOutput):
output_token_ids = output.outputs[0].token_ids output_token_ids = output.outputs[0].token_ids
self.parser = get_streamable_parser_for_assistant() self.parser = get_streamable_parser_for_assistant()
for token_id in output_token_ids: for token_id in output_token_ids:
@ -224,8 +230,9 @@ class HarmonyContext(ConversationContext):
output_msgs = self.parser.messages output_msgs = self.parser.messages
# The responses finish reason is set in the last message # The responses finish reason is set in the last message
self.finish_reason = output.outputs[0].finish_reason self.finish_reason = output.outputs[0].finish_reason
else: self._messages.extend(output_msgs)
# Tool output.
def append_tool_output(self, output: list[Message]) -> None:
output_msgs = output output_msgs = output
self._messages.extend(output_msgs) self._messages.extend(output_msgs)
@ -502,8 +509,7 @@ class StreamingHarmonyContext(HarmonyContext):
def messages(self) -> list: def messages(self) -> list:
return self._messages return self._messages
def append_output(self, output: RequestOutput | list[Message]) -> None: def append_output(self, output: RequestOutput) -> None:
if isinstance(output, RequestOutput):
# append_output is called for each output token in streaming case, # append_output is called for each output token in streaming case,
# so we only want to add the prompt tokens once for each message. # so we only want to add the prompt tokens once for each message.
if self.first_tok_of_message: if self.first_tok_of_message:
@ -528,7 +534,8 @@ class StreamingHarmonyContext(HarmonyContext):
self._messages.extend( self._messages.extend(
self.parser.messages[len(self._messages) - self.num_init_messages :] self.parser.messages[len(self._messages) - self.num_init_messages :]
) )
else:
def append_tool_output(self, output: list[Message]) -> None:
# Handle the case of tool output in direct message format # Handle the case of tool output in direct message format
assert len(output) == 1, "Tool output should be a single message" assert len(output) == 1, "Tool output should be a single message"
msg = output[0] msg = output[0]

View File

@ -1227,7 +1227,7 @@ class OpenAIServing:
# Call the tool and update the context with the result. # Call the tool and update the context with the result.
tool_output = await context.call_tool() tool_output = await context.call_tool()
context.append_output(tool_output) context.append_tool_output(tool_output)
# TODO: uncomment this and enable tool output streaming # TODO: uncomment this and enable tool output streaming
# yield context # yield context