[Misc] refactor disaggregated-prefill-v1 example (#18474)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
Reid authored 2025-05-21 19:10:14 +08:00, committed by GitHub
parent 907f935de9
commit 107f5fc4cb
3 changed files with 87 additions and 62 deletions

README.md

@@ -5,5 +5,6 @@ This example contains scripts that demonstrate disaggregated prefill in the offl
 ## Files
 
 - `run.sh` - A helper script that will run `prefill_example.py` and `decode_example.py` sequentially.
+    - Make sure you are in the `examples/offline_inference/disaggregated-prefill-v1` directory before running `run.sh`.
 - `prefill_example.py` - A script which performs prefill only, saving the KV state to the `local_storage` directory and the prompts to `output.txt`.
 - `decode_example.py` - A script which performs decode only, loading the KV state from the `local_storage` directory and the prompts from `output.txt`.
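
The two scripts are meant to be run back to back from the example directory, as the README above describes. A minimal sketch of that sequence, assuming `run.sh` does nothing more than invoke the two scripts in order (`run.sh` itself is not part of this diff, so this is an assumption rather than its actual contents):

```python
# Hedged sketch of the prefill -> decode sequence that run.sh drives.
# Assumes it is executed from examples/offline_inference/disaggregated-prefill-v1.
import subprocess

# Prefill pass: saves the KV state to ./local_storage and the prompts to ./output.txt
subprocess.run(["python", "prefill_example.py"], check=True)

# Decode pass: loads ./local_storage and ./output.txt and continues generation
subprocess.run(["python", "decode_example.py"], check=True)
```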

decode_example.py

@@ -3,35 +3,47 @@
 from vllm import LLM, SamplingParams
 from vllm.config import KVTransferConfig
 
-# Read prompts from output.txt
-prompts = []
-try:
-    with open("output.txt") as f:
-        for line in f:
-            prompts.append(line.strip())
-    print(f"Loaded {len(prompts)} prompts from output.txt")
-except FileNotFoundError:
-    print("Error: output.txt file not found")
-    exit(-1)
-
-sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
+
+def read_prompts():
+    """Read prompts from output.txt"""
+    prompts = []
+    try:
+        with open("output.txt") as f:
+            for line in f:
+                prompts.append(line.strip())
+        print(f"Loaded {len(prompts)} prompts from output.txt")
+        return prompts
+    except FileNotFoundError:
+        print("Error: output.txt file not found")
+        exit(-1)
 
-llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
-          enforce_eager=True,
-          gpu_memory_utilization=0.8,
-          max_num_batched_tokens=64,
-          max_num_seqs=16,
-          kv_transfer_config=KVTransferConfig(
-              kv_connector="SharedStorageConnector",
-              kv_role="kv_both",
-              kv_connector_extra_config={
-                  "shared_storage_path": "local_storage"
-              }))  #, max_model_len=2048, max_num_batched_tokens=2048)
-
-# 1ST generation (prefill instance)
-outputs = llm.generate(prompts, sampling_params)
+
+def main():
+    prompts = read_prompts()
+    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
 
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+              enforce_eager=True,
+              gpu_memory_utilization=0.8,
+              max_num_batched_tokens=64,
+              max_num_seqs=16,
+              kv_transfer_config=KVTransferConfig(
+                  kv_connector="SharedStorageConnector",
+                  kv_role="kv_both",
+                  kv_connector_extra_config={
+                      "shared_storage_path": "local_storage"
+                  }))  #, max_model_len=2048, max_num_batched_tokens=2048)
+
+    # 1ST generation (prefill instance)
+    outputs = llm.generate(prompts, sampling_params)
+
+    print("-" * 30)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 30)
+
+
+if __name__ == "__main__":
+    main()

prefill_example.py

@@ -3,42 +3,54 @@
 from vllm import LLM, SamplingParams
 from vllm.config import KVTransferConfig
 
-context = "Hi " * 1000
-context2 = "Hey " * 500
-prompts = [
-    context + "Hello, my name is",
-    context + "The capital of France is",
-    context2 + "Your name is",
-    context2 + "The capital of China is",
-]
-
-sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
+
+def read_prompts():
+    context = "Hi " * 1000
+    context2 = "Hey " * 500
+    return [
+        context + "Hello, my name is",
+        context + "The capital of France is",
+        context2 + "Your name is",
+        context2 + "The capital of China is",
+    ]
 
-llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
-          enforce_eager=True,
-          gpu_memory_utilization=0.8,
-          kv_transfer_config=KVTransferConfig(
-              kv_connector="SharedStorageConnector",
-              kv_role="kv_both",
-              kv_connector_extra_config={
-                  "shared_storage_path": "local_storage"
-              }))  #, max_model_len=2048, max_num_batched_tokens=2048)
-
-# 1ST generation (prefill instance)
-outputs = llm.generate(
-    prompts,
-    sampling_params,
-)
+
+def main():
+    prompts = read_prompts()
 
-new_prompts = []
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    new_prompts.append(prompt + generated_text)
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
 
-# Write new_prompts to output.txt
-with open("output.txt", "w") as f:
-    for prompt in new_prompts:
-        f.write(prompt + "\n")
-print(f"Saved {len(new_prompts)} prompts to output.txt")
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+              enforce_eager=True,
+              gpu_memory_utilization=0.8,
+              kv_transfer_config=KVTransferConfig(
+                  kv_connector="SharedStorageConnector",
+                  kv_role="kv_both",
+                  kv_connector_extra_config={
+                      "shared_storage_path": "local_storage"
+                  }))  #, max_model_len=2048, max_num_batched_tokens=2048)
+
+    # 1ST generation (prefill instance)
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+    )
+
+    new_prompts = []
+    print("-" * 30)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        new_prompts.append(prompt + generated_text)
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 30)
+
+    # Write new_prompts to output.txt
+    with open("output.txt", "w") as f:
+        for prompt in new_prompts:
+            f.write(prompt + "\n")
+    print(f"Saved {len(new_prompts)} prompts to output.txt")
+
+
+if __name__ == "__main__":
+    main()
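
Because `decode_example.py` exits when `output.txt` is missing, it can help to verify the prefill artifacts before starting the decode pass. A minimal sketch under the defaults used above (`local_storage` and `output.txt`); `check_prefill_artifacts` is a hypothetical helper, not part of the example:

```python
import os
import sys


def check_prefill_artifacts(storage_dir: str = "local_storage",
                            prompts_file: str = "output.txt") -> bool:
    """Return True if the prefill run left behind what decode needs.

    Hypothetical helper; the paths match the defaults in the example above.
    """
    has_prompts = os.path.isfile(prompts_file)
    has_kv = os.path.isdir(storage_dir) and bool(os.listdir(storage_dir))
    if not has_prompts:
        print(f"Missing {prompts_file}: run prefill_example.py first")
    if not has_kv:
        print(f"{storage_dir} is missing or empty: KV state was not saved")
    return has_prompts and has_kv


if __name__ == "__main__":
    sys.exit(0 if check_prefill_artifacts() else 1)
```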