diff --git a/examples/online_serving/openai_chat_completion_client.py b/examples/online_serving/openai_chat_completion_client.py
index 74e0c045d6214..bf99777d56974 100644
--- a/examples/online_serving/openai_chat_completion_client.py
+++ b/examples/online_serving/openai_chat_completion_client.py
@@ -3,6 +3,9 @@
 NOTE: start a supported chat completion model server with `vllm serve`, e.g.
 vllm serve meta-llama/Llama-2-7b-chat-hf
 """
+
+import argparse
+
 from openai import OpenAI
 
 # Modify OpenAI's API key and API base to use vLLM's API server.
@@ -24,7 +27,15 @@ messages = [{
 }]
 
 
-def main():
+def parse_args():
+    parser = argparse.ArgumentParser(description="Client for vLLM API server")
+    parser.add_argument("--stream",
+                        action="store_true",
+                        help="Enable streaming response")
+    return parser.parse_args()
+
+
+def main(args):
     client = OpenAI(
         # defaults to os.environ.get("OPENAI_API_KEY")
         api_key=openai_api_key,
@@ -34,16 +45,23 @@ def main():
     models = client.models.list()
     model = models.data[0].id
 
+    # Chat Completion API
     chat_completion = client.chat.completions.create(
         messages=messages,
         model=model,
+        stream=args.stream,
     )
 
     print("-" * 50)
     print("Chat completion results:")
-    print(chat_completion)
+    if args.stream:
+        for c in chat_completion:
+            print(c)
+    else:
+        print(chat_completion)
     print("-" * 50)
 
 
 if __name__ == "__main__":
-    main()
+    args = parse_args()
+    main(args)
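
For context, a minimal sketch of how the streamed response could be consumed to print only the generated text, rather than the raw chunk objects the example prints with `--stream`. This is illustrative only, not part of the diff; it relies on the standard OpenAI Python client chunk shape (`chunk.choices[0].delta.content`), and the `api_key`/`base_url` values and the message content below are assumptions not shown in the hunks above.

# Illustrative sketch only, not part of the diff above. Assumes the
# OpenAI Python client and a vLLM server at the example's usual local
# defaults; both connection values are assumptions.
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
model = client.models.list().data[0].id

# With stream=True, create() returns an iterator of ChatCompletionChunk
# objects instead of a single ChatCompletion.
stream = client.chat.completions.create(
    messages=[{"role": "user", "content": "Tell me a joke."}],
    model=model,
    stream=True,
)
for chunk in stream:
    # Each chunk carries an incremental delta; content can be None on
    # role-only or final chunks, so guard before printing.
    delta = chunk.choices[0].delta.content
    if delta is not None:
        print(delta, end="", flush=True)
print()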