diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index e785985bfed9d..5e00945708f32 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -32,8 +32,8 @@ Multi-node: import os from time import sleep -from vllm import LLM, EngineArgs, SamplingParams -from vllm.utils import FlexibleArgumentParser, get_open_port +from vllm import LLM, SamplingParams +from vllm.utils import get_open_port def parse_args(): @@ -76,6 +76,11 @@ def parse_args(): default=0.8, help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."), ) + parser.add_argument( + "--enable-microbatching", + action="store_true", + help=("Enable microbatched execution"), + ) return parser.parse_args() @@ -135,24 +140,20 @@ def main( # sampling params. here we set different max_tokens for different # ranks for demonstration. sampling_params = SamplingParams( - temperature=0.8, top_p=0.95, max_tokens=[20, 16][global_dp_rank % 2] + temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2] ) - # Fixed params - args.pop("tensor_parallel_size") - args.pop("enable_expert_parallel") - # Create an LLM. llm = LLM( + model=model, tensor_parallel_size=GPUs_per_dp_rank, + enforce_eager=enforce_eager, enable_expert_parallel=True, trust_remote_code=trust_remote_code, max_num_seqs=max_num_seqs, gpu_memory_utilization=gpu_memory_utilization, ) - print("BEFORE GENERATE") outputs = llm.generate(prompts, sampling_params) - print("AFTER GENERATE") # Print the outputs. for i, output in enumerate(outputs): if i >= 5: @@ -170,22 +171,19 @@ def main( if __name__ == "__main__": + args = parse_args() - args = vars(parse_args()) - - dp_size = args.pop("dp_size") - tp_size = args.pop("tp_size") - node_size = args.pop("node_size") - node_rank = args.pop("node_rank") + dp_size = args.dp_size + tp_size = args.tp_size + node_size = args.node_size + node_rank = args.node_rank if node_size == 1: dp_master_ip = "127.0.0.1" dp_master_port = get_open_port() - args.pop("master_addr") - args.pop("master_port") else: - dp_master_ip = args.pop("master_addr") - dp_master_port = args.pop("master_port") + dp_master_ip = args.master_addr + dp_master_port = args.master_port assert dp_size % node_size == 0, "dp_size should be divisible by node_size" dp_per_node = dp_size // node_size @@ -216,7 +214,7 @@ if __name__ == "__main__": procs.append(proc) exit_code = 0 for proc in procs: - proc.join(timeout=1200) + proc.join(timeout=300) if proc.exitcode is None: print(f"Killing process {proc.pid} that didn't stop within 5 minutes.") proc.kill() @@ -224,4 +222,4 @@ if __name__ == "__main__": elif proc.exitcode: exit_code = proc.exitcode - exit(exit_code) + exit(exit_code) \ No newline at end of file