support more args in dp example

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-20 04:42:41 +00:00
parent df8f889f37
commit f93bdd3151

View File

@ -31,16 +31,13 @@ import os
from time import sleep
from vllm import LLM, SamplingParams
from vllm.utils import get_open_port
from vllm.utils import get_open_port, FlexibleArgumentParser
from vllm import LLM, EngineArgs
def parse_args():
import argparse
parser = argparse.ArgumentParser(description="Data Parallel Inference")
parser.add_argument("--model",
type=str,
default="ibm-research/PowerMoE-3b",
help="Model name or path")
parser = FlexibleArgumentParser()
EngineArgs.add_cli_args(parser)
parser.set_defaults(model="ibm-research/PowerMoE-3b")
parser.add_argument("--dp-size",
type=int,
default=2,
@ -65,17 +62,11 @@ def parse_args():
type=int,
default=0,
help="Master node port")
parser.add_argument("--enforce-eager",
action='store_true',
help="Enforce eager mode execution.")
parser.add_argument("--trust-remote-code",
action='store_true',
help="Trust remote code.")
return parser.parse_args()
def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code):
def main(args, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
dp_master_port, GPUs_per_dp_rank):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
@ -114,13 +105,15 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
top_p=0.95,
max_tokens=[16, 20][global_dp_rank % 2])
# Fixed params
args.pop("tensor_parallel_size")
args.pop("enable_expert_parallel")
# Create an LLM.
llm = LLM(
model=model,
tensor_parallel_size=GPUs_per_dp_rank,
enforce_eager=enforce_eager,
enable_expert_parallel=True,
trust_remote_code=trust_remote_code,
**args,
)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
@ -139,19 +132,21 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
if __name__ == "__main__":
args = parse_args()
args = vars(parse_args())
dp_size = args.dp_size
tp_size = args.tp_size
node_size = args.node_size
node_rank = args.node_rank
dp_size = args.pop("dp_size")
tp_size = args.pop("tp_size")
node_size = args.pop("node_size")
node_rank = args.pop("node_rank")
if node_size == 1:
dp_master_ip = "127.0.0.1"
dp_master_port = get_open_port()
args.pop("master_addr")
args.pop("master_port")
else:
dp_master_ip = args.master_addr
dp_master_port = args.master_port
dp_master_ip = args.pop("master_addr")
dp_master_port = args.pop("master_port")
assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
dp_per_node = dp_size // node_size
@ -162,10 +157,9 @@ if __name__ == "__main__":
for local_dp_rank, global_dp_rank in enumerate(
range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
proc = Process(target=main,
args=(args.model, dp_size, local_dp_rank,
args=(args, dp_size, local_dp_rank,
global_dp_rank, dp_master_ip, dp_master_port,
tp_size, args.enforce_eager,
args.trust_remote_code))
tp_size,))
proc.start()
procs.append(proc)
exit_code = 0