Modularize fused experts and integrate PPLX kernels (#15956)

bnellnm 2025-05-14 16:11:54 -04:00 committed by GitHub
parent 418d2f8bfb
commit f9c069c85e
42 changed files with 3830 additions and 660 deletions
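At a high level, the change splits fused-MoE execution into a prepare/finalize stage and an experts stage, composed by FusedMoEModularKernel, with the PPLX all-to-all kernels plugging in as one prepare/finalize implementation. A minimal sketch of the composition, mirroring the single-rank batched path exercised by the new tests (shapes, constructor arguments, and the config setup follow the test code; this is an illustration, not the production wiring inside the FusedMoE layer):

import torch

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedExperts, BatchedPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)

m, n, k, e, topk = 64, 1024, 512, 8, 2
dtype = torch.bfloat16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)

vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
    topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
    # prepare: scatter tokens into per-expert batches; experts: per-expert
    # GEMMs plus activation; finalize: weighted gather back to token order.
    fused_experts = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(m, world_size=1, dp_size=1, rank=0),
        BatchedExperts(max_num_tokens=m, world_size=1, dp_size=1))
    out = fused_experts(a, w1, w2, topk_weight, topk_ids,
                        global_num_experts=e)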

View File

@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   int64_t num_tokens = input.numel() / input.size(-1);               \
   dim3 grid(num_tokens);                                             \
   dim3 block(std::min(d, 1024));                                     \
+  if (num_tokens == 0) {                                             \
+    return;                                                          \
+  }                                                                  \
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));  \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();      \
   VLLM_DISPATCH_FLOATING_TYPES(                                      \
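This guard makes the activation launch macro a no-op for empty inputs, a case that can now arise on the batched expert-parallel path when an expert or DP rank receives no tokens in a step (that motivation is inferred from the rest of the commit, not stated in it). A small sketch of the edge case, calling the op the same way the new batched reference code does:

import torch

from vllm import _custom_ops as ops

# Zero tokens; silu_and_mul halves the last dimension (2048 -> 1024).
# Without the guard the macro would compute dim3 grid(0), which is not a
# valid launch configuration; with it, the call simply returns.
x = torch.empty((0, 2048), device="cuda", dtype=torch.float16)
out = torch.empty((0, 1024), device="cuda", dtype=torch.float16)
ops.silu_and_mul(out, x)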

View File

@@ -65,5 +65,19 @@
   AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)          \
   AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

+#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...)   \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)         \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)         \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)        \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)         \
+  AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
+
 #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                              \
+      TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))

View File

@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
  }

  if (use_global_memory) {
-    VLLM_DISPATCH_INTEGRAL_TYPES(
+    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
          // tensors
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
              cumsum_buffer.data_ptr<int32_t>());
        });
  } else if (use_i16) {
-    VLLM_DISPATCH_INTEGRAL_TYPES(
+    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
          // set dynamic shared mem
          auto kernel =
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
              topk_ids.numel());
        });
  } else {
-    VLLM_DISPATCH_INTEGRAL_TYPES(
+    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
          auto kernel =
              vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
  TORCH_CHECK(num_experts == 256,
              "sgl_moe_align_block_size kernel only supports deepseek v3.");

-  VLLM_DISPATCH_INTEGRAL_TYPES(
+  VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
      topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
        // calc needed amount of shared mem for `cumsum` tensors
        auto options_int =

View File

@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
        }
    }

-template <int TPB>
-__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
-    int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
+template <int TPB, typename IndType>
+__launch_bounds__(TPB) __global__ void moeTopK(
+    const float* inputs_after_softmax,
+    const bool* finished,
+    float* output,
+    IndType* indices,
+    int* source_rows,
+    const int num_experts,
+    const int k,
+    const int start_expert,
+    const int end_expert)
 {

    using cub_kvp = cub::KeyValuePair<int, float>;
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
  2) This implementation assumes k is small, but will work for any k.
*/

-template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
+template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
 __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
-    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
+    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
        int* source_rows, const int k, const int start_expert, const int end_expert)
 {
    // We begin by enforcing compile time assertions and setting up compile time constants.
@@ -397,8 +405,8 @@ struct TopkConstants
};
} // namespace detail

-template <int EXPERTS, int WARPS_PER_TB>
-void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
+template <int EXPERTS, int WARPS_PER_TB, typename IndType>
+void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
    int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
 {
    static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
        token_expert_indices, num_tokens, topk, 0, num_experts, \
        stream);

+template <typename IndType>
 void topkGatingSoftmaxKernelLauncher(
    const float* gating_output,
    float* topk_weights,
-    int* topk_indicies,
+    IndType* topk_indicies,
    int* token_expert_indices,
    float* softmax_workspace,
    const int num_tokens,
@@ -493,14 +502,32 @@ void topk_softmax(
  const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
-  vllm::moe::topkGatingSoftmaxKernelLauncher(
-      gating_output.data_ptr<float>(),
-      topk_weights.data_ptr<float>(),
-      topk_indices.data_ptr<int>(),
-      token_expert_indices.data_ptr<int>(),
-      softmax_workspace.data_ptr<float>(),
-      num_tokens,
-      num_experts,
-      topk,
-      stream);
+
+  if(topk_indices.scalar_type() == at::ScalarType::Int)
+  {
+    vllm::moe::topkGatingSoftmaxKernelLauncher(
+        gating_output.data_ptr<float>(),
+        topk_weights.data_ptr<float>(),
+        topk_indices.data_ptr<int>(),
+        token_expert_indices.data_ptr<int>(),
+        softmax_workspace.data_ptr<float>(),
+        num_tokens,
+        num_experts,
+        topk,
+        stream);
+  }
+  else
+  {
+    assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
+    vllm::moe::topkGatingSoftmaxKernelLauncher(
+        gating_output.data_ptr<float>(),
+        topk_weights.data_ptr<float>(),
+        topk_indices.data_ptr<uint32_t>(),
+        token_expert_indices.data_ptr<int>(),
+        softmax_workspace.data_ptr<float>(),
+        num_tokens,
+        num_experts,
+        topk,
+        stream);
+  }
}
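On the Python side, the new expert-parallel code works with unsigned expert ids: the tests added in this commit take the int32 ids produced by fused_topk and convert them to uint32 before handing them to the PPLX prepare/finalize stage, alongside the UInt32 branch above, which lets callers pass unsigned index tensors directly to topk_softmax. A short sketch of that conversion, mirroring the test code (the tensor shapes here are arbitrary):

import torch

from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

m, e, topk = 32, 8, 2
hidden_states = torch.randn((m, 128), device="cuda", dtype=torch.bfloat16)
gating_output = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

# fused_topk now returns a third value, ignored here.
topk_weights, topk_ids, _ = fused_topk(hidden_states, gating_output, topk,
                                       False)
# The PPLX all-to-all takes unsigned 32-bit expert ids, so the new tests
# convert before dispatch.
topk_ids = topk_ids.to(dtype=torch.uint32)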

View File

@@ -65,11 +65,17 @@ def parse_args():
                        type=int,
                        default=0,
                        help="Master node port")
+    parser.add_argument("--enforce-eager",
+                        action='store_true',
+                        help="Enforce eager mode execution.")
+    parser.add_argument("--trust-remote-code",
+                        action='store_true',
+                        help="Trust remote code.")
    return parser.parse_args()


 def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
-         dp_master_port, GPUs_per_dp_rank):
+         dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code):
    os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
    os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
@@ -109,10 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
                                     max_tokens=[16, 20][global_dp_rank % 2])

    # Create an LLM.
-    llm = LLM(model=model,
-              tensor_parallel_size=GPUs_per_dp_rank,
-              enforce_eager=True,
-              enable_expert_parallel=True)
+    llm = LLM(
+        model=model,
+        tensor_parallel_size=GPUs_per_dp_rank,
+        enforce_eager=enforce_eager,
+        enable_expert_parallel=True,
+        trust_remote_code=trust_remote_code,
+    )
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for i, output in enumerate(outputs):
@@ -155,7 +164,8 @@ if __name__ == "__main__":
        proc = Process(target=main,
                       args=(args.model, dp_size, local_dp_rank,
                             global_dp_rank, dp_master_ip, dp_master_port,
-                             tp_size))
+                             tp_size, args.enforce_eager,
+                             args.trust_remote_code))
        proc.start()
        procs.append(proc)
    exit_code = 0

View File

@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import pytest
import torch
import triton.language as tl
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel)
@dataclass
class BatchedMMConfig:
dtype: torch.dtype
num_experts: int
max_tokens_per_expert: int
K: int
N: int
@dataclass
class BatchedMMTensors:
A: torch.Tensor # [E, max_tokens, K]
B: torch.Tensor # [E, K, N] - column major
C: torch.Tensor # [E, max_tokens, N]
num_expert_tokens: torch.Tensor # [E]
@staticmethod
def make_tensors(config: BatchedMMConfig):
A = torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K),
device="cuda",
dtype=config.dtype) / 10
B = torch.randn((config.num_experts, config.N, config.K),
device="cuda",
dtype=config.dtype)
C = torch.zeros(
(config.num_experts, config.max_tokens_per_expert, config.N),
device="cuda",
dtype=config.dtype)
num_expert_tokens = torch.randint(low=0,
high=config.max_tokens_per_expert,
size=(config.num_experts, ),
device="cuda",
dtype=torch.int32)
return BatchedMMTensors(A, B, C, num_expert_tokens)
def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
num_expert_tokens: torch.Tensor) -> torch.Tensor:
num_expert_tokens_cpu = num_expert_tokens.clone()
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
num_experts = num_expert_tokens.size(0)
for e in range(num_experts):
num_tokens = num_expert_tokens_cpu[e]
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
return C
@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
[32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype):
config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
tensors = BatchedMMTensors.make_tensors(config)
test_output = tensors.C
ref_output = test_output.clone()
compute_tl_dtype = {
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32
}[test_output.dtype]
invoke_moe_batched_triton_kernel(
tensors.A,
tensors.B,
test_output,
tensors.num_expert_tokens,
compute_tl_dtype,
# Quantization data
None,
None,
None,
# Quantization schemes
False,
False,
False,
config={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 16
})
ref_output = ref_impl(tensors.A, tensors.B, ref_output,
tensors.num_expert_tokens)
rtol, atol = {
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),
}[test_output.dtype]
torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)

View File

@@ -30,6 +30,11 @@ MNK_FACTORS = [
    (224, 3072, 1536),
]

+vllm_config = VllmConfig(parallel_config=ParallelConfig(
+    pipeline_parallel_size=1))
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
 @dataclasses.dataclass
 class MOETensors:
@@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
        'w1_q': moe_tensors.w1_q.transpose(1, 2),  # type: ignore[union-attr]
        'w2_q': moe_tensors.w2_q.transpose(1, 2),  # type: ignore[union-attr]
        'topk_weights': topk_weights,
-        'topk_ids_': topk_ids,
+        'topk_ids': topk_ids,
        'ab_strides1': moe_tensors.ab_strides1,
        'c_strides1': moe_tensors.c_strides1,
        'ab_strides2': moe_tensors.ab_strides2,
@@ -231,18 +236,15 @@ def test_cutlass_moe_8_bit_no_graph(
    per_out_ch: bool,
):
    current_platform.seed_everything(7)
-    with set_current_vllm_config(
-            VllmConfig(parallel_config=ParallelConfig(
-                pipeline_parallel_size=1))):
+    with set_current_vllm_config(vllm_config):
        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                  per_out_ch)

        score = torch.randn((m, e), device="cuda", dtype=torch.half)
-        topk_weights, topk_ids = fused_topk(mt.a,
-                                            score,
-                                            topk,
-                                            renormalize=False)
+        topk_weights, topk_ids, _ = fused_topk(mt.a,
+                                               score,
+                                               topk,
+                                               renormalize=False)

        # Note that we are using the dequantized versions of the tensors.
        # Using a, w1 and w2 directly results in minor output differences.
@@ -276,20 +278,17 @@ def test_cutlass_moe_8_bit_cuda_graph(
    per_out_ch: bool,
):
    current_platform.seed_everything(7)
-    with set_current_vllm_config(
-            VllmConfig(parallel_config=ParallelConfig(
-                pipeline_parallel_size=1))):
+    with set_current_vllm_config(vllm_config):
        dtype = torch.half

        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                  per_out_ch)

        score = torch.randn((m, e), device="cuda", dtype=dtype)
-        topk_weights, topk_ids = fused_topk(mt.a,
-                                            score,
-                                            topk,
-                                            renormalize=False)
+        topk_weights, topk_ids, _ = fused_topk(mt.a,
+                                               score,
+                                               topk,
+                                               renormalize=False)

        # Note that we are using the dequantized versions of the tensors.
        # Using a, w1 and w2 directly results in minor output differences.
@@ -334,18 +333,15 @@ def test_cutlass_moe_8_bit_EP(
    ep_size: int,
):
    current_platform.seed_everything(7)
-    with set_current_vllm_config(
-            VllmConfig(parallel_config=ParallelConfig(
-                pipeline_parallel_size=1))):
+    with set_current_vllm_config(vllm_config):
        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                  per_out_channel)

        score = torch.randn((m, e), device="cuda", dtype=torch.half)
-        topk_weights, topk_ids = fused_topk(mt.a,
-                                            score,
-                                            topk,
-                                            renormalize=False)
+        topk_weights, topk_ids, _ = fused_topk(mt.a,
+                                               score,
+                                               topk,
+                                               renormalize=False)

        # Note that we are using the dequantized versions of the tensors.
        # Using a, w1 and w2 directly results in minor output differences.

View File

@@ -12,6 +12,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

 import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
@@ -32,6 +33,10 @@ NUM_EXPERTS = [8, 64]
 EP_SIZE = [1, 4]
 TOP_KS = [2, 6]

+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
 @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
@@ -70,31 +75,33 @@ def test_fused_moe(
    else:
        e_map = None

-    torch_output = torch_moe(a, w1, w2, score, topk, e_map)
-    iterative_output = iterative_moe(a,
-                                     w1,
-                                     w2,
-                                     score,
-                                     topk,
-                                     global_num_experts=e,
-                                     expert_map=e_map,
-                                     renormalize=False)
-
-    # Pad the weight if moe padding is enabled
-    if padding:
-        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
-        torch.cuda.empty_cache()
-        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
-        torch.cuda.empty_cache()
-
-    triton_output = fused_moe(a,
-                              w1,
-                              w2,
-                              score,
-                              topk,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              renormalize=False)
+    with set_current_vllm_config(vllm_config):
+        torch_output = torch_moe(a, w1, w2, score, topk, e_map)
+        iterative_output = iterative_moe(a,
+                                         w1,
+                                         w2,
+                                         score,
+                                         topk,
+                                         global_num_experts=e,
+                                         expert_map=e_map,
+                                         renormalize=False)
+
+        # Pad the weight if moe padding is enabled
+        if padding:
+            w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
+            torch.cuda.empty_cache()
+            w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
+            torch.cuda.empty_cache()
+
+        triton_output = fused_moe(a,
+                                  w1,
+                                  w2,
+                                  score,
+                                  topk,
+                                  global_num_experts=e,
+                                  expert_map=e_map,
+                                  renormalize=False)

    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
    torch.testing.assert_close(iterative_output,
                               torch_output,
@@ -115,7 +122,6 @@ def test_fused_moe(
 def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                        ep_size: int, dtype: torch.dtype, group_size: int,
                        has_zp: bool, weight_bits: int):
-    print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
@@ -194,22 +200,24 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
    else:
        e_map = None

-    triton_output = fused_moe(a,
-                              w1_qweight,
-                              w2_qweight,
-                              score,
-                              topk,
-                              renormalize=False,
-                              use_int4_w4a16=weight_bits == 4,
-                              use_int8_w8a16=weight_bits == 8,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              w1_scale=w1_scales,
-                              w2_scale=w2_scales,
-                              w1_zp=w1_qzeros if has_zp else None,
-                              w2_zp=w2_qzeros if has_zp else None,
-                              block_shape=[0, group_size])
-    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
+    with set_current_vllm_config(vllm_config):
+        triton_output = fused_moe(a,
+                                  w1_qweight,
+                                  w2_qweight,
+                                  score,
+                                  topk,
+                                  renormalize=False,
+                                  use_int4_w4a16=weight_bits == 4,
+                                  use_int8_w8a16=weight_bits == 8,
+                                  global_num_experts=e,
+                                  expert_map=e_map,
+                                  w1_scale=w1_scales,
+                                  w2_scale=w2_scales,
+                                  w1_zp=w1_qzeros if has_zp else None,
+                                  w2_zp=w2_qzeros if has_zp else None,
+                                  block_shape=[0, group_size])
+        torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)

    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@@ -515,7 +523,8 @@ def test_fused_marlin_moe(

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

-    torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
+    with set_current_vllm_config(vllm_config):
+        torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)

    marlin_output = torch.ops.vllm.fused_marlin_moe(
        a,
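The recurring edit in these kernel tests is to build one module-level VllmConfig and enter it around the MoE calls (the block-quant tests below keep their existing comment that this avoids warning spam). A condensed sketch of the pattern; the run_moe helper is ours for illustration, not part of the commit:

import torch

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_moe

# One config shared by every test in the file, as in this commit.
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


def run_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
            score: torch.Tensor, topk: int) -> torch.Tensor:
    # Entering the context makes vllm_config the "current" config while the
    # fused MoE kernels run.
    with set_current_vllm_config(vllm_config):
        return fused_moe(a, w1, w2, score, topk, renormalize=False)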

View File

@@ -0,0 +1,691 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for the MOE layers.
Run `pytest tests/kernels/test_pplx_moe.py`.
"""
import dataclasses
import os
import traceback
from typing import Callable, Optional
import pytest
import torch
try:
from pplx_kernels import AllToAll
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_finalize, nvshmem_get_unique_id,
nvshmem_init)
has_pplx = True
except ImportError:
has_pplx = False
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import override_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
get_default_config)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
(222, 2048, 1024)]
PPLX_MOE_COMBOS = [
(1, 128, 128),
(2, 128, 512),
(3, 1024, 2048),
(32, 128, 1024),
(45, 512, 2048),
(64, 1024, 1024),
(222, 1024, 2048),
]
NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [1, 2, 6]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
P = ParamSpec("P")
requires_pplx = pytest.mark.skipif(
not has_pplx,
reason="Requires PPLX kernels",
)
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
"tcp://localhost:29500",
worker,
) + args,
nprocs=world_size,
join=True,
)
def parallel_launch_from_env(
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Launches a worker function in parallel across all processes in the current
environment. The environment must have the following variables set:
- WORLD_SIZE: The total number of processes.
- WORLD_LOCAL_SIZE: The number of processes on the current node.
- NODE_RANK: The rank of the current
- MASTER_ADDR: The address of the master process.
- MASTER_PORT: The port of the master process.
"""
assert not kwargs
world_size = int(os.environ["WORLD_SIZE"])
world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
node_rank = int(os.environ["NODE_RANK"])
assert "MASTER_ADDR" in os.environ
assert "MASTER_PORT" in os.environ
spawn(
_worker_parallel_launch,
args=(
world_size,
world_local_size,
node_rank,
"env://",
worker,
) + args,
nprocs=world_local_size,
join=True,
)
def torch_prepare(
a: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
max_num_tokens: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert topk_ids.dim() == 2
assert topk_ids.shape[0] == a.shape[0]
num_tokens, hidden_dim = a.shape
topk = topk_ids.shape[1]
tokens_per_expert = torch.bincount(topk_ids.view(-1),
minlength=num_experts)
assert tokens_per_expert.numel() == num_experts
if max_num_tokens is None:
max_num_tokens = int(tokens_per_expert.max().item())
b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim),
dtype=a.dtype,
device=a.device)
token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)
for token in range(num_tokens):
for j in range(topk):
expert_id = topk_ids[token, j]
idx = token_counts[expert_id]
b_a[expert_id, idx:idx + 1, :] = a[token, :]
token_counts[expert_id] = token_counts[expert_id] + 1
return b_a, tokens_per_expert
def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor,
topk_ids: torch.Tensor) -> torch.Tensor:
num_tokens = topk_ids.shape[0]
num_experts = b_out.shape[0]
K = b_out.shape[-1]
out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
expert_counts = torch.zeros(num_experts,
dtype=torch.int,
device=b_out.device)
for token in range(num_tokens):
expert_ids = topk_ids[token]
for i in range(expert_ids.numel()):
expert_id = expert_ids[i]
idx = expert_counts[expert_id]
out[token, :] = out[token, :] + b_out[expert_id, idx:idx +
1, :] * topk_weight[token, i]
expert_counts[expert_id] = expert_counts[expert_id] + 1
return out
def torch_batched_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
num_experts = w1.shape[0]
b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts)
assert b_a.dim() == 3
num_tokens, topk = topk_ids.shape
_, max_num_tokens, K = b_a.shape
assert num_experts == b_a.shape[0] and w2.shape[1] == K
out = torch.zeros((num_experts, max_num_tokens, K),
dtype=b_a.dtype,
device=b_a.device)
tmp = torch.empty((max_num_tokens, w1.shape[1] // 2),
dtype=b_a.dtype,
device=b_a.device)
for expert in range(num_experts):
num = tokens_per_expert[expert]
if num > 0:
torch.ops._C.silu_and_mul(
tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1))
out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)
return torch_finalize(out, topk_weight, topk_ids)
def batched_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
num_experts = w1.shape[0]
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0),
BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
# Note: same as torch_moe but with fused_topk factored out.
def torch_moe2(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
M, K = a.shape
topk = topk_ids.shape[1]
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
num_experts = w1.shape[0]
for i in range(num_experts):
mask = (topk_ids == i).view(-1)
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(M, -1, w2.shape[1]) *
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_fused_moe_batched_experts(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
):
current_platform.seed_everything(7)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)
torch.testing.assert_close(baseline_output,
torch_output,
atol=2e-2,
rtol=0)
torch.testing.assert_close(baseline_output,
batched_output,
atol=2e-2,
rtol=0)
def rank_chunk(num: int, r: int, w: int) -> int:
rem = num % w
return (num // w) + (1 if r < rem else 0)
def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
chunk = rank_chunk(t.shape[0], r, w)
return t[(r * chunk):(r + 1) * chunk]
def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
topk_weight: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
assert torch.cuda.current_device() == pgi.local_rank
topk = topk_ids.shape[1]
num_tokens, hidden_dim = a.shape
block_size = 128
device = pgi.device
rank = pgi.rank
world_size = pgi.world_size
max_num_tokens = rank_chunk(num_tokens, 0, world_size)
ata = AllToAll.internode(
max_num_tokens=max_num_tokens,
num_experts=num_experts,
experts_per_token=topk,
rank=rank,
world_size=world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
((hidden_dim + block_size - 1) // block_size *
torch.float32.itemsize)),
)
topk_ids = topk_ids.to(dtype=torch.uint32)
prepare_finalize = PplxPrepareAndFinalize(
ata,
max_num_tokens,
world_size,
rank,
dp_size,
a.dtype,
)
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
b_a, b_a_scale, expert_num_tokens = prepare_finalize.prepare(
a_chunk,
None,
None,
chunk_topk_weight,
chunk_topk_ids,
num_experts,
None,
False,
)
b_a = b_a * 1.5
out = torch.full(
(max_num_tokens, hidden_dim),
torch.nan,
dtype=a.dtype,
device=device,
)
prepare_finalize.finalize(
out,
b_a,
chunk_topk_weight,
chunk_topk_ids,
False,
)
torch.cuda.synchronize()
ata.destroy()
num_tokens = a_chunk.shape[0]
return out[:num_tokens]
def _pplx_prepare_finalize(
pgi: ProcessGroupInfo,
dp_size: int,
a: torch.Tensor,
score: torch.Tensor,
topk: torch.Tensor,
num_experts: int,
):
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
device = pgi.device
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
k = a.shape[1]
a_rep = torch.repeat_interleave(a, topk, dim=0).to(device)
torch_output = (a_rep.view(-1, topk, k) * 1.5 *
topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(
a.dtype)
pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids,
num_experts)
torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
nvshmem_finalize()
# TODO (bnell): this test point does not work for odd M due to how the test is
# written, not due to limitations of the pplx kernels. The pplx_moe
# test below is able to deal with odd M.
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@requires_pplx
def test_pplx_prepare_finalize(
mnk: tuple[int, int, int],
e: int,
topk: int,
dtype: torch.dtype,
world_dp_size: tuple[int, int],
):
current_platform.seed_everything(7)
m, n, k = mnk
world_size, dp_size = world_dp_size
device = "cuda"
a = torch.randn((m, k), device=device, dtype=dtype) / 10
score = torch.randn((m, e), device=device, dtype=dtype)
parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
topk, e)
def pplx_moe(
rank: int,
world_size: int,
dp_size: int,
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
use_compile: bool = True,
use_cudagraphs: bool = True,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
device = torch.device("cuda", rank)
hidden_dim = a.shape[1]
num_experts = w1.shape[0]
block_size = 128
topk = topk_ids.shape[1]
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
ata = AllToAll.internode(
max_num_tokens=max_num_tokens,
num_experts=num_experts,
experts_per_token=topk,
rank=rank,
world_size=world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
((hidden_dim + block_size - 1) // block_size *
torch.float32.itemsize)),
)
topk_ids = topk_ids.to(dtype=torch.uint32)
prepare_finalize = PplxPrepareAndFinalize(
ata,
max_num_tokens,
world_size,
rank,
dp_size,
)
experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
world_size=world_size,
dp_size=dp_size)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
# Chunking weights like this only works for batched format
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
if use_compile:
_fused_experts = torch.compile(fused_experts,
backend='inductor',
fullgraph=True)
else:
_fused_experts = fused_experts
out = _fused_experts(a_chunk,
w1_chunk,
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts)
if use_cudagraphs:
out.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
out = _fused_experts(a_chunk,
w1_chunk,
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
ata.destroy()
return out
def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
assert torch.cuda.current_device() == pgi.local_rank
num_experts = w1.shape[0]
device = pgi.device
rank = pgi.rank
world_size = pgi.world_size
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
prepare_finalize = BatchedPrepareAndFinalize(
max_num_tokens=max_num_tokens,
world_size=world_size,
dp_size=dp_size,
rank=rank,
)
experts = BatchedExperts(max_num_tokens=a.shape[0],
world_size=1,
dp_size=1)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
out = fused_experts(
a_chunk,
# Chunking weights like this only works for batched format
chunk_by_rank(w1, rank, world_size).to(device),
chunk_by_rank(w2, rank, world_size).to(device),
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts)
return out
def _pplx_moe(
pgi: ProcessGroupInfo,
dp_size: int,
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
):
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
m, k = a.shape
e, _, n = w2.shape
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
topk_weight, topk_ids)
# TODO (bnell): fix + re-enable
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
# topk_ids)
torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
#torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
nvshmem_finalize()
@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@requires_pplx
def test_pplx_moe(
mnk: tuple[int, int, int],
e: int,
topk: int,
dtype: torch.dtype,
world_dp_size: tuple[int, int],
):
current_platform.seed_everything(7)
m, n, k = mnk
world_size, dp_size = world_dp_size
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk)
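The file also defines parallel_launch_from_env, whose environment contract is given in its docstring but which none of the in-tree tests call. A hedged sketch of how a multi-node run could be driven through it; the wrapper below and its problem sizes are ours, and it assumes WORLD_SIZE, WORLD_LOCAL_SIZE, NODE_RANK, MASTER_ADDR and MASTER_PORT are exported by the job scheduler and that every node runs the same script:

import torch

from tests.kernels.test_pplx_moe import _pplx_moe, parallel_launch_from_env
from vllm.platforms import current_platform


def main() -> None:
    # Identical seeding on every node so that all ranks build the same
    # inputs, as the single-node tests above do.
    current_platform.seed_everything(7)
    dtype = torch.bfloat16
    m, n, k, e, topk, dp_size = 222, 1024, 2048, 8, 2, 1
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)
    # Spawns WORLD_LOCAL_SIZE workers on this node; each worker initializes
    # NVSHMEM, runs the PPLX MoE path, and compares against torch_moe2.
    parallel_launch_from_env(_pplx_moe, dp_size, a, w1, w2, score, topk)


if __name__ == "__main__":
    main()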

View File

@@ -7,6 +7,7 @@ import pytest
 import torch

 from vllm import _custom_ops as ops
+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.platforms import current_platform
@@ -15,6 +16,10 @@ if current_platform.get_device_capability() < (9, 0):
    pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
                allow_module_level=True)

+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
 def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
    """Matrix multiplication function that supports per-token input
@@ -137,20 +142,21 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
    w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale

    score = torch.randn((M, E), dtype=dtype)

-    ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
-    out = fused_moe(
-        a,
-        w1,
-        w2,
-        score,
-        topk,
-        renormalize=False,
-        use_fp8_w8a8=True,  # using fp8
-        per_channel_quant=True,
-        w1_scale=w1_s,
-        w2_scale=w2_s,
-        block_shape=None,  # Not using block quantization
-    )
+    with set_current_vllm_config(vllm_config):
+        ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
+        out = fused_moe(
+            a,
+            w1,
+            w2,
+            score,
+            topk,
+            renormalize=False,
+            use_fp8_w8a8=True,  # using fp8
+            per_channel_quant=True,
+            w1_scale=w1_s,
+            w2_scale=w2_s,
+            block_shape=None,  # Not using block quantization
+        )

    # Check results
    rel_diff = (torch.mean(

View File

@@ -11,7 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
-    deep_gemm_moe_fp8)
+    _valid_deep_gemm_shape, deep_gemm_moe_fp8)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size)
@@ -30,6 +30,10 @@ if current_platform.get_device_capability() < (9, 0):
    pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
                allow_module_level=True)

+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
 # Test configurations
 DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
 NUM_TOKENS = [7, 83, 2048]
@@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    score = torch.randn((M, E), dtype=dtype)

    # Set the context to avoid lots of warning spam.
-    vllm_config = VllmConfig()
    with set_current_vllm_config(vllm_config):
        out = fused_moe(
            a,
@@ -258,6 +261,7 @@ def per_block_cast_to_fp8(
 @pytest.mark.parametrize(
    "M,N,K,block_size,out_dtype,seed",
    itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
+@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
 @torch.inference_mode()
 def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
    # only aligned sizes
@@ -381,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
    block_size = [block_m, block_m]
    dtype = torch.bfloat16

-    # only aligned sizes
-    if (N % block_m != 0 or K % block_m != 0 or topk > E):
-        pytest.skip(
-            f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")
-
-    if N <= 512:
-        pytest.skip("Skipping N <= 512 until performance issues solved.")
-
-    vllm_config = VllmConfig()
+    if topk > E:
+        pytest.skip(f"Skipping test: topk={topk} > E={E}")
+
+    if not _valid_deep_gemm_shape(M, N, K):
+        pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")

    torch.manual_seed(seed)
    fp8_info = torch.finfo(torch.float8_e4m3fn)

View File

@@ -18,6 +18,10 @@ if current_platform.get_device_capability() < (7, 0):
    pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
                allow_module_level=True)

+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
 # For test
 def native_per_token_group_quant_int8(x,
@@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    score = torch.randn((M, E), dtype=dtype)

    # Set the context to avoid lots of warning spam.
-    vllm_config = VllmConfig()
    with set_current_vllm_config(vllm_config):
        out = fused_moe(
            a,

View File

@@ -23,6 +23,7 @@ If you only need to use the distributed environment without model/pipeline
"""
import contextlib
import gc
+import importlib.util
import pickle
import weakref
from collections import namedtuple
@@ -42,7 +43,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
 from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
-                        supports_custom_op)
+                        run_once, supports_custom_op)


 @dataclass
@@ -936,9 +937,49 @@ def init_distributed_environment(
            "world group already initialized with a different world size")


+PPLX_DID_INIT: bool = False
+
+
+@run_once
+def pplx_init(rank, world_size):
+    has_pplx = importlib.util.find_spec("pplx_kernels") is not None
+
+    if has_pplx and world_size > 1:
+        from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
+                                          nvshmem_get_unique_id, nvshmem_init)
+        try:
+            global PPLX_DID_INIT
+            logger.debug(
+                "Initialize NVSHMEM for PPLX kernels: rank=%d, "
+                "world size=%d", rank, world_size)
+            uid = nvshmem_get_unique_id(
+            ) if rank == 0 else nvshmem_alloc_empty_unique_id()
+            uid_gpu = uid.cuda()
+            get_world_group().broadcast(uid_gpu, src=0)
+            uid = uid_gpu.to(device='cpu')
+            logger.debug("PPLX NVSHMEM UID = %s", uid)
+            nvshmem_init(uid, rank, world_size)
+            PPLX_DID_INIT = True
+        except Exception as ex:
+            logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex)
+
+
+@run_once
+def pplx_finalize():
+    global PPLX_DID_INIT
+    if PPLX_DID_INIT:
+        from pplx_kernels.nvshmem import nvshmem_finalize
+        logger.debug("PPLX NVSHMEM finalize")
+        from vllm.model_executor.layers.fused_moe.layer import (
+            _all_to_all_cache)
+        _all_to_all_cache.destroy()
+        nvshmem_finalize()
+
+
 def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
+    enable_expert_parallel: bool = False,
    backend: Optional[str] = None,
 ) -> None:
    """
@@ -1041,10 +1082,14 @@ def initialize_model_parallel(
        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
        _EP.rank_in_group)

+    if enable_expert_parallel:
+        pplx_init(rank, world_size)
+

 def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
+    enable_expert_parallel: bool = False,
    backend: Optional[str] = None,
 ) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
@@ -1055,7 +1100,8 @@ def ensure_model_parallel_initialized(
                            get_world_group().device_group)
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
-                                  pipeline_model_parallel_size, backend)
+                                  pipeline_model_parallel_size,
+                                  enable_expert_parallel, backend)
        return

    assert (
@@ -1133,6 +1179,9 @@ def get_tensor_model_parallel_rank():
 def destroy_model_parallel():
    """Set the groups to none and destroy them."""
    global _TP
+
+    pplx_finalize()
+
    if _TP:
        _TP.destroy()
        _TP = None

View File

@@ -23,7 +23,7 @@ from torch.distributed.rendezvous import rendezvous

 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.utils import get_tcp_uri
+from vllm.utils import get_tcp_uri, is_torch_equal_or_newer

 logger = init_logger(__name__)
@@ -362,12 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
    Destroy ProcessGroup returned by
    stateless_init_torch_distributed_process_group().
    """
-    # Lazy import for non-CUDA backends.
-    try:
-        # pytorch <= 2.6
-        from torch.distributed.distributed_c10d import _shutdown_backend
-        _shutdown_backend(pg)
-    except ImportError:
-        # pytorch >= 2.7
-        pg.shutdown()
+    if is_torch_equal_or_newer("2.7"):
+        pg.shutdown()
+    else:
+        # Lazy import for non-CUDA backends.
+        from torch.distributed.distributed_c10d import _shutdown_backend
+        _shutdown_backend(pg)

    _unregister_process_group(pg.group_name)

View File

@@ -27,6 +27,7 @@ batchsize_forward_time: defaultdict = defaultdict(list)

 @dataclass
 class DPMetadata:
+    max_tokens_across_dp_cpu: torch.Tensor
    cu_tokens_across_dp_cpu: torch.Tensor

@@ -90,8 +91,10 @@ def set_forward_context(attn_metadata: Any,
                                         dtype=torch.int32)
        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
+        max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
-        dp_metadata = DPMetadata(cu_tokens_across_dp_cpu)
+        dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
+                                 cu_tokens_across_dp_cpu)

    global _forward_context
    prev_context = _forward_context

View File

@@ -38,8 +38,8 @@ if HAS_TRITON:
    from vllm.model_executor.layers.fused_moe.cutlass_moe import (
        cutlass_moe_fp4, cutlass_moe_fp8)
    from vllm.model_executor.layers.fused_moe.fused_moe import (
-        fused_experts, fused_moe, fused_topk, get_config_file_name,
-        grouped_topk)
+        TritonExperts, fused_experts, fused_moe, fused_topk,
+        get_config_file_name, grouped_topk)

    __all__ += [
        "fused_moe",
@@ -49,4 +49,5 @@ if HAS_TRITON:
        "grouped_topk",
        "cutlass_moe_fp8",
        "cutlass_moe_fp4",
+        "TritonExperts",
    ]

View File

@@ -5,10 +5,176 @@ from typing import Optional

 import torch

+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP)
+from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
 from vllm.scalar_type import scalar_types
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
out_dtype: torch.dtype,
):
super().__init__()
self.ab_strides1 = ab_strides1
self.c_strides1 = c_strides1
self.ab_strides2 = ab_strides2
self.c_strides2 = c_strides2
self.out_dtype = out_dtype
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
# Note that K, N are transposed
N, K = K, N
workspace1 = M * topk * max(2 * N, K)
workspace2 = M * topk * N
return (workspace1, workspace2, self.out_dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
a1q = hidden_states
assert w1_scale is not None
assert w2_scale is not None
assert w1.dtype == torch.float8_e4m3fn
assert w2.dtype == torch.float8_e4m3fn
assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1"
assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2"
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
assert a1q_scale is None or a1q_scale.dim(
) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
0], "Input scale shape mismatch"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1.shape[2], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2.shape[2], "W2 scale shape mismatch"
assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
assert w1.shape[0] == w1_scale.shape[
0], "w1 scales expert number mismatch"
assert w1.shape[0] == w2_scale.shape[
0], "w2 scales expert number mismatch"
assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
assert self.ab_strides1.shape[0] == w1.shape[
0], "AB Strides 1 expert number mismatch"
assert self.c_strides1.shape[0] == w1.shape[
0], "C Strides 1 expert number mismatch"
assert self.ab_strides2.shape[0] == w2.shape[
0], "AB Strides 2 expert number mismatch"
assert self.c_strides2.shape[0] == w2.shape[
0], "C Strides 2 expert number mismatch"
assert self.out_dtype in [torch.half,
torch.bfloat16], "Invalid output dtype"
M = a1q.shape[0]
_, N, K = w2.shape # because w1 + w2 are transposed
device = a1q.device
assert w1.shape[1] == K
assert global_num_experts != -1
assert a1q_scale is not None
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(expert_map[topk_ids] != -1,
expert_map[topk_ids], -1)
else:
local_topk_ids = topk_ids
topk = local_topk_ids.shape[1]
per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
expert_offsets = torch.empty((global_num_experts + 1),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((global_num_experts, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((global_num_experts, 3),
dtype=torch.int32,
device=device)
# With expert_map, each rank processes only a subset of experts. As
# a result, not all of the a_map and c2 tensors are filled. We fill
# them with zeros for correctness.
if expert_map is not None:
a_map = torch.zeros((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
else:
a_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
c_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets,
problem_sizes1, problem_sizes2, a_map,
c_map, global_num_experts, N, K)
a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
c1 = _resize_cache(workspace13, (M * topk, N * 2))
c2 = _resize_cache(workspace2, (M * topk, N))
c3 = _resize_cache(workspace13, (M * topk, K))
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale,
expert_offsets[:-1], problem_sizes1,
self.ab_strides1, self.ab_strides1, self.c_strides1)
self.activation(activation, c2, c1)
a2q, a2q_scale = ops.scaled_fp8_quant(
c2, a2_scale, use_per_token_if_dynamic=per_act_token)
if expert_map is not None:
c3.fill_(0)
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale,
expert_offsets[:-1], problem_sizes2,
self.ab_strides2, self.ab_strides2, self.c_strides2)
c3 = c3[c_map]
return c3
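The c1/c3 aliasing above is the usual workspace-reuse trick: both grouped-gemm outputs are views carved out of the same flat workspace13 buffer, which is safe because c1 has been fully consumed (activation plus fp8 re-quantization into c2) before c3 is written. A minimal sketch of that pattern, assuming _resize_cache simply reinterprets the leading elements of a flat tensor:

import torch

def _resize_cache_sketch(buf: torch.Tensor, shape: tuple) -> torch.Tensor:
    # Hypothetical stand-in for _resize_cache: view the first prod(shape)
    # elements of a flat workspace buffer with the requested shape.
    numel = 1
    for s in shape:
        numel *= s
    return buf.flatten()[:numel].view(*shape)

M_topk, N, K = 32, 2048, 7168
workspace13 = torch.empty(M_topk * max(2 * N, K))
c1 = _resize_cache_sketch(workspace13, (M_topk, 2 * N))  # first grouped-gemm output
# ... c1 is consumed here (activation + quantization into c2) ...
c3 = _resize_cache_sketch(workspace13, (M_topk, K))      # second output reuses the buffer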
#TODO make the grouped gemm kernel consistent with scaled gemm kernel #TODO make the grouped gemm kernel consistent with scaled gemm kernel
def cutlass_moe_fp8( def cutlass_moe_fp8(
a: torch.Tensor, a: torch.Tensor,
@ -17,7 +183,7 @@ def cutlass_moe_fp8(
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids_: torch.Tensor, topk_ids: torch.Tensor,
ab_strides1: torch.Tensor, ab_strides1: torch.Tensor,
c_strides1: torch.Tensor, c_strides1: torch.Tensor,
ab_strides2: torch.Tensor, ab_strides2: torch.Tensor,
@ -59,7 +225,7 @@ def cutlass_moe_fp8(
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms. quantize the intermediate result between the gemms.
Shape: scalar or [M] Shape: scalar or [M]
- out_dtype (torch.Tensor): The output tensor type. - out_dtype (torch.dtype): The output tensor type.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel, - expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i] mapping from global expert-id to local expert-id. When expert_map[i]
@ -71,115 +237,36 @@ def cutlass_moe_fp8(
Returns: Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer. - torch.Tensor: The fp16 output tensor after applying the MoE layer.
""" """
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
assert w1_q.dtype == torch.float8_e4m3fn
assert w2_q.dtype == torch.float8_e4m3fn
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert a1_scale is None or a1_scale.dim(
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[
0], "Input scale shape mismatch"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1_q.shape[2], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2_q.shape[2], "W2 scale shape mismatch"
assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[
0], "w1 scales expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[
0], "w2 scales expert number mismatch"
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
assert ab_strides1.shape[0] == w1_q.shape[
0], "AB Strides 1 expert number mismatch"
assert c_strides1.shape[0] == w1_q.shape[
0], "C Strides 1 expert number mismatch"
assert ab_strides2.shape[0] == w2_q.shape[
0], "AB Strides 2 expert number mismatch"
assert c_strides2.shape[0] == w2_q.shape[
0], "C Strides 2 expert number mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
num_experts = w1_q.size(0)
m = a.size(0)
k = w1_q.size(1)
n = w2_q.size(1)
local_topk_ids = topk_ids_
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(expert_map[topk_ids_] != -1,
expert_map[topk_ids_], -1)
topk = local_topk_ids.size(1)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False) a2_scale.numel() != 1 if a2_scale is not None else False)
if apply_router_weight_on_input:
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
# TODO: this only works for topK=1, will need to update for topK>1
a = a * topk_weights.to(out_dtype)
a_q, a1_scale = ops.scaled_fp8_quant( fn = mk.FusedMoEModularKernel(
a, a1_scale, use_per_token_if_dynamic=per_act_token) MoEPrepareAndFinalizeNoEP(
device = a_q.device per_channel_quant=per_act_token,
quant_dtype=torch.float8_e4m3fn,
),
CutlassExpertsFp8(
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
out_dtype,
),
)
expert_offsets = torch.empty((num_experts + 1), return fn(
dtype=torch.int32, a,
device=device) w1_q,
problem_sizes1 = torch.empty((num_experts, 3), w2_q,
dtype=torch.int32, topk_weights,
device=device) topk_ids,
problem_sizes2 = torch.empty((num_experts, 3), expert_map=expert_map,
dtype=torch.int32, w1_scale=w1_scale,
device=device) w2_scale=w2_scale,
a1_scale=a1_scale,
a_map_initializer = torch.empty a2_scale=a2_scale,
c2_initializer = torch.empty apply_router_weight_on_input=apply_router_weight_on_input,
if expert_map is not None: )
# With expert_map each Rank processes only a subset of experts. As
# a result not all of a_map and c2 tensors are filled. We fill it
# zeros for correctness.
a_map_initializer = torch.zeros
c2_initializer = torch.zeros
a_map = a_map_initializer((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
c_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, a_map, c_map, num_experts, n,
k)
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype)
ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
expert_offsets[:-1], problem_sizes1, ab_strides1,
ab_strides1, c_strides1)
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
torch.ops._C.silu_and_mul(intermediate, c1)
intemediate_q, a2_scale = ops.scaled_fp8_quant(
intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
expert_offsets[:-1], problem_sizes2, ab_strides2,
ab_strides2, c_strides2)
# Gather tokens
c2 = c2[c_map].view(m, topk, k)
if not apply_router_weight_on_input:
c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
return c2.sum(dim=1)
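The rewritten cutlass_moe_fp8 above is now just a composition of a prepare/finalize object and an experts object inside a FusedMoEModularKernel. Below is a minimal, self-contained sketch of that composition; the class and function names are illustrative stand-ins rather than the real modular_kernel API, and the expert math is plain dense PyTorch instead of the fp8 grouped gemms:

import torch
import torch.nn.functional as F

class PrepareFinalizeSketch:
    # Stand-in for MoEPrepareAndFinalizeNoEP: no quantization, no dispatch.
    def prepare(self, a: torch.Tensor) -> torch.Tensor:
        return a

    def finalize(self, output: torch.Tensor, fused_out: torch.Tensor,
                 topk_weights: torch.Tensor) -> None:
        M, topk = topk_weights.shape
        output.copy_((fused_out.view(M, topk, -1)
                      * topk_weights.unsqueeze(-1)).sum(dim=1))

class ExpertsSketch:
    # Stand-in for CutlassExpertsFp8 / TritonExperts: gemm1 -> silu_and_mul -> gemm2.
    def apply(self, a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
              topk_ids: torch.Tensor) -> torch.Tensor:
        M, topk = topk_ids.shape
        K = w2.shape[1]
        out = torch.empty(M * topk, K, dtype=a.dtype)
        for i in range(M):
            for j in range(topk):
                e = int(topk_ids[i, j])
                h = a[i] @ w1[e].t()               # [2N]
                d = h.shape[0] // 2
                h = F.silu(h[:d]) * h[d:]          # silu_and_mul -> [N]
                out[i * topk + j] = h @ w2[e].t()  # [K]
        return out

def modular_moe_sketch(a, w1, w2, topk_weights, topk_ids):
    pf, experts = PrepareFinalizeSketch(), ExpertsSketch()
    fused = experts.apply(pf.prepare(a), w1, w2, topk_ids)
    out = torch.empty_like(a)
    pf.finalize(out, fused, topk_weights)
    return out

# Example: 4 tokens, hidden size 8, 3 experts, intermediate size 6, topk=2.
a = torch.randn(4, 8)
w1 = torch.randn(3, 2 * 6, 8)   # [E, 2N, K]
w2 = torch.randn(3, 8, 6)       # [E, K, N]
topk_weights = torch.full((4, 2), 0.5)
topk_ids = torch.randint(0, 3, (4, 2))
print(modular_moe_sketch(a, w1, w2, topk_weights, topk_ids).shape)  # torch.Size([4, 8])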
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()

View File

@ -1,16 +1,17 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import functools
import importlib.util import importlib.util
from typing import Optional from typing import Optional
import torch import torch
import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_align_block_size) _moe_permute)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, from vllm.model_executor.layers.fused_moe.prepare_finalize import (
_fp8_quantize, MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_resize_cache) _resize_cache)
from vllm.utils import round_up from vllm.utils import round_up
@ -19,6 +20,19 @@ logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
@functools.cache
def deep_gemm_block_shape() -> list[int]:
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
block = dg.get_m_alignment_for_contiguous_layout()
return [block, block]
def _valid_deep_gemm_shape(M: int, N: int, K: int):
align = deep_gemm_block_shape()[0]
return align <= M and N % align == 0 and K % align == 0
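A quick illustration of the shape gate above, assuming the alignment reported by DeepGEMM's get_m_alignment_for_contiguous_layout() is 128 (the code always queries the library rather than hard-coding the value):

align = 128  # assumed value of dg.get_m_alignment_for_contiguous_layout()

def valid_shape(M: int, N: int, K: int) -> bool:
    return align <= M and N % align == 0 and K % align == 0

assert valid_shape(1024, 4096, 7168)      # aligned dims, enough tokens
assert not valid_shape(64, 4096, 7168)    # M smaller than the alignment
assert not valid_shape(1024, 4100, 7168)  # N not a multiple of the alignment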
def _valid_deep_gemm(hidden_states: torch.Tensor, def _valid_deep_gemm(hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
@ -29,89 +43,112 @@ def _valid_deep_gemm(hidden_states: torch.Tensor,
aligned by `dg.get_m_alignment_for_contiguous_layout()`. aligned by `dg.get_m_alignment_for_contiguous_layout()`.
""" """
if not has_deep_gemm: if not has_deep_gemm:
logger.debug("DeepGemm disabled: deep_gemm not available.")
return False return False
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
# Expert maps not supported yet.
if expert_map is not None: if expert_map is not None:
logger.debug("DeepGemm disabled: expert map NYI.")
return False return False
align = dg.get_m_alignment_for_contiguous_layout() M = hidden_states.size(0)
M = hidden_states.shape[0] _, K, N = w2.size()
_, K, N = w2.shape if not _valid_deep_gemm_shape(M, N, K):
logger.debug("DeepGemm disabled: unalinged problem size.")
# For now, disable DeepGemm for small N until better permute/unpermute
# ops are available.
if N <= 512:
return False return False
if align > M or N % align != 0 or K % align != 0: if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn):
logger.debug("DeepGemm disabled: invalid weight dtype(s).")
return False return False
return (hidden_states.is_contiguous() and w1.is_contiguous() if (not hidden_states.is_contiguous() or not w1.is_contiguous()
and w2.is_contiguous()) or not w2.is_contiguous()):
logger.debug(
"DeepGemm disabled: weights or activations not contiguous.")
return False
return True
def _moe_permute( class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
curr_hidden_states: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
curr_topk_ids: torch.Tensor,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
block_m: int,
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
top_k_num = curr_topk_ids.shape[1]
tokens_in_chunk, _ = curr_hidden_states.shape def __init__(self):
super().__init__()
self.block_shape = deep_gemm_block_shape()
sorted_token_ids, expert_ids, num_tokens_post_padded = ( def workspace_shapes(
moe_align_block_size(curr_topk_ids, self,
block_m, a: torch.Tensor,
global_num_experts, M: int,
expert_map, N: int,
pad_sorted_ids=True)) K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
block_m = self.block_shape[0]
M_sum = (M * topk) + num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)
workspace1 = M_sum * max(N * 2, K)
workspace2 = M_sum * N
return (workspace1, workspace2, a.dtype)
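Worked example of the M_sum padding above: each expert's token block is padded up to a multiple of block_m, so the worst case adds (block_m - 1) rows per expert before the final round-up (numbers below are illustrative):

M, topk, num_experts, block_m = 7, 2, 4, 128
M_sum = M * topk + num_experts * (block_m - 1)        # 14 + 508 = 522
M_sum = ((M_sum + block_m - 1) // block_m) * block_m  # round_up -> 640
N, K = 4096, 7168
workspace1 = M_sum * max(N * 2, K)  # 640 * 8192 elements
workspace2 = M_sum * N              # 640 * 4096 elements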
inv_perm: Optional[torch.Tensor] = None def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
import deep_gemm as dg
num_tokens = top_k_num * tokens_in_chunk a1q = hidden_states
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) _, N, K = w1.size()
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
# Permute according to sorted token ids. assert global_num_experts != -1
curr_hidden_states = _fp8_perm(curr_hidden_states, assert w2.size(1) == K
sorted_token_ids // top_k_num)
if a1q_scale is not None: a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
a1q_scale = a1q_scale[sorted_token_ids // top_k_num] a1q,
a1q_scale,
topk_ids,
global_num_experts,
expert_map,
self.block_shape[0],
)
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, # Note: M_sum is different than the pre-permuted shape of a1q.
inv_perm) M_sum = a1q.size(0)
workspace1 = _resize_cache(workspace13, (M_sum, N))
workspace2 = _resize_cache(workspace2, (M_sum, N // 2))
workspace3 = _resize_cache(workspace13, (M_sum, K))
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids)
def _moe_unpermute_and_reduce( self.activation(activation, workspace2, workspace1.view(-1, N))
out: torch.Tensor,
curr_hidden: torch.Tensor, a2q_scale: Optional[torch.Tensor] = None
inv_perm: Optional[torch.Tensor],
topk_weight: torch.Tensor, a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, False,
) -> None: self.block_shape)
"""
Unpermute the final result and apply topk_weights, then perform the final dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
reduction on the hidden states. (a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
"""
M, topk = topk_weight.shape workspace3 = workspace3[inv_perm, ...]
K = curr_hidden.shape[1]
curr_hidden = curr_hidden[inv_perm, ...] return workspace3
curr_hidden = curr_hidden.view(-1, topk, K)
curr_hidden.mul_(topk_weight.view(M, -1, 1))
ops.moe_sum(curr_hidden, out)
def deep_gemm_moe_fp8( def deep_gemm_moe_fp8(
@ -128,6 +165,7 @@ def deep_gemm_moe_fp8(
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input=False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes an a8w8-quantized Mixture of Experts (MoE) layer This function computes an a8w8-quantized Mixture of Experts (MoE) layer
@ -166,129 +204,24 @@ def deep_gemm_moe_fp8(
Returns: Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer. - torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
""" """
# Lazy import to avoid CUDA initialization problems. fn = mk.FusedMoEModularKernel(
import deep_gemm as dg MoEPrepareAndFinalizeNoEP(quant_dtype=torch.float8_e4m3fn,
block_shape=deep_gemm_block_shape()),
assert expert_map is None, "Expert maps not supported yet" DeepGemmExperts(),
)
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" return fn(
hidden_states,
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" w1,
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" w2,
assert w1.stride(-1) == 1, "Stride of last dimension must be 1" topk_weights,
assert w2.stride(-1) == 1, "Stride of last dimension must be 1" topk_ids,
assert hidden_states.dtype in [ inplace,
torch.float32, torch.float16, torch.bfloat16 activation,
] global_num_experts,
assert w1.dtype == torch.float8_e4m3fn expert_map,
assert w2.dtype == torch.float8_e4m3fn w1_scale=w1_scale,
assert w1.shape[0] == w2.shape[0], "Expert number mismatch" w2_scale=w2_scale,
assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" a1_scale=a1_scale,
assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" a2_scale=a2_scale,
assert a1_scale is None or a1_scale.dim( apply_router_weight_on_input=apply_router_weight_on_input,
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[ )
0] == hidden_states.shape[0], "Input scale shape mismatch"
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
num_tokens, _ = hidden_states.shape
E, N, _ = w1.shape
K = w2.shape[1]
if global_num_experts == -1:
global_num_experts = E
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)
if inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
block_m = dg.get_m_alignment_for_contiguous_layout()
block_shape = [block_m, block_m]
assert w1_scale is not None
assert w2_scale is not None
# We attempt to transpose and align offline in Fp8MoEMethod, in which
# case these calls will be nops. Otherwise, they'll be performed every
# time the layer is executed.
w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()
M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)
num_chunks = (num_tokens // CHUNK_SIZE) + 1
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13 = torch.empty(M_sum * max(N, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
workspace1 = workspace13[:M_sum * N].view(M_sum, N)
workspace2 = torch.empty((M_sum, N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
workspace3 = workspace13[:M_sum * K].view(M_sum, K)
for chunk in range(num_chunks):
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE,
num_tokens))
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.shape
if tokens_in_chunk == 0:
break
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
a1q_scale: Optional[torch.Tensor] = None
qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
a1_scale, block_shape)
(qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
curr_topk_ids, global_num_experts,
expert_map, block_m)
# Adjust the intermediate cache size and config for the last chunk.
# Note that in most cases we only have one chunk so the cache size
# and config are already set correctly and do not need to be adjusted.
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
curr_M = sorted_token_ids.numel()
workspace1 = _resize_cache(workspace1, (curr_M, N))
workspace2 = _resize_cache(workspace2, (curr_M, N // 2))
workspace3 = _resize_cache(workspace3, (curr_M, K))
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1,
expert_ids)
if activation == "silu":
torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N))
elif activation == "gelu":
torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
a2q_scale: Optional[torch.Tensor] = None
qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale,
block_shape)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
_moe_unpermute_and_reduce(
out_hidden_states[begin_chunk_idx:end_chunk_idx],
workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights)
return out_hidden_states

View File

@ -0,0 +1,755 @@
# SPDX-License-Identifier: Apache-2.0
"""Fused batched MoE kernel."""
from typing import Optional
import torch
import triton
import triton.language as tl
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
@triton.jit
def moe_mmk(
a_ptrs,
b_ptrs,
K,
expert_id,
a_scale_ptr,
b_scale_ptr,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_ak,
stride_bk,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Offsets and masks
offs_m,
offs_n,
mask_m,
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
compute_type: tl.constexpr,
use_w8a8: tl.constexpr,
use_w8a16: tl.constexpr):
offs_k = tl.arange(0, BLOCK_K)
if use_w8a16:
b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[
None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
if use_w8a8:
# block-wise
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
offs_bsn = offs_n // group_n
b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse +
offs_bsn * stride_bsn)
# tensor-wise
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + expert_id)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(a_ptrs,
mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K),
other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
# We accumulate along the K dimension.
if use_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_w8a8:
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_K
offs_ks = k_start // group_k
a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
mask=mask_m,
other=0.0)
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
accumulator += tl.dot(a, b) * a_scale[:,
None] * b_scale[None, :]
else:
if use_w8a8:
# acc used to enable fp8_fast_accum
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
if use_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
elif use_w8a8:
if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type)
else:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
return accumulator
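For the block-wise w8a8 path above, each K-block of A has one scale per row and each (K-group, N-group) tile of B has one scale; the partial dot product is rescaled before accumulation. A plain-PyTorch sketch of that dequantization, assuming the Triton K-block size equals the quantization group size:

import torch

M, N, K = 4, 8, 16
group_k, group_n = 8, 4
a_q = torch.randint(-127, 128, (M, K)).float()
b_q = torch.randint(-127, 128, (K, N)).float()
a_scale = torch.rand(M, K // group_k)             # one scale per (row, K-group)
b_scale = torch.rand(K // group_k, N // group_n)  # one scale per (K-group, N-group)

acc = torch.zeros(M, N)
for kg in range(K // group_k):
    a_blk = a_q[:, kg * group_k:(kg + 1) * group_k]
    b_blk = b_q[kg * group_k:(kg + 1) * group_k, :]
    bs = b_scale[kg].repeat_interleave(group_n)   # broadcast B scales to N columns
    acc += (a_blk @ b_blk) * a_scale[:, kg:kg + 1] * bs.unsqueeze(0)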
@triton.jit
def expert_triton_kernel(
a_ptr, #[max_tokens, K]
b_ptr, #[K, N]
c_ptr, #[max_tokens, N]
expert_id,
compute_type: tl.constexpr,
# Dimensions
M,
N,
K,
# Quantization data
a_scale_ptr,
b_scale_ptr,
b_zp_ptr,
# strides
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Blockwise quantization data
group_n,
group_k,
# Quantization schemes
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
# Kernel config
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr):
offs_m = tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N) % N
offs_k = tl.arange(0, BLOCK_K)
mask_m = offs_m < M
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
accumulator = moe_mmk(
a_ptrs,
b_ptrs,
K,
expert_id,
a_scale_ptr,
b_scale_ptr,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_ak,
stride_bk,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Offsets and masks
offs_m,
offs_n,
mask_m,
# Block size for block-wise quantization
group_n,
group_k,
# Meta-parameters
BLOCK_M,
BLOCK_N,
BLOCK_K,
compute_type,
use_fp8_w8a8,
use_int8_w8a16)
# store in C
offs_cn = tl.arange(0, BLOCK_N)
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn
c_mask = mask_m[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
@triton.jit
def batched_triton_kernel(
a_ptr, # [E, max_num_tokens, K]
b_ptr, # [E, K, N]
c_ptr, # [E, max_num_tokens, N]
expert_num_tokens, # [E]
compute_type: tl.constexpr,
# Dimensions
max_num_tokens,
K,
N,
# Quantization data
a_scale_ptr,
b_scale_ptr,
b_zp_ptr,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_ae,
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_ce,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Blockwise quantization data
group_n: tl.constexpr,
group_k: tl.constexpr,
# Quantization schemes
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
# Kernel config
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr):
expert_id = tl.program_id(axis=0)
e_num_tokens = tl.load(expert_num_tokens + expert_id)
if e_num_tokens == 0:
# Early exit
return
pid_mn = tl.program_id(axis=1)
#num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
pid_m = pid_mn // num_pid_n
pid_n = pid_mn % num_pid_n
cta_m_start = pid_m * BLOCK_M
cta_n_start = pid_n * BLOCK_N
if cta_m_start >= e_num_tokens:
# Early exit
return
cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start)
cta_n_size = min(BLOCK_N, N - cta_n_start)
a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm +
cta_n_start * stride_cn)
expert_triton_kernel(
a_ptr,
b_ptr,
c_ptr,
expert_id,
compute_type,
cta_m_size, # M
cta_n_size, # N
K, # K
a_scale_ptr,
b_scale_ptr,
b_zp_ptr,
# Strides
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Blockwise quantization data
group_n,
group_k,
# Quantization schemes
use_fp8_w8a8,
use_int8_w8a16,
# Kernel config
BLOCK_M,
BLOCK_N,
BLOCK_K)
def invoke_moe_batched_triton_kernel(
A: torch.Tensor, # [E, max_tokens, K]
B: torch.Tensor, # [E, K, N]
C: torch.Tensor, # [E, max_tokens, N]
expert_num_tokens: torch.Tensor, # [E]
compute_type: tl.dtype,
# Quantization data
A_scale: torch.Tensor,
B_scale: torch.Tensor,
B_zp: torch.Tensor,
# Quantization schemes
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
config: dict[str, int],
block_shape: Optional[list[int]] = None):
assert not use_int4_w4a16
max_num_tokens = A.size(1)
K = A.size(2)
N = C.size(2)
BLOCK_M = config['BLOCK_SIZE_M']
BLOCK_N = config['BLOCK_SIZE_N']
BLOCK_K = config['BLOCK_SIZE_K']
assert (torch.compiler.is_compiling()
or torch.cuda.is_current_stream_capturing()
or max_num_tokens % BLOCK_M == 0)
grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
triton.cdiv(B.size(1), BLOCK_N))
batched_triton_kernel[grid](
A,
B,
C,
expert_num_tokens,
compute_type,
# Dimensions
max_num_tokens,
K,
N,
# Quantization data
A_scale,
B_scale,
B_zp,
# Strides
A.stride(0),
A.stride(1),
A.stride(2),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(0),
C.stride(1),
C.stride(2),
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
# Blockwise quantization data
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
# Quantization schemes
use_fp8_w8a8,
use_int8_w8a16,
# Kernel config
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K)
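Worked example of the launch grid above: one program per (expert, M-tile x N-tile) pair, and programs belonging to experts with zero tokens, or to tiles past e_num_tokens, exit early inside the kernel (numbers are illustrative):

E, max_num_tokens, N = 8, 64, 4096
BLOCK_M, BLOCK_N = 32, 64

def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

grid = (E, cdiv(max_num_tokens, BLOCK_M) * cdiv(N, BLOCK_N))  # (8, 2 * 64) = (8, 128)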
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""
A reference prepare/finalize class that reorganizes the tokens into
expert batched format, i.e. E x max_num_tokens x K. This is the format
that the PPLX dispatch/combine kernels use.
"""
def __init__(self, max_num_tokens: Optional[int], world_size: int,
dp_size: int, rank: int):
super().__init__()
self.world_size = world_size
self.dp_size = dp_size
self.rank = rank
self.max_num_tokens = max_num_tokens
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
assert a1.dim() == 2
assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0)
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
num_tokens, hidden_dim = a1.size()
topk = topk_ids.size(1)
if self.max_num_tokens is None:
tokens_per_expert = torch.bincount(topk_ids.view(-1),
minlength=num_experts)
self.max_num_tokens = int(tokens_per_expert.max().item())
else:
tokens_per_expert = torch.zeros(num_experts,
dtype=torch.int,
device=a1.device)
assert num_experts % self.world_size == 0
num_local_experts = num_experts // self.world_size
b_a1 = torch.zeros(
(num_local_experts, self.max_num_tokens, hidden_dim),
dtype=a1.dtype,
device=a1.device)
first_expert = num_local_experts * self.rank
last_expert = first_expert + num_local_experts
for expert_id in range(first_expert, last_expert):
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
rows = torch.count_nonzero(topks.flatten())
b_a1[expert_id -
first_expert, :rows, :] = a1[:topks.numel()][topks]
tokens_per_expert[expert_id - first_expert] = rows
return b_a1, a1_scale, tokens_per_expert
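A tiny illustration of the expert-batched layout produced by prepare(): tokens routed to each local expert are packed into the leading row slots of a [num_local_experts, max_num_tokens, hidden_dim] buffer, and tokens_per_expert records how many slots are valid (single-rank case, values are made up):

import torch

num_experts, max_num_tokens, hidden_dim = 3, 2, 4
a1 = torch.arange(4 * hidden_dim, dtype=torch.float32).view(4, hidden_dim)
topk_ids = torch.tensor([[0], [2], [0], [1]])  # topk = 1

b_a1 = torch.zeros(num_experts, max_num_tokens, hidden_dim)
tokens_per_expert = torch.zeros(num_experts, dtype=torch.int)
for e in range(num_experts):
    rows = torch.any(topk_ids == e, dim=1)
    n = int(rows.sum())
    b_a1[e, :n] = a1[rows]
    tokens_per_expert[e] = n
# b_a1[0] holds tokens 0 and 2, b_a1[1] holds token 3, b_a1[2] holds token 1.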
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
num_tokens = topk_ids.size(0)
num_local_experts = fused_expert_output.size(0)
K = fused_expert_output.size(-1)
assert output.size(0) == num_tokens and output.size(1) == K
output.fill_(0)
first_expert = num_local_experts * self.rank
last_expert = first_expert + num_local_experts
for expert_id in range(first_expert, last_expert):
matching_tokens = topk_ids == expert_id
topks = torch.any(matching_tokens, dim=1).flatten()
rows = torch.count_nonzero(topks)
rhs = fused_expert_output[expert_id - first_expert, :rows, :]
if not apply_router_weight_on_input:
rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1))
output[topks] = output[topks] + rhs
class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
A reference MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the pplx
dispatch/combine kernels use.
"""
def __init__(
self,
world_size: int,
dp_size: int,
max_num_tokens: Optional[int] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
block_shape: Optional[list[int]] = None,
block_m: Optional[int] = None,
):
super().__init__()
assert block_shape is None
assert block_m is None
assert not use_fp8_w8a8, "NYI"
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
self.max_num_tokens = max_num_tokens
self.world_size = world_size
self.dp_size = dp_size
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
assert a.dim() == 2
num_dp = self.world_size // self.dp_size
max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens
#print(f"WORKSPACE {max_num_tokens} {num_dp}")
workspace13 = num_experts * max_num_tokens * num_dp * K
workspace2 = max_num_tokens * num_dp * N
return (workspace13, workspace2, a.dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
assert hidden_states.dim() == 3
assert expert_num_tokens is not None
hidden_dim = hidden_states.size(-1)
if self.max_num_tokens is None:
max_num_tokens = hidden_states.size(1)
else:
max_num_tokens = self.max_num_tokens
num_dp = self.world_size // self.dp_size
num_experts = global_num_experts
out = _resize_cache(workspace13,
(num_experts, max_num_tokens * num_dp, hidden_dim))
num_local_experts = w1.size(0)
assert num_local_experts == w1.size(0), (
f"{num_local_experts} == {w1.size(0)}")
N = w1.size(1) // 2
# Not cudagraph friendly
assert (torch.compiler.is_compiling()
or torch.cuda.is_current_stream_capturing()
or torch.all(expert_num_tokens <= max_num_tokens * num_dp)), (
f"{expert_num_tokens} <= {max_num_tokens * num_dp}")
for expert in range(num_local_experts):
# Indexing expert_num_tokens doesn't work w/cudagraphs or inductor
if (torch.compiler.is_compiling()
or torch.cuda.is_current_stream_capturing()):
num = max_num_tokens * num_dp
else:
num = int(expert_num_tokens[expert].item())
tmp = _resize_cache(workspace2, (num, N))
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
self.activation(activation, tmp, input)
out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
return out
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
A Triton based MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the pplx
dispatch/combine kernels use.
"""
def __init__(
self,
max_num_tokens: Optional[int] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
block_shape: Optional[list[int]] = None,
world_size: int = 1,
dp_size: int = 1,
):
super().__init__()
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int8_w8a8 = use_int8_w8a8
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a16 = use_int8_w8a16
self.block_shape = block_shape
self.max_num_tokens = max_num_tokens
assert not use_int8_w8a8, "NYI"
assert not use_int4_w4a16, "NYI"
self.world_size = world_size
self.dp_size = dp_size
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
assert a.dim() == 2
num_dp = self.world_size // self.dp_size
max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens
workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
return (workspace13, workspace2, a.dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
# Check constraints.
if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
"Hidden size mismatch")
else:
assert hidden_states.size(-1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(-1)} "
f"!= {w1.size(2)}")
assert hidden_states.is_contiguous(
), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
]
# TODO: num_tokens -> max_num_tokens?
E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
hidden_states, w1, w2, topk_ids)
assert w1.size(0) == E
assert w2.size(0) == E
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
dtype=hidden_states.dtype)
config = try_get_optimal_moe_config(
w1.size(),
w2.size(),
top_k_num,
config_dtype,
num_tokens,
block_shape=self.block_shape,
)
if hidden_states.dtype == torch.bfloat16:
compute_type = tl.bfloat16
elif hidden_states.dtype == torch.float16:
compute_type = tl.float16
elif hidden_states.dtype == torch.float32:
compute_type = tl.float32
elif hidden_states.dtype == torch.float8_e4m3fn:
compute_type = tl.bfloat16
else:
raise ValueError(
f"Unsupported compute_type: {hidden_states.dtype}")
#print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N))
intermediate_cache2 = _resize_cache(workspace2,
(E, num_tokens, N // 2))
intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K))
# MM1
invoke_moe_batched_triton_kernel(A=hidden_states,
B=w1,
C=intermediate_cache1,
expert_num_tokens=expert_num_tokens,
compute_type=compute_type,
A_scale=a1q_scale,
B_scale=w1_scale,
B_zp=w1_zp,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
config=config,
block_shape=self.block_shape)
# TODO: would be nice to use expert_num_tokens here to reduce
# garbage compute
self.activation(activation, intermediate_cache2.view(-1, N // 2),
intermediate_cache1.view(-1, N))
#qintermediate_cache2 = intermediate_cache2
a2q_scale = a2_scale
# TODO (varun) : support w8a8
assert not self.use_fp8_w8a8
#if self.use_fp8_w8a8:
# qintermediate_cache2, a2q_scale = _fp8_quantize(
# intermediate_cache2, a2_scale, self.block_shape)
invoke_moe_batched_triton_kernel(A=intermediate_cache2,
B=w2,
C=intermediate_cache3,
expert_num_tokens=expert_num_tokens,
compute_type=compute_type,
A_scale=a2q_scale,
B_scale=w2_scale,
B_zp=w2_zp,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
config=config,
block_shape=self.block_shape)
return intermediate_cache3

View File

@ -8,16 +8,17 @@ from typing import Any, Callable, Optional
import torch import torch
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm, deep_gemm_moe_fp8) _valid_deep_gemm, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size) moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
per_token_group_quant_fp8) MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.quantization.utils.int8_utils import ( from vllm.model_executor.layers.fused_moe.utils import (
per_token_group_quant_int8, per_token_quant_int8) _resize_cache, moe_kernel_quantize_input)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
@ -484,6 +485,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert topk_weights is None or topk_weights.stride(1) == 1 assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8 or use_int8_w8a8:
assert B_scale is not None
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
== B_scale.shape[-2])
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
== B_scale.shape[-1])
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
M = A.shape[0] M = A.shape[0]
num_tokens = M * top_k num_tokens = M * top_k
@ -855,6 +870,7 @@ def fused_topk(
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
indices_type: Optional[torch.dtype] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], ( assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch") "Number of tokens mismatch")
@ -865,10 +881,11 @@ def fused_topk(
topk, topk,
dtype=torch.float32, dtype=torch.float32,
device=hidden_states.device) device=hidden_states.device)
topk_ids = torch.empty(M, topk_ids = torch.empty(
topk, M,
dtype=torch.int32, topk,
device=hidden_states.device) dtype=torch.int32 if indices_type is None else indices_type,
device=hidden_states.device)
token_expert_indices = torch.empty(M, token_expert_indices = torch.empty(M,
topk, topk,
dtype=torch.int32, dtype=torch.int32,
@ -962,6 +979,20 @@ def get_config_dtype_str(
return None return None
# TODO (bnell): use scalar_type instead of bools?
def get_config_qtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
) -> Optional[torch.dtype]:
if use_fp8_w8a8:
return torch.float8_e4m3fn
elif use_int8_w8a8:
return torch.int8
return None
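Usage sketch of the helper above, assuming it is importable from vllm.model_executor.layers.fused_moe.fused_moe alongside the rest of this module: fp8 and int8 activation quantization map to a concrete dtype, while the weight-only schemes leave activations unquantized:

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import get_config_qtype

assert get_config_qtype(use_fp8_w8a8=True, use_int8_w8a8=False,
                        use_int8_w8a16=False,
                        use_int4_w4a16=False) is torch.float8_e4m3fn
assert get_config_qtype(use_fp8_w8a8=False, use_int8_w8a8=True,
                        use_int8_w8a16=False, use_int4_w4a16=False) is torch.int8
# Weight-only schemes (w8a16 / w4a16) leave activations unquantized:
assert get_config_qtype(use_fp8_w8a8=False, use_int8_w8a8=False,
                        use_int8_w8a16=True, use_int4_w4a16=False) is None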
def inplace_fused_experts(hidden_states: torch.Tensor, def inplace_fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
@ -1128,7 +1159,10 @@ def fused_experts(hidden_states: torch.Tensor,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor: allow_deep_gemm: bool = False) -> torch.Tensor:
if (allow_deep_gemm and use_fp8_w8a8 # For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N = w1.shape[1]
if (allow_deep_gemm and use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
assert apply_router_weight_on_input is False assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8( return deep_gemm_moe_fp8(
@ -1145,6 +1179,7 @@ def fused_experts(hidden_states: torch.Tensor,
w2_scale=w2_scale, w2_scale=w2_scale,
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
) )
else: else:
return dispatch_fused_experts_func(inplace)( return dispatch_fused_experts_func(inplace)(
@ -1171,87 +1206,37 @@ def fused_experts(hidden_states: torch.Tensor,
block_shape=block_shape) block_shape=block_shape)
def moe_kernel_prepare_input( def fused_experts_impl(
A: torch.Tensor, hidden_states: torch.Tensor,
B: torch.Tensor, w1: torch.Tensor,
A_scale: Optional[torch.Tensor], w2: torch.Tensor,
B_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
use_fp8_w8a8: bool, topk_ids: torch.Tensor,
use_int8_w8a8: bool, inplace: bool = False,
use_int8_w8a16: bool, activation: str = "silu",
use_int4_w4a16: bool, apply_router_weight_on_input: bool = False,
per_channel_quant: bool, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> torch.Tensor:
if use_fp8_w8a8:
assert B_scale is not None
if block_shape is None:
# If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8 quantization, dynamic or static
A, A_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant)
else:
# activation block-wise fp8 quantization
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a8:
assert B_scale is not None
if block_shape is None:
# activation channel-wise int8 quantization
assert (per_channel_quant
), "int8 quantization only supports block or channel-wise"
A, A_scale = per_token_quant_int8(A)
else:
# activation block-wise int8 quantization
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
return A, A_scale
def fused_experts_impl(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None):
# Check constraints. # Check constraints.
if use_int4_w4a16: if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[ assert hidden_states.shape[1] // 2 == w1.shape[
2], "Hidden size mismatch" 2], "Hidden size mismatch"
else: else:
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert hidden_states.shape[1] == w1.shape[2], (
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
@ -1261,7 +1246,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
torch.float32, torch.float16, torch.bfloat16 torch.float32, torch.float16, torch.bfloat16
] ]
num_tokens, _ = hidden_states.shape num_tokens = hidden_states.shape[0]
E, N, _ = w1.shape E, N, _ = w1.shape
K = w2.shape[1] K = w2.shape[1]
if global_num_experts == -1: if global_num_experts == -1:
@ -1276,6 +1261,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16)
get_config_func = functools.partial( get_config_func = functools.partial(
try_get_optimal_moe_config, try_get_optimal_moe_config,
w1.shape, w1.shape,
@ -1338,15 +1328,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input( qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
A=curr_hidden_states, A=curr_hidden_states,
B=w1,
A_scale=a1_scale, A_scale=a1_scale,
B_scale=w1_scale, qtype=qtype,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
block_shape=block_shape) block_shape=block_shape)
@ -1357,7 +1342,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel(qcurr_hidden_states, invoke_fused_moe_kernel(qcurr_hidden_states,
w1, w1,
intermediate_cache1, intermediate_cache1,
qa1_scale, a1q_scale,
w1_scale, w1_scale,
w1_zp, w1_zp,
curr_topk_weights, curr_topk_weights,
@ -1384,22 +1369,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
else: else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}") raise ValueError(f"Unsupported FusedMoe activation: {activation}")
qintermediate_cache2, qa2_scale = moe_kernel_prepare_input( qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
A=intermediate_cache2, A=intermediate_cache2,
B=w2,
A_scale=a2_scale, A_scale=a2_scale,
B_scale=w2_scale, qtype=qtype,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
block_shape=block_shape) block_shape=block_shape)
invoke_fused_moe_kernel(qintermediate_cache2, invoke_fused_moe_kernel(qintermediate_cache2,
w2, w2,
intermediate_cache3, intermediate_cache3,
qa2_scale, a2q_scale,
w2_scale, w2_scale,
w2_zp, w2_zp,
curr_topk_weights, curr_topk_weights,
@ -1534,3 +1514,209 @@ def fused_moe(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape) block_shape=block_shape)
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
block_m: Optional[int] = None,
):
super().__init__()
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a8 = use_int8_w8a8
self.use_int8_w8a16 = use_int8_w8a16
self.block_shape = block_shape
self.block_m = block_m
self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16)
self.per_channel_quant = per_channel_quant
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
factor = num_experts if a.dim() == 3 else 1
workspace1 = M * topk * max(N * 2, K) * factor
workspace2 = M * topk * N * factor
return (workspace1, workspace2, a.dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
# Check constraints.
if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
"Hidden size mismatch")
else:
assert hidden_states.size(-1) == w1.size(2), \
(f"Hidden size mismatch {hidden_states.size(-1)} "
f"!= {w1.size(2)}")
assert hidden_states.is_contiguous(
), "Hidden_states must be contiguous"
assert hidden_states.dim() == 2
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
]
E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
hidden_states, w1, w2, topk_ids)
if global_num_experts == -1:
global_num_experts = E
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
dtype=hidden_states.dtype)
config = try_get_optimal_moe_config(
w1.shape,
w2.shape,
top_k_num,
config_dtype,
num_tokens,
block_shape=self.block_shape,
)
if hidden_states.dtype == torch.bfloat16:
compute_type = tl.bfloat16
elif hidden_states.dtype == torch.float16:
compute_type = tl.float16
elif hidden_states.dtype == torch.float32:
compute_type = tl.float32
elif hidden_states.dtype == torch.float8_e4m3fn:
compute_type = tl.bfloat16
else:
raise ValueError(
f"Unsupported compute_type: {hidden_states.dtype}")
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13,
(num_tokens, top_k_num, N))
intermediate_cache2 = _resize_cache(workspace2,
(num_tokens * top_k_num, N // 2))
intermediate_cache3 = _resize_cache(workspace13,
(num_tokens, top_k_num, K))
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
global_num_experts, expert_map))
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1q_scale,
w1_scale,
w1_zp,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape)
self.activation(activation, intermediate_cache2,
intermediate_cache1.view(-1, N))
a2q_scale: Optional[torch.Tensor] = None
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant,
self.block_shape)
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
intermediate_cache3,
a2q_scale,
w2_scale,
w2_zp,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
1,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape)
return intermediate_cache3
def modular_triton_fused_moe(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
) -> mk.FusedMoEModularKernel:
qtype = get_config_qtype(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
)
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
quant_dtype=qtype,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
),
TritonExperts(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
),
)
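A hedged usage sketch of the factory above; the returned FusedMoEModularKernel is assumed to be callable with the same hidden_states/w1/w2/topk arguments that fused_experts takes, mirroring the call sites shown in cutlass_moe_fp8 and deep_gemm_moe_fp8 earlier in this change:

from vllm.model_executor.layers.fused_moe.fused_moe import modular_triton_fused_moe

# Unquantized bf16/fp16 path; the keyword list below is an assumption based
# on the call sites above, not an exhaustive signature.
mk_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
                                 use_int8_w8a8=False,
                                 use_int8_w8a16=False,
                                 use_int4_w4a16=False,
                                 per_channel_quant=False)
# out = mk_fn(hidden_states, w1, w2, topk_weights, topk_ids,
#             global_num_experts=num_experts, expert_map=None)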

View File

@ -1,15 +1,19 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import importlib
import threading
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
from weakref import WeakValueDictionary
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs import vllm.envs as envs
from vllm.config import get_current_vllm_config from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group, from vllm.distributed import (get_dp_group, get_ep_group,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
@ -26,8 +30,17 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fused_moe import fused_experts from .fused_batched_moe import (BatchedPrepareAndFinalize,
BatchedTritonExperts)
from .fused_moe import TritonExperts, fused_experts
from .modular_kernel import (FusedMoEModularKernel,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
if has_pplx:
from .pplx_prepare_finalize import PplxPrepareAndFinalize
else: else:
fused_experts = None # type: ignore fused_experts = None # type: ignore
if is_rocm_aiter_moe_enabled(): if is_rocm_aiter_moe_enabled():
@ -42,6 +55,179 @@ else:
fused_moe_pallas = None # type: ignore fused_moe_pallas = None # type: ignore
logger = init_logger(__name__) logger = init_logger(__name__)
# Note: this limit is somewhat arbitrary and might be changed later.
# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
MOE_DP_CHUNK_SIZE = 256
@dataclass
class FusedMoEParallelConfig:
tp_size: int
dp_size: int
ep_size: int
tp_rank: int
dp_rank: int
ep_rank: int
use_ep: bool # whether to use EP or not
@property
def use_pplx_kernels(self):
return self.dp_size > 1 and self.use_ep and has_pplx
@staticmethod
def make(tp_size_: int, dp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
"""
Determine MoE parallel configuration. Based on the input tp_size_,
dp_size_ and vLLM's parallel config, determine what levels of
parallelism to use in the fused MoE layer.
Args:
tp_size_ (int): tp_size passed into the FusedMoE constructor.
dp_size_ (int): dp_size passed into the FusedMoE constructor.
vllm_parallel_config (ParallelConfig): vllm's parallel config
object.
Examples:
When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1,
we simply return the sizes unaltered and the ranks set to 0.
Expert Parallelism is considered only when either dp_size_ or tp_size_
is non-trivial.
When TP = 2, DP = 1 and EP = False, the configuration on different
devices,
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
legend : {size, rank}
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
- Comment : Tensors are sharded across 2 devices.
When TP = 1, DP = 2 and EP = False, the configuration on different
devices,
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
- device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
- Comment: There are 2 engine instances and the tensors are sharded
across 2 devices.
When TP = 2, DP = 2 and EP = False, the configuration on different
devices,
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
- device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
- device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
- device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
- Comment: There are 2 engine instances and the tensors are sharded
across 4 devices.
When, TP = 2, DP = 1 and EP = True, the configuration on different
devices,
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
- Comment: The experts are split between the 2 devices.
When, TP = 1, DP = 2 and EP = True, the configuration on different
devices,
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
- device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
- Comment: There are 2 engine instances and the experts are split
between the 2 devices.
When TP = 2, DP = 2 and EP = True, the configuration on different
devices,
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
- device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
- device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
- device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
- Comment: There are 2 engine instances and the experts are split
between the 4 devices.
"""
def flatten_tp_across_dp(dp_rank: int):
tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
# There are actually dp_size_ * tp_size_ devices. Update tp_size
# and tp_rank so we shard across all devices.
tp_size = dp_size_ * tp_size_
tp_rank = dp_rank * tp_size_ + tp_rank
return tp_size, tp_rank
use_ep = (dp_size_ * tp_size_ > 1
and vllm_parallel_config.enable_expert_parallel)
dp_size = dp_size_
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
if not use_ep:
return FusedMoEParallelConfig(tp_size=tp_size,
tp_rank=tp_rank,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=1,
ep_rank=0,
use_ep=False)
# DP + EP / TP + EP / DP + TP + EP
assert use_ep
# In EP, each device owns a set of experts fully. There is no tensor
# parallelism, so update tp_size, tp_rank, ep_size and ep_rank to reflect
# that.
ep_size = tp_size
ep_rank = tp_rank
return FusedMoEParallelConfig(tp_size=1,
tp_rank=0,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
use_ep=True)
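A standalone sketch of the rank-flattening arithmetic used by make(), with the distributed-group lookups replaced by explicit loop variables so it runs without initializing process groups; purely illustrative:

def sketch_flattened_ranks(tp_size_: int, dp_size_: int, enable_ep: bool):
    # Mirrors flatten_tp_across_dp/make for every (dp_rank, tp_rank) pair.
    configs = []
    for dp_rank in range(dp_size_):
        for tp_rank in range(tp_size_):
            flat_tp_size = dp_size_ * tp_size_
            flat_tp_rank = dp_rank * tp_size_ + tp_rank
            if enable_ep and flat_tp_size > 1:
                # With EP, each device fully owns its experts: TP collapses to
                # 1 and the flattened TP rank becomes the EP rank.
                configs.append(dict(tp=(1, 0), dp=(dp_size_, dp_rank),
                                    ep=(flat_tp_size, flat_tp_rank)))
            else:
                configs.append(dict(tp=(flat_tp_size, flat_tp_rank),
                                    dp=(dp_size_, dp_rank), ep=(1, 0)))
    return configs

# TP = 2, DP = 2, EP = True reproduces the last example in the docstring:
# device i -> TP = {1, 0}, DP = {2, i // 2}, EP = {4, i}
print(sketch_flattened_ranks(2, 2, True))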
# Adapted from pplx-kernels tests/all_to_all_utils.py
@dataclass
class MoEConfig:
num_experts: int
experts_per_token: int
hidden_dim: int
num_local_experts: int
moe_parallel_config: FusedMoEParallelConfig
in_dtype: torch.dtype # The activation type.
# TODO: add more quantization params, blocked, per-token, etc.
block_size: int = 128
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@property
def dp_size(self):
return self.moe_parallel_config.dp_size
@property
def ep_size(self):
return self.moe_parallel_config.ep_size
@property
def tp_rank(self):
return self.moe_parallel_config.tp_rank
@property
def dp_rank(self):
return self.moe_parallel_config.dp_rank
@property
def ep_rank(self):
return self.moe_parallel_config.ep_rank
@property
def use_ep(self):
return self.moe_parallel_config.use_ep
@property
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels
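A minimal sketch constructing these config objects directly; values are arbitrary, and in practice FusedMoEParallelConfig.make() derives the sizes and ranks from the TP and DP process groups:

import torch

parallel = FusedMoEParallelConfig(tp_size=1, tp_rank=0,
                                  dp_size=2, dp_rank=0,
                                  ep_size=4, ep_rank=0,
                                  use_ep=True)
moe = MoEConfig(num_experts=64,
                experts_per_token=8,
                hidden_dim=4096,
                num_local_experts=64 // parallel.ep_size,
                moe_parallel_config=parallel,
                in_dtype=torch.bfloat16)
assert moe.ep_size == 4 and moe.num_local_experts == 16
# moe.use_pplx_kernels is True only when dp_size > 1, EP is enabled and the
# pplx_kernels package is importable.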
class FusedMoeWeightScaleSupported(Enum): class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor" TENSOR = "tensor"
@ -58,6 +244,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError raise NotImplementedError
def set_prepare_finalize(
self,
dp_size: int,
world_size: int,
prepare_finalize: FusedMoEPrepareAndFinalize,
) -> bool:
return False
@abstractmethod @abstractmethod
def apply( def apply(
self, self,
@ -80,12 +274,54 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise NotImplementedError raise NotImplementedError
class AllToAllCache:
def __init__(self):
self._cache: WeakValueDictionary = WeakValueDictionary()
self._lock = threading.RLock() # Reentrant lock for thread safety
def destroy(self):
with self._lock:
# TODO: can we do del self._cache?
for _, a2a in self._cache.items():
a2a.destroy()
def get_or_create(self, **kwargs):
assert has_pplx
import pplx_kernels as pplx
# Create a hashable key from the kwargs
key = tuple(sorted((k, v) for k, v in kwargs.items()))
with self._lock:
instance = self._cache.get(key)
if instance is None:
# TODO (varun): Add support to switch to intranode
# when all communications are within the same
# node.
logger.debug("Create AllToAll %s", kwargs)
instance = pplx.AllToAll.internode(**kwargs)
self._cache[key] = instance
return instance
# Global singleton
_all_to_all_cache = AllToAllCache()
# Factory function as a cleaner interface
def get_all_to_all(**kwargs):
return _all_to_all_cache.get_or_create(**kwargs)
@CustomOp.register("unquantized_fused_moe") @CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization.""" """MoE method without quantization."""
def __init__(self): def __init__(self, moe: MoEConfig):
super().__init__() super().__init__()
self.fused_experts = fused_experts
self.moe = moe
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
@ -193,6 +429,47 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input) apply_router_weight_on_input=apply_router_weight_on_input)
def set_prepare_finalize(
self,
dp_size: int,
world_size: int,
prepare_finalize: FusedMoEPrepareAndFinalize,
) -> bool:
assert self.fused_experts == fused_experts
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
logger.debug("BatchedTritonExperts %s", self.moe)
experts = BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=world_size,
dp_size=dp_size,
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
)
else:
logger.debug("TritonExperts %s", self.moe)
experts = TritonExperts(
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
per_channel_quant=False,
)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
return True
def forward_cuda( def forward_cuda(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
@ -221,9 +498,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=torch.uint32 if self.moe.use_pplx_kernels else None)
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
assert not apply_router_weight_on_input
assert expert_map is None
return self.rocm_aiter_fused_experts( return self.rocm_aiter_fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
@ -232,18 +512,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input) apply_router_weight_on_input=apply_router_weight_on_input)
else:
return fused_experts( return self.fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map) expert_map=expert_map,
)
def forward_cpu( def forward_cpu(
self, self,
@ -399,6 +680,45 @@ def determine_expert_map(
return (local_num_experts, expert_map) return (local_num_experts, expert_map)
def _construct_prepare_finalize(
moe: MoEConfig, quant_config: Optional[QuantizationConfig]
) -> Optional[FusedMoEPrepareAndFinalize]:
max_num_tokens = MOE_DP_CHUNK_SIZE
world_size = moe.ep_size
dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP.
rank = moe.ep_rank
if moe.use_pplx_kernels:
logger.debug("using PplxPrepareAndFinalize")
all_to_all = get_all_to_all(
max_num_tokens=max_num_tokens,
num_experts=moe.num_experts,
experts_per_token=moe.experts_per_token, # topk
rank=rank,
world_size=world_size,
dp_size=dp_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
((moe.hidden_dim + moe.block_size - 1) //
moe.block_size * torch.float32.itemsize)))
return PplxPrepareAndFinalize(
all_to_all,
max_num_tokens=max_num_tokens,
world_size=world_size,
rank=rank,
dp_size=dp_size,
quant_dtype=moe.in_dtype,
)
return None
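As a quick worked example of the hidden_dim_scale_bytes computation above: with hidden_dim = 4096, an fp8 activation dtype (itemsize 1) and block_size = 128, the per-token scale payload is ceil(4096 / 128) * 4 = 32 * 4 = 128 bytes; for activation dtypes wider than one byte it is 0.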
class FusedMoE(torch.nn.Module): class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models. """FusedMoE layer for MoE models.
@ -449,21 +769,16 @@ class FusedMoE(torch.nn.Module):
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype self.params_dtype = params_dtype
# Note: here we guard against accessing the TP and DP groups when
# uninitialized (this happens when testing)
self.tp_size = (tp_size if tp_size is not None else
get_tensor_model_parallel_world_size())
tp_rank = 0 if self.tp_size == 1 else get_tensor_model_parallel_rank()
self.dp_size = (dp_size
if dp_size is not None else get_dp_group().world_size)
self.dp_rank = (0
if self.dp_size == 1 else get_dp_group().rank_in_group)
self.global_num_experts = num_experts
# Use expert parallelism instead of tensor parallelism?
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
use_ep = (vllm_config.parallel_config.enable_expert_parallel self.moe_parallel_config: FusedMoEParallelConfig = (
and self.tp_size * self.dp_size > 1) FusedMoEParallelConfig.make(
tp_size_=(tp_size if tp_size is not None else
get_tensor_model_parallel_world_size()),
dp_size_=(dp_size if dp_size is not None else
get_dp_group().world_size),
vllm_parallel_config=vllm_config.parallel_config))
self.global_num_experts = num_experts
# For smuggling this layer into the fused moe custom op # For smuggling this layer into the fused moe custom op
self.use_direct_call = self.dp_size == 1 self.use_direct_call = self.dp_size == 1
@ -474,28 +789,17 @@ class FusedMoE(torch.nn.Module):
compilation_config.static_forward_context[prefix] = self compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix self.layer_name = prefix
if use_ep: # Determine expert maps
# Set TP size to 1 to adjust for EP and adjust EP size and rank if self.use_ep:
# for DP attention.
self.ep_rank = tp_rank + self.tp_size * self.dp_rank
self.tp_rank = 0
self.ep_size = self.tp_size * self.dp_size
self.tp_size = 1
self.local_num_experts, self.expert_map = determine_expert_map( self.local_num_experts, self.expert_map = determine_expert_map(
ep_size=self.ep_size, ep_size=self.ep_size,
ep_rank=self.ep_rank, ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts) global_num_experts=self.global_num_experts)
else: else:
# Adjust TP size for DP attention self.local_num_experts, self.expert_map = (self.global_num_experts,
self.tp_rank = tp_rank + self.tp_size * self.dp_rank None)
self.ep_rank = 0
self.tp_size = self.tp_size * self.dp_size
self.ep_size = 1
self.local_num_experts = self.global_num_experts
self.expert_map = None
self.top_k = top_k self.top_k = top_k
self.global_num_experts = num_experts
assert intermediate_size % self.tp_size == 0 assert intermediate_size % self.tp_size == 0
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -520,14 +824,40 @@ class FusedMoE(torch.nn.Module):
from vllm_hpu_extension.ops import DynamicFusedMOE from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
moe = MoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype,
)
# Note: get_quant_method will look at the layer's local_num_experts # Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first. # for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
if quant_config is None: if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = ( quant_method = UnquantizedFusedMoEMethod(moe)
UnquantizedFusedMoEMethod()) prepare_finalize = _construct_prepare_finalize(moe, quant_config)
else: else:
self.quant_method = quant_config.get_quant_method(self, prefix) quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None # No pplx for quantized types yet.
prepare_finalize = None
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
if prepare_finalize is not None:
world_size = moe.ep_size
dp_size = int(moe.ep_size // moe.dp_size)
success = self.quant_method.set_prepare_finalize(
dp_size, world_size, prepare_finalize)
if not success:
logger.warning("DP+EP not supported for %s.",
type(self.quant_method))
moe_quant_params = { moe_quant_params = {
"num_experts": self.local_num_experts, "num_experts": self.local_num_experts,
@ -546,6 +876,38 @@ class FusedMoE(torch.nn.Module):
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@property
def dp_size(self):
return self.moe_parallel_config.dp_size
@property
def ep_size(self):
return self.moe_parallel_config.ep_size
@property
def tp_rank(self):
return self.moe_parallel_config.tp_rank
@property
def dp_rank(self):
return self.moe_parallel_config.dp_rank
@property
def ep_rank(self):
return self.moe_parallel_config.ep_rank
@property
def use_ep(self):
return self.moe_parallel_config.use_ep
@property
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels
def _load_per_tensor_weight_scale(self, shard_id: str, def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
@ -830,7 +1192,8 @@ class FusedMoE(torch.nn.Module):
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None): e_score_correction_bias: Optional[torch.Tensor] = None,
indices_type: Optional[torch.dtype] = None):
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
# DeekSeekv2 uses grouped_top_k # DeekSeekv2 uses grouped_top_k
@ -846,21 +1209,52 @@ class FusedMoE(torch.nn.Module):
topk_group=topk_group, topk_group=topk_group,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias)
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
elif custom_routing_function is None: elif custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk( topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states, hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
topk=top_k, topk=top_k,
renormalize=renormalize) renormalize=renormalize,
indices_type=indices_type,
)
else: else:
topk_weights, topk_ids = custom_routing_function( topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states, hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
topk=top_k, topk=top_k,
renormalize=renormalize) renormalize=renormalize)
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
return topk_weights, topk_ids return topk_weights, topk_ids
def must_reduce_shared_expert_outputs(self) -> bool:
"""
The shared_experts are typically computed using the RowParallelLinear
layer. The result of this function is typically used as
the reduce_results argument to the module.
When just tensor-parallel is used, it is not required to reduce
the shared_experts results immediately. Instead we reduce once at the
end of the MoE op. (Refer to the DeepSeekV2MoE module.)
With EP and the pplx kernels this is no longer viable, as all
GPU ranks in the DP group produce the complete set of hidden_states.
Therefore the shared_experts output must be reduced early.
"""
return self.use_pplx_kernels
def maybe_all_reduce_tensor_model_parallel(
self, final_hidden_states: torch.Tensor):
"""
The pplx combine kernel reduces across GPU ranks by default.
"""
if self.use_pplx_kernels:
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor):
if self.use_direct_call: if self.use_direct_call:
@ -869,9 +1263,62 @@ class FusedMoE(torch.nn.Module):
return torch.ops.vllm.moe_forward(hidden_states, router_logits, return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name) self.layer_name)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor):
full_final_hidden_states = torch.empty_like(full_hidden_states)
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :]
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
)
if not skip_result_store:
full_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states)
ctx = get_forward_context()
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE
num_tokens = full_hidden_states.size(0)
for chunk_start_ in range(0, max_tokens_across_dp,
moe_dp_chunk_size_per_rank):
chunk_start = chunk_start_
chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
max_tokens_across_dp)
# clamp start and end
chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens)
process_chunk(chunk_start,
chunk_end,
skip_result_store=chunk_start_ >= num_tokens)
return full_final_hidden_states
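A pure-Python sketch of the chunk-boundary logic above; it shows how a rank with fewer local tokens than max_tokens_across_dp keeps issuing clamped (and eventually dummy) chunks so that all DP ranks stay in lockstep. Numbers are illustrative:

def chunk_schedule(num_tokens: int, max_tokens_across_dp: int, chunk_size: int):
    # Returns (chunk_start, chunk_end, skip_result_store) per iteration,
    # mirroring the loop in forward_impl_chunked.
    schedule = []
    for chunk_start_ in range(0, max_tokens_across_dp, chunk_size):
        chunk_end = min(chunk_start_ + chunk_size, max_tokens_across_dp)
        chunk_start = min(chunk_start_, num_tokens - 1)
        chunk_end = min(chunk_end, num_tokens)
        schedule.append((chunk_start, chunk_end, chunk_start_ >= num_tokens))
    return schedule

# This rank has 300 tokens, the largest DP rank has 700, chunk size 256:
# [(0, 256, False), (256, 300, False), (299, 300, True)]
print(chunk_schedule(300, 700, 256))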
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor):
assert self.quant_method is not None assert self.quant_method is not None
if self.moe_parallel_config.use_pplx_kernels:
return self.forward_impl_chunked(hidden_states, router_logits)
if self.dp_size > 1: if self.dp_size > 1:
hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits = get_ep_group().dispatch(

View File

@ -0,0 +1,364 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Optional
import torch
#
# This file defines a set of base classes used to make MoE kernels more modular.
# The goal is to be able to utilize different communication mechanisms with
# any fused MoE kernel without needing to have combinatoric implementations.
#
# The fused moe kernels are broken down into the following components:
#
# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]
#
# Each component will be independent of the others except for
# [Quantize-Dispatch] and [Combine] (see below). The components can then be
# mixed and matched so that DP+EP can be supported easily for multiple
# MoE kernel implementations.
#
# The following main classes are defined:
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
# inputs (e.g. quantization, distribution) and finalization of MoE outputs.
# The prepare method must take care of any needed quantization and the
# finalize method must apply weights and do the final reduction of the output.
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
# MoE operation. One important feature to note is that this class does not
# apply topk weights or reduce the final output.
# * FusedMoEModularKernel - an interface class that combines a
# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
# provide the standard fused MoE kernel interface.
#
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single
# class `FusedMoEPrepareAndFinalize` since they could use collective
# communication mechanisms that need to be consistent.
#
def _moe_problem_size(
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""
Extract the MoE problem size from the given tensor arguments:
- a1: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.
Note: extracting the problem shape from the weight and activation tensors is
not obvious. It needs to be done this way specifically due to subtle issues
with particular kernels, e.g. the int4 kernels divide the trailing dimension
by two, so it's not "correct" to extract N or K from the trailing dimension
of w1 or w2. Similarly, some kernels transpose the weights, so this needs
to be kept in mind.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, N, _ = w1.size()
K = w2.size(1)
if a1.dim() == 2:
# Make sure we are using the correct a1 (pre-permute).
assert topk_ids.size(0) == a1.size(0), \
f"{topk_ids.size(0)} != {a1.size(0)}"
M = a1.size(0)
else:
assert a1.dim() == 3
assert a1.size(0) == E, f"{a1.size(0)} != {E}"
M = a1.size(1) # This is max_num_tokens
assert topk_ids.dim() == 2
topk = topk_ids.size(1)
return E, M, N, K, topk
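For concreteness, a tiny shape-only illustration of the two accepted activation layouts (sizes are arbitrary):

import torch

E, M, N, K, topk = 4, 8, 128, 64, 2
w1 = torch.empty(E, N, K)
w2 = torch.empty(E, K, N // 2)
topk_ids = torch.zeros(M, topk, dtype=torch.int32)

# Standard (pre-permute) layout: a1 is (M, K).
assert _moe_problem_size(torch.empty(M, K), w1, w2, topk_ids) == (E, M, N, K, topk)

# Batched layout: a1 is (E, max_num_tokens, K) and M becomes max_num_tokens.
assert _moe_problem_size(torch.empty(E, 16, K), w1, w2, topk_ids)[1] == 16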
class FusedMoEPrepareAndFinalize(ABC):
"""
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
described above.
"""
@abstractmethod
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform any quantization and/or dispatching needed
for this kernel.
- a1: The (unquantized) input to the MoE layer.
- a1_scale: Optional scales for a1
- a2_scale: Optional scales for the second MoE gemm. Required to make
sure the quantization is consistent for both gemms.
- topk_ids: The topk ids.
- topk_weights: The topk weights.
- num_experts: The total number of experts in the global expert space.
- expert_map: A tensor mapping expert indices from the global expert
space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching.
Returns a tuple of:
- quantized + dispatched a.
- quantized + dispatched a1_scales.
- an optional tensor of per-expert token counts (used by batched expert
formats).
"""
raise NotImplementedError
@abstractmethod
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
"""
Perform any combine step, apply the topk weights, and reduce the
fused experts output.
- output: The output tensor, written in place. Must be (M, K) shape.
- fused_expert_output: The unweighted, unreduced output of the fused
experts, it will have (M, topk, K) shape.
- topk_weights: The weights to be applied to the fused_experts_output.
- topk_ids: The topk_ids.
- apply_router_weight_on_input: When False, apply the weights to
fused_expert_output.
"""
raise NotImplementedError
class FusedMoEPermuteExpertsUnpermute(ABC):
"""
An abstract base class for the [Permute-Experts-Unpermute] step described
above.
"""
@abstractmethod
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
"""
Compute the number of elements for the temporary outputs of the two
gemms and activation in the fused expert function. Since the
gemms are independent, the workspace for the first gemm can be shared
with the workspace for the last gemm.
Returns a tuple of:
- Number of workspace13 elements: must be large enough to hold the
result of either expert gemm.
- Number of workspace2 elements: must be large enough to hold the
result of the activation function.
- Workspace type: The dtype to use for the workspace tensors.
"""
raise NotImplementedError
def activation(self, activation: str, output: torch.Tensor,
input: torch.Tensor) -> None:
assert output.size(-1) * 2 == input.size(-1)
if activation == "silu":
torch.ops._C.silu_and_mul(output, input)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(output, input)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
@abstractmethod
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
"""
This function computes the intermediate result of a Mixture of Experts
(MoE) layer using two sets of weights, w1 and w2.
Parameters:
- hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
must be large enough to hold output of either MoE gemm.
- workspace2 (torch.Tensor): A scratch tensor used for the activation
function.
- expert_num_tokens: An optional tensor containing the number of tokens
assigned to each expert when using batched experts format input.
Returns:
- torch.Tensor: The unweighted, unreduced output tensor
"""
raise NotImplementedError
class FusedMoEModularKernel(torch.nn.Module):
"""
This class combines a FusedMoEPrepareAndFinalize instance and
a FusedMoEPermuteExpertsUnpermute to provide an interface that
is compatible with the `fused_experts` function in fused_moe.py.
It takes care of managing any required scratch space.
Note: Instances of this class should only be used for a single model
layer due to any layer-specific state that may be used by the component
objects.
"""
def __init__(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEPermuteExpertsUnpermute,
):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
def forward(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
of weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The topk weights applied at the end of
the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
a1 = hidden_states
E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids)
if global_num_experts == -1:
global_num_experts = E
output = a1 if inplace else torch.zeros_like(a1)
workspace13_shape, workspace2_shape, workspace_dtype = (
self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
global_num_experts))
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13 = torch.zeros(workspace13_shape,
device=a1.device,
dtype=workspace_dtype)
workspace2 = torch.zeros(workspace2_shape,
device=a1.device,
dtype=workspace_dtype)
a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts,
expert_map, apply_router_weight_on_input)
fused_out = self.fused_experts.apply(
a1q,
w1,
w2,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input)
return output

View File

@ -3,6 +3,74 @@ from typing import Optional
import torch import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
def _moe_permute(
curr_hidden_states: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
curr_topk_ids: torch.Tensor,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
block_m: int,
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
top_k_num = curr_topk_ids.size(1)
tokens_in_chunk = curr_hidden_states.size(0)
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids,
block_m,
global_num_experts,
expert_map,
pad_sorted_ids=True))
inv_perm: Optional[torch.Tensor] = None
num_tokens = top_k_num * tokens_in_chunk
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
# Permute according to sorted token ids.
curr_hidden_states = _fp8_perm(curr_hidden_states,
sorted_token_ids // top_k_num)
if a1q_scale is not None:
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm)
def _moe_unpermute_and_reduce(
out: torch.Tensor,
curr_hidden: torch.Tensor,
inv_perm: Optional[torch.Tensor],
topk_weight: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M, topk = topk_weight.size()
K = curr_hidden.size(-1)
if inv_perm is not None:
curr_hidden = curr_hidden[inv_perm, ...]
curr_hidden = curr_hidden.view(-1, topk, K)
if not apply_router_weight_on_input:
curr_hidden.mul_(topk_weight.view(M, -1, 1))
ops.moe_sum(curr_hidden, out)
def moe_permute( def moe_permute(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -17,21 +85,21 @@ def moe_permute(
fill_invalid_expert: int = -1 fill_invalid_expert: int = -1
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
This function expands and permutes activation to gather uncontinuous tokens This function expands and permutes activation to gather uncontinuous tokens
for each expert. for each expert.
Parameters: Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer. - hidden_states (torch.Tensor): The input tensor to the MoE layer.
- topk_weights (torch.Tensor): topk expert route weight for each token. - topk_weights (torch.Tensor): topk expert route weight for each token.
- topk_ids (torch.Tensor): topk expert route id for each token. - topk_ids (torch.Tensor): topk expert route id for each token.
- token_expert_indices (torch.Tensor): indice for expanded hidden. - token_expert_indices (torch.Tensor): indice for expanded hidden.
- topk (int): The number of top-k experts to select. - topk (int): The number of top-k experts to select.
- n_expert (int): The number of expert. - n_expert (int): The number of expert.
- n_local_expert (int): The number of expert in current EP rank. - n_local_expert (int): The number of expert in current EP rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert from the global expert space to the local expert space of the expert
parallel shard. parallel shard.
- align_block_size (Optional[int]): align group gemm block size for deepgemm - align_block_size (Optional[int]): align group gemm block size for deepgemm
- fill_invalid_expert(int): fill expert id in m_indices for invalid expert - fill_invalid_expert(int): fill expert id in m_indices for invalid expert
to workaround DeepGemm unsupported -1 in m_indices to workaround DeepGemm unsupported -1 in m_indices
Returns: Returns:
- permuted_hidden_states (torch.Tensor): permuted activation. - permuted_hidden_states (torch.Tensor): permuted activation.
@ -39,10 +107,10 @@ def moe_permute(
of each expert for standard grouped gemm. if enable 'align_block_size' of each expert for standard grouped gemm. if enable 'align_block_size'
expert_first_token_offset will align up to 'align_block_size'. expert_first_token_offset will align up to 'align_block_size'.
- src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute. - src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute.
- m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records
the group which the j-th row of the LHS belong to.` the group which the j-th row of the LHS belong to.`
""" """
n_token, n_hidden = hidden_states.shape n_token, n_hidden = hidden_states.size()
assert (n_hidden * hidden_states.element_size() assert (n_hidden * hidden_states.element_size()
) % 16 == 0, "permue kernel need hidden dim align to 16B" ) % 16 == 0, "permue kernel need hidden dim align to 16B"
permuted_row_size = n_token * topk permuted_row_size = n_token * topk
@ -87,7 +155,7 @@ def moe_unpermute(
n_local_expert: int, n_local_expert: int,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function expands and permutes activation to gathering uncontinuous This function expands and permutes activation to gathering uncontinuous
tokens for each expert. tokens for each expert.
Parameters: Parameters:
- permuted_hidden_states (torch.Tensor): permuted activation. - permuted_hidden_states (torch.Tensor): permuted activation.
@ -99,10 +167,10 @@ def moe_unpermute(
- n_expert (int): The number of expert. - n_expert (int): The number of expert.
- n_local_expert (int): The number of expert in current EP rank. - n_local_expert (int): The number of expert in current EP rank.
Returns: Returns:
- hidden_states (torch.Tensor): The reduced and unpermuted activation - hidden_states (torch.Tensor): The reduced and unpermuted activation
tensor. tensor.
""" """
n_token, n_hidden = topk_weights.shape[0], permuted_hidden_states.shape[-1] n_token, n_hidden = topk_weights.size(0), permuted_hidden_states.size(-1)
assert (n_hidden * permuted_hidden_states.element_size() assert (n_hidden * permuted_hidden_states.element_size()
) % 16 == 0, "unpermue kernel need hidden dim align to 16B" ) % 16 == 0, "unpermue kernel need hidden dim align to 16B"
hidden_states = torch.empty((n_token, n_hidden), hidden_states = torch.empty((n_token, n_hidden),

View File

@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import pplx_kernels as pplx
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
# Note use: layer.get_all_to_all() to get an AllToAll instance
# The max_num_tokens, world_size and dp_size must be the same
# as the ones used to create the AllToAll.
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def __init__(self,
a2a: pplx.AllToAll,
max_num_tokens: int,
world_size: int,
rank: int,
dp_size: int,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
super().__init__()
assert max_num_tokens > 0
self.a2a = a2a
self.block_shape = block_shape
self.max_num_tokens = max_num_tokens
self.world_size = world_size
self.rank = rank
self.dp_size = dp_size
self.quant_dtype = quant_dtype
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
rank_topk_weights: torch.Tensor,
rank_topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
assert rank_topk_ids.size(0) == num_tokens
# assert expert_map is None, "NYI"
# Is this always going to be a1.device?
device = a1.device
if apply_router_weight_on_input:
topk = rank_topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * rank_topk_weights.to(a1.dtype)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
self.quant_dtype,
per_act_token,
self.block_shape)
# rem_experts needs to be 0 for pplx to work properly.
rem_experts = num_experts % self.world_size
assert rem_experts == 0
num_local_experts = ((num_experts // self.world_size) +
(1 if self.rank < rem_experts else 0))
expert_num_tokens = torch.empty(
num_local_experts,
dtype=torch.int32,
device=device,
)
num_dp = self.world_size // self.dp_size
expert_x = torch.empty(
(num_local_experts, self.max_num_tokens * num_dp, hidden_dim),
dtype=a1q.dtype,
device=device,
)
expert_x_scale: Optional[torch.Tensor] = None
if a1q.dtype.itemsize == 1:
float32_size = torch.float32.itemsize
block_size = (self.block_shape[0] if self.block_shape is not None
else 1) * float32_size
expert_x_scale = torch.empty(
(
num_experts,
expert_x.size(1),
(expert_x.size(2) + block_size - 1) // block_size,
),
dtype=torch.float32,
device=device,
)
# This argument is optional, defaults to indices.size(0)
# There's not much point setting this unless it is != indices.size(0)
bound_m: Optional[torch.Tensor] = None
self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=rank_topk_ids,
bound_m=bound_m,
)
return expert_x, expert_x_scale, expert_num_tokens
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
num_tokens = output.size(0) # M
# This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0)
bound_m: Optional[torch.Tensor] = None
assert topk_ids.size(0) == num_tokens, (
f"{topk_ids.size(0)} == {num_tokens}")
assert output.size(0) <= self.max_num_tokens, (
f"{output.size(0)} <= {self.max_num_tokens}")
assert output.size(1) == fused_expert_output.size(-1)
# Set weights to 1 if we did them in dispatch. This is hacky.
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
self.a2a.combine(out_tokens=output,
indices=topk_ids,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m)

View File

@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_unpermute_and_reduce)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def __init__(
self,
quant_dtype: Optional[torch.dtype] = None,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None,
):
super().__init__()
self.per_channel_quant = per_channel_quant
self.block_shape = block_shape
self.quant_dtype = quant_dtype
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
self.quant_dtype,
self.per_channel_quant,
self.block_shape)
return a1q, a1q_scale, None
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
_moe_unpermute_and_reduce(output, fused_expert_output, None,
topk_weights, apply_router_weight_on_input)

View File

@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None,
block_m: Optional[int] = None,
allow_deep_gemm: bool = False):
super().__init__()
self.triton_expert = TritonExperts(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
block_m=block_m)
self.deep_gemm_expert = DeepGemmExperts()
self.allow_deep_gemm = allow_deep_gemm
self.use_fp8_w8a8 = use_fp8_w8a8
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
return self.deep_gemm_expert.workspace_shapes(
a, M, N, K, topk, num_experts)
else:
return self.triton_expert.workspace_shapes(a, M, N, K, topk,
num_experts)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
N = w1.size(1)
if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
return self.deep_gemm_expert.apply(
hidden_states,
w1,
w2,
topk_ids,
activation,
global_num_experts,
expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_num_tokens,
)
else:
return self.triton_expert.apply(
hidden_states,
w1,
w2,
topk_ids,
activation,
global_num_experts,
expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_num_tokens,
)

View File

@ -7,6 +7,8 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8) per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.utils import cdiv from vllm.utils import cdiv
@ -15,34 +17,81 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
Shrink the given tensor and apply the given view to it. This is Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches. used to resize the intermediate fused_moe caches.
""" """
assert prod(v) <= x.numel() assert prod(
v) <= x.numel(), f"{prod(v)} <= {x.numel()}" # CUDAGRAPH unfriendly?
return x.flatten()[:prod(v)].view(*v) return x.flatten()[:prod(v)].view(*v)
def _fp8_quantize( def _fp8_quantize(
A: torch.Tensor, A: torch.Tensor,
A_scale: Optional[torch.Tensor], A_scale: Optional[torch.Tensor],
block_shape: Optional[list[int]], per_act_token: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Perform fp8 quantization on the inputs. If a block_shape Perform fp8 quantization on the inputs. If a block_shape
is provided, the output will be blocked. is provided, the output will be blocked.
""" """
if block_shape is None: if block_shape is None:
A, A_scale = ops.scaled_fp8_quant(A, A_scale) A, A_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_act_token)
else: else:
assert len(block_shape) == 2 assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1] _, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k) A, A_scale = per_token_group_quant_fp8(A, block_k)
assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1] assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
return A, A_scale return A, A_scale
def _int8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
per_act_token: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Perform int8 quantization on the inputs. If a block_shape
is provided, the output will be blocked.
"""
# If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8/int8 quantization, dynamic or static
if block_shape is None:
assert per_act_token, \
"int8 quantization only supports block or channel-wise"
A, A_scale = per_token_quant_int8(A)
else:
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
return A, A_scale
def moe_kernel_quantize_input(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
qtype: Optional[torch.dtype],
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if qtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_channel_quant, block_shape)
elif qtype == torch.int8:
return _int8_quantize(A, A_scale, per_channel_quant, block_shape)
else:
assert A_scale is None
return A, A_scale
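A small usage sketch, assuming dynamic per-tensor fp8 quantization on a CUDA device (no block shape, no per-channel scales):

import torch

a = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
aq, a_scale = moe_kernel_quantize_input(a, None, torch.float8_e4m3fn,
                                        per_channel_quant=False)
# aq is float8_e4m3fn with the same shape as a; a_scale is a single
# per-tensor fp32 scale. With qtype=None the input is returned unchanged
# and the scale stays None.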
def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
""" """
A permutation routine that works on fp8 types. A permutation routine that works on fp8 types.
""" """
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: if torch.is_floating_point(m) and m.dtype.itemsize == 1:
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
else: else:
return m[idx, ...] return m[idx, ...]

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import functools
import importlib.util import importlib.util
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
@ -9,6 +10,7 @@ from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
@ -434,6 +436,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
""" """
def __init__(self, quant_config: Fp8Config): def __init__(self, quant_config: Fp8Config):
from vllm.model_executor.layers.fused_moe import fused_experts
self.quant_config = quant_config self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None self.block_quant = self.quant_config.weight_block_size is not None
@ -458,6 +461,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger.warning_once( logger.warning_once(
"DeepGemm not supported on the current platform.") "DeepGemm not supported on the current platform.")
self.fused_experts = functools.partial(
fused_experts,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm)
def create_weights(self, layer: Module, num_experts: int, hidden_size: int, def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
@ -783,6 +791,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale del layer.w13_input_scale
del layer.w2_input_scale del layer.w2_input_scale
def set_prepare_finalize(
self,
dp_size: int,
world_size: int,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
) -> bool:
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
if self.use_marlin or self.rocm_aiter_moe_enabled:
return False
experts = TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)
self.fused_experts = mk.FusedMoEModularKernel(
prepare_finalize,
experts,
)
return True
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
@ -801,10 +834,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts)
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
@ -819,6 +848,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) )
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_fused_experts)
return rocm_aiter_fused_experts( return rocm_aiter_fused_experts(
x, x,
layer.w13_weight, layer.w13_weight,
@ -835,8 +866,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size) block_shape=self.quant_config.weight_block_size)
elif self.use_marlin:
if self.use_marlin:
assert activation == "silu", ( assert activation == "silu", (
f"{activation} not supported for Marlin MoE.") f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, ( assert not apply_router_weight_on_input, (
@ -853,28 +883,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
quant_type_id=scalar_types.float8_e4m3fn.id, quant_type_id=scalar_types.float8_e4m3fn.id,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map) expert_map=expert_map)
else:
return fused_experts( return self.fused_experts(
x, hidden_states=x,
layer.w13_weight, w1=layer.w13_weight,
layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=activation,
use_fp8_w8a8=True, use_fp8_w8a8=True,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map, expert_map=expert_map,
w1_scale=(layer.w13_weight_scale_inv w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale), if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale), if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size, )
allow_deep_gemm=self.allow_deep_gemm,
)
class Fp8KVCacheMethod(BaseKVCacheMethod): class Fp8KVCacheMethod(BaseKVCacheMethod):
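The set_prepare_finalize hook added above is the integration point for the modular-kernel path: a prepare/finalize object (for example the PPLX all-to-all wrapper) is paired with a TritonOrDeepGemmExperts instance and installed as self.fused_experts, so apply() no longer needs to know which backend runs the expert GEMMs. A rough sketch of that composition, using only names that appear in this diff; the constructor arguments mirror the hunk and are not exhaustive.

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts)

def build_fused_experts(prepare_finalize: mk.FusedMoEPrepareAndFinalize,
                        weight_block_size, allow_deep_gemm: bool):
    # The experts object only implements the grouped GEMMs + activation;
    # cross-rank dispatch/combine is delegated to prepare_finalize.
    experts = TritonOrDeepGemmExperts(
        use_fp8_w8a8=True,
        block_shape=weight_block_size,
        allow_deep_gemm=allow_deep_gemm,
    )
    return mk.FusedMoEModularKernel(prepare_finalize, experts)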

View File

@ -79,7 +79,6 @@ class DbrxExperts(FusedMoE):
prefix=prefix, prefix=prefix,
) )
self.config = config self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.d_model = config.d_model self.d_model = config.d_model
self.intermediate_size = (self.config.ffn_config.ffn_hidden_size // self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
self.tp_size) self.tp_size)

View File

@ -31,9 +31,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group, from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -143,7 +141,8 @@ class DeepseekV2MoE(nn.Module):
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
reduce_results=False, reduce_results=self.experts.must_reduce_shared_expert_outputs(
),
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
) )
@ -154,6 +153,7 @@ class DeepseekV2MoE(nn.Module):
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
@ -171,9 +171,11 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = (
final_hidden_states) self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim)
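The same pattern recurs in the Llama4 and Qwen MoE blocks further down: the shared expert's reduce_results flag is now driven by the routed experts via must_reduce_shared_expert_outputs(), and the final reduction goes through the FusedMoE layer so it can be skipped or replaced when expert parallelism already combines the outputs. A condensed sketch of that control flow, written as a free function over hypothetical layer objects rather than the actual module code:

import torch

def moe_forward(experts, shared_experts, gate, hidden_states: torch.Tensor,
                routed_scaling_factor: float, tp_size: int) -> torch.Tensor:
    # `experts` stands in for a FusedMoE layer; `shared_experts` is the shared
    # MLP whose reduce_results was set from
    # experts.must_reduce_shared_expert_outputs().
    shared_output = shared_experts(hidden_states)
    router_logits, _ = gate(hidden_states)
    final = experts(hidden_states=hidden_states, router_logits=router_logits)
    final = final + shared_output * (1.0 / routed_scaling_factor)
    if tp_size > 1:
        # The FusedMoE layer decides whether a TP all-reduce is still needed
        # once EP dispatch has already combined the expert outputs.
        final = experts.maybe_all_reduce_tensor_model_parallel(final)
    return final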

View File

@ -25,8 +25,7 @@ from transformers import Llama4TextConfig
from vllm.attention import Attention from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size, from vllm.distributed import get_tensor_model_parallel_world_size
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -89,7 +88,7 @@ class Llama4MoE(nn.Module):
quant_config=quant_config, quant_config=quant_config,
bias=False, bias=False,
prefix=f"{prefix}.shared_expert", prefix=f"{prefix}.shared_expert",
reduce_results=False, # We need to do scatter before reduce reduce_results=self.experts.must_reduce_shared_expert_outputs(),
) )
def forward(self, hidden_states): def forward(self, hidden_states):
@ -102,7 +101,8 @@ class Llama4MoE(nn.Module):
experts_out = routed_out + shared_out experts_out = routed_out + shared_out
if self.tp_size > 1: if self.tp_size > 1:
experts_out = tensor_model_parallel_all_reduce(experts_out) experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
experts_out)
return experts_out return experts_out

View File

@ -33,9 +33,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
@ -129,7 +127,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
intermediate_size=config.shared_expert_intermediate_size, intermediate_size=config.shared_expert_intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
reduce_results=False, reduce_results=self.experts.must_reduce_shared_expert_outputs(
),
) )
else: else:
self.shared_expert = None self.shared_expert = None
@ -156,7 +155,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
if shared_output is not None: if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa: E501
final_hidden_states) final_hidden_states)
return final_hidden_states.view(orig_shape) return final_hidden_states.view(orig_shape)

View File

@ -30,9 +30,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
@ -137,7 +135,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
router_logits=router_logits) router_logits=router_logits)
final_hidden_states = final_hidden_states final_hidden_states = final_hidden_states
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa: E501
final_hidden_states) final_hidden_states)
return final_hidden_states.view(orig_shape) return final_hidden_states.view(orig_shape)

View File

@ -158,6 +158,7 @@ class CudaPlatformBase(Platform):
"currently not supported with CUDA Graphs.") "currently not supported with CUDA Graphs.")
vllm_config.model_config.enforce_eager = True vllm_config.model_config.enforce_eager = True
compilation_config.use_cudagraph = False compilation_config.use_cudagraph = False
compilation_config.use_inductor = False
@classmethod @classmethod
def get_current_memory_usage(cls, def get_current_memory_usage(cls,

View File

@ -865,8 +865,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if attn_metadata is None: if attn_metadata is None:
# Profiling run. # The zero fill is required when used with DP + EP
return output # to ensure all ranks within a DP group compute the
# same expert outputs.
return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens num_actual_toks = attn_metadata.num_actual_tokens
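A minimal sketch of the profiling-run guard above, with the attention path elided. The point of the zero fill is that an uninitialized buffer could differ across data-parallel ranks, letting their routers pick different experts during profiling, whereas zeros keep every rank's dummy expert workload identical.

import torch

def mla_forward_sketch(output: torch.Tensor, attn_metadata) -> torch.Tensor:
    if attn_metadata is None:
        # Profiling run with no metadata: zero-fill rather than returning the
        # uninitialized buffer so all ranks in a DP group compute the same
        # expert outputs.
        return output.fill_(0)
    ...  # real MLA attention path elided
    return output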

View File

@ -341,7 +341,8 @@ def init_worker_distributed_environment(
distributed_init_method, local_rank) distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
ensure_kv_transfer_initialized(vllm_config) ensure_kv_transfer_initialized(vllm_config)
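Every worker backend below makes the same change: ensure_model_parallel_initialized now also receives the expert-parallel flag from the parallel config. A minimal sketch of the updated call, assuming only the third positional argument added in this commit:

from vllm.distributed import ensure_model_parallel_initialized

def init_model_parallel(parallel_config) -> None:
    # Tensor-, pipeline- and (new in this commit) expert-parallel settings
    # all come from the engine's parallel config.
    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
        parallel_config.enable_expert_parallel,
    )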

View File

@ -265,4 +265,5 @@ def init_tpu_worker_distributed_environment(
backend="gloo", backend="gloo",
) )
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)

View File

@ -390,7 +390,8 @@ class CPUWorker(LocalOrDistributedWorkerBase):
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
def get_cache_block_size_bytes(self) -> int: def get_cache_block_size_bytes(self) -> int:
"""Return the size in bytes of a single KV cache block. """Return the size in bytes of a single KV cache block.

View File

@ -416,7 +416,8 @@ def init_worker_distributed_environment(
backend='hccl') backend='hccl')
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
if torch.distributed.is_initialized(): if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size() torch_world_size = torch.distributed.get_world_size()
@ -442,7 +443,8 @@ def init_worker_distributed_environment(
torch.distributed.all_reduce(dummy_tensor_hpu) torch.distributed.all_reduce(dummy_tensor_hpu)
assert dummy_tensor_hpu.item() == parallel_config.world_size assert dummy_tensor_hpu.item() == parallel_config.world_size
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len, def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len,

View File

@ -76,7 +76,8 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
) )
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size, self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size) self.parallel_config.pipeline_parallel_size,
self.parallel_config.enable_expert_parallel)
# Device initialization should happen after initializing the distributed # Device initialization should happen after initializing the distributed
# runtime. # runtime.

View File

@ -530,7 +530,8 @@ def init_worker_distributed_environment(
init_distributed_environment(parallel_config.world_size, rank, init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank) distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
ensure_kv_transfer_initialized(vllm_config) ensure_kv_transfer_initialized(vllm_config)

View File

@ -176,7 +176,8 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
# global all_reduce needed for overall oneccl warm up # global all_reduce needed for overall oneccl warm up
torch.distributed.all_reduce(torch.zeros(1).xpu()) torch.distributed.all_reduce(torch.zeros(1).xpu())