From c2d6d2f960176491e0499656409f30b947ee8027 Mon Sep 17 00:00:00 2001
From: Daniil Arapov <59310708+Delviet@users.noreply.github.com>
Date: Sun, 2 Jun 2024 01:53:52 +0300
Subject: [PATCH] [Bugfix]: Fix issues related to prefix caching example
 (#5177) (#5180)

---
 examples/offline_inference_with_prefix.py | 47 ++++++++++++++++++-----
 1 file changed, 37 insertions(+), 10 deletions(-)

diff --git a/examples/offline_inference_with_prefix.py b/examples/offline_inference_with_prefix.py
index 7ed0563f14e0..166e98549b53 100644
--- a/examples/offline_inference_with_prefix.py
+++ b/examples/offline_inference_with_prefix.py
@@ -1,5 +1,8 @@
+from time import time
+
 from vllm import LLM, SamplingParams
 
+# Common prefix.
 prefix = (
     "You are an expert school principal, skilled in effectively managing "
     "faculty and staff. Draft 10-15 questions for a potential first grade "
@@ -18,36 +21,60 @@ prompts = [
     "The capital of France is",
     "The future of AI is",
 ]
+
+generating_prompts = [prefix + prompt for prompt in prompts]
+
 # Create a sampling params object.
 sampling_params = SamplingParams(temperature=0.0)
 
 # Create an LLM.
-llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
+regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)
 
-generating_prompts = [prefix + prompt for prompt in prompts]
+prefix_cached_llm = LLM(model="facebook/opt-125m",
+                        enable_prefix_caching=True,
+                        gpu_memory_utilization=0.4)
 
+print("Results without `enable_prefix_caching`")
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
-outputs = llm.generate(generating_prompts, sampling_params)
+start_time_regular = time()
+outputs = regular_llm.generate(generating_prompts, sampling_params)
+duration_regular = time() - start_time_regular
+
+regular_generated_texts = []
 # Print the outputs.
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
+    regular_generated_texts.append(generated_text)
     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
 print("-" * 80)
 
 # The llm.generate call will batch all prompts and send the batch at once
-# if resources allow. The prefix will only be cached after the first batch
-# is processed, so we need to call generate once to calculate the prefix
-# and cache it.
-outputs = llm.generate(generating_prompts[0], sampling_params)
+# if resources allow.
+start_time_cached = time()
+outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
+duration_cached = time() - start_time_cached
 
-# Subsequent batches can leverage the cached prefix
-outputs = llm.generate(generating_prompts, sampling_params)
+print("Results with `enable_prefix_caching`")
 
-# Print the outputs. You should see the same outputs as before
+cached_generated_texts = []
+# Print the outputs. You should see the same outputs as before.
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
+    cached_generated_texts.append(generated_text)
     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+print("-" * 80)
+
+# Compare the results and display the speedup
+generated_same = all([
+    regular_generated_texts[i] == cached_generated_texts[i]
+    for i in range(len(prompts))
+])
+print(f"Generated answers are the same: {generated_same}")
+
+speedup = round(duration_regular / duration_cached, 2)
+print(f"Speed up of cached generation compared to the regular is: {speedup}")
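
Reviewer note: the patched example loads two separate engines so the regular and prefix-cached runs can be timed within a single script. For a lighter local sanity check of prefix caching itself, a minimal sketch is given below. It assumes only the LLM, SamplingParams, and generate APIs already used in this diff; the short prompt list, the single-engine warm-up/second-pass comparison, and the variable names are illustrative choices, not part of this patch.

from time import time

from vllm import LLM, SamplingParams

# Illustrative shared prefix; it should span at least one full KV-cache
# block (block_size tokens, 16 by default) for its blocks to be reusable.
prefix = ("You are an expert school principal, skilled in effectively "
          "managing faculty and staff. Answer the question that follows: ")
prompts = [prefix + q for q in ["Hello, my name is", "The future of AI is"]]

sampling_params = SamplingParams(temperature=0.0)
llm = LLM(model="facebook/opt-125m",
          enable_prefix_caching=True,
          gpu_memory_utilization=0.4)

# First pass fills the prefix cache (and pays other one-time costs).
start = time()
llm.generate(prompts, sampling_params)
warmup_duration = time() - start

# Second pass over the same prompts can reuse the cached prefix blocks.
start = time()
llm.generate(prompts, sampling_params)
cached_duration = time() - start

print(f"warm-up pass: {warmup_duration:.2f}s, "
      f"second pass: {cached_duration:.2f}s")

As with the example in the diff, the measured times also include scheduling, sampling, and one-off initialization work, so the resulting ratio is indicative rather than a clean kernel-level benchmark.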