add support for --fully-sharded-loras in fused_moe (#28761)

Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
gnovack 2025-11-19 00:32:00 -08:00 committed by GitHub
parent ae4821a108
commit d69062c67a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 274 additions and 10 deletions

View File

@ -1,13 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import random
import pytest
import torch
from tests.utils import multi_gpu_test
from vllm import _custom_ops as ops
from vllm.distributed import (
init_distributed_environment,
initialize_model_parallel,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.lora.ops.triton_ops import fused_moe_lora
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port
@pytest.fixture(autouse=True)
@ -122,6 +134,8 @@ def use_fused_moe_lora_kernel(
max_loras,
num_experts,
block_size,
fully_sharded=False,
offset=0,
):
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
@ -195,10 +209,10 @@ def use_fused_moe_lora_kernel(
config["NUM_STAGES"],
config["SPLIT_K"],
mul_routed_weight,
fully_sharded=fully_sharded,
offset=offset,
)
return output
def use_torch(
hidden_states,
@ -317,3 +331,193 @@ def test_fused_moe_lora_kernel(
)
torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("num_tokens", [100])
@pytest.mark.parametrize("top_k_num", [6])
@pytest.mark.parametrize("num_experts", [64])
@pytest.mark.parametrize("max_loras", [4])
@pytest.mark.parametrize("N", [1408])
@pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("column_parallel", [True, False])
def test_fused_moe_lora_kernel_fully_sharded(
num_tokens,
top_k_num,
num_experts,
max_loras,
N,
K,
max_lora_rank,
block_size,
dtype,
seed,
column_parallel,
):
current_platform.seed_everything(seed)
# the number of randomly generated sentences.
num_sequences = 10
# generate data
topk_ids, topk_weights, token_lora_mapping = sample_data(
num_tokens, num_sequences, max_loras, num_experts, top_k_num
)
def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn(
fn,
args=(
nprocs,
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
dtype,
seed,
N,
K,
num_tokens,
topk_ids,
topk_weights,
token_lora_mapping,
max_lora_rank,
top_k_num,
max_loras,
num_experts,
block_size,
column_parallel,
),
nprocs=nprocs,
)
run_torch_spawn(use_fused_moe_lora_kernel_tensor_parallel, nprocs=2)
def use_fused_moe_lora_kernel_tensor_parallel(
local_rank,
world_size,
init_method,
dtype,
seed,
N,
K,
num_tokens,
topk_ids,
topk_weights,
token_lora_mapping,
max_lora_rank,
top_k_num,
max_loras,
num_experts,
block_size,
column_parallel,
):
def _get_shard_slice(shard_size):
return slice(local_rank * shard_size, (local_rank + 1) * shard_size)
current_platform.seed_everything(seed)
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
init_distributed_environment(
world_size=world_size,
rank=local_rank,
local_rank=local_rank,
distributed_init_method=init_method,
)
initialize_model_parallel(world_size, 1)
tp_size = get_tensor_model_parallel_world_size()
input_dim = K if column_parallel else N
output_dim = N if column_parallel else K
# init lora weights
lora_a = torch.rand(
(
max_loras,
num_experts,
max_lora_rank,
input_dim,
),
dtype=dtype,
)
lora_b = torch.rand(
(
max_loras,
num_experts,
output_dim,
max_lora_rank,
),
dtype=dtype,
)
hidden_states = torch.rand(
(
num_tokens,
input_dim,
),
dtype=dtype,
)
output = torch.zeros((num_tokens, top_k_num, output_dim), dtype=dtype)
topk_ids = topk_ids.to(device)
topk_weights = topk_weights.to(device)
token_lora_mapping = token_lora_mapping.to(device)
ref_output = use_torch(
hidden_states,
token_lora_mapping,
topk_ids,
[lora_a],
[lora_b],
top_k_num,
)
if column_parallel:
# Column parallel (e.g. gate_up_proj): LoRA A is sliced along the rank dim,
# and Lora B is sliced along the output dim
lora_a_shard_size = max_lora_rank // tp_size
lora_a = lora_a[:, :, _get_shard_slice(lora_a_shard_size), :]
max_lora_rank = lora_a_shard_size
offset = 0
lora_b_shard_size = output_dim // tp_size
lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :]
output = output[:, :, _get_shard_slice(lora_b_shard_size)].contiguous()
else:
# Row parallel (e.g. down proj): LoRA A is sliced along the input dim,
# and LoRA B is sliced along the output dim
lora_a_shard_size = input_dim // tp_size
lora_a = lora_a[:, :, :, _get_shard_slice(lora_a_shard_size)]
hidden_states = hidden_states[:, _get_shard_slice(lora_a_shard_size)]
lora_b_shard_size = output_dim // tp_size
lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :]
offset = lora_b_shard_size * local_rank
use_fused_moe_lora_kernel(
topk_ids,
topk_weights,
token_lora_mapping,
max_lora_rank,
top_k_num,
[lora_a],
[lora_b],
hidden_states,
output,
max_loras,
num_experts,
block_size,
fully_sharded=True,
offset=offset,
)
if column_parallel:
output = tensor_model_parallel_all_gather(output)
else:
output = tensor_model_parallel_all_reduce(output)
torch.testing.assert_close(output, ref_output, atol=1e-1, rtol=1e-1)

View File

@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import vllm
from vllm.lora.request import LoRARequest
@ -111,8 +113,9 @@ def test_olmoe_lora_mixed(olmoe_lora_files):
generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
@pytest.mark.parametrize("fully_sharded_loras", [False, True])
@multi_gpu_test(num_gpus=2)
def test_olmoe_lora_tp2(olmoe_lora_files):
def test_olmoe_lora_tp2(olmoe_lora_files, fully_sharded_loras):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
@ -122,14 +125,16 @@ def test_olmoe_lora_tp2(olmoe_lora_files):
trust_remote_code=True,
enable_chunked_prefill=True,
tensor_parallel_size=2,
fully_sharded_loras=fully_sharded_loras,
)
generate_and_test(llm, olmoe_lora_files, lora_id=1)
generate_and_test(llm, olmoe_lora_files, lora_id=2)
@pytest.mark.parametrize("fully_sharded_loras", [False, True])
@multi_gpu_test(num_gpus=4)
def test_olmoe_lora_tp4(olmoe_lora_files):
def test_olmoe_lora_tp4(olmoe_lora_files, fully_sharded_loras):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
@ -139,6 +144,7 @@ def test_olmoe_lora_tp4(olmoe_lora_files):
trust_remote_code=True,
enable_chunked_prefill=True,
tensor_parallel_size=4,
fully_sharded_loras=fully_sharded_loras,
)
generate_and_test(llm, olmoe_lora_files, lora_id=1)

