diff --git a/examples/offline_inference_with_prefix.py b/examples/offline_inference_with_prefix.py
index 166e98549b53..04c2843792a1 100644
--- a/examples/offline_inference_with_prefix.py
+++ b/examples/offline_inference_with_prefix.py
@@ -51,8 +51,10 @@ for output in outputs:
 
 print("-" * 80)
 
-# The llm.generate call will batch all prompts and send the batch at once
-# if resources allow.
+# Warmup so that the shared prompt's KV cache is computed.
+prefix_cached_llm.generate(generating_prompts[0], sampling_params)
+
+# Generate with prefix caching.
 start_time_cached = time()
 outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
 duration_cached = time() - start_time_cached
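For context, below is a minimal, self-contained sketch of the warmup-then-measure pattern this diff introduces: generate once with a single prompt so the KV blocks for the shared prefix are populated, then time the batched run that can reuse them. It assumes vLLM's `LLM`/`SamplingParams` API with `enable_prefix_caching=True`; the prefix text, questions, and sampling settings are illustrative placeholders and not taken from the diff.

```python
# Sketch of the warmup-then-measure pattern from the example.
# Assumptions: vLLM's LLM/SamplingParams API; prefix text, questions,
# and sampling settings are illustrative, not from the diff itself.
from time import time

from vllm import LLM, SamplingParams

# A long shared prefix followed by several distinct questions.
prefix = (
    "You are a helpful assistant. Answer the following question "
    "concisely and accurately.\n\n"
)
questions = [
    "What is the capital of France?",
    "Explain what a KV cache is.",
    "Name three prime numbers.",
]
generating_prompts = [prefix + q for q in questions]

sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

# Enable automatic prefix caching so KV blocks computed for the shared
# prefix can be reused across prompts.
prefix_cached_llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)

# Warmup: run a single prompt first so the shared prefix's KV cache is
# populated before the timed run.
prefix_cached_llm.generate(generating_prompts[0], sampling_params)

# Timed run: the batched prompts can now hit the cached prefix blocks.
start_time_cached = time()
outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
duration_cached = time() - start_time_cached

print(f"Batched generation with prefix caching took {duration_cached:.2f}s")
for output in outputs:
    print(output.outputs[0].text)
```

Without the warmup call, the first timed request would also pay the cost of prefilling the shared prefix, which skews the comparison against the non-cached baseline earlier in the example.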