[Misc] refactor disaggregated-prefill-v1 example (#18474)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
Reid 2025-05-21 19:10:14 +08:00 committed by GitHub
parent 907f935de9
commit 107f5fc4cb
3 changed files with 87 additions and 62 deletions

examples/offline_inference/disaggregated-prefill-v1/README.md

@@ -5,5 +5,6 @@ This example contains scripts that demonstrate disaggregated prefill in the offl
 ## Files
 - `run.sh` - A helper script that will run `prefill_example.py` and `decode_example.py` sequentially.
+  - Make sure you are in the `examples/offline_inference/disaggregated-prefill-v1` directory before running `run.sh`.
 - `prefill_example.py` - A script which performs prefill only, saving the KV state to the `local_storage` directory and the prompts to `output.txt`.
 - `decode_example.py` - A script which performs decode only, loading the KV state from the `local_storage` directory and the prompts from `output.txt`.
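Note: `run.sh` itself is unchanged by this commit, so its contents are not shown. For orientation, a minimal Python equivalent of the sequencing the README describes might look like the sketch below; the script names and ordering come from the README, everything else is an assumption:

    # Hypothetical Python equivalent of run.sh: run prefill, then decode.
    # Assumes it is launched from examples/offline_inference/disaggregated-prefill-v1.
    import subprocess
    import sys

    for script in ("prefill_example.py", "decode_example.py"):
        print(f"Running {script} ...")
        result = subprocess.run([sys.executable, script])
        if result.returncode != 0:
            sys.exit(f"{script} failed with exit code {result.returncode}")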

examples/offline_inference/disaggregated-prefill-v1/decode_example.py

@@ -3,17 +3,23 @@
 from vllm import LLM, SamplingParams
 from vllm.config import KVTransferConfig
 
-# Read prompts from output.txt
-prompts = []
-try:
-    with open("output.txt") as f:
-        for line in f:
-            prompts.append(line.strip())
-    print(f"Loaded {len(prompts)} prompts from output.txt")
-except FileNotFoundError:
-    print("Error: output.txt file not found")
-    exit(-1)
-
-sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
-llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+
+def read_prompts():
+    """Read prompts from output.txt"""
+    prompts = []
+    try:
+        with open("output.txt") as f:
+            for line in f:
+                prompts.append(line.strip())
+        print(f"Loaded {len(prompts)} prompts from output.txt")
+        return prompts
+    except FileNotFoundError:
+        print("Error: output.txt file not found")
+        exit(-1)
+
+
+def main():
+    prompts = read_prompts()
+
+    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
@@ -31,7 +37,13 @@ llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
-# 1ST generation (prefill instance)
-outputs = llm.generate(prompts, sampling_params)
-
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    # 1ST generation (prefill instance)
+    outputs = llm.generate(prompts, sampling_params)
+
+    print("-" * 30)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 30)
+
+
+if __name__ == "__main__":
+    main()
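Both scripts pass a kv_transfer_config to the LLM constructor, but those lines fall outside the hunks shown above. As orientation only, a sketch of how the v1 shared-storage connector is typically wired up; the connector name and extra-config key here are assumptions, not taken from this diff:

    from vllm import LLM
    from vllm.config import KVTransferConfig

    # Assumed wiring (not shown in this diff): the shared-storage connector
    # persists KV blocks under local_storage, and "kv_both" lets one config
    # serve both the save (prefill) and load (decode) sides.
    ktc = KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": "local_storage"},
    )

    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", kv_transfer_config=ktc)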

examples/offline_inference/disaggregated-prefill-v1/prefill_example.py

@@ -3,15 +3,21 @@
 from vllm import LLM, SamplingParams
 from vllm.config import KVTransferConfig
 
-context = "Hi " * 1000
-context2 = "Hey " * 500
-prompts = [
-    context + "Hello, my name is",
-    context + "The capital of France is",
-    context2 + "Your name is",
-    context2 + "The capital of China is",
-]
-
-sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
-llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+
+def read_prompts():
+    context = "Hi " * 1000
+    context2 = "Hey " * 500
+    return [
+        context + "Hello, my name is",
+        context + "The capital of France is",
+        context2 + "Your name is",
+        context2 + "The capital of China is",
+    ]
+
+
+def main():
+    prompts = read_prompts()
+
+    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
@@ -31,14 +37,20 @@ outputs = llm.generate(
-)
-
-new_prompts = []
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    new_prompts.append(prompt + generated_text)
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-
-# Write new_prompts to output.txt
-with open("output.txt", "w") as f:
-    for prompt in new_prompts:
-        f.write(prompt + "\n")
-print(f"Saved {len(new_prompts)} prompts to output.txt")
+    )
+
+    new_prompts = []
+    print("-" * 30)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        new_prompts.append(prompt + generated_text)
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 30)
+
+    # Write new_prompts to output.txt
+    with open("output.txt", "w") as f:
+        for prompt in new_prompts:
+            f.write(prompt + "\n")
+    print(f"Saved {len(new_prompts)} prompts to output.txt")
+
+
+if __name__ == "__main__":
+    main()
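A note on the design: with max_tokens=1, the prefill script does essentially no decoding; the run exists to populate `local_storage` with KV state and to carry the prompts (plus one generated token each) into `output.txt` for the decode step. A hypothetical smoke check between the two steps, assuming that file layout from the README:

    from pathlib import Path

    # Hypothetical check that prefill_example.py produced what
    # decode_example.py expects; the local_storage/output.txt layout
    # is taken from the README, not from this diff.
    assert Path("output.txt").is_file(), "prompts were not saved"
    assert Path("local_storage").is_dir() and any(Path("local_storage").iterdir()), \
        "no KV state was saved"
    print("Prefill artifacts found; decode_example.py can be run next.")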