View File

@ -12,6 +12,7 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.distributed.utils import divide
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.model_executor.layers.fused_moe import FusedMoE
@ -205,6 +206,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
shrink_config, ## pass the shrink config
expand_config, ## pass the expand config
self.adapter_enabled,
fully_sharded=self.fully_sharded,
)
result = func(*args, **kwargs)
@ -250,7 +252,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
intermediate_cache3 = args[0]
max_lora_rank = self.w1_lora_a_stacked.shape[-2]
max_lora_rank = self.w2_lora_a_stacked.shape[-2]
shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)
self.punica_wrapper.add_lora_fused_moe(
intermediate_cache3,
intermediate_cache2,
@ -266,6 +271,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
expand_config, ## pass the expand config
self.adapter_enabled,
True,
fully_sharded=self.fully_sharded,
offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
)
result = func(*args, **kwargs)
@ -294,6 +301,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
self.fully_sharded = lora_config.fully_sharded_loras
self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
@ -303,7 +311,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
(
max_loras,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
lora_config.max_lora_rank
if not self.fully_sharded
else divide(lora_config.max_lora_rank, self.tp_size),
self.base_layer.hidden_size,
),
dtype=lora_config.lora_dtype,
@ -334,7 +344,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.hidden_size,
self.base_layer.hidden_size
if not self.fully_sharded
else divide(self.base_layer.hidden_size, self.tp_size),
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
@ -345,7 +357,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
(
max_loras,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
lora_config.max_lora_rank
if not self.fully_sharded
else divide(lora_config.max_lora_rank, self.tp_size),
self.base_layer.hidden_size,
),
dtype=lora_config.lora_dtype,
@ -419,6 +433,20 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
w3_lora_b = w3_lora_b[start_idx:end_idx, :]
w2_lora_a = w2_lora_a[:, start_idx:end_idx]
if self.fully_sharded:
# Based on S-LoRA, we slice W1 and W3 A along the rank dim,
# and W2 B along the hidden_size dim.
w13_shard_size = self.w1_lora_a_stacked[index, eid].shape[0]
w13_start_idx = self.tp_rank * w13_shard_size
w13_end_idx = (self.tp_rank + 1) * w13_shard_size
w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]
w2_shard_size = self.w2_lora_b_stacked[index, eid].shape[0]
w2_start_idx = self.tp_rank * w2_shard_size
w2_end_idx = (self.tp_rank + 1) * w2_shard_size
w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
self.w1_lora_a_stacked[
index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
].copy_(w1_lora_a, non_blocking=True)

