mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 20:15:42 +08:00
[ci][amd] fix EPLB execution test (#28742)
Signed-off-by: Bradley Davis <bradleyhd@meta.com>
This commit is contained in:
parent
7218f83992
commit
1e1c06789e
@ -1,13 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import multiprocessing
|
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
@ -17,10 +17,12 @@ from vllm.distributed.parallel_state import (
|
|||||||
)
|
)
|
||||||
from vllm.utils.system_utils import update_environment_variables
|
from vllm.utils.system_utils import update_environment_variables
|
||||||
|
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
def distributed_run(fn, world_size):
|
|
||||||
|
def distributed_run(fn, world_size, *args):
|
||||||
number_of_processes = world_size
|
number_of_processes = world_size
|
||||||
processes: list[multiprocessing.Process] = []
|
processes: list[mp.Process] = []
|
||||||
for i in range(number_of_processes):
|
for i in range(number_of_processes):
|
||||||
env: dict[str, str] = {}
|
env: dict[str, str] = {}
|
||||||
env["RANK"] = str(i)
|
env["RANK"] = str(i)
|
||||||
@ -29,7 +31,7 @@ def distributed_run(fn, world_size):
|
|||||||
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||||
env["MASTER_ADDR"] = "localhost"
|
env["MASTER_ADDR"] = "localhost"
|
||||||
env["MASTER_PORT"] = "12345"
|
env["MASTER_PORT"] = "12345"
|
||||||
p = multiprocessing.Process(target=fn, args=(env,))
|
p = mp.Process(target=fn, args=(env, world_size, *args))
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
@ -40,24 +42,16 @@ def distributed_run(fn, world_size):
|
|||||||
assert p.exitcode == 0
|
assert p.exitcode == 0
|
||||||
|
|
||||||
|
|
||||||
def worker_fn_wrapper(fn):
|
def set_env_vars_and_device(env: dict[str, str]) -> None:
|
||||||
# `multiprocessing.Process` cannot accept environment variables directly
|
update_environment_variables(env)
|
||||||
# so we need to pass the environment variables as arguments
|
local_rank = os.environ["LOCAL_RANK"]
|
||||||
# and update the environment variables in the function
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
def wrapped_fn(env):
|
torch.cuda.set_device(device)
|
||||||
update_environment_variables(env)
|
init_distributed_environment()
|
||||||
local_rank = os.environ["LOCAL_RANK"]
|
|
||||||
device = torch.device(f"cuda:{local_rank}")
|
|
||||||
torch.cuda.set_device(device)
|
|
||||||
init_distributed_environment()
|
|
||||||
|
|
||||||
# Ensure each worker process has the same random seed
|
# Ensure each worker process has the same random seed
|
||||||
random.seed(42)
|
random.seed(42)
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
|
|
||||||
fn()
|
|
||||||
|
|
||||||
return wrapped_fn
|
|
||||||
|
|
||||||
|
|
||||||
def create_expert_indices_with_redundancy(
|
def create_expert_indices_with_redundancy(
|
||||||
@ -275,6 +269,79 @@ def verify_redundant_experts_have_same_weights(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_rearrange_expert_weights_with_redundancy(
|
||||||
|
env, world_size, num_layers, num_local_experts, num_logical_experts
|
||||||
|
) -> None:
|
||||||
|
# Initialize model parallel (using tensor parallel as an entrypoint
|
||||||
|
# to expert parallel)
|
||||||
|
set_env_vars_and_device(env)
|
||||||
|
ensure_model_parallel_initialized(
|
||||||
|
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||||
|
)
|
||||||
|
|
||||||
|
ep_group = get_tp_group().cpu_group
|
||||||
|
ep_rank = torch.distributed.get_rank()
|
||||||
|
device = torch.device(f"cuda:{ep_rank}")
|
||||||
|
|
||||||
|
# Test parameters
|
||||||
|
total_physical_experts = world_size * num_local_experts
|
||||||
|
hidden_sizes = [32, 64] # Two different weight matrices
|
||||||
|
|
||||||
|
# Create old expert indices (with redundancy)
|
||||||
|
redundancy_config = create_redundancy_config(
|
||||||
|
num_logical_experts, total_physical_experts
|
||||||
|
)
|
||||||
|
|
||||||
|
old_indices = create_expert_indices_with_redundancy(
|
||||||
|
num_layers,
|
||||||
|
num_logical_experts,
|
||||||
|
total_physical_experts,
|
||||||
|
redundancy_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create new expert indices (with redundancy)
|
||||||
|
new_redundancy_config = create_redundancy_config(
|
||||||
|
num_logical_experts, total_physical_experts
|
||||||
|
)
|
||||||
|
new_indices = create_expert_indices_with_redundancy(
|
||||||
|
num_layers,
|
||||||
|
num_logical_experts,
|
||||||
|
total_physical_experts,
|
||||||
|
new_redundancy_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create expert weights
|
||||||
|
expert_weights = create_expert_weights(
|
||||||
|
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute weight rearrangement
|
||||||
|
rearrange_expert_weights_inplace(
|
||||||
|
old_indices,
|
||||||
|
new_indices,
|
||||||
|
expert_weights,
|
||||||
|
ep_group,
|
||||||
|
is_profile=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the rearrangement result
|
||||||
|
verify_expert_weights_after_shuffle(
|
||||||
|
expert_weights,
|
||||||
|
new_indices,
|
||||||
|
hidden_sizes,
|
||||||
|
ep_rank,
|
||||||
|
num_local_experts,
|
||||||
|
)
|
||||||
|
|
||||||
|
verify_redundant_experts_have_same_weights(
|
||||||
|
expert_weights,
|
||||||
|
new_indices,
|
||||||
|
hidden_sizes,
|
||||||
|
world_size,
|
||||||
|
num_local_experts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"world_size,num_layers,num_local_experts,num_logical_experts",
|
"world_size,num_layers,num_local_experts,num_logical_experts",
|
||||||
[
|
[
|
||||||
@ -305,78 +372,69 @@ def test_rearrange_expert_weights_with_redundancy(
|
|||||||
|
|
||||||
if torch.cuda.device_count() < world_size:
|
if torch.cuda.device_count() < world_size:
|
||||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||||
|
distributed_run(
|
||||||
|
_test_rearrange_expert_weights_with_redundancy,
|
||||||
|
world_size,
|
||||||
|
num_layers,
|
||||||
|
num_local_experts,
|
||||||
|
num_logical_experts,
|
||||||
|
)
|
||||||
|
|
||||||
@worker_fn_wrapper
|
|
||||||
def worker_fn():
|
|
||||||
# Initialize model parallel (using tensor parallel as an entrypoint
|
|
||||||
# to expert parallel)
|
|
||||||
ensure_model_parallel_initialized(
|
|
||||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
|
||||||
)
|
|
||||||
|
|
||||||
ep_group = get_tp_group().cpu_group
|
def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
||||||
ep_rank = torch.distributed.get_rank()
|
set_env_vars_and_device(env)
|
||||||
device = torch.device(f"cuda:{ep_rank}")
|
ensure_model_parallel_initialized(
|
||||||
|
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||||
|
)
|
||||||
|
|
||||||
# Test parameters
|
ep_group = get_tp_group().cpu_group
|
||||||
total_physical_experts = world_size * num_local_experts
|
ep_rank = torch.distributed.get_rank()
|
||||||
hidden_sizes = [32, 64] # Two different weight matrices
|
device = torch.device(f"cuda:{ep_rank}")
|
||||||
|
|
||||||
# Create old expert indices (with redundancy)
|
num_layers = 2
|
||||||
redundancy_config = create_redundancy_config(
|
num_local_experts = 2
|
||||||
num_logical_experts, total_physical_experts
|
total_physical_experts = world_size * num_local_experts
|
||||||
)
|
num_logical_experts = total_physical_experts // 2 # Some redundancy
|
||||||
|
hidden_sizes = [32, 64]
|
||||||
|
|
||||||
old_indices = create_expert_indices_with_redundancy(
|
# Create redundancy configuration
|
||||||
num_layers,
|
redundancy_config = [2] * num_logical_experts
|
||||||
num_logical_experts,
|
|
||||||
total_physical_experts,
|
|
||||||
redundancy_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create new expert indices (with redundancy)
|
# Same indices - no change
|
||||||
new_redundancy_config = create_redundancy_config(
|
indices = create_expert_indices_with_redundancy(
|
||||||
num_logical_experts, total_physical_experts
|
num_layers, num_logical_experts, total_physical_experts, redundancy_config
|
||||||
)
|
)
|
||||||
new_indices = create_expert_indices_with_redundancy(
|
|
||||||
num_layers,
|
|
||||||
num_logical_experts,
|
|
||||||
total_physical_experts,
|
|
||||||
new_redundancy_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create expert weights
|
expert_weights = create_expert_weights(
|
||||||
expert_weights = create_expert_weights(
|
num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
|
||||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Execute weight rearrangement
|
# Save original weights
|
||||||
rearrange_expert_weights_inplace(
|
original_weights = []
|
||||||
old_indices,
|
for layer_weights in expert_weights:
|
||||||
new_indices,
|
layer_copy = []
|
||||||
expert_weights,
|
for weight in layer_weights:
|
||||||
ep_group,
|
layer_copy.append(weight.clone())
|
||||||
is_profile=False,
|
original_weights.append(layer_copy)
|
||||||
)
|
|
||||||
|
|
||||||
# Verify the rearrangement result
|
# Execute rearrangement (should be no change)
|
||||||
verify_expert_weights_after_shuffle(
|
rearrange_expert_weights_inplace(
|
||||||
expert_weights,
|
indices,
|
||||||
new_indices,
|
indices, # Same indices
|
||||||
hidden_sizes,
|
expert_weights,
|
||||||
ep_rank,
|
ep_group,
|
||||||
num_local_experts,
|
is_profile=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
verify_redundant_experts_have_same_weights(
|
# Verify that the weights have not changed
|
||||||
expert_weights,
|
for layer in range(num_layers):
|
||||||
new_indices,
|
for weight_idx in range(len(hidden_sizes)):
|
||||||
hidden_sizes,
|
torch.testing.assert_close(
|
||||||
world_size,
|
expert_weights[layer][weight_idx],
|
||||||
num_local_experts,
|
original_weights[layer][weight_idx],
|
||||||
)
|
msg=f"""Layer {layer}, weight {weight_idx}
|
||||||
|
should remain unchanged""",
|
||||||
distributed_run(worker_fn, world_size)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("world_size", [2, 4])
|
@pytest.mark.parametrize("world_size", [2, 4])
|
||||||
@ -388,62 +446,69 @@ def test_rearrange_expert_weights_no_change(world_size):
|
|||||||
|
|
||||||
if torch.cuda.device_count() < world_size:
|
if torch.cuda.device_count() < world_size:
|
||||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||||
|
distributed_run(_test_rearrange_expert_weights_no_change, world_size)
|
||||||
|
|
||||||
@worker_fn_wrapper
|
|
||||||
def worker_fn():
|
|
||||||
ensure_model_parallel_initialized(
|
|
||||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
|
||||||
)
|
|
||||||
|
|
||||||
ep_group = get_tp_group().cpu_group
|
def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
||||||
ep_rank = torch.distributed.get_rank()
|
set_env_vars_and_device(env)
|
||||||
device = torch.device(f"cuda:{ep_rank}")
|
ensure_model_parallel_initialized(
|
||||||
|
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||||
|
)
|
||||||
|
|
||||||
num_layers = 2
|
ep_group = get_tp_group().cpu_group
|
||||||
num_local_experts = 2
|
ep_rank = torch.distributed.get_rank()
|
||||||
total_physical_experts = world_size * num_local_experts
|
device = torch.device(f"cuda:{ep_rank}")
|
||||||
num_logical_experts = total_physical_experts // 2 # Some redundancy
|
|
||||||
hidden_sizes = [32, 64]
|
|
||||||
|
|
||||||
# Create redundancy configuration
|
num_layers = 1
|
||||||
redundancy_config = [2] * num_logical_experts
|
num_local_experts = 2
|
||||||
|
total_physical_experts = world_size * num_local_experts
|
||||||
|
num_logical_experts = total_physical_experts // 2
|
||||||
|
hidden_sizes = [32]
|
||||||
|
|
||||||
# Same indices - no change
|
# Create different index distributions
|
||||||
indices = create_expert_indices_with_redundancy(
|
old_redundancy = create_redundancy_config(
|
||||||
num_layers, num_logical_experts, total_physical_experts, redundancy_config
|
num_logical_experts, total_physical_experts
|
||||||
)
|
)
|
||||||
|
new_redundancy = create_redundancy_config(
|
||||||
|
num_logical_experts, total_physical_experts
|
||||||
|
)
|
||||||
|
|
||||||
expert_weights = create_expert_weights(
|
old_indices = create_expert_indices_with_redundancy(
|
||||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
|
num_layers, num_logical_experts, total_physical_experts, old_redundancy
|
||||||
)
|
)
|
||||||
|
new_indices = create_expert_indices_with_redundancy(
|
||||||
|
num_layers, num_logical_experts, total_physical_experts, new_redundancy
|
||||||
|
)
|
||||||
|
|
||||||
# Save original weights
|
expert_weights = create_expert_weights(
|
||||||
original_weights = []
|
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||||
for layer_weights in expert_weights:
|
)
|
||||||
layer_copy = []
|
|
||||||
for weight in layer_weights:
|
|
||||||
layer_copy.append(weight.clone())
|
|
||||||
original_weights.append(layer_copy)
|
|
||||||
|
|
||||||
# Execute rearrangement (should be no change)
|
# Save original weights
|
||||||
rearrange_expert_weights_inplace(
|
original_weights = []
|
||||||
indices,
|
for layer_weights in expert_weights:
|
||||||
indices, # Same indices
|
layer_copy = []
|
||||||
expert_weights,
|
for weight in layer_weights:
|
||||||
ep_group,
|
layer_copy.append(weight.clone())
|
||||||
is_profile=False,
|
original_weights.append(layer_copy)
|
||||||
)
|
|
||||||
|
|
||||||
# Verify that the weights have not changed
|
# Execute profile mode rearrangement
|
||||||
for layer in range(num_layers):
|
rearrange_expert_weights_inplace(
|
||||||
for weight_idx in range(len(hidden_sizes)):
|
old_indices,
|
||||||
torch.testing.assert_close(
|
new_indices,
|
||||||
expert_weights[layer][weight_idx],
|
expert_weights,
|
||||||
original_weights[layer][weight_idx],
|
ep_group,
|
||||||
msg=f"Layer {layer}, weight {weight_idx} should remain unchanged",
|
is_profile=True, # Profile mode
|
||||||
)
|
)
|
||||||
|
|
||||||
distributed_run(worker_fn, world_size)
|
# In profile mode, the weights should remain unchanged
|
||||||
|
for layer in range(num_layers):
|
||||||
|
for weight_idx in range(len(hidden_sizes)):
|
||||||
|
torch.testing.assert_close(
|
||||||
|
expert_weights[layer][weight_idx],
|
||||||
|
original_weights[layer][weight_idx],
|
||||||
|
msg="In profile mode, the weights should remain unchanged",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("world_size", [2, 4])
|
@pytest.mark.parametrize("world_size", [2, 4])
|
||||||
@ -452,66 +517,4 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
|||||||
|
|
||||||
if torch.cuda.device_count() < world_size:
|
if torch.cuda.device_count() < world_size:
|
||||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||||
|
distributed_run(_test_rearrange_expert_weights_profile_mode, world_size)
|
||||||
@worker_fn_wrapper
|
|
||||||
def worker_fn():
|
|
||||||
ensure_model_parallel_initialized(
|
|
||||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
|
||||||
)
|
|
||||||
|
|
||||||
ep_group = get_tp_group().cpu_group
|
|
||||||
ep_rank = torch.distributed.get_rank()
|
|
||||||
device = torch.device(f"cuda:{ep_rank}")
|
|
||||||
|
|
||||||
num_layers = 1
|
|
||||||
num_local_experts = 2
|
|
||||||
total_physical_experts = world_size * num_local_experts
|
|
||||||
num_logical_experts = total_physical_experts // 2
|
|
||||||
hidden_sizes = [32]
|
|
||||||
|
|
||||||
# Create different index distributions
|
|
||||||
old_redundancy = create_redundancy_config(
|
|
||||||
num_logical_experts, total_physical_experts
|
|
||||||
)
|
|
||||||
new_redundancy = create_redundancy_config(
|
|
||||||
num_logical_experts, total_physical_experts
|
|
||||||
)
|
|
||||||
|
|
||||||
old_indices = create_expert_indices_with_redundancy(
|
|
||||||
num_layers, num_logical_experts, total_physical_experts, old_redundancy
|
|
||||||
)
|
|
||||||
new_indices = create_expert_indices_with_redundancy(
|
|
||||||
num_layers, num_logical_experts, total_physical_experts, new_redundancy
|
|
||||||
)
|
|
||||||
|
|
||||||
expert_weights = create_expert_weights(
|
|
||||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save original weights
|
|
||||||
original_weights = []
|
|
||||||
for layer_weights in expert_weights:
|
|
||||||
layer_copy = []
|
|
||||||
for weight in layer_weights:
|
|
||||||
layer_copy.append(weight.clone())
|
|
||||||
original_weights.append(layer_copy)
|
|
||||||
|
|
||||||
# Execute profile mode rearrangement
|
|
||||||
rearrange_expert_weights_inplace(
|
|
||||||
old_indices,
|
|
||||||
new_indices,
|
|
||||||
expert_weights,
|
|
||||||
ep_group,
|
|
||||||
is_profile=True, # Profile mode
|
|
||||||
)
|
|
||||||
|
|
||||||
# In profile mode, the weights should remain unchanged
|
|
||||||
for layer in range(num_layers):
|
|
||||||
for weight_idx in range(len(hidden_sizes)):
|
|
||||||
torch.testing.assert_close(
|
|
||||||
expert_weights[layer][weight_idx],
|
|
||||||
original_weights[layer][weight_idx],
|
|
||||||
msg="In profile mode, the weights should remain unchanged",
|
|
||||||
)
|
|
||||||
|
|
||||||
distributed_run(worker_fn, world_size)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user