[ci][amd] fix EPLB execution test (#28742)

Signed-off-by: Bradley Davis <bradleyhd@meta.com>
Bradley D 2025-11-19 23:53:38 -08:00 committed by GitHub
parent 7218f83992
commit 1e1c06789e

@@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import os
import random
import pytest
import torch
import torch.distributed
import torch.multiprocessing as mp
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import (
@@ -17,10 +17,12 @@ from vllm.distributed.parallel_state import (
)
from vllm.utils.system_utils import update_environment_variables
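# CUDA cannot be re-initialized in a forked subprocess, so the "spawn" start
# method is forced here; worker functions must therefore be picklable,
# module-level callables.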
mp.set_start_method("spawn", force=True)
def distributed_run(fn, world_size):
def distributed_run(fn, world_size, *args):
number_of_processes = world_size
processes: list[multiprocessing.Process] = []
processes: list[mp.Process] = []
for i in range(number_of_processes):
env: dict[str, str] = {}
env["RANK"] = str(i)
@@ -29,7 +31,7 @@ def distributed_run(fn, world_size):
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = "12345"
p = multiprocessing.Process(target=fn, args=(env,))
p = mp.Process(target=fn, args=(env, world_size, *args))
processes.append(p)
p.start()
@@ -40,24 +42,16 @@ def distributed_run(fn, world_size):
assert p.exitcode == 0
def worker_fn_wrapper(fn):
# `multiprocessing.Process` cannot accept environment variables directly
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
def wrapped_fn(env):
update_environment_variables(env)
local_rank = os.environ["LOCAL_RANK"]
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
init_distributed_environment()
def set_env_vars_and_device(env: dict[str, str]) -> None:
update_environment_variables(env)
local_rank = os.environ["LOCAL_RANK"]
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
init_distributed_environment()
# Ensure each worker process has the same random seed
random.seed(42)
torch.manual_seed(42)
fn()
return wrapped_fn
# Ensure each worker process has the same random seed
random.seed(42)
torch.manual_seed(42)
def create_expert_indices_with_redundancy(
@@ -275,6 +269,79 @@ def verify_redundant_experts_have_same_weights(
)
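# Module-level worker body executed in a spawned subprocess; receives
# (env, world_size, *extra_args) from distributed_run.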
def _test_rearrange_expert_weights_with_redundancy(
env, world_size, num_layers, num_local_experts, num_logical_experts
) -> None:
# Initialize model parallel (using tensor parallel as an entrypoint
# to expert parallel)
set_env_vars_and_device(env)
ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
ep_group = get_tp_group().cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
# Test parameters
total_physical_experts = world_size * num_local_experts
hidden_sizes = [32, 64] # Two different weight matrices
# Create old expert indices (with redundancy)
redundancy_config = create_redundancy_config(
num_logical_experts, total_physical_experts
)
old_indices = create_expert_indices_with_redundancy(
num_layers,
num_logical_experts,
total_physical_experts,
redundancy_config,
)
# Create new expert indices (with redundancy)
new_redundancy_config = create_redundancy_config(
num_logical_experts, total_physical_experts
)
new_indices = create_expert_indices_with_redundancy(
num_layers,
num_logical_experts,
total_physical_experts,
new_redundancy_config,
)
# Create expert weights
expert_weights = create_expert_weights(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
)
# Execute weight rearrangement
rearrange_expert_weights_inplace(
old_indices,
new_indices,
expert_weights,
ep_group,
is_profile=False,
)
# Verify the rearrangement result
verify_expert_weights_after_shuffle(
expert_weights,
new_indices,
hidden_sizes,
ep_rank,
num_local_experts,
)
verify_redundant_experts_have_same_weights(
expert_weights,
new_indices,
hidden_sizes,
world_size,
num_local_experts,
)
@pytest.mark.parametrize(
"world_size,num_layers,num_local_experts,num_logical_experts",
[
@@ -305,78 +372,69 @@ def test_rearrange_expert_weights_with_redundancy(
if torch.cuda.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test")
distributed_run(
_test_rearrange_expert_weights_with_redundancy,
world_size,
num_layers,
num_local_experts,
num_logical_experts,
)
@worker_fn_wrapper
def worker_fn():
# Initialize model parallel (using tensor parallel as an entrypoint
# to expert parallel)
ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
ep_group = get_tp_group().cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
set_env_vars_and_device(env)
ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
# Test parameters
total_physical_experts = world_size * num_local_experts
hidden_sizes = [32, 64] # Two different weight matrices
ep_group = get_tp_group().cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
# Create old expert indices (with redundancy)
redundancy_config = create_redundancy_config(
num_logical_experts, total_physical_experts
)
num_layers = 2
num_local_experts = 2
total_physical_experts = world_size * num_local_experts
num_logical_experts = total_physical_experts // 2 # Some redundancy
hidden_sizes = [32, 64]
old_indices = create_expert_indices_with_redundancy(
num_layers,
num_logical_experts,
total_physical_experts,
redundancy_config,
)
# Create redundancy configuration
redundancy_config = [2] * num_logical_experts
# Create new expert indices (with redundancy)
new_redundancy_config = create_redundancy_config(
num_logical_experts, total_physical_experts
)
new_indices = create_expert_indices_with_redundancy(
num_layers,
num_logical_experts,
total_physical_experts,
new_redundancy_config,
)
# Same indices - no change
indices = create_expert_indices_with_redundancy(
num_layers, num_logical_experts, total_physical_experts, redundancy_config
)
# Create expert weights
expert_weights = create_expert_weights(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
)
expert_weights = create_expert_weights(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
)
# Execute weight rearrangement
rearrange_expert_weights_inplace(
old_indices,
new_indices,
expert_weights,
ep_group,
is_profile=False,
)
# Save original weights
original_weights = []
for layer_weights in expert_weights:
layer_copy = []
for weight in layer_weights:
layer_copy.append(weight.clone())
original_weights.append(layer_copy)
# Verify the rearrangement result
verify_expert_weights_after_shuffle(
expert_weights,
new_indices,
hidden_sizes,
ep_rank,
num_local_experts,
)
# Execute rearrangement (should be no change)
rearrange_expert_weights_inplace(
indices,
indices, # Same indices
expert_weights,
ep_group,
is_profile=False,
)
verify_redundant_experts_have_same_weights(
expert_weights,
new_indices,
hidden_sizes,
world_size,
num_local_experts,
)
distributed_run(worker_fn, world_size)
# Verify that the weights have not changed
for layer in range(num_layers):
for weight_idx in range(len(hidden_sizes)):
torch.testing.assert_close(
expert_weights[layer][weight_idx],
original_weights[layer][weight_idx],
msg=f"""Layer {layer}, weight {weight_idx}
should remain unchanged""",
)
@pytest.mark.parametrize("world_size", [2, 4])
@@ -388,62 +446,69 @@ def test_rearrange_expert_weights_no_change(world_size):
if torch.cuda.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test")
distributed_run(_test_rearrange_expert_weights_no_change, world_size)
@worker_fn_wrapper
def worker_fn():
ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
ep_group = get_tp_group().cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
set_env_vars_and_device(env)
ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
num_layers = 2
num_local_experts = 2
total_physical_experts = world_size * num_local_experts
num_logical_experts = total_physical_experts // 2 # Some redundancy
hidden_sizes = [32, 64]
ep_group = get_tp_group().cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
# Create redundancy configuration
redundancy_config = [2] * num_logical_experts
num_layers = 1
num_local_experts = 2
total_physical_experts = world_size * num_local_experts
num_logical_experts = total_physical_experts // 2
hidden_sizes = [32]
# Same indices - no change
indices = create_expert_indices_with_redundancy(
num_layers, num_logical_experts, total_physical_experts, redundancy_config
)
# Create different index distributions
old_redundancy = create_redundancy_config(
num_logical_experts, total_physical_experts
)
new_redundancy = create_redundancy_config(
num_logical_experts, total_physical_experts
)
expert_weights = create_expert_weights(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
)
old_indices = create_expert_indices_with_redundancy(
num_layers, num_logical_experts, total_physical_experts, old_redundancy
)
new_indices = create_expert_indices_with_redundancy(
num_layers, num_logical_experts, total_physical_experts, new_redundancy
)
# Save original weights
original_weights = []
for layer_weights in expert_weights:
layer_copy = []
for weight in layer_weights:
layer_copy.append(weight.clone())
original_weights.append(layer_copy)
expert_weights = create_expert_weights(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
)
# Execute rearrangement (should be no change)
rearrange_expert_weights_inplace(
indices,
indices, # Same indices
expert_weights,
ep_group,
is_profile=False,
)
# Save original weights
original_weights = []
for layer_weights in expert_weights:
layer_copy = []
for weight in layer_weights:
layer_copy.append(weight.clone())
original_weights.append(layer_copy)
# Verify that the weights have not changed
for layer in range(num_layers):
for weight_idx in range(len(hidden_sizes)):
torch.testing.assert_close(
expert_weights[layer][weight_idx],
original_weights[layer][weight_idx],
msg=f"Layer {layer}, weight {weight_idx} should remain unchanged",
)
# Execute profile mode rearrangement
rearrange_expert_weights_inplace(
old_indices,
new_indices,
expert_weights,
ep_group,
is_profile=True, # Profile mode
)
distributed_run(worker_fn, world_size)
# In profile mode, the weights should remain unchanged
for layer in range(num_layers):
for weight_idx in range(len(hidden_sizes)):
torch.testing.assert_close(
expert_weights[layer][weight_idx],
original_weights[layer][weight_idx],
msg="In profile mode, the weights should remain unchanged",
)
@pytest.mark.parametrize("world_size", [2, 4])
@@ -452,66 +517,4 @@ def test_rearrange_expert_weights_profile_mode(world_size):
if torch.cuda.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test")
@worker_fn_wrapper
def worker_fn():
ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
ep_group = get_tp_group().cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
num_layers = 1
num_local_experts = 2
total_physical_experts = world_size * num_local_experts
num_logical_experts = total_physical_experts // 2
hidden_sizes = [32]
# Create different index distributions
old_redundancy = create_redundancy_config(
num_logical_experts, total_physical_experts
)
new_redundancy = create_redundancy_config(
num_logical_experts, total_physical_experts
)
old_indices = create_expert_indices_with_redundancy(
num_layers, num_logical_experts, total_physical_experts, old_redundancy
)
new_indices = create_expert_indices_with_redundancy(
num_layers, num_logical_experts, total_physical_experts, new_redundancy
)
expert_weights = create_expert_weights(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
)
# Save original weights
original_weights = []
for layer_weights in expert_weights:
layer_copy = []
for weight in layer_weights:
layer_copy.append(weight.clone())
original_weights.append(layer_copy)
# Execute profile mode rearrangement
rearrange_expert_weights_inplace(
old_indices,
new_indices,
expert_weights,
ep_group,
is_profile=True, # Profile mode
)
# In profile mode, the weights should remain unchanged
for layer in range(num_layers):
for weight_idx in range(len(hidden_sizes)):
torch.testing.assert_close(
expert_weights[layer][weight_idx],
original_weights[layer][weight_idx],
msg="In profile mode, the weights should remain unchanged",
)
distributed_run(worker_fn, world_size)
distributed_run(_test_rearrange_expert_weights_profile_mode, world_size)