mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-18 07:07:03 +08:00
[Kernel] Enable moe LoRA kernel support FP16 (#27468)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
a663f6ae64
commit
f4e8154076
@ -204,6 +204,11 @@ def use_torch(
|
|||||||
return torch.stack(outputs, dim=0)
|
return torch.stack(outputs, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
DTYPES = [torch.float16, torch.bfloat16]
|
||||||
|
DEVICES = [f"cuda:{0}"]
|
||||||
|
SEED = [42]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_tokens", [100])
|
@pytest.mark.parametrize("num_tokens", [100])
|
||||||
@pytest.mark.parametrize("top_k_num", [6, 12])
|
@pytest.mark.parametrize("top_k_num", [6, 12])
|
||||||
@pytest.mark.parametrize("num_experts", [64])
|
@pytest.mark.parametrize("num_experts", [64])
|
||||||
@ -212,6 +217,9 @@ def use_torch(
|
|||||||
@pytest.mark.parametrize("K", [2048])
|
@pytest.mark.parametrize("K", [2048])
|
||||||
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
|
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
|
||||||
@pytest.mark.parametrize("block_size", [16])
|
@pytest.mark.parametrize("block_size", [16])
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("device", DEVICES)
|
||||||
|
@pytest.mark.parametrize("seed", SEED)
|
||||||
def test_fused_moe_lora_kernel(
|
def test_fused_moe_lora_kernel(
|
||||||
num_tokens,
|
num_tokens,
|
||||||
top_k_num,
|
top_k_num,
|
||||||
@ -221,9 +229,12 @@ def test_fused_moe_lora_kernel(
|
|||||||
K,
|
K,
|
||||||
max_lora_rank,
|
max_lora_rank,
|
||||||
block_size,
|
block_size,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
seed,
|
||||||
):
|
):
|
||||||
torch.set_default_device("cuda:0")
|
torch.set_default_device(device)
|
||||||
current_platform.seed_everything(42)
|
current_platform.seed_everything(seed)
|
||||||
# the number of randomly generated sentences.
|
# the number of randomly generated sentences.
|
||||||
num_sequences = 10
|
num_sequences = 10
|
||||||
# generate data
|
# generate data
|
||||||
@ -240,7 +251,7 @@ def test_fused_moe_lora_kernel(
|
|||||||
max_lora_rank,
|
max_lora_rank,
|
||||||
K,
|
K,
|
||||||
),
|
),
|
||||||
dtype=torch.bfloat16,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
lora_b_stacked = [
|
lora_b_stacked = [
|
||||||
@ -251,7 +262,7 @@ def test_fused_moe_lora_kernel(
|
|||||||
N,
|
N,
|
||||||
max_lora_rank,
|
max_lora_rank,
|
||||||
),
|
),
|
||||||
dtype=torch.bfloat16,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
hidden_states = torch.rand(
|
hidden_states = torch.rand(
|
||||||
@ -259,11 +270,11 @@ def test_fused_moe_lora_kernel(
|
|||||||
num_tokens,
|
num_tokens,
|
||||||
K,
|
K,
|
||||||
),
|
),
|
||||||
dtype=torch.bfloat16,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
# fused_moe_lora_kernel output
|
# fused_moe_lora_kernel output
|
||||||
output = torch.zeros((num_tokens, top_k_num, N), dtype=torch.bfloat16)
|
output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype)
|
||||||
use_fused_moe_lora_kernel(
|
use_fused_moe_lora_kernel(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
|
|||||||
@ -2,9 +2,8 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
|
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
|
||||||
@ -110,7 +109,7 @@ def _fused_moe_lora_kernel(
|
|||||||
|
|
||||||
# get a_ptr,b_ptr,c_ptr
|
# get a_ptr,b_ptr,c_ptr
|
||||||
cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
|
cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
|
||||||
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(tl.bfloat16))
|
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
|
||||||
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
|
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
|
||||||
|
|
||||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
|
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
|
||||||
@ -154,7 +153,7 @@ def _fused_moe_lora_kernel(
|
|||||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
||||||
accumulator = accumulator * moe_weight[:, None]
|
accumulator = accumulator * moe_weight[:, None]
|
||||||
|
|
||||||
accumulator = accumulator.to(tl.bfloat16)
|
accumulator = accumulator.to(c_ptr.dtype.element_ty)
|
||||||
# Write back the block of the output
|
# Write back the block of the output
|
||||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
|
c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
|
||||||
@ -205,6 +204,10 @@ def _fused_moe_lora(
|
|||||||
assert output.shape[0] == topk_weights.shape[0]
|
assert output.shape[0] == topk_weights.shape[0]
|
||||||
assert top_k_num == topk_weights.shape[1]
|
assert top_k_num == topk_weights.shape[1]
|
||||||
|
|
||||||
|
for lora_a, lora_b in zip(lora_a_stacked, lora_b_stacked):
|
||||||
|
assert lora_a.dtype == lora_b.dtype == output.dtype == qcurr_hidden_states.dtype
|
||||||
|
assert lora_a.dtype in [torch.float16, torch.bfloat16]
|
||||||
|
|
||||||
device = qcurr_hidden_states.device
|
device = qcurr_hidden_states.device
|
||||||
num_slices = len(lora_a_stacked)
|
num_slices = len(lora_a_stacked)
|
||||||
|
|
||||||
@ -227,9 +230,9 @@ def _fused_moe_lora(
|
|||||||
num_tokens = M * top_k_num
|
num_tokens = M * top_k_num
|
||||||
w1_output_dim_size = w1_lora_b_stacked.shape[2]
|
w1_output_dim_size = w1_lora_b_stacked.shape[2]
|
||||||
|
|
||||||
lora_intermediate_cache1 = torch.zeros(
|
lora_intermediate_cache1 = torch.empty(
|
||||||
(num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
|
(num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
|
||||||
dtype=torch.bfloat16,
|
dtype=output.dtype,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -288,10 +291,6 @@ def _fused_moe_lora(
|
|||||||
K = max_lora_rank
|
K = max_lora_rank
|
||||||
N = w1_output_dim_size
|
N = w1_output_dim_size
|
||||||
|
|
||||||
# a_intermediate_cache1 = a_intermediate_cache1.view(
|
|
||||||
# M, -1, a_intermediate_cache1.shape[3]
|
|
||||||
# )
|
|
||||||
|
|
||||||
a_intermediate_cache1 = a_intermediate_cache1.view(
|
a_intermediate_cache1 = a_intermediate_cache1.view(
|
||||||
-1, a_intermediate_cache1.shape[3]
|
-1, a_intermediate_cache1.shape[3]
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user