Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 09:06:03 +08:00)
[Misc] refactor disaggregated-prefill-v1 example (#18474)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
This commit is contained in:
parent 907f935de9
commit 107f5fc4cb
README.md:

@@ -5,5 +5,6 @@ This example contains scripts that demonstrate disaggregated prefill in the offline setting of vLLM.

## Files

- `run.sh` - A helper script that will run `prefill_example.py` and `decode_example.py` sequentially (a sketch of this sequential workflow follows this list).
  - Make sure you are in the `examples/offline_inference/disaggregated-prefill-v1` directory before running `run.sh`.
- `prefill_example.py` - A script which performs prefill only, saving the KV state to the `local_storage` directory and the prompts to `output.txt`.
- `decode_example.py` - A script which performs decode only, loading the KV state from the `local_storage` directory and the prompts from `output.txt`.
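The README above describes the workflow only in prose. As a rough illustration, the sequential run that `run.sh` performs could be driven from Python as below; the real `run.sh` is a shell script whose contents are not part of this diff, and only the script names and their ordering are taken from the README.

```python
# Illustrative only: a Python stand-in for the sequential workflow the README
# attributes to run.sh. The real run.sh is a shell script not shown in this
# diff; script names and ordering come from the README above.
import subprocess
import sys

for script in ("prefill_example.py", "decode_example.py"):
    # Run each stage in its own process so the prefill LLM releases the GPU
    # before the decode LLM starts.
    print(f"Running {script} ...")
    result = subprocess.run([sys.executable, script])
    if result.returncode != 0:
        sys.exit(f"{script} failed with exit code {result.returncode}")
```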
decode_example.py:

```diff
@@ -3,35 +3,47 @@
 from vllm import LLM, SamplingParams
 from vllm.config import KVTransferConfig
 
-# Read prompts from output.txt
-prompts = []
-try:
-    with open("output.txt") as f:
-        for line in f:
-            prompts.append(line.strip())
-    print(f"Loaded {len(prompts)} prompts from output.txt")
-except FileNotFoundError:
-    print("Error: output.txt file not found")
-    exit(-1)
-
-sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
-
-llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
-          enforce_eager=True,
-          gpu_memory_utilization=0.8,
-          max_num_batched_tokens=64,
-          max_num_seqs=16,
-          kv_transfer_config=KVTransferConfig(
-              kv_connector="SharedStorageConnector",
-              kv_role="kv_both",
-              kv_connector_extra_config={
-                  "shared_storage_path": "local_storage"
-              })) #, max_model_len=2048, max_num_batched_tokens=2048)
-
-# 1ST generation (prefill instance)
-outputs = llm.generate(prompts, sampling_params)
-
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+def read_prompts():
+    """Read prompts from output.txt"""
+    prompts = []
+    try:
+        with open("output.txt") as f:
+            for line in f:
+                prompts.append(line.strip())
+        print(f"Loaded {len(prompts)} prompts from output.txt")
+        return prompts
+    except FileNotFoundError:
+        print("Error: output.txt file not found")
+        exit(-1)
+
+
+def main():
+    prompts = read_prompts()
+    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
+
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+              enforce_eager=True,
+              gpu_memory_utilization=0.8,
+              max_num_batched_tokens=64,
+              max_num_seqs=16,
+              kv_transfer_config=KVTransferConfig(
+                  kv_connector="SharedStorageConnector",
+                  kv_role="kv_both",
+                  kv_connector_extra_config={
+                      "shared_storage_path": "local_storage"
+                  })) #, max_model_len=2048, max_num_batched_tokens=2048)
+
+    # 1ST generation (prefill instance)
+    outputs = llm.generate(prompts, sampling_params)
+
+    print("-" * 30)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 30)
+
+
+if __name__ == "__main__":
+    main()
```
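As the refactored `read_prompts()` shows, the decode script simply exits when `output.txt` is missing, so it has to run after the prefill script and in the same working directory where `local_storage` was written. A minimal pre-flight check along these lines (purely illustrative, not part of this commit) makes that ordering explicit:

```python
# Illustrative pre-flight check, not part of this commit: verify that the
# prefill step has already produced the artifacts decode_example.py expects
# in the current working directory.
import os
import sys

for artifact in ("output.txt", "local_storage"):
    if not os.path.exists(artifact):
        sys.exit(f"Missing {artifact!r}; run prefill_example.py first.")
print("Prefill artifacts found; safe to run decode_example.py.")
```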
prefill_example.py:

```diff
@@ -3,42 +3,54 @@
 from vllm import LLM, SamplingParams
 from vllm.config import KVTransferConfig
 
-context = "Hi " * 1000
-context2 = "Hey " * 500
-prompts = [
-    context + "Hello, my name is",
-    context + "The capital of France is",
-    context2 + "Your name is",
-    context2 + "The capital of China is",
-]
-
-sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
-
-llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
-          enforce_eager=True,
-          gpu_memory_utilization=0.8,
-          kv_transfer_config=KVTransferConfig(
-              kv_connector="SharedStorageConnector",
-              kv_role="kv_both",
-              kv_connector_extra_config={
-                  "shared_storage_path": "local_storage"
-              })) #, max_model_len=2048, max_num_batched_tokens=2048)
-
-# 1ST generation (prefill instance)
-outputs = llm.generate(
-    prompts,
-    sampling_params,
-)
-
-new_prompts = []
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    new_prompts.append(prompt + generated_text)
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-
-# Write new_prompts to output.txt
-with open("output.txt", "w") as f:
-    for prompt in new_prompts:
-        f.write(prompt + "\n")
-print(f"Saved {len(new_prompts)} prompts to output.txt")
+def read_prompts():
+    context = "Hi " * 1000
+    context2 = "Hey " * 500
+    return [
+        context + "Hello, my name is",
+        context + "The capital of France is",
+        context2 + "Your name is",
+        context2 + "The capital of China is",
+    ]
+
+
+def main():
+    prompts = read_prompts()
+
+    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
+
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+              enforce_eager=True,
+              gpu_memory_utilization=0.8,
+              kv_transfer_config=KVTransferConfig(
+                  kv_connector="SharedStorageConnector",
+                  kv_role="kv_both",
+                  kv_connector_extra_config={
+                      "shared_storage_path": "local_storage"
+                  })) #, max_model_len=2048, max_num_batched_tokens=2048)
+
+    # 1ST generation (prefill instance)
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+    )
+
+    new_prompts = []
+    print("-" * 30)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        new_prompts.append(prompt + generated_text)
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 30)
+
+    # Write new_prompts to output.txt
+    with open("output.txt", "w") as f:
+        for prompt in new_prompts:
+            f.write(prompt + "\n")
+    print(f"Saved {len(new_prompts)} prompts to output.txt")
+
+
+if __name__ == "__main__":
+    main()
```
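Both examples construct identical `KVTransferConfig` objects: the prefill side persists the KV cache under `local_storage` through the `SharedStorageConnector`, and the decode side reads it back through the same connector. If that duplication ever became a maintenance burden, a small shared helper could build the config in one place. The sketch below is only an illustration of that idea, not part of this commit; the helper name is invented here, while the keyword arguments are copied from the diffs above.

```python
# Sketch only: a hypothetical helper that both example scripts could import
# instead of repeating the same KVTransferConfig literal. The kwargs mirror
# the diffs above; the function itself is not part of this commit.
from vllm.config import KVTransferConfig


def shared_storage_kv_config(path: str = "local_storage") -> KVTransferConfig:
    """Build the SharedStorageConnector config used by both examples."""
    return KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": path},
    )
```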