[Misc] format and refactor some examples (#16252)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
Reid 2025-04-08 18:42:32 +08:00 committed by GitHub
parent 995e3d1f41
commit 7f00899ff7
13 changed files with 178 additions and 115 deletions

View File

@@ -90,8 +90,9 @@ def run_simple_demo(args: argparse.Namespace):
         },
     ]
     outputs = llm.chat(messages, sampling_params=sampling_params)
+    print("-" * 50)
     print(outputs[0].outputs[0].text)
+    print("-" * 50)


 def run_advanced_demo(args: argparse.Namespace):
@@ -162,7 +163,9 @@ def run_advanced_demo(args: argparse.Namespace):
     ]
     outputs = llm.chat(messages=messages, sampling_params=sampling_params)
+    print("-" * 50)
     print(outputs[0].outputs[0].text)
+    print("-" * 50)


 def main():
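
Most of the examples touched by this commit apply the same output-formatting change: a `print("-" * 50)` separator before the result loop, another separator after each result, and the prompt and generated text split onto separate lines. A minimal sketch of that pattern as a standalone helper (hypothetical, not part of the commit; the examples keep the loop inline, and it assumes `RequestOutput` is importable from the top-level `vllm` package):

# Hypothetical helper illustrating the printing pattern used across these examples.
from vllm import RequestOutput

def print_outputs(outputs: list[RequestOutput]) -> None:
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)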

View File

@@ -61,6 +61,7 @@ def process_requests(engine: LLMEngine,
     """Continuously process a list of prompts and handle the outputs."""
     request_id = 0

+    print("-" * 50)
     while test_prompts or engine.has_unfinished_requests():
         if test_prompts:
             prompt, sampling_params, lora_request = test_prompts.pop(0)
@@ -75,6 +76,7 @@ def process_requests(engine: LLMEngine,
         for request_output in request_outputs:
             if request_output.finished:
                 print(request_output)
+                print("-" * 50)


 def initialize_engine() -> LLMEngine:

View File

@@ -12,27 +12,36 @@ prompts = [
 # Create a sampling params object.
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

-# Create an LLM.
-llm = LLM(
-    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-    max_num_seqs=8,
-    # The max_model_len and block_size arguments are required to be same as
-    # max sequence length when targeting neuron device.
-    # Currently, this is a known limitation in continuous batching support
-    # in transformers-neuronx.
-    # TODO(liangfu): Support paged-attention in transformers-neuronx.
-    max_model_len=1024,
-    block_size=1024,
-    # The device can be automatically detected when AWS Neuron SDK is installed.
-    # The device argument can be either unspecified for automated detection,
-    # or explicitly assigned.
-    device="neuron",
-    tensor_parallel_size=2)
-# Generate texts from the prompts. The output is a list of RequestOutput objects
-# that contain the prompt, generated text, and other information.
-outputs = llm.generate(prompts, sampling_params)
-# Print the outputs.
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+def main():
+    # Create an LLM.
+    llm = LLM(
+        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        max_num_seqs=8,
+        # The max_model_len and block_size arguments are required to be same as
+        # max sequence length when targeting neuron device.
+        # Currently, this is a known limitation in continuous batching support
+        # in transformers-neuronx.
+        # TODO(liangfu): Support paged-attention in transformers-neuronx.
+        max_model_len=1024,
+        block_size=1024,
+        # ruff: noqa: E501
+        # The device can be automatically detected when AWS Neuron SDK is installed.
+        # The device argument can be either unspecified for automated detection,
+        # or explicitly assigned.
+        device="neuron",
+        tensor_parallel_size=2)
+    # Generate texts from the prompts. The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+    # Print the outputs.
+    print("-" * 50)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 50)
+
+
+if __name__ == "__main__":
+    main()
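
The comments in this example describe a transformers-neuronx limitation: until paged attention is supported, continuous batching requires `block_size` to equal `max_model_len` (the maximum sequence length) when targeting the neuron device. A small, hypothetical guard expressing that constraint (illustrative only, not part of the commit):

# Hypothetical guard for the neuron constraint described above:
# continuous batching in transformers-neuronx currently requires
# block_size == max_model_len (no paged attention yet).
def check_neuron_lengths(max_model_len: int, block_size: int) -> None:
    if block_size != max_model_len:
        raise ValueError(
            f"block_size ({block_size}) must equal max_model_len "
            f"({max_model_len}) when targeting the neuron device.")

check_neuron_lengths(max_model_len=1024, block_size=1024)  # values used in this example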

View File

@@ -22,31 +22,40 @@ prompts = [
 # Create a sampling params object.
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

-# Create an LLM.
-llm = LLM(
-    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-    max_num_seqs=8,
-    # The max_model_len and block_size arguments are required to be same as
-    # max sequence length when targeting neuron device.
-    # Currently, this is a known limitation in continuous batching support
-    # in transformers-neuronx.
-    # TODO(liangfu): Support paged-attention in transformers-neuronx.
-    max_model_len=2048,
-    block_size=2048,
-    # The device can be automatically detected when AWS Neuron SDK is installed.
-    # The device argument can be either unspecified for automated detection,
-    # or explicitly assigned.
-    device="neuron",
-    quantization="neuron_quant",
-    override_neuron_config={
-        "cast_logits_dtype": "bfloat16",
-    },
-    tensor_parallel_size=2)
-# Generate texts from the prompts. The output is a list of RequestOutput objects
-# that contain the prompt, generated text, and other information.
-outputs = llm.generate(prompts, sampling_params)
-# Print the outputs.
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+def main():
+    # Create an LLM.
+    llm = LLM(
+        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        max_num_seqs=8,
+        # The max_model_len and block_size arguments are required to be same as
+        # max sequence length when targeting neuron device.
+        # Currently, this is a known limitation in continuous batching support
+        # in transformers-neuronx.
+        # TODO(liangfu): Support paged-attention in transformers-neuronx.
+        max_model_len=2048,
+        block_size=2048,
+        # ruff: noqa: E501
+        # The device can be automatically detected when AWS Neuron SDK is installed.
+        # The device argument can be either unspecified for automated detection,
+        # or explicitly assigned.
+        device="neuron",
+        quantization="neuron_quant",
+        override_neuron_config={
+            "cast_logits_dtype": "bfloat16",
+        },
+        tensor_parallel_size=2)
+    # Generate texts from the prompts. The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+    # Print the outputs.
+    print("-" * 50)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 50)
+
+
+if __name__ == "__main__":
+    main()

View File

@@ -31,55 +31,62 @@ generating_prompts = [prefix + prompt for prompt in prompts]
 # Create a sampling params object.
 sampling_params = SamplingParams(temperature=0.0)

-# Create an LLM without prefix caching as a baseline.
-regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)
-
-print("Results without `enable_prefix_caching`")
-
-# Generate texts from the prompts. The output is a list of RequestOutput objects
-# that contain the prompt, generated text, and other information.
-outputs = regular_llm.generate(generating_prompts, sampling_params)
-
-regular_generated_texts = []
-# Print the outputs.
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    regular_generated_texts.append(generated_text)
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-
-print("-" * 80)
-
-# Destroy the LLM object and free up the GPU memory.
-del regular_llm
-cleanup_dist_env_and_memory()
-
-# Create an LLM with prefix caching enabled.
-prefix_cached_llm = LLM(model="facebook/opt-125m",
-                        enable_prefix_caching=True,
-                        gpu_memory_utilization=0.4)
-
-# Warmup so that the shared prompt's KV cache is computed.
-prefix_cached_llm.generate(generating_prompts[0], sampling_params)
-
-# Generate with prefix caching.
-outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
-
-print("Results with `enable_prefix_caching`")
-
-cached_generated_texts = []
-# Print the outputs. You should see the same outputs as before.
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    cached_generated_texts.append(generated_text)
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-
-print("-" * 80)
-
-# Compare the results and display the speedup
-generated_same = all([
-    regular_generated_texts[i] == cached_generated_texts[i]
-    for i in range(len(prompts))
-])
-print(f"Generated answers are the same: {generated_same}")
+
+def main():
+    # Create an LLM without prefix caching as a baseline.
+    regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)
+
+    print("Results without `enable_prefix_caching`")
+
+    # ruff: noqa: E501
+    # Generate texts from the prompts. The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = regular_llm.generate(generating_prompts, sampling_params)
+
+    regular_generated_texts = []
+    # Print the outputs.
+    print("-" * 50)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        regular_generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 50)
+
+    # Destroy the LLM object and free up the GPU memory.
+    del regular_llm
+    cleanup_dist_env_and_memory()
+
+    # Create an LLM with prefix caching enabled.
+    prefix_cached_llm = LLM(model="facebook/opt-125m",
+                            enable_prefix_caching=True,
+                            gpu_memory_utilization=0.4)
+
+    # Warmup so that the shared prompt's KV cache is computed.
+    prefix_cached_llm.generate(generating_prompts[0], sampling_params)
+
+    # Generate with prefix caching.
+    outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
+
+    print("Results with `enable_prefix_caching`")
+
+    cached_generated_texts = []
+    # Print the outputs. You should see the same outputs as before.
+    print("-" * 50)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        cached_generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 50)
+
+    # Compare the results and display the speedup
+    generated_same = all([
+        regular_generated_texts[i] == cached_generated_texts[i]
+        for i in range(len(prompts))
+    ])
+    print(f"Generated answers are the same: {generated_same}")
+
+
+if __name__ == "__main__":
+    main()

View File

@@ -19,8 +19,6 @@ SEED = 42
 # because it is almost impossible to make the scheduling deterministic in the
 # online serving setting.

-llm = LLM(model="facebook/opt-125m", seed=SEED)
-
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -29,8 +27,17 @@ prompts = [
 ]

 sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

-outputs = llm.generate(prompts, sampling_params)
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+def main():
+    llm = LLM(model="facebook/opt-125m", seed=SEED)
+    outputs = llm.generate(prompts, sampling_params)
+    print("-" * 50)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 50)
+
+
+if __name__ == "__main__":
+    main()
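
The seed is fixed here so that repeated offline runs are reproducible (as the comments note, online serving cannot be made deterministic because of scheduling). A hypothetical reproducibility check, not part of the example, assuming enough GPU memory to build the engine twice in one process and that `cleanup_dist_env_and_memory` is available from `vllm.distributed` as in the prefix-caching example:

# Hypothetical check: two runs with the same seed should produce identical text.
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory

def generate_texts(seed: int) -> list[str]:
    llm = LLM(model="facebook/opt-125m", seed=seed)
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.8, top_p=0.95))
    texts = [o.outputs[0].text for o in outputs]
    # Tear the engine down so the second run starts from a clean state.
    del llm
    cleanup_dist_env_and_memory()
    return texts

assert generate_texts(42) == generate_texts(42)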

View File

@@ -85,11 +85,13 @@ sampling_params = SamplingParams(temperature=0)

 outputs = ray.get(llm.generate.remote(prompts, sampling_params))

+print("-" * 50)
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, "
-          f"Generated text: {generated_text!r}")
+    print(f"Prompt: {prompt!r}\n"
+          f"Generated text: {generated_text!r}")
+    print("-" * 50)

 # set up the communication between the training process
 # and the inference engine.
@@ -120,8 +122,10 @@ assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
 # use the updated model to generate texts, they will be nonsense
 # because the weights are all zeros.
 outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
+print("-" * 50)
 for output in outputs_updated:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, "
-          f"Generated text: {generated_text!r}")
+    print(f"Prompt: {prompt!r}\n"
+          f"Generated text: {generated_text!r}")
+    print("-" * 50)

View File

@@ -32,10 +32,12 @@ if __name__ == "__main__":
     llm.stop_profile()

     # Print the outputs.
+    print("-" * 50)
     for output in outputs:
         prompt = output.prompt
         generated_text = output.outputs[0].text
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 50)

     # Add a buffer to wait for profiler in the background process
     # (in case MP is on) to finish writing profiling output.

View File

@@ -36,11 +36,13 @@ llm = LLM(
 outputs = llm.generate(prompts, sampling_params)

 # all ranks will have the same outputs
+print("-" * 50)
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, "
-          f"Generated text: {generated_text!r}")
+    print(f"Prompt: {prompt!r}\n"
+          f"Generated text: {generated_text!r}")
+    print("-" * 50)

 """
 Further tips:

View File

@@ -16,14 +16,22 @@ N = 1
 # Currently, top-p sampling is disabled. `top_p` should be 1.0.
 sampling_params = SamplingParams(temperature=0, top_p=1.0, n=N, max_tokens=16)

-# Set `enforce_eager=True` to avoid ahead-of-time compilation.
-# In real workloads, `enforace_eager` should be `False`.
-llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
-          max_num_batched_tokens=64,
-          max_num_seqs=4)
-outputs = llm.generate(prompts, sampling_params)
-for output, answer in zip(outputs, answers):
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-    assert generated_text.startswith(answer)
+
+def main():
+    # Set `enforce_eager=True` to avoid ahead-of-time compilation.
+    # In real workloads, `enforace_eager` should be `False`.
+    llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
+              max_num_batched_tokens=64,
+              max_num_seqs=4)
+    outputs = llm.generate(prompts, sampling_params)
+    print("-" * 50)
+    for output, answer in zip(outputs, answers):
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        assert generated_text.startswith(answer)
+        print("-" * 50)
+
+
+if __name__ == "__main__":
+    main()

View File

@@ -1089,14 +1089,18 @@ def main(args):
         start_time = time.time()
         outputs = llm.generate(inputs, sampling_params=sampling_params)
         elapsed_time = time.time() - start_time
+        print("-" * 50)
         print("-- generate time = {}".format(elapsed_time))
+        print("-" * 50)
     else:
         outputs = llm.generate(inputs, sampling_params=sampling_params)

+    print("-" * 50)
     for o in outputs:
         generated_text = o.outputs[0].text
         print(generated_text)
+        print("-" * 50)


 if __name__ == "__main__":

View File

@@ -143,8 +143,10 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
         "multi_modal_data": mm_data,
     })

+    print("-" * 50)
     for output in outputs:
         print(output.outputs.embedding)
+        print("-" * 50)


 def main(args: Namespace):

View File

@@ -644,9 +644,11 @@ def run_generate(model, question: str, image_urls: list[str],
         },
         sampling_params=sampling_params)

+    print("-" * 50)
     for o in outputs:
         generated_text = o.outputs[0].text
         print(generated_text)
+        print("-" * 50)


 def run_chat(model: str, question: str, image_urls: list[str],
@@ -687,9 +689,11 @@ def run_chat(model: str, question: str, image_urls: list[str],
         chat_template=req_data.chat_template,
     )

+    print("-" * 50)
     for o in outputs:
         generated_text = o.outputs[0].text
         print(generated_text)
+        print("-" * 50)


 def main(args: Namespace):