mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-18 15:46:59 +08:00
add support for --fully-sharded-loras in fused_moe (#28761)
Signed-off-by: gnovack <gnovack@amazon.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
ae4821a108
commit
d69062c67a
@ -1,13 +1,25 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from tests.utils import multi_gpu_test
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.distributed import (
|
||||||
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel,
|
||||||
|
tensor_model_parallel_all_gather,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
|
from vllm.distributed.parallel_state import (
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
)
|
||||||
from vllm.lora.ops.triton_ops import fused_moe_lora
|
from vllm.lora.ops.triton_ops import fused_moe_lora
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.network_utils import get_open_port
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@ -122,6 +134,8 @@ def use_fused_moe_lora_kernel(
|
|||||||
max_loras,
|
max_loras,
|
||||||
num_experts,
|
num_experts,
|
||||||
block_size,
|
block_size,
|
||||||
|
fully_sharded=False,
|
||||||
|
offset=0,
|
||||||
):
|
):
|
||||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||||
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
|
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
|
||||||
@ -195,10 +209,10 @@ def use_fused_moe_lora_kernel(
|
|||||||
config["NUM_STAGES"],
|
config["NUM_STAGES"],
|
||||||
config["SPLIT_K"],
|
config["SPLIT_K"],
|
||||||
mul_routed_weight,
|
mul_routed_weight,
|
||||||
|
fully_sharded=fully_sharded,
|
||||||
|
offset=offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def use_torch(
|
def use_torch(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@ -317,3 +331,193 @@ def test_fused_moe_lora_kernel(
|
|||||||
)
|
)
|
||||||
|
|
||||||
torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1)
|
torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1)
|
||||||
|
|
||||||
|
|
||||||
|
@multi_gpu_test(num_gpus=2)
|
||||||
|
@pytest.mark.parametrize("num_tokens", [100])
|
||||||
|
@pytest.mark.parametrize("top_k_num", [6])
|
||||||
|
@pytest.mark.parametrize("num_experts", [64])
|
||||||
|
@pytest.mark.parametrize("max_loras", [4])
|
||||||
|
@pytest.mark.parametrize("N", [1408])
|
||||||
|
@pytest.mark.parametrize("K", [2048])
|
||||||
|
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
|
||||||
|
@pytest.mark.parametrize("block_size", [16])
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEED)
|
||||||
|
@pytest.mark.parametrize("column_parallel", [True, False])
|
||||||
|
def test_fused_moe_lora_kernel_fully_sharded(
|
||||||
|
num_tokens,
|
||||||
|
top_k_num,
|
||||||
|
num_experts,
|
||||||
|
max_loras,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
max_lora_rank,
|
||||||
|
block_size,
|
||||||
|
dtype,
|
||||||
|
seed,
|
||||||
|
column_parallel,
|
||||||
|
):
|
||||||
|
current_platform.seed_everything(seed)
|
||||||
|
# the number of randomly generated sentences.
|
||||||
|
num_sequences = 10
|
||||||
|
# generate data
|
||||||
|
topk_ids, topk_weights, token_lora_mapping = sample_data(
|
||||||
|
num_tokens, num_sequences, max_loras, num_experts, top_k_num
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_torch_spawn(fn, nprocs):
|
||||||
|
torch.multiprocessing.spawn(
|
||||||
|
fn,
|
||||||
|
args=(
|
||||||
|
nprocs,
|
||||||
|
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
|
||||||
|
dtype,
|
||||||
|
seed,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
num_tokens,
|
||||||
|
topk_ids,
|
||||||
|
topk_weights,
|
||||||
|
token_lora_mapping,
|
||||||
|
max_lora_rank,
|
||||||
|
top_k_num,
|
||||||
|
max_loras,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
column_parallel,
|
||||||
|
),
|
||||||
|
nprocs=nprocs,
|
||||||
|
)
|
||||||
|
|
||||||
|
run_torch_spawn(use_fused_moe_lora_kernel_tensor_parallel, nprocs=2)
|
||||||
|
|
||||||
|
|
||||||
|
def use_fused_moe_lora_kernel_tensor_parallel(
|
||||||
|
local_rank,
|
||||||
|
world_size,
|
||||||
|
init_method,
|
||||||
|
dtype,
|
||||||
|
seed,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
num_tokens,
|
||||||
|
topk_ids,
|
||||||
|
topk_weights,
|
||||||
|
token_lora_mapping,
|
||||||
|
max_lora_rank,
|
||||||
|
top_k_num,
|
||||||
|
max_loras,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
column_parallel,
|
||||||
|
):
|
||||||
|
def _get_shard_slice(shard_size):
|
||||||
|
return slice(local_rank * shard_size, (local_rank + 1) * shard_size)
|
||||||
|
|
||||||
|
current_platform.seed_everything(seed)
|
||||||
|
|
||||||
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
|
torch.cuda.set_device(device)
|
||||||
|
torch.set_default_device(device)
|
||||||
|
torch.set_default_dtype(dtype)
|
||||||
|
|
||||||
|
init_distributed_environment(
|
||||||
|
world_size=world_size,
|
||||||
|
rank=local_rank,
|
||||||
|
local_rank=local_rank,
|
||||||
|
distributed_init_method=init_method,
|
||||||
|
)
|
||||||
|
initialize_model_parallel(world_size, 1)
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
input_dim = K if column_parallel else N
|
||||||
|
output_dim = N if column_parallel else K
|
||||||
|
|
||||||
|
# init lora weights
|
||||||
|
lora_a = torch.rand(
|
||||||
|
(
|
||||||
|
max_loras,
|
||||||
|
num_experts,
|
||||||
|
max_lora_rank,
|
||||||
|
input_dim,
|
||||||
|
),
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
lora_b = torch.rand(
|
||||||
|
(
|
||||||
|
max_loras,
|
||||||
|
num_experts,
|
||||||
|
output_dim,
|
||||||
|
max_lora_rank,
|
||||||
|
),
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = torch.rand(
|
||||||
|
(
|
||||||
|
num_tokens,
|
||||||
|
input_dim,
|
||||||
|
),
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = torch.zeros((num_tokens, top_k_num, output_dim), dtype=dtype)
|
||||||
|
topk_ids = topk_ids.to(device)
|
||||||
|
topk_weights = topk_weights.to(device)
|
||||||
|
token_lora_mapping = token_lora_mapping.to(device)
|
||||||
|
|
||||||
|
ref_output = use_torch(
|
||||||
|
hidden_states,
|
||||||
|
token_lora_mapping,
|
||||||
|
topk_ids,
|
||||||
|
[lora_a],
|
||||||
|
[lora_b],
|
||||||
|
top_k_num,
|
||||||
|
)
|
||||||
|
|
||||||
|
if column_parallel:
|
||||||
|
# Column parallel (e.g. gate_up_proj): LoRA A is sliced along the rank dim,
|
||||||
|
# and Lora B is sliced along the output dim
|
||||||
|
lora_a_shard_size = max_lora_rank // tp_size
|
||||||
|
lora_a = lora_a[:, :, _get_shard_slice(lora_a_shard_size), :]
|
||||||
|
max_lora_rank = lora_a_shard_size
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
lora_b_shard_size = output_dim // tp_size
|
||||||
|
lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :]
|
||||||
|
output = output[:, :, _get_shard_slice(lora_b_shard_size)].contiguous()
|
||||||
|
else:
|
||||||
|
# Row parallel (e.g. down proj): LoRA A is sliced along the input dim,
|
||||||
|
# and LoRA B is sliced along the output dim
|
||||||
|
lora_a_shard_size = input_dim // tp_size
|
||||||
|
lora_a = lora_a[:, :, :, _get_shard_slice(lora_a_shard_size)]
|
||||||
|
hidden_states = hidden_states[:, _get_shard_slice(lora_a_shard_size)]
|
||||||
|
|
||||||
|
lora_b_shard_size = output_dim // tp_size
|
||||||
|
lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :]
|
||||||
|
offset = lora_b_shard_size * local_rank
|
||||||
|
|
||||||
|
use_fused_moe_lora_kernel(
|
||||||
|
topk_ids,
|
||||||
|
topk_weights,
|
||||||
|
token_lora_mapping,
|
||||||
|
max_lora_rank,
|
||||||
|
top_k_num,
|
||||||
|
[lora_a],
|
||||||
|
[lora_b],
|
||||||
|
hidden_states,
|
||||||
|
output,
|
||||||
|
max_loras,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
fully_sharded=True,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
if column_parallel:
|
||||||
|
output = tensor_model_parallel_all_gather(output)
|
||||||
|
else:
|
||||||
|
output = tensor_model_parallel_all_reduce(output)
|
||||||
|
|
||||||
|
torch.testing.assert_close(output, ref_output, atol=1e-1, rtol=1e-1)
|
||||||
|
|||||||
@ -2,6 +2,8 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
import vllm
|
import vllm
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
@ -111,8 +113,9 @@ def test_olmoe_lora_mixed(olmoe_lora_files):
|
|||||||
generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
|
generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("fully_sharded_loras", [False, True])
|
||||||
@multi_gpu_test(num_gpus=2)
|
@multi_gpu_test(num_gpus=2)
|
||||||
def test_olmoe_lora_tp2(olmoe_lora_files):
|
def test_olmoe_lora_tp2(olmoe_lora_files, fully_sharded_loras):
|
||||||
llm = vllm.LLM(
|
llm = vllm.LLM(
|
||||||
MODEL_PATH,
|
MODEL_PATH,
|
||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
@ -122,14 +125,16 @@ def test_olmoe_lora_tp2(olmoe_lora_files):
|
|||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
enable_chunked_prefill=True,
|
enable_chunked_prefill=True,
|
||||||
tensor_parallel_size=2,
|
tensor_parallel_size=2,
|
||||||
|
fully_sharded_loras=fully_sharded_loras,
|
||||||
)
|
)
|
||||||
|
|
||||||
generate_and_test(llm, olmoe_lora_files, lora_id=1)
|
generate_and_test(llm, olmoe_lora_files, lora_id=1)
|
||||||
generate_and_test(llm, olmoe_lora_files, lora_id=2)
|
generate_and_test(llm, olmoe_lora_files, lora_id=2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("fully_sharded_loras", [False, True])
|
||||||
@multi_gpu_test(num_gpus=4)
|
@multi_gpu_test(num_gpus=4)
|
||||||
def test_olmoe_lora_tp4(olmoe_lora_files):
|
def test_olmoe_lora_tp4(olmoe_lora_files, fully_sharded_loras):
|
||||||
llm = vllm.LLM(
|
llm = vllm.LLM(
|
||||||
MODEL_PATH,
|
MODEL_PATH,
|
||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
@ -139,6 +144,7 @@ def test_olmoe_lora_tp4(olmoe_lora_files):
|
|||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
enable_chunked_prefill=True,
|
enable_chunked_prefill=True,
|
||||||
tensor_parallel_size=4,
|
tensor_parallel_size=4,
|
||||||
|
fully_sharded_loras=fully_sharded_loras,
|
||||||
)
|
)
|
||||||
|
|
||||||
generate_and_test(llm, olmoe_lora_files, lora_id=1)
|
generate_and_test(llm, olmoe_lora_files, lora_id=1)
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from vllm.distributed.parallel_state import (
|
|||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
|
from vllm.distributed.utils import divide
|
||||||
from vllm.lora.layers.base import BaseLayerWithLoRA
|
from vllm.lora.layers.base import BaseLayerWithLoRA
|
||||||
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
|
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
@ -205,6 +206,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
shrink_config, ## pass the shrink config
|
shrink_config, ## pass the shrink config
|
||||||
expand_config, ## pass the expand config
|
expand_config, ## pass the expand config
|
||||||
self.adapter_enabled,
|
self.adapter_enabled,
|
||||||
|
fully_sharded=self.fully_sharded,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = func(*args, **kwargs)
|
result = func(*args, **kwargs)
|
||||||
@ -250,7 +252,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
|
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
|
||||||
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
|
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
|
||||||
intermediate_cache3 = args[0]
|
intermediate_cache3 = args[0]
|
||||||
max_lora_rank = self.w1_lora_a_stacked.shape[-2]
|
max_lora_rank = self.w2_lora_a_stacked.shape[-2]
|
||||||
|
|
||||||
|
shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)
|
||||||
|
|
||||||
self.punica_wrapper.add_lora_fused_moe(
|
self.punica_wrapper.add_lora_fused_moe(
|
||||||
intermediate_cache3,
|
intermediate_cache3,
|
||||||
intermediate_cache2,
|
intermediate_cache2,
|
||||||
@ -266,6 +271,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
expand_config, ## pass the expand config
|
expand_config, ## pass the expand config
|
||||||
self.adapter_enabled,
|
self.adapter_enabled,
|
||||||
True,
|
True,
|
||||||
|
fully_sharded=self.fully_sharded,
|
||||||
|
offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = func(*args, **kwargs)
|
result = func(*args, **kwargs)
|
||||||
@ -294,6 +301,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
model_config: PretrainedConfig | None = None,
|
model_config: PretrainedConfig | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initializes lora matrices."""
|
"""Initializes lora matrices."""
|
||||||
|
self.fully_sharded = lora_config.fully_sharded_loras
|
||||||
|
|
||||||
self.adapter_enabled = torch.tensor(
|
self.adapter_enabled = torch.tensor(
|
||||||
[0] * (max_loras + 1), dtype=torch.int, device=self.device
|
[0] * (max_loras + 1), dtype=torch.int, device=self.device
|
||||||
@ -303,7 +311,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
(
|
(
|
||||||
max_loras,
|
max_loras,
|
||||||
self.base_layer.local_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank
|
||||||
|
if not self.fully_sharded
|
||||||
|
else divide(lora_config.max_lora_rank, self.tp_size),
|
||||||
self.base_layer.hidden_size,
|
self.base_layer.hidden_size,
|
||||||
),
|
),
|
||||||
dtype=lora_config.lora_dtype,
|
dtype=lora_config.lora_dtype,
|
||||||
@ -334,7 +344,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
(
|
(
|
||||||
max_loras,
|
max_loras,
|
||||||
self.base_layer.local_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
self.base_layer.hidden_size,
|
self.base_layer.hidden_size
|
||||||
|
if not self.fully_sharded
|
||||||
|
else divide(self.base_layer.hidden_size, self.tp_size),
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank,
|
||||||
),
|
),
|
||||||
dtype=lora_config.lora_dtype,
|
dtype=lora_config.lora_dtype,
|
||||||
@ -345,7 +357,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
(
|
(
|
||||||
max_loras,
|
max_loras,
|
||||||
self.base_layer.local_num_experts,
|
self.base_layer.local_num_experts,
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank
|
||||||
|
if not self.fully_sharded
|
||||||
|
else divide(lora_config.max_lora_rank, self.tp_size),
|
||||||
self.base_layer.hidden_size,
|
self.base_layer.hidden_size,
|
||||||
),
|
),
|
||||||
dtype=lora_config.lora_dtype,
|
dtype=lora_config.lora_dtype,
|
||||||
@ -419,6 +433,20 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
|||||||
w3_lora_b = w3_lora_b[start_idx:end_idx, :]
|
w3_lora_b = w3_lora_b[start_idx:end_idx, :]
|
||||||
w2_lora_a = w2_lora_a[:, start_idx:end_idx]
|
w2_lora_a = w2_lora_a[:, start_idx:end_idx]
|
||||||
|
|
||||||
|
if self.fully_sharded:
|
||||||
|
# Based on S-LoRA, we slice W1 and W3 A along the rank dim,
|
||||||
|
# and W2 B along the hidden_size dim.
|
||||||
|
w13_shard_size = self.w1_lora_a_stacked[index, eid].shape[0]
|
||||||
|
w13_start_idx = self.tp_rank * w13_shard_size
|
||||||
|
w13_end_idx = (self.tp_rank + 1) * w13_shard_size
|
||||||
|
w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
|
||||||
|
w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]
|
||||||
|
|
||||||
|
w2_shard_size = self.w2_lora_b_stacked[index, eid].shape[0]
|
||||||
|
w2_start_idx = self.tp_rank * w2_shard_size
|
||||||
|
w2_end_idx = (self.tp_rank + 1) * w2_shard_size
|
||||||
|
w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
|
||||||
|
|
||||||
self.w1_lora_a_stacked[
|
self.w1_lora_a_stacked[
|
||||||
index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
|
index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
|
||||||
].copy_(w1_lora_a, non_blocking=True)
|
].copy_(w1_lora_a, non_blocking=True)
|
||||||
|
|||||||
@ -3,6 +3,10 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.distributed import (
|
||||||
|
tensor_model_parallel_all_gather,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
@ -311,6 +315,7 @@ def _fused_moe_lora_expand(
|
|||||||
num_stages: int,
|
num_stages: int,
|
||||||
split_k: int,
|
split_k: int,
|
||||||
mul_routed_weight: bool = False,
|
mul_routed_weight: bool = False,
|
||||||
|
offset: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
b_ptr = _get_ptr(lora_b_stacked, device)
|
b_ptr = _get_ptr(lora_b_stacked, device)
|
||||||
K = max_lora_rank
|
K = max_lora_rank
|
||||||
@ -380,7 +385,7 @@ def _fused_moe_lora_expand(
|
|||||||
**expand_config,
|
**expand_config,
|
||||||
)
|
)
|
||||||
for i in range(num_slices):
|
for i in range(num_slices):
|
||||||
output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i]
|
output[:, :, i * N + offset : (i + 1) * N + offset] += b_intermediate_cache1[i]
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@ -416,6 +421,8 @@ def _fused_moe_lora(
|
|||||||
expand_num_stages: int,
|
expand_num_stages: int,
|
||||||
expand_split_k: int,
|
expand_split_k: int,
|
||||||
mul_routed_weight: bool = False,
|
mul_routed_weight: bool = False,
|
||||||
|
fully_sharded: bool = False,
|
||||||
|
offset: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert len(lora_a_stacked) == len(lora_b_stacked) > 0
|
assert len(lora_a_stacked) == len(lora_b_stacked) > 0
|
||||||
assert (
|
assert (
|
||||||
@ -430,7 +437,6 @@ def _fused_moe_lora(
|
|||||||
== expert_ids.shape[0]
|
== expert_ids.shape[0]
|
||||||
== num_tokens_post_padded.shape[0]
|
== num_tokens_post_padded.shape[0]
|
||||||
)
|
)
|
||||||
assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1]
|
|
||||||
assert output.shape[0] == topk_weights.shape[0]
|
assert output.shape[0] == topk_weights.shape[0]
|
||||||
assert top_k_num == topk_weights.shape[1]
|
assert top_k_num == topk_weights.shape[1]
|
||||||
device = qcurr_hidden_states.device
|
device = qcurr_hidden_states.device
|
||||||
@ -480,6 +486,19 @@ def _fused_moe_lora(
|
|||||||
mul_routed_weight,
|
mul_routed_weight,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if fully_sharded:
|
||||||
|
if max_lora_rank == w1_lora_b_stacked.shape[-1]:
|
||||||
|
a_intermediate_cache1 = tensor_model_parallel_all_reduce(
|
||||||
|
a_intermediate_cache1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
a_intermediate_cache1 = tensor_model_parallel_all_gather(
|
||||||
|
a_intermediate_cache1
|
||||||
|
)
|
||||||
|
|
||||||
|
# reset max_lora_rank to the full rank after allgather
|
||||||
|
max_lora_rank = a_intermediate_cache1.shape[-1]
|
||||||
|
|
||||||
_fused_moe_lora_expand(
|
_fused_moe_lora_expand(
|
||||||
output,
|
output,
|
||||||
a_intermediate_cache1,
|
a_intermediate_cache1,
|
||||||
@ -510,6 +529,7 @@ def _fused_moe_lora(
|
|||||||
expand_num_stages,
|
expand_num_stages,
|
||||||
expand_split_k,
|
expand_split_k,
|
||||||
mul_routed_weight,
|
mul_routed_weight,
|
||||||
|
offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -483,6 +483,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
|
|||||||
expand_config,
|
expand_config,
|
||||||
adapter_enabled: torch.Tensor,
|
adapter_enabled: torch.Tensor,
|
||||||
mul_routed_weight=False,
|
mul_routed_weight=False,
|
||||||
|
fully_sharded: bool = False,
|
||||||
|
offset: int = 0,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Performs a fused forward computation for LoRA of
|
Performs a fused forward computation for LoRA of
|
||||||
|
|||||||
@ -375,6 +375,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
|||||||
expand_config,
|
expand_config,
|
||||||
adapter_enabled: torch.Tensor,
|
adapter_enabled: torch.Tensor,
|
||||||
mul_routed_weight=False,
|
mul_routed_weight=False,
|
||||||
|
fully_sharded: bool = False,
|
||||||
|
offset: int = 0,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
|
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
|
||||||
@ -408,4 +410,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
|||||||
expand_config.get("NUM_STAGES", 3),
|
expand_config.get("NUM_STAGES", 3),
|
||||||
expand_config.get("SPLIT_K", 1),
|
expand_config.get("SPLIT_K", 1),
|
||||||
mul_routed_weight,
|
mul_routed_weight,
|
||||||
|
fully_sharded,
|
||||||
|
offset,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user