mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 06:37:04 +08:00
fix data_parallel.py
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
c0efbbb5de
commit
0767d9863f
@ -32,8 +32,8 @@ Multi-node:
|
||||
import os
|
||||
from time import sleep
|
||||
|
||||
from vllm import LLM, EngineArgs, SamplingParams
|
||||
from vllm.utils import FlexibleArgumentParser, get_open_port
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils import get_open_port
|
||||
|
||||
|
||||
def parse_args():
|
||||
@ -76,6 +76,11 @@ def parse_args():
|
||||
default=0.8,
|
||||
help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-microbatching",
|
||||
action="store_true",
|
||||
help=("Enable microbatched execution"),
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -135,24 +140,20 @@ def main(
|
||||
# sampling params. here we set different max_tokens for different
|
||||
# ranks for demonstration.
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.8, top_p=0.95, max_tokens=[20, 16][global_dp_rank % 2]
|
||||
temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2]
|
||||
)
|
||||
|
||||
# Fixed params
|
||||
args.pop("tensor_parallel_size")
|
||||
args.pop("enable_expert_parallel")
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(
|
||||
model=model,
|
||||
tensor_parallel_size=GPUs_per_dp_rank,
|
||||
enforce_eager=enforce_eager,
|
||||
enable_expert_parallel=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
max_num_seqs=max_num_seqs,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
)
|
||||
print("BEFORE GENERATE")
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
print("AFTER GENERATE")
|
||||
# Print the outputs.
|
||||
for i, output in enumerate(outputs):
|
||||
if i >= 5:
|
||||
@ -170,22 +171,19 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
args = vars(parse_args())
|
||||
|
||||
dp_size = args.pop("dp_size")
|
||||
tp_size = args.pop("tp_size")
|
||||
node_size = args.pop("node_size")
|
||||
node_rank = args.pop("node_rank")
|
||||
dp_size = args.dp_size
|
||||
tp_size = args.tp_size
|
||||
node_size = args.node_size
|
||||
node_rank = args.node_rank
|
||||
|
||||
if node_size == 1:
|
||||
dp_master_ip = "127.0.0.1"
|
||||
dp_master_port = get_open_port()
|
||||
args.pop("master_addr")
|
||||
args.pop("master_port")
|
||||
else:
|
||||
dp_master_ip = args.pop("master_addr")
|
||||
dp_master_port = args.pop("master_port")
|
||||
dp_master_ip = args.master_addr
|
||||
dp_master_port = args.master_port
|
||||
|
||||
assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
|
||||
dp_per_node = dp_size // node_size
|
||||
@ -216,7 +214,7 @@ if __name__ == "__main__":
|
||||
procs.append(proc)
|
||||
exit_code = 0
|
||||
for proc in procs:
|
||||
proc.join(timeout=1200)
|
||||
proc.join(timeout=300)
|
||||
if proc.exitcode is None:
|
||||
print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
|
||||
proc.kill()
|
||||
@ -224,4 +222,4 @@ if __name__ == "__main__":
|
||||
elif proc.exitcode:
|
||||
exit_code = proc.exitcode
|
||||
|
||||
exit(exit_code)
|
||||
exit(exit_code)
|
||||
Loading…
x
Reference in New Issue
Block a user