fix data_parallel.py

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-07-02 19:25:59 +00:00
parent c0efbbb5de
commit 0767d9863f

View File

@ -32,8 +32,8 @@ Multi-node:
import os
from time import sleep
from vllm import LLM, EngineArgs, SamplingParams
from vllm.utils import FlexibleArgumentParser, get_open_port
from vllm import LLM, SamplingParams
from vllm.utils import get_open_port
def parse_args():
@ -76,6 +76,11 @@ def parse_args():
default=0.8,
help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
)
parser.add_argument(
"--enable-microbatching",
action="store_true",
help=("Enable microbatched execution"),
)
return parser.parse_args()
@ -135,24 +140,20 @@ def main(
# sampling params. here we set different max_tokens for different
# ranks for demonstration.
sampling_params = SamplingParams(
temperature=0.8, top_p=0.95, max_tokens=[20, 16][global_dp_rank % 2]
temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2]
)
# Fixed params
args.pop("tensor_parallel_size")
args.pop("enable_expert_parallel")
# Create an LLM.
llm = LLM(
model=model,
tensor_parallel_size=GPUs_per_dp_rank,
enforce_eager=enforce_eager,
enable_expert_parallel=True,
trust_remote_code=trust_remote_code,
max_num_seqs=max_num_seqs,
gpu_memory_utilization=gpu_memory_utilization,
)
print("BEFORE GENERATE")
outputs = llm.generate(prompts, sampling_params)
print("AFTER GENERATE")
# Print the outputs.
for i, output in enumerate(outputs):
if i >= 5:
@ -170,22 +171,19 @@ def main(
if __name__ == "__main__":
args = parse_args()
args = vars(parse_args())
dp_size = args.pop("dp_size")
tp_size = args.pop("tp_size")
node_size = args.pop("node_size")
node_rank = args.pop("node_rank")
dp_size = args.dp_size
tp_size = args.tp_size
node_size = args.node_size
node_rank = args.node_rank
if node_size == 1:
dp_master_ip = "127.0.0.1"
dp_master_port = get_open_port()
args.pop("master_addr")
args.pop("master_port")
else:
dp_master_ip = args.pop("master_addr")
dp_master_port = args.pop("master_port")
dp_master_ip = args.master_addr
dp_master_port = args.master_port
assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
dp_per_node = dp_size // node_size
@ -216,7 +214,7 @@ if __name__ == "__main__":
procs.append(proc)
exit_code = 0
for proc in procs:
proc.join(timeout=1200)
proc.join(timeout=300)
if proc.exitcode is None:
print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
proc.kill()
@ -224,4 +222,4 @@ if __name__ == "__main__":
elif proc.exitcode:
exit_code = proc.exitcode
exit(exit_code)
exit(exit_code)