mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 18:35:00 +08:00
[ci][amd] fix EPLB execution test (#28742)
Signed-off-by: Bradley Davis <bradleyhd@meta.com>
This commit is contained in:
parent
7218f83992
commit
1e1c06789e
@ -1,13 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||
from vllm.distributed.parallel_state import (
|
||||
@ -17,10 +17,12 @@ from vllm.distributed.parallel_state import (
|
||||
)
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
def distributed_run(fn, world_size):
|
||||
|
||||
def distributed_run(fn, world_size, *args):
|
||||
number_of_processes = world_size
|
||||
processes: list[multiprocessing.Process] = []
|
||||
processes: list[mp.Process] = []
|
||||
for i in range(number_of_processes):
|
||||
env: dict[str, str] = {}
|
||||
env["RANK"] = str(i)
|
||||
@ -29,7 +31,7 @@ def distributed_run(fn, world_size):
|
||||
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||
env["MASTER_ADDR"] = "localhost"
|
||||
env["MASTER_PORT"] = "12345"
|
||||
p = multiprocessing.Process(target=fn, args=(env,))
|
||||
p = mp.Process(target=fn, args=(env, world_size, *args))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
@ -40,11 +42,7 @@ def distributed_run(fn, world_size):
|
||||
assert p.exitcode == 0
|
||||
|
||||
|
||||
def worker_fn_wrapper(fn):
|
||||
# `multiprocessing.Process` cannot accept environment variables directly
|
||||
# so we need to pass the environment variables as arguments
|
||||
# and update the environment variables in the function
|
||||
def wrapped_fn(env):
|
||||
def set_env_vars_and_device(env: dict[str, str]) -> None:
|
||||
update_environment_variables(env)
|
||||
local_rank = os.environ["LOCAL_RANK"]
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
@ -55,10 +53,6 @@ def worker_fn_wrapper(fn):
|
||||
random.seed(42)
|
||||
torch.manual_seed(42)
|
||||
|
||||
fn()
|
||||
|
||||
return wrapped_fn
|
||||
|
||||
|
||||
def create_expert_indices_with_redundancy(
|
||||
num_layers: int,
|
||||
@ -275,41 +269,12 @@ def verify_redundant_experts_have_same_weights(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"world_size,num_layers,num_local_experts,num_logical_experts",
|
||||
[
|
||||
# 2 GPU, 2 experts per GPU
|
||||
# 3 logical experts, 4 physical experts, 1 redundant experts
|
||||
(2, 1, 2, 3),
|
||||
# 2 GPU, 3 experts per GPU
|
||||
# 4 logical experts, 6 physical experts, 2 redundant experts
|
||||
(2, 2, 3, 4),
|
||||
# 2 GPU, 8 experts per GPU
|
||||
# 16 logical experts, 16 physical experts, 0 redundant experts
|
||||
(2, 4, 8, 16),
|
||||
# 4 GPU, 2 experts per GPU
|
||||
# 6 logical experts, 8 physical experts, 2 redundant experts
|
||||
(4, 1, 2, 6),
|
||||
# 4 GPU, 2 experts per GPU
|
||||
# 5 logical experts, 8 physical experts, 3 redundant experts
|
||||
(4, 2, 2, 5),
|
||||
# 4 GPU, 8 experts per GPU
|
||||
# 16 logical experts, 32 physical experts, 16 redundant experts
|
||||
(4, 8, 8, 16),
|
||||
],
|
||||
)
|
||||
def test_rearrange_expert_weights_with_redundancy(
|
||||
world_size, num_layers, num_local_experts, num_logical_experts
|
||||
):
|
||||
"""Test the functionality of rearranging expert weights with redundancy."""
|
||||
|
||||
if torch.cuda.device_count() < world_size:
|
||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn():
|
||||
def _test_rearrange_expert_weights_with_redundancy(
|
||||
env, world_size, num_layers, num_local_experts, num_logical_experts
|
||||
) -> None:
|
||||
# Initialize model parallel (using tensor parallel as an entrypoint
|
||||
# to expert parallel)
|
||||
set_env_vars_and_device(env)
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
@ -376,21 +341,48 @@ def test_rearrange_expert_weights_with_redundancy(
|
||||
num_local_experts,
|
||||
)
|
||||
|
||||
distributed_run(worker_fn, world_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
def test_rearrange_expert_weights_no_change(world_size):
|
||||
"""
|
||||
Test that when the indices do not change, the weights should remain
|
||||
unchanged.
|
||||
"""
|
||||
@pytest.mark.parametrize(
|
||||
"world_size,num_layers,num_local_experts,num_logical_experts",
|
||||
[
|
||||
# 2 GPU, 2 experts per GPU
|
||||
# 3 logical experts, 4 physical experts, 1 redundant experts
|
||||
(2, 1, 2, 3),
|
||||
# 2 GPU, 3 experts per GPU
|
||||
# 4 logical experts, 6 physical experts, 2 redundant experts
|
||||
(2, 2, 3, 4),
|
||||
# 2 GPU, 8 experts per GPU
|
||||
# 16 logical experts, 16 physical experts, 0 redundant experts
|
||||
(2, 4, 8, 16),
|
||||
# 4 GPU, 2 experts per GPU
|
||||
# 6 logical experts, 8 physical experts, 2 redundant experts
|
||||
(4, 1, 2, 6),
|
||||
# 4 GPU, 2 experts per GPU
|
||||
# 5 logical experts, 8 physical experts, 3 redundant experts
|
||||
(4, 2, 2, 5),
|
||||
# 4 GPU, 8 experts per GPU
|
||||
# 16 logical experts, 32 physical experts, 16 redundant experts
|
||||
(4, 8, 8, 16),
|
||||
],
|
||||
)
|
||||
def test_rearrange_expert_weights_with_redundancy(
|
||||
world_size, num_layers, num_local_experts, num_logical_experts
|
||||
):
|
||||
"""Test the functionality of rearranging expert weights with redundancy."""
|
||||
|
||||
if torch.cuda.device_count() < world_size:
|
||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||
distributed_run(
|
||||
_test_rearrange_expert_weights_with_redundancy,
|
||||
world_size,
|
||||
num_layers,
|
||||
num_local_experts,
|
||||
num_logical_experts,
|
||||
)
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn():
|
||||
|
||||
def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
||||
set_env_vars_and_device(env)
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
@ -440,21 +432,25 @@ def test_rearrange_expert_weights_no_change(world_size):
|
||||
torch.testing.assert_close(
|
||||
expert_weights[layer][weight_idx],
|
||||
original_weights[layer][weight_idx],
|
||||
msg=f"Layer {layer}, weight {weight_idx} should remain unchanged",
|
||||
msg=f"""Layer {layer}, weight {weight_idx}
|
||||
should remain unchanged""",
|
||||
)
|
||||
|
||||
distributed_run(worker_fn, world_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
def test_rearrange_expert_weights_profile_mode(world_size):
|
||||
"""Test profile mode (should not copy actual weights)"""
|
||||
def test_rearrange_expert_weights_no_change(world_size):
|
||||
"""
|
||||
Test that when the indices do not change, the weights should remain
|
||||
unchanged.
|
||||
"""
|
||||
|
||||
if torch.cuda.device_count() < world_size:
|
||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||
distributed_run(_test_rearrange_expert_weights_no_change, world_size)
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn():
|
||||
|
||||
def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
||||
set_env_vars_and_device(env)
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
@ -514,4 +510,11 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
||||
msg="In profile mode, the weights should remain unchanged",
|
||||
)
|
||||
|
||||
distributed_run(worker_fn, world_size)
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
def test_rearrange_expert_weights_profile_mode(world_size):
|
||||
"""Test profile mode (should not copy actual weights)"""
|
||||
|
||||
if torch.cuda.device_count() < world_size:
|
||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||
distributed_run(_test_rearrange_expert_weights_profile_mode, world_size)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user