View File

@ -3,6 +3,10 @@
import torch
from vllm.distributed import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
@ -311,6 +315,7 @@ def _fused_moe_lora_expand(
num_stages: int,
split_k: int,
mul_routed_weight: bool = False,
offset: int = 0,
) -> None:
b_ptr = _get_ptr(lora_b_stacked, device)
K = max_lora_rank
@ -380,7 +385,7 @@ def _fused_moe_lora_expand(
**expand_config,
)
for i in range(num_slices):
output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i]
output[:, :, i * N + offset : (i + 1) * N + offset] += b_intermediate_cache1[i]
@torch.inference_mode()
@ -416,6 +421,8 @@ def _fused_moe_lora(
expand_num_stages: int,
expand_split_k: int,
mul_routed_weight: bool = False,
fully_sharded: bool = False,
offset: int = 0,
) -> None:
assert len(lora_a_stacked) == len(lora_b_stacked) > 0
assert (
@ -430,7 +437,6 @@ def _fused_moe_lora(
== expert_ids.shape[0]
== num_tokens_post_padded.shape[0]
)
assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1]
assert output.shape[0] == topk_weights.shape[0]
assert top_k_num == topk_weights.shape[1]
device = qcurr_hidden_states.device
@ -480,6 +486,19 @@ def _fused_moe_lora(
mul_routed_weight,
)
if fully_sharded:
if max_lora_rank == w1_lora_b_stacked.shape[-1]:
a_intermediate_cache1 = tensor_model_parallel_all_reduce(
a_intermediate_cache1
)
else:
a_intermediate_cache1 = tensor_model_parallel_all_gather(
a_intermediate_cache1
)
# reset max_lora_rank to the full rank after allgather
max_lora_rank = a_intermediate_cache1.shape[-1]
_fused_moe_lora_expand(
output,
a_intermediate_cache1,
@ -510,6 +529,7 @@ def _fused_moe_lora(
expand_num_stages,
expand_split_k,
mul_routed_weight,
offset,
)

View File

@ -483,6 +483,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
expand_config,
adapter_enabled: torch.Tensor,
mul_routed_weight=False,
fully_sharded: bool = False,
offset: int = 0,
):
"""
Performs a fused forward computation for LoRA of

View File

@ -375,6 +375,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
expand_config,
adapter_enabled: torch.Tensor,
mul_routed_weight=False,
fully_sharded: bool = False,
offset: int = 0,
):
"""
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
@ -408,4 +410,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
expand_config.get("NUM_STAGES", 3),
expand_config.get("SPLIT_K", 1),
mul_routed_weight,
fully_sharded,
offset,
)