[ci][amd] fix EPLB execution test (#28742)

Signed-off-by: Bradley Davis <bradleyhd@meta.com>
Bradley D 2025-11-19 23:53:38 -08:00 committed by GitHub
parent 7218f83992
commit 1e1c06789e


@@ -1,13 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import multiprocessing
 import os
 import random

 import pytest
 import torch
 import torch.distributed
+import torch.multiprocessing as mp

 from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
 from vllm.distributed.parallel_state import (
@@ -17,10 +17,12 @@ from vllm.distributed.parallel_state import (
 )

 from vllm.utils.system_utils import update_environment_variables

+mp.set_start_method("spawn", force=True)
+

-def distributed_run(fn, world_size):
+def distributed_run(fn, world_size, *args):
     number_of_processes = world_size
-    processes: list[multiprocessing.Process] = []
+    processes: list[mp.Process] = []
     for i in range(number_of_processes):
         env: dict[str, str] = {}
         env["RANK"] = str(i)
@@ -29,7 +31,7 @@ def distributed_run(fn, world_size)
         env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
         env["MASTER_ADDR"] = "localhost"
         env["MASTER_PORT"] = "12345"
-        p = multiprocessing.Process(target=fn, args=(env,))
+        p = mp.Process(target=fn, args=(env, world_size, *args))
         processes.append(p)
         p.start()
@@ -40,24 +42,16 @@ def distributed_run(fn, world_size)
         assert p.exitcode == 0


-def worker_fn_wrapper(fn):
-    # `multiprocessing.Process` cannot accept environment variables directly
-    # so we need to pass the environment variables as arguments
-    # and update the environment variables in the function
-    def wrapped_fn(env):
-        update_environment_variables(env)
-        local_rank = os.environ["LOCAL_RANK"]
-        device = torch.device(f"cuda:{local_rank}")
-        torch.cuda.set_device(device)
-        init_distributed_environment()
-
-        # Ensure each worker process has the same random seed
-        random.seed(42)
-        torch.manual_seed(42)
-
-        fn()
-
-    return wrapped_fn
+def set_env_vars_and_device(env: dict[str, str]) -> None:
+    update_environment_variables(env)
+    local_rank = os.environ["LOCAL_RANK"]
+    device = torch.device(f"cuda:{local_rank}")
+    torch.cuda.set_device(device)
+    init_distributed_environment()
+
+    # Ensure each worker process has the same random seed
+    random.seed(42)
+    torch.manual_seed(42)


 def create_expert_indices_with_redundancy(
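Context for the hunk above: once the module forces the "spawn" start method, each worker's target function is pickled into the child process, and a function defined inside another function (like the old wrapped_fn closure) cannot be pickled, so the closure wrapper is replaced by the module-level set_env_vars_and_device helper plus plain module-level test bodies. A minimal, hypothetical sketch of that constraint, using only the standard library (make_worker and top_level_worker are illustrative names, not vLLM code):

import multiprocessing as mp


def make_worker():
    def worker():  # local (closure) function: not picklable
        print("child")

    return worker


def top_level_worker(env: dict[str, str]) -> None:
    # Module-level functions are picklable by qualified name.
    print("child got", env)


if __name__ == "__main__":
    ctx = mp.get_context("spawn")

    ok = ctx.Process(target=top_level_worker, args=({"RANK": "0"},))
    ok.start()
    ok.join()  # runs and exits cleanly

    try:
        bad = ctx.Process(target=make_worker())
        bad.start()  # pickling the closure fails under "spawn"
    except Exception as exc:
        print("rejected:", type(exc).__name__)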
@@ -275,6 +269,79 @@ def verify_redundant_experts_have_same_weights(
     )


+def _test_rearrange_expert_weights_with_redundancy(
+    env, world_size, num_layers, num_local_experts, num_logical_experts
+) -> None:
+    # Initialize model parallel (using tensor parallel as an entrypoint
+    # to expert parallel)
+    set_env_vars_and_device(env)
+    ensure_model_parallel_initialized(
+        tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
+    )
+
+    ep_group = get_tp_group().cpu_group
+    ep_rank = torch.distributed.get_rank()
+    device = torch.device(f"cuda:{ep_rank}")
+
+    # Test parameters
+    total_physical_experts = world_size * num_local_experts
+    hidden_sizes = [32, 64]  # Two different weight matrices
+
+    # Create old expert indices (with redundancy)
+    redundancy_config = create_redundancy_config(
+        num_logical_experts, total_physical_experts
+    )
+    old_indices = create_expert_indices_with_redundancy(
+        num_layers,
+        num_logical_experts,
+        total_physical_experts,
+        redundancy_config,
+    )
+
+    # Create new expert indices (with redundancy)
+    new_redundancy_config = create_redundancy_config(
+        num_logical_experts, total_physical_experts
+    )
+    new_indices = create_expert_indices_with_redundancy(
+        num_layers,
+        num_logical_experts,
+        total_physical_experts,
+        new_redundancy_config,
+    )
+
+    # Create expert weights
+    expert_weights = create_expert_weights(
+        num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
+    )
+
+    # Execute weight rearrangement
+    rearrange_expert_weights_inplace(
+        old_indices,
+        new_indices,
+        expert_weights,
+        ep_group,
+        is_profile=False,
+    )
+
+    # Verify the rearrangement result
+    verify_expert_weights_after_shuffle(
+        expert_weights,
+        new_indices,
+        hidden_sizes,
+        ep_rank,
+        num_local_experts,
+    )
+
+    verify_redundant_experts_have_same_weights(
+        expert_weights,
+        new_indices,
+        hidden_sizes,
+        world_size,
+        num_local_experts,
+    )
+
+
 @pytest.mark.parametrize(
     "world_size,num_layers,num_local_experts,num_logical_experts",
     [
@@ -305,78 +372,69 @@ def test_rearrange_expert_weights_with_redundancy(
     if torch.cuda.device_count() < world_size:
         pytest.skip(f"Need at least {world_size} GPUs to run the test")

-    @worker_fn_wrapper
-    def worker_fn():
-        # Initialize model parallel (using tensor parallel as an entrypoint
-        # to expert parallel)
-        ensure_model_parallel_initialized(
-            tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
-        )
-
-        ep_group = get_tp_group().cpu_group
-        ep_rank = torch.distributed.get_rank()
-        device = torch.device(f"cuda:{ep_rank}")
-
-        # Test parameters
-        total_physical_experts = world_size * num_local_experts
-        hidden_sizes = [32, 64]  # Two different weight matrices
-
-        # Create old expert indices (with redundancy)
-        redundancy_config = create_redundancy_config(
-            num_logical_experts, total_physical_experts
-        )
-        old_indices = create_expert_indices_with_redundancy(
-            num_layers,
-            num_logical_experts,
-            total_physical_experts,
-            redundancy_config,
-        )
-
-        # Create new expert indices (with redundancy)
-        new_redundancy_config = create_redundancy_config(
-            num_logical_experts, total_physical_experts
-        )
-        new_indices = create_expert_indices_with_redundancy(
-            num_layers,
-            num_logical_experts,
-            total_physical_experts,
-            new_redundancy_config,
-        )
-
-        # Create expert weights
-        expert_weights = create_expert_weights(
-            num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
-        )
-
-        # Execute weight rearrangement
-        rearrange_expert_weights_inplace(
-            old_indices,
-            new_indices,
-            expert_weights,
-            ep_group,
-            is_profile=False,
-        )
-
-        # Verify the rearrangement result
-        verify_expert_weights_after_shuffle(
-            expert_weights,
-            new_indices,
-            hidden_sizes,
-            ep_rank,
-            num_local_experts,
-        )
-
-        verify_redundant_experts_have_same_weights(
-            expert_weights,
-            new_indices,
-            hidden_sizes,
-            world_size,
-            num_local_experts,
-        )
-
-    distributed_run(worker_fn, world_size)
+    distributed_run(
+        _test_rearrange_expert_weights_with_redundancy,
+        world_size,
+        num_layers,
+        num_local_experts,
+        num_logical_experts,
+    )
+
+
+def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
+    set_env_vars_and_device(env)
+    ensure_model_parallel_initialized(
+        tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
+    )
+
+    ep_group = get_tp_group().cpu_group
+    ep_rank = torch.distributed.get_rank()
+    device = torch.device(f"cuda:{ep_rank}")
+
+    num_layers = 2
+    num_local_experts = 2
+    total_physical_experts = world_size * num_local_experts
+    num_logical_experts = total_physical_experts // 2  # Some redundancy
+    hidden_sizes = [32, 64]
+
+    # Create redundancy configuration
+    redundancy_config = [2] * num_logical_experts
+
+    # Same indices - no change
+    indices = create_expert_indices_with_redundancy(
+        num_layers, num_logical_experts, total_physical_experts, redundancy_config
+    )
+
+    expert_weights = create_expert_weights(
+        num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
+    )
+
+    # Save original weights
+    original_weights = []
+    for layer_weights in expert_weights:
+        layer_copy = []
+        for weight in layer_weights:
+            layer_copy.append(weight.clone())
+        original_weights.append(layer_copy)
+
+    # Execute rearrangement (should be no change)
+    rearrange_expert_weights_inplace(
+        indices,
+        indices,  # Same indices
+        expert_weights,
+        ep_group,
+        is_profile=False,
+    )
+
+    # Verify that the weights have not changed
+    for layer in range(num_layers):
+        for weight_idx in range(len(hidden_sizes)):
+            torch.testing.assert_close(
+                expert_weights[layer][weight_idx],
+                original_weights[layer][weight_idx],
+                msg=f"""Layer {layer}, weight {weight_idx}
+                should remain unchanged""",
+            )


 @pytest.mark.parametrize("world_size", [2, 4])
@@ -388,62 +446,69 @@ def test_rearrange_expert_weights_no_change(world_size):
     if torch.cuda.device_count() < world_size:
         pytest.skip(f"Need at least {world_size} GPUs to run the test")

-    @worker_fn_wrapper
-    def worker_fn():
-        ensure_model_parallel_initialized(
-            tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
-        )
-
-        ep_group = get_tp_group().cpu_group
-        ep_rank = torch.distributed.get_rank()
-        device = torch.device(f"cuda:{ep_rank}")
-
-        num_layers = 2
-        num_local_experts = 2
-        total_physical_experts = world_size * num_local_experts
-        num_logical_experts = total_physical_experts // 2  # Some redundancy
-        hidden_sizes = [32, 64]
-
-        # Create redundancy configuration
-        redundancy_config = [2] * num_logical_experts
-
-        # Same indices - no change
-        indices = create_expert_indices_with_redundancy(
-            num_layers, num_logical_experts, total_physical_experts, redundancy_config
-        )
-
-        expert_weights = create_expert_weights(
-            num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
-        )
-
-        # Save original weights
-        original_weights = []
-        for layer_weights in expert_weights:
-            layer_copy = []
-            for weight in layer_weights:
-                layer_copy.append(weight.clone())
-            original_weights.append(layer_copy)
-
-        # Execute rearrangement (should be no change)
-        rearrange_expert_weights_inplace(
-            indices,
-            indices,  # Same indices
-            expert_weights,
-            ep_group,
-            is_profile=False,
-        )
-
-        # Verify that the weights have not changed
-        for layer in range(num_layers):
-            for weight_idx in range(len(hidden_sizes)):
-                torch.testing.assert_close(
-                    expert_weights[layer][weight_idx],
-                    original_weights[layer][weight_idx],
-                    msg=f"Layer {layer}, weight {weight_idx} should remain unchanged",
-                )
-
-    distributed_run(worker_fn, world_size)
+    distributed_run(_test_rearrange_expert_weights_no_change, world_size)
+
+
+def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
+    set_env_vars_and_device(env)
+    ensure_model_parallel_initialized(
+        tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
+    )
+
+    ep_group = get_tp_group().cpu_group
+    ep_rank = torch.distributed.get_rank()
+    device = torch.device(f"cuda:{ep_rank}")
+
+    num_layers = 1
+    num_local_experts = 2
+    total_physical_experts = world_size * num_local_experts
+    num_logical_experts = total_physical_experts // 2
+    hidden_sizes = [32]
+
+    # Create different index distributions
+    old_redundancy = create_redundancy_config(
+        num_logical_experts, total_physical_experts
+    )
+    new_redundancy = create_redundancy_config(
+        num_logical_experts, total_physical_experts
+    )
+
+    old_indices = create_expert_indices_with_redundancy(
+        num_layers, num_logical_experts, total_physical_experts, old_redundancy
+    )
+    new_indices = create_expert_indices_with_redundancy(
+        num_layers, num_logical_experts, total_physical_experts, new_redundancy
+    )
+
+    expert_weights = create_expert_weights(
+        num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
+    )
+
+    # Save original weights
+    original_weights = []
+    for layer_weights in expert_weights:
+        layer_copy = []
+        for weight in layer_weights:
+            layer_copy.append(weight.clone())
+        original_weights.append(layer_copy)
+
+    # Execute profile mode rearrangement
+    rearrange_expert_weights_inplace(
+        old_indices,
+        new_indices,
+        expert_weights,
+        ep_group,
+        is_profile=True,  # Profile mode
+    )
+
+    # In profile mode, the weights should remain unchanged
+    for layer in range(num_layers):
+        for weight_idx in range(len(hidden_sizes)):
+            torch.testing.assert_close(
+                expert_weights[layer][weight_idx],
+                original_weights[layer][weight_idx],
+                msg="In profile mode, the weights should remain unchanged",
+            )


 @pytest.mark.parametrize("world_size", [2, 4])
@@ -452,66 +517,4 @@ def test_rearrange_expert_weights_profile_mode(world_size):
     if torch.cuda.device_count() < world_size:
         pytest.skip(f"Need at least {world_size} GPUs to run the test")

-    @worker_fn_wrapper
-    def worker_fn():
-        ensure_model_parallel_initialized(
-            tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
-        )
-
-        ep_group = get_tp_group().cpu_group
-        ep_rank = torch.distributed.get_rank()
-        device = torch.device(f"cuda:{ep_rank}")
-
-        num_layers = 1
-        num_local_experts = 2
-        total_physical_experts = world_size * num_local_experts
-        num_logical_experts = total_physical_experts // 2
-        hidden_sizes = [32]
-
-        # Create different index distributions
-        old_redundancy = create_redundancy_config(
-            num_logical_experts, total_physical_experts
-        )
-        new_redundancy = create_redundancy_config(
-            num_logical_experts, total_physical_experts
-        )
-
-        old_indices = create_expert_indices_with_redundancy(
-            num_layers, num_logical_experts, total_physical_experts, old_redundancy
-        )
-        new_indices = create_expert_indices_with_redundancy(
-            num_layers, num_logical_experts, total_physical_experts, new_redundancy
-        )
-
-        expert_weights = create_expert_weights(
-            num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
-        )
-
-        # Save original weights
-        original_weights = []
-        for layer_weights in expert_weights:
-            layer_copy = []
-            for weight in layer_weights:
-                layer_copy.append(weight.clone())
-            original_weights.append(layer_copy)
-
-        # Execute profile mode rearrangement
-        rearrange_expert_weights_inplace(
-            old_indices,
-            new_indices,
-            expert_weights,
-            ep_group,
-            is_profile=True,  # Profile mode
-        )
-
-        # In profile mode, the weights should remain unchanged
-        for layer in range(num_layers):
-            for weight_idx in range(len(hidden_sizes)):
-                torch.testing.assert_close(
-                    expert_weights[layer][weight_idx],
-                    original_weights[layer][weight_idx],
-                    msg="In profile mode, the weights should remain unchanged",
-                )
-
-    distributed_run(worker_fn, world_size)
+    distributed_run(_test_rearrange_expert_weights_profile_mode, world_size)
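
For reference, a self-contained sketch of how the reworked harness is driven after this change: the test body is a picklable module-level function that receives (env, world_size, *extra), and the launcher forwards the extra positional arguments to every worker. run_workers and _toy_test_body below are illustrative stand-ins that only mirror the shape of distributed_run and the _test_* functions; they are not the real EPLB test code.

import os

import torch.multiprocessing as mp

mp.set_start_method("spawn", force=True)


def _toy_test_body(env: dict[str, str], world_size: int, num_layers: int) -> None:
    # Module-level so the "spawn" start method can pickle it.
    os.environ.update(env)
    print(f"rank {os.environ['RANK']}/{world_size}: num_layers={num_layers}")


def run_workers(fn, world_size, *args):
    # Mirrors the updated distributed_run signature: extra args are
    # forwarded to every worker alongside its environment dict.
    processes: list[mp.Process] = []
    for i in range(world_size):
        env = {"RANK": str(i), "LOCAL_RANK": str(i), "WORLD_SIZE": str(world_size)}
        p = mp.Process(target=fn, args=(env, world_size, *args))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
        assert p.exitcode == 0


if __name__ == "__main__":
    run_workers(_toy_test_body, 2, 3)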