diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index b17761e00d0f1..054fa33403d04 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -30,10 +30,9 @@ Multi-node: import os from time import sleep -from vllm import LLM, SamplingParams -from vllm.utils import get_open_port, FlexibleArgumentParser -from vllm import LLM, EngineArgs -import torch +from vllm import LLM, EngineArgs, SamplingParams +from vllm.utils import FlexibleArgumentParser, get_open_port + def parse_args(): parser = FlexibleArgumentParser() @@ -158,9 +157,15 @@ if __name__ == "__main__": for local_dp_rank, global_dp_rank in enumerate( range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)): proc = Process(target=main, - args=(args, dp_size, local_dp_rank, - global_dp_rank, dp_master_ip, dp_master_port, - tp_size,)) + args=( + args, + dp_size, + local_dp_rank, + global_dp_rank, + dp_master_ip, + dp_master_port, + tp_size, + )) proc.start() procs.append(proc) exit_code = 0