support more args in dp example

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-20 04:42:41 +00:00
parent df8f889f37
commit f93bdd3151

View File

@ -31,16 +31,13 @@ import os
from time import sleep from time import sleep
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.utils import get_open_port from vllm.utils import get_open_port, FlexibleArgumentParser
from vllm import LLM, EngineArgs
def parse_args(): def parse_args():
import argparse parser = FlexibleArgumentParser()
parser = argparse.ArgumentParser(description="Data Parallel Inference") EngineArgs.add_cli_args(parser)
parser.add_argument("--model", parser.set_defaults(model="ibm-research/PowerMoE-3b")
type=str,
default="ibm-research/PowerMoE-3b",
help="Model name or path")
parser.add_argument("--dp-size", parser.add_argument("--dp-size",
type=int, type=int,
default=2, default=2,
@ -65,17 +62,11 @@ def parse_args():
type=int, type=int,
default=0, default=0,
help="Master node port") help="Master node port")
parser.add_argument("--enforce-eager",
action='store_true',
help="Enforce eager mode execution.")
parser.add_argument("--trust-remote-code",
action='store_true',
help="Trust remote code.")
return parser.parse_args() return parser.parse_args()
def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, def main(args, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code): dp_master_port, GPUs_per_dp_rank):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size) os.environ["VLLM_DP_SIZE"] = str(dp_size)
@ -114,13 +105,15 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
top_p=0.95, top_p=0.95,
max_tokens=[16, 20][global_dp_rank % 2]) max_tokens=[16, 20][global_dp_rank % 2])
# Fixed params
args.pop("tensor_parallel_size")
args.pop("enable_expert_parallel")
# Create an LLM. # Create an LLM.
llm = LLM( llm = LLM(
model=model,
tensor_parallel_size=GPUs_per_dp_rank, tensor_parallel_size=GPUs_per_dp_rank,
enforce_eager=enforce_eager,
enable_expert_parallel=True, enable_expert_parallel=True,
trust_remote_code=trust_remote_code, **args,
) )
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
# Print the outputs. # Print the outputs.
@ -139,19 +132,21 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = vars(parse_args())
dp_size = args.dp_size dp_size = args.pop("dp_size")
tp_size = args.tp_size tp_size = args.pop("tp_size")
node_size = args.node_size node_size = args.pop("node_size")
node_rank = args.node_rank node_rank = args.pop("node_rank")
if node_size == 1: if node_size == 1:
dp_master_ip = "127.0.0.1" dp_master_ip = "127.0.0.1"
dp_master_port = get_open_port() dp_master_port = get_open_port()
args.pop("master_addr")
args.pop("master_port")
else: else:
dp_master_ip = args.master_addr dp_master_ip = args.pop("master_addr")
dp_master_port = args.master_port dp_master_port = args.pop("master_port")
assert dp_size % node_size == 0, "dp_size should be divisible by node_size" assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
dp_per_node = dp_size // node_size dp_per_node = dp_size // node_size
@ -162,10 +157,9 @@ if __name__ == "__main__":
for local_dp_rank, global_dp_rank in enumerate( for local_dp_rank, global_dp_rank in enumerate(
range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)): range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
proc = Process(target=main, proc = Process(target=main,
args=(args.model, dp_size, local_dp_rank, args=(args, dp_size, local_dp_rank,
global_dp_rank, dp_master_ip, dp_master_port, global_dp_rank, dp_master_ip, dp_master_port,
tp_size, args.enforce_eager, tp_size,))
args.trust_remote_code))
proc.start() proc.start()
procs.append(proc) procs.append(proc)
exit_code = 0 exit_code = 0