diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py
index 0ae992ad1110c..b724e112b9dd3 100644
--- a/tests/lora/test_fused_moe_lora_kernel.py
+++ b/tests/lora/test_fused_moe_lora_kernel.py
@@ -204,6 +204,11 @@ def use_torch(
     return torch.stack(outputs, dim=0)
 
 
+DTYPES = [torch.float16, torch.bfloat16]
+DEVICES = [f"cuda:{0}"]
+SEED = [42]
+
+
 @pytest.mark.parametrize("num_tokens", [100])
 @pytest.mark.parametrize("top_k_num", [6, 12])
 @pytest.mark.parametrize("num_experts", [64])
@@ -212,6 +217,9 @@ def use_torch(
 @pytest.mark.parametrize("K", [2048])
 @pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
 @pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("seed", SEED)
 def test_fused_moe_lora_kernel(
     num_tokens,
     top_k_num,
@@ -221,9 +229,12 @@ def test_fused_moe_lora_kernel(
     K,
     max_lora_rank,
     block_size,
+    dtype,
+    device,
+    seed,
 ):
-    torch.set_default_device("cuda:0")
-    current_platform.seed_everything(42)
+    torch.set_default_device(device)
+    current_platform.seed_everything(seed)
     # the number of randomly generated sentences.
     num_sequences = 10
     # generate data
@@ -240,7 +251,7 @@ def test_fused_moe_lora_kernel(
                 max_lora_rank,
                 K,
             ),
-            dtype=torch.bfloat16,
+            dtype=dtype,
         )
     ]
     lora_b_stacked = [
@@ -251,7 +262,7 @@ def test_fused_moe_lora_kernel(
                 N,
                 max_lora_rank,
             ),
-            dtype=torch.bfloat16,
+            dtype=dtype,
         )
     ]
     hidden_states = torch.rand(
@@ -259,11 +270,11 @@ def test_fused_moe_lora_kernel(
             num_tokens,
             K,
         ),
-        dtype=torch.bfloat16,
+        dtype=dtype,
     )
 
     # fused_moe_lora_kernel output
-    output = torch.zeros((num_tokens, top_k_num, N), dtype=torch.bfloat16)
+    output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype)
     use_fused_moe_lora_kernel(
         topk_ids,
         topk_weights,
diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
index 2031ade64b5fc..e681f3882908e 100644
--- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -2,9 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import torch
-import triton
-import triton.language as tl
 
+from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import direct_register_custom_op
 
 _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
@@ -110,7 +109,7 @@ def _fused_moe_lora_kernel(
 
     # get a_ptr,b_ptr,c_ptr
     cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
-    cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(tl.bfloat16))
+    cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
     cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
 
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
@@ -154,7 +153,7 @@ def _fused_moe_lora_kernel(
         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]
 
-    accumulator = accumulator.to(tl.bfloat16)
+    accumulator = accumulator.to(c_ptr.dtype.element_ty)
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
@@ -205,6 +204,10 @@ def _fused_moe_lora(
     assert output.shape[0] == topk_weights.shape[0]
     assert top_k_num == topk_weights.shape[1]
 
+    for lora_a, lora_b in zip(lora_a_stacked, lora_b_stacked):
+        assert lora_a.dtype == lora_b.dtype == output.dtype == qcurr_hidden_states.dtype
+        assert lora_a.dtype in [torch.float16, torch.bfloat16]
+
     device = qcurr_hidden_states.device
     num_slices = len(lora_a_stacked)
 
@@ -227,9 +230,9 @@ def _fused_moe_lora(
     num_tokens = M * top_k_num
     w1_output_dim_size = w1_lora_b_stacked.shape[2]
 
-    lora_intermediate_cache1 = torch.zeros(
+    lora_intermediate_cache1 = torch.empty(
         (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
-        dtype=torch.bfloat16,
+        dtype=output.dtype,
         device=device,
     )
 
@@ -288,10 +291,6 @@ def _fused_moe_lora(
     K = max_lora_rank
     N = w1_output_dim_size
 
-    # a_intermediate_cache1 = a_intermediate_cache1.view(
-    #     M, -1, a_intermediate_cache1.shape[3]
-    # )
-
     a_intermediate_cache1 = a_intermediate_cache1.view(
         -1, a_intermediate_cache1.shape[3]
     )
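A quick illustration of the invariant this change introduces (a minimal sketch, not the test's exact fixtures; the tensor names and shapes below are made up for the example): the kernel now takes its element type from the output pointer via `c_ptr.dtype.element_ty` instead of hard-coding `tl.bfloat16`, so callers only need to keep every LoRA tensor, the hidden states, and the output buffer in one supported dtype, which is what the new asserts in `_fused_moe_lora` check.

```python
import torch

# Illustrative tensors only; the shapes are arbitrary. The point is the shared dtype.
dtype = torch.float16  # torch.bfloat16 works the same way
hidden_states = torch.rand(100, 2048, dtype=dtype)
lora_a = torch.rand(1, 64, 16, 2048, dtype=dtype)
lora_b = torch.rand(1, 64, 1024, 16, dtype=dtype)
output = torch.zeros(100, 6, 1024, dtype=dtype)

# Mirrors the asserts added in _fused_moe_lora: one dtype everywhere,
# restricted to the two half-precision types the kernel supports.
assert lora_a.dtype == lora_b.dtype == output.dtype == hidden_states.dtype
assert lora_a.dtype in (torch.float16, torch.bfloat16)
```

The switch from `torch.zeros` to `torch.empty` for `lora_intermediate_cache1` presumably relies on the kernel overwriting every element it later reads, trading the redundant zero-fill for an uninitialized allocation.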