Modularize fused experts and integrate PPLX kernels (#15956)

bnellnm 2025-05-14 16:11:54 -04:00 committed by GitHub
parent 418d2f8bfb
commit f9c069c85e
42 changed files with 3830 additions and 660 deletions

View File

@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
} \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \

View File

@ -65,5 +65,19 @@
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))

View File

@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}
if (use_global_memory) {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
cumsum_buffer.data_ptr<int32_t>());
});
} else if (use_i16) {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// set dynamic shared mem
auto kernel =
@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids.numel());
});
} else {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
auto kernel =
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
TORCH_CHECK(num_experts == 256,
"sgl_moe_align_block_size kernel only supports deepseek v3.");
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `cumsum` tensors
auto options_int =

View File

@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
}
}
template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
template <int TPB, typename IndType>
__launch_bounds__(TPB) __global__ void moeTopK(
const float* inputs_after_softmax,
const bool* finished,
float* output,
IndType* indices,
int* source_rows,
const int num_experts,
const int k,
const int start_expert,
const int end_expert)
{
using cub_kvp = cub::KeyValuePair<int, float>;
@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
2) This implementation assumes k is small, but will work for any k.
*/
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert)
{
// We begin by enforcing compile time assertions and setting up compile time constants.
@ -397,8 +405,8 @@ struct TopkConstants
};
} // namespace detail
template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
template <int EXPERTS, int WARPS_PER_TB, typename IndType>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);
template <typename IndType>
void topkGatingSoftmaxKernelLauncher(
const float* gating_output,
float* topk_weights,
int* topk_indicies,
IndType* topk_indicies,
int* token_expert_indices,
float* softmax_workspace,
const int num_tokens,
@ -493,6 +502,9 @@ void topk_softmax(
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
if(topk_indices.scalar_type() == at::ScalarType::Int)
{
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
@ -504,3 +516,18 @@ void topk_softmax(
topk,
stream);
}
else
{
assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
}

View File

@ -65,11 +65,17 @@ def parse_args():
type=int,
default=0,
help="Master node port")
parser.add_argument("--enforce-eager",
action='store_true',
help="Enforce eager mode execution.")
parser.add_argument("--trust-remote-code",
action='store_true',
help="Trust remote code.")
return parser.parse_args()
def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
dp_master_port, GPUs_per_dp_rank):
dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
@ -109,10 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
max_tokens=[16, 20][global_dp_rank % 2])
# Create an LLM.
llm = LLM(model=model,
llm = LLM(
model=model,
tensor_parallel_size=GPUs_per_dp_rank,
enforce_eager=True,
enable_expert_parallel=True)
enforce_eager=enforce_eager,
enable_expert_parallel=True,
trust_remote_code=trust_remote_code,
)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for i, output in enumerate(outputs):
@ -155,7 +164,8 @@ if __name__ == "__main__":
proc = Process(target=main,
args=(args.model, dp_size, local_dp_rank,
global_dp_rank, dp_master_ip, dp_master_port,
tp_size))
tp_size, args.enforce_eager,
args.trust_remote_code))
proc.start()
procs.append(proc)
exit_code = 0

View File

@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import pytest
import torch
import triton.language as tl
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel)
@dataclass
class BatchedMMConfig:
dtype: torch.dtype
num_experts: int
max_tokens_per_expert: int
K: int
N: int
@dataclass
class BatchedMMTensors:
A: torch.Tensor # [E, max_tokens, K]
B: torch.Tensor # [E, K, N] - column major
C: torch.Tensor # [E, max_tokens, N]
num_expert_tokens: torch.Tensor # [E]
@staticmethod
def make_tensors(config: BatchedMMConfig):
A = torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K),
device="cuda",
dtype=config.dtype) / 10
B = torch.randn((config.num_experts, config.N, config.K),
device="cuda",
dtype=config.dtype)
C = torch.zeros(
(config.num_experts, config.max_tokens_per_expert, config.N),
device="cuda",
dtype=config.dtype)
num_expert_tokens = torch.randint(low=0,
high=config.max_tokens_per_expert,
size=(config.num_experts, ),
device="cuda",
dtype=torch.int32)
return BatchedMMTensors(A, B, C, num_expert_tokens)
def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
num_expert_tokens: torch.Tensor) -> torch.Tensor:
num_expert_tokens_cpu = num_expert_tokens.clone()
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
num_experts = num_expert_tokens.size(0)
for e in range(num_experts):
num_tokens = num_expert_tokens_cpu[e]
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
return C
@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
[32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype):
config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
tensors = BatchedMMTensors.make_tensors(config)
test_output = tensors.C
ref_output = test_output.clone()
compute_tl_dtype = {
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32
}[test_output.dtype]
invoke_moe_batched_triton_kernel(
tensors.A,
tensors.B,
test_output,
tensors.num_expert_tokens,
compute_tl_dtype,
# Quantization data
None,
None,
None,
# Quantization schemes
False,
False,
False,
config={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 16
})
ref_output = ref_impl(tensors.A, tensors.B, ref_output,
tensors.num_expert_tokens)
rtol, atol = {
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),
}[test_output.dtype]
torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)

View File

@ -30,6 +30,11 @@ MNK_FACTORS = [
(224, 3072, 1536),
]
vllm_config = VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@dataclasses.dataclass
class MOETensors:
@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr]
'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr]
'topk_weights': topk_weights,
'topk_ids_': topk_ids,
'topk_ids': topk_ids,
'ab_strides1': moe_tensors.ab_strides1,
'c_strides1': moe_tensors.c_strides1,
'ab_strides2': moe_tensors.ab_strides2,
@ -231,15 +236,12 @@ def test_cutlass_moe_8_bit_no_graph(
per_out_ch: bool,
):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_ch)
score = torch.randn((m, e), device="cuda", dtype=torch.half)
topk_weights, topk_ids = fused_topk(mt.a,
topk_weights, topk_ids, _ = fused_topk(mt.a,
score,
topk,
renormalize=False)
@ -276,17 +278,14 @@ def test_cutlass_moe_8_bit_cuda_graph(
per_out_ch: bool,
):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
with set_current_vllm_config(vllm_config):
dtype = torch.half
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_ch)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(mt.a,
topk_weights, topk_ids, _ = fused_topk(mt.a,
score,
topk,
renormalize=False)
@ -334,15 +333,12 @@ def test_cutlass_moe_8_bit_EP(
ep_size: int,
):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_channel)
score = torch.randn((m, e), device="cuda", dtype=torch.half)
topk_weights, topk_ids = fused_topk(mt.a,
topk_weights, topk_ids, _ = fused_topk(mt.a,
score,
topk,
renormalize=False)

View File

@ -12,6 +12,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
@ -32,6 +33,10 @@ NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@ -70,6 +75,7 @@ def test_fused_moe(
else:
e_map = None
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w1, w2, score, topk, e_map)
iterative_output = iterative_moe(a,
w1,
@ -95,6 +101,7 @@ def test_fused_moe(
global_num_experts=e,
expert_map=e_map,
renormalize=False)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
torch.testing.assert_close(iterative_output,
torch_output,
@ -115,7 +122,6 @@ def test_fused_moe(
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
ep_size: int, dtype: torch.dtype, group_size: int,
has_zp: bool, weight_bits: int):
print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
@ -194,6 +200,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
else:
e_map = None
with set_current_vllm_config(vllm_config):
triton_output = fused_moe(a,
w1_qweight,
w2_qweight,
@ -210,6 +217,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size])
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@ -515,6 +523,7 @@ def test_fused_marlin_moe(
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
marlin_output = torch.ops.vllm.fused_marlin_moe(

View File

@ -0,0 +1,691 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for the MOE layers.
Run `pytest tests/kernels/test_pplx_moe.py`.
"""
import dataclasses
import os
import traceback
from typing import Callable, Optional
import pytest
import torch
try:
from pplx_kernels import AllToAll
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_finalize, nvshmem_get_unique_id,
nvshmem_init)
has_pplx = True
except ImportError:
has_pplx = False
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import override_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
get_default_config)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
(222, 2048, 1024)]
PPLX_MOE_COMBOS = [
(1, 128, 128),
(2, 128, 512),
(3, 1024, 2048),
(32, 128, 1024),
(45, 512, 2048),
(64, 1024, 1024),
(222, 1024, 2048),
]
NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [1, 2, 6]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
P = ParamSpec("P")
requires_pplx = pytest.mark.skipif(
not has_pplx,
reason="Requires PPLX kernels",
)
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
"tcp://localhost:29500",
worker,
) + args,
nprocs=world_size,
join=True,
)
def parallel_launch_from_env(
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Launches a worker function in parallel across all processes in the current
environment. The environment must have the following variables set:
- WORLD_SIZE: The total number of processes.
- WORLD_LOCAL_SIZE: The number of processes on the current node.
- NODE_RANK: The rank of the current node.
- MASTER_ADDR: The address of the master process.
- MASTER_PORT: The port of the master process.
"""
assert not kwargs
world_size = int(os.environ["WORLD_SIZE"])
world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
node_rank = int(os.environ["NODE_RANK"])
assert "MASTER_ADDR" in os.environ
assert "MASTER_PORT" in os.environ
spawn(
_worker_parallel_launch,
args=(
world_size,
world_local_size,
node_rank,
"env://",
worker,
) + args,
nprocs=world_local_size,
join=True,
)
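As a quick usage sketch (editorial, not part of this commit): the environment-variable protocol documented above can be exercised on a single node roughly as follows, assuming the snippet lives next to the helpers above; the worker body and the two-process world size are illustrative placeholders.
import os

def _example_worker(pgi: ProcessGroupInfo, tag: str) -> None:
    # Assumed worker: anything accepting a ProcessGroupInfo plus extra args.
    print(f"{tag}: rank {pgi.rank}/{pgi.world_size} on {pgi.device}")

if __name__ == "__main__":
    # Single-node, two-process launch; values are placeholders.
    os.environ.setdefault("WORLD_SIZE", "2")
    os.environ.setdefault("WORLD_LOCAL_SIZE", "2")
    os.environ.setdefault("NODE_RANK", "0")
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    parallel_launch_from_env(_example_worker, "demo")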
def torch_prepare(
a: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
max_num_tokens: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert topk_ids.dim() == 2
assert topk_ids.shape[0] == a.shape[0]
num_tokens, hidden_dim = a.shape
topk = topk_ids.shape[1]
tokens_per_expert = torch.bincount(topk_ids.view(-1),
minlength=num_experts)
assert tokens_per_expert.numel() == num_experts
if max_num_tokens is None:
max_num_tokens = int(tokens_per_expert.max().item())
b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim),
dtype=a.dtype,
device=a.device)
token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)
for token in range(num_tokens):
for j in range(topk):
expert_id = topk_ids[token, j]
idx = token_counts[expert_id]
b_a[expert_id, idx:idx + 1, :] = a[token, :]
token_counts[expert_id] = token_counts[expert_id] + 1
return b_a, tokens_per_expert
def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor,
topk_ids: torch.Tensor) -> torch.Tensor:
num_tokens = topk_ids.shape[0]
num_experts = b_out.shape[0]
K = b_out.shape[-1]
out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
expert_counts = torch.zeros(num_experts,
dtype=torch.int,
device=b_out.device)
for token in range(num_tokens):
expert_ids = topk_ids[token]
for i in range(expert_ids.numel()):
expert_id = expert_ids[i]
idx = expert_counts[expert_id]
out[token, :] = out[token, :] + b_out[expert_id, idx:idx +
1, :] * topk_weight[token, i]
expert_counts[expert_id] = expert_counts[expert_id] + 1
return out
def torch_batched_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
num_experts = w1.shape[0]
b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts)
assert b_a.dim() == 3
num_tokens, topk = topk_ids.shape
_, max_num_tokens, K = b_a.shape
assert num_experts == b_a.shape[0] and w2.shape[1] == K
out = torch.zeros((num_experts, max_num_tokens, K),
dtype=b_a.dtype,
device=b_a.device)
tmp = torch.empty((max_num_tokens, w1.shape[1] // 2),
dtype=b_a.dtype,
device=b_a.device)
for expert in range(num_experts):
num = tokens_per_expert[expert]
if num > 0:
torch.ops._C.silu_and_mul(
tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1))
out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)
return torch_finalize(out, topk_weight, topk_ids)
def batched_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
num_experts = w1.shape[0]
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0),
BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
# Note: same as torch_moe but with fused_topk factored out.
def torch_moe2(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
M, K = a.shape
topk = topk_ids.shape[1]
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
num_experts = w1.shape[0]
for i in range(num_experts):
mask = (topk_ids == i).view(-1)
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(M, -1, w2.shape[1]) *
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_fused_moe_batched_experts(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
):
current_platform.seed_everything(7)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)
torch.testing.assert_close(baseline_output,
torch_output,
atol=2e-2,
rtol=0)
torch.testing.assert_close(baseline_output,
batched_output,
atol=2e-2,
rtol=0)
def rank_chunk(num: int, r: int, w: int) -> int:
rem = num % w
return (num // w) + (1 if r < rem else 0)
def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
chunk = rank_chunk(t.shape[0], r, w)
return t[(r * chunk):(r + 1) * chunk]
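A small worked example (editorial) of the chunking arithmetic above: rank_chunk gives the lower ranks one extra row when the size does not divide evenly, and chunk_by_rank slices the tensor with that per-rank size.
import torch

# 8 rows over 4 ranks divide evenly: every rank gets 2 disjoint rows.
t = torch.arange(8).unsqueeze(1)
assert [rank_chunk(8, r, 4) for r in range(4)] == [2, 2, 2, 2]
assert torch.cat([chunk_by_rank(t, r, 4) for r in range(4)]).shape[0] == 8

# 10 rows over 4 ranks: 10 % 4 == 2, so ranks 0 and 1 get one extra row.
assert [rank_chunk(10, r, 4) for r in range(4)] == [3, 3, 2, 2]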
def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
topk_weight: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
assert torch.cuda.current_device() == pgi.local_rank
topk = topk_ids.shape[1]
num_tokens, hidden_dim = a.shape
block_size = 128
device = pgi.device
rank = pgi.rank
world_size = pgi.world_size
max_num_tokens = rank_chunk(num_tokens, 0, world_size)
ata = AllToAll.internode(
max_num_tokens=max_num_tokens,
num_experts=num_experts,
experts_per_token=topk,
rank=rank,
world_size=world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
((hidden_dim + block_size - 1) // block_size *
torch.float32.itemsize)),
)
topk_ids = topk_ids.to(dtype=torch.uint32)
prepare_finalize = PplxPrepareAndFinalize(
ata,
max_num_tokens,
world_size,
rank,
dp_size,
a.dtype,
)
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
b_a, b_a_scale, expert_num_tokens = prepare_finalize.prepare(
a_chunk,
None,
None,
chunk_topk_weight,
chunk_topk_ids,
num_experts,
None,
False,
)
b_a = b_a * 1.5
out = torch.full(
(max_num_tokens, hidden_dim),
torch.nan,
dtype=a.dtype,
device=device,
)
prepare_finalize.finalize(
out,
b_a,
chunk_topk_weight,
chunk_topk_ids,
False,
)
torch.cuda.synchronize()
ata.destroy()
num_tokens = a_chunk.shape[0]
return out[:num_tokens]
def _pplx_prepare_finalize(
pgi: ProcessGroupInfo,
dp_size: int,
a: torch.Tensor,
score: torch.Tensor,
topk: torch.Tensor,
num_experts: int,
):
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
device = pgi.device
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
k = a.shape[1]
a_rep = torch.repeat_interleave(a, topk, dim=0).to(device)
torch_output = (a_rep.view(-1, topk, k) * 1.5 *
topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(
a.dtype)
pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids,
num_experts)
torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
nvshmem_finalize()
# TODO (bnell): this test point does not work for odd M due to how the test is
# written, not due to limitations of the pplx kernels. The pplx_moe
# test below is able to deal with odd M.
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@requires_pplx
def test_pplx_prepare_finalize(
mnk: tuple[int, int, int],
e: int,
topk: int,
dtype: torch.dtype,
world_dp_size: tuple[int, int],
):
current_platform.seed_everything(7)
m, n, k = mnk
world_size, dp_size = world_dp_size
device = "cuda"
a = torch.randn((m, k), device=device, dtype=dtype) / 10
score = torch.randn((m, e), device=device, dtype=dtype)
parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
topk, e)
def pplx_moe(
rank: int,
world_size: int,
dp_size: int,
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
use_compile: bool = True,
use_cudagraphs: bool = True,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
device = torch.device("cuda", rank)
hidden_dim = a.shape[1]
num_experts = w1.shape[0]
block_size = 128
topk = topk_ids.shape[1]
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
ata = AllToAll.internode(
max_num_tokens=max_num_tokens,
num_experts=num_experts,
experts_per_token=topk,
rank=rank,
world_size=world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
((hidden_dim + block_size - 1) // block_size *
torch.float32.itemsize)),
)
topk_ids = topk_ids.to(dtype=torch.uint32)
prepare_finalize = PplxPrepareAndFinalize(
ata,
max_num_tokens,
world_size,
rank,
dp_size,
)
experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
world_size=world_size,
dp_size=dp_size)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
# Chunking weights like this only works for batched format
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
if use_compile:
_fused_experts = torch.compile(fused_experts,
backend='inductor',
fullgraph=True)
else:
_fused_experts = fused_experts
out = _fused_experts(a_chunk,
w1_chunk,
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts)
if use_cudagraphs:
out.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
out = _fused_experts(a_chunk,
w1_chunk,
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
ata.destroy()
return out
def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
assert torch.cuda.current_device() == pgi.local_rank
num_experts = w1.shape[0]
device = pgi.device
rank = pgi.rank
world_size = pgi.world_size
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
prepare_finalize = BatchedPrepareAndFinalize(
max_num_tokens=max_num_tokens,
world_size=world_size,
dp_size=dp_size,
rank=rank,
)
experts = BatchedExperts(max_num_tokens=a.shape[0],
world_size=1,
dp_size=1)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
out = fused_experts(
a_chunk,
# Chunking weights like this only works for batched format
chunk_by_rank(w1, rank, world_size).to(device),
chunk_by_rank(w2, rank, world_size).to(device),
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts)
return out
def _pplx_moe(
pgi: ProcessGroupInfo,
dp_size: int,
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
):
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
m, k = a.shape
e, _, n = w2.shape
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
topk_weight, topk_ids)
# TODO (bnell): fix + re-enable
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
# topk_ids)
torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
#torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
nvshmem_finalize()
@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@requires_pplx
def test_pplx_moe(
mnk: tuple[int, int, int],
e: int,
topk: int,
dtype: torch.dtype,
world_dp_size: tuple[int, int],
):
current_platform.seed_everything(7)
m, n, k = mnk
world_size, dp_size = world_dp_size
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk)

View File

@ -7,6 +7,7 @@ import pytest
import torch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.platforms import current_platform
@ -15,6 +16,10 @@ if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
"""Matrix multiplication function that supports per-token input
@ -137,6 +142,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
score = torch.randn((M, E), dtype=dtype)
with set_current_vllm_config(vllm_config):
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
out = fused_moe(
a,

View File

@ -11,7 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
deep_gemm_moe_fp8)
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
@ -30,6 +30,10 @@ if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# Test configurations
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048]
@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score = torch.randn((M, E), dtype=dtype)
# Set the context to avoid lots of warning spam.
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
out = fused_moe(
a,
@ -258,6 +261,7 @@ def per_block_cast_to_fp8(
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
# only aligned sizes
@ -381,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
block_size = [block_m, block_m]
dtype = torch.bfloat16
# only aligned sizes
if (N % block_m != 0 or K % block_m != 0 or topk > E):
pytest.skip(
f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")
if topk > E:
pytest.skip(f"Skipping test: topk={topk} > E={E}")
if N <= 512:
pytest.skip("Skipping N <= 512 until performance issues solved.")
vllm_config = VllmConfig()
if not _valid_deep_gemm_shape(M, N, K):
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
torch.manual_seed(seed)
fp8_info = torch.finfo(torch.float8_e4m3fn)

View File

@ -18,6 +18,10 @@ if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# For test
def native_per_token_group_quant_int8(x,
@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score = torch.randn((M, E), dtype=dtype)
# Set the context to avoid lots of warning spam.
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
out = fused_moe(
a,

View File

@ -23,6 +23,7 @@ If you only need to use the distributed environment without model/pipeline
"""
import contextlib
import gc
import importlib.util
import pickle
import weakref
from collections import namedtuple
@ -42,7 +43,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
supports_custom_op)
run_once, supports_custom_op)
@dataclass
@ -936,9 +937,49 @@ def init_distributed_environment(
"world group already initialized with a different world size")
PPLX_DID_INIT: bool = False
@run_once
def pplx_init(rank, world_size):
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
if has_pplx and world_size > 1:
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id, nvshmem_init)
try:
global PPLX_DID_INIT
logger.debug(
"Initialize NVSHMEM for PPLX kernels: rank=%d, "
"world size=%d", rank, world_size)
uid = nvshmem_get_unique_id(
) if rank == 0 else nvshmem_alloc_empty_unique_id()
uid_gpu = uid.cuda()
get_world_group().broadcast(uid_gpu, src=0)
uid = uid_gpu.to(device='cpu')
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, rank, world_size)
PPLX_DID_INIT = True
except Exception as ex:
logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex)
@run_once
def pplx_finalize():
global PPLX_DID_INIT
if PPLX_DID_INIT:
from pplx_kernels.nvshmem import nvshmem_finalize
logger.debug("PPLX NVSHMEM finalize")
from vllm.model_executor.layers.fused_moe.layer import (
_all_to_all_cache)
_all_to_all_cache.destroy()
nvshmem_finalize()
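Editorial note: pplx_init and pplx_finalize rely on the run_once decorator imported from vllm.utils above. Below is a minimal sketch of what such a run-once guard typically does, assuming the usual semantics; the actual vllm.utils implementation may differ.
import functools
from typing import Any, Callable

def run_once_sketch(fn: Callable[..., None]) -> Callable[..., None]:
    # Wraps fn so that only the first call executes; later calls are no-ops.
    called = False

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        nonlocal called
        if not called:
            called = True
            fn(*args, **kwargs)

    return wrapper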
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
enable_expert_parallel: bool = False,
backend: Optional[str] = None,
) -> None:
"""
@ -1041,10 +1082,14 @@ def initialize_model_parallel(
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
_EP.rank_in_group)
if enable_expert_parallel:
pplx_init(rank, world_size)
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
enable_expert_parallel: bool = False,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
@ -1055,7 +1100,8 @@ def ensure_model_parallel_initialized(
get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size, backend)
pipeline_model_parallel_size,
enable_expert_parallel, backend)
return
assert (
@ -1133,6 +1179,9 @@ def get_tensor_model_parallel_rank():
def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TP
pplx_finalize()
if _TP:
_TP.destroy()
_TP = None

View File

@ -23,7 +23,7 @@ from torch.distributed.rendezvous import rendezvous
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import get_tcp_uri
from vllm.utils import get_tcp_uri, is_torch_equal_or_newer
logger = init_logger(__name__)
@ -362,12 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
Destroy ProcessGroup returned by
stateless_init_torch_distributed_process_group().
"""
if is_torch_equal_or_newer("2.7"):
pg.shutdown()
else:
# Lazy import for non-CUDA backends.
try:
# pytorch <= 2.6
from torch.distributed.distributed_c10d import _shutdown_backend
_shutdown_backend(pg)
except ImportError:
# pytorch >= 2.7
pg.shutdown()
_unregister_process_group(pg.group_name)

View File

@ -27,6 +27,7 @@ batchsize_forward_time: defaultdict = defaultdict(list)
@dataclass
class DPMetadata:
max_tokens_across_dp_cpu: torch.Tensor
cu_tokens_across_dp_cpu: torch.Tensor
@ -90,8 +91,10 @@ def set_forward_context(attn_metadata: Any,
dtype=torch.int32)
from vllm.distributed.parallel_state import get_dp_group
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
dp_metadata = DPMetadata(cu_tokens_across_dp_cpu)
dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
cu_tokens_across_dp_cpu)
global _forward_context
prev_context = _forward_context

View File

@ -38,8 +38,8 @@ if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4, cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
TritonExperts, fused_experts, fused_moe, fused_topk,
get_config_file_name, grouped_topk)
__all__ += [
"fused_moe",
@ -49,4 +49,5 @@ if HAS_TRITON:
"grouped_topk",
"cutlass_moe_fp8",
"cutlass_moe_fp4",
"TritonExperts",
]

View File

@ -5,10 +5,176 @@ from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
from vllm.scalar_type import scalar_types
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
out_dtype: torch.dtype,
):
super().__init__()
self.ab_strides1 = ab_strides1
self.c_strides1 = c_strides1
self.ab_strides2 = ab_strides2
self.c_strides2 = c_strides2
self.out_dtype = out_dtype
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
# Note that K, N are transposed
N, K = K, N
workspace1 = M * topk * max(2 * N, K)
workspace2 = M * topk * N
return (workspace1, workspace2, self.out_dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
a1q = hidden_states
assert w1_scale is not None
assert w2_scale is not None
assert w1.dtype == torch.float8_e4m3fn
assert w2.dtype == torch.float8_e4m3fn
assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1"
assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2"
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
assert a1q_scale is None or a1q_scale.dim(
) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
0], "Input scale shape mismatch"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1.shape[2], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2.shape[2], "W2 scale shape mismatch"
assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
assert w1.shape[0] == w1_scale.shape[
0], "w1 scales expert number mismatch"
assert w1.shape[0] == w2_scale.shape[
0], "w2 scales expert number mismatch"
assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
assert self.ab_strides1.shape[0] == w1.shape[
0], "AB Strides 1 expert number mismatch"
assert self.c_strides1.shape[0] == w1.shape[
0], "C Strides 1 expert number mismatch"
assert self.ab_strides2.shape[0] == w2.shape[
0], "AB Strides 2 expert number mismatch"
assert self.c_strides2.shape[0] == w2.shape[
0], "C Strides 2 expert number mismatch"
assert self.out_dtype in [torch.half,
torch.bfloat16], "Invalid output dtype"
M = a1q.shape[0]
_, N, K = w2.shape # because w1 + w2 are transposed
device = a1q.device
assert w1.shape[1] == K
assert global_num_experts != -1
assert a1q_scale is not None
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(expert_map[topk_ids] != -1,
expert_map[topk_ids], -1)
else:
local_topk_ids = topk_ids
topk = local_topk_ids.shape[1]
per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
expert_offsets = torch.empty((global_num_experts + 1),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((global_num_experts, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((global_num_experts, 3),
dtype=torch.int32,
device=device)
# With expert_map each Rank processes only a subset of experts. As
# a result not all of a_map and c2 tensors are filled. We fill it
# zeros for correctness.
if expert_map is not None:
a_map = torch.zeros((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
else:
a_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
c_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets,
problem_sizes1, problem_sizes2, a_map,
c_map, global_num_experts, N, K)
a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
c1 = _resize_cache(workspace13, (M * topk, N * 2))
c2 = _resize_cache(workspace2, (M * topk, N))
c3 = _resize_cache(workspace13, (M * topk, K))
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale,
expert_offsets[:-1], problem_sizes1,
self.ab_strides1, self.ab_strides1, self.c_strides1)
self.activation(activation, c2, c1)
a2q, a2q_scale = ops.scaled_fp8_quant(
c2, a2_scale, use_per_token_if_dynamic=per_act_token)
if expert_map is not None:
c3.fill_(0)
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale,
expert_offsets[:-1], problem_sizes2,
self.ab_strides2, self.ab_strides2, self.c_strides2)
c3 = c3[c_map]
return c3
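A concrete editorial example of the expert_map translation performed in apply() above; the tensors are illustrative. With expert parallelism each rank owns a subset of experts, so global expert ids the rank does not own are mapped to -1.
import torch

# This rank owns global experts 4..7 of 8; expert_map[g] is the local id or -1.
expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3])
topk_ids = torch.tensor([[1, 5], [4, 7]])  # per-token global expert ids
local_topk_ids = torch.where(expert_map[topk_ids] != -1,
                             expert_map[topk_ids], -1)
# Tokens routed to experts this rank does not own become -1.
assert local_topk_ids.tolist() == [[-1, 1], [0, 3]]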
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def cutlass_moe_fp8(
a: torch.Tensor,
@ -17,7 +183,7 @@ def cutlass_moe_fp8(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids_: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
@ -59,7 +225,7 @@ def cutlass_moe_fp8(
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
- out_dtype (torch.Tensor): The output tensor type.
- out_dtype (torch.dtype): The output tensor type.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
@ -71,115 +237,36 @@ def cutlass_moe_fp8(
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
assert w1_q.dtype == torch.float8_e4m3fn
assert w2_q.dtype == torch.float8_e4m3fn
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert a1_scale is None or a1_scale.dim(
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[
0], "Input scale shape mismatch"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1_q.shape[2], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2_q.shape[2], "W2 scale shape mismatch"
assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[
0], "w1 scales expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[
0], "w2 scales expert number mismatch"
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
assert ab_strides1.shape[0] == w1_q.shape[
0], "AB Strides 1 expert number mismatch"
assert c_strides1.shape[0] == w1_q.shape[
0], "C Strides 1 expert number mismatch"
assert ab_strides2.shape[0] == w2_q.shape[
0], "AB Strides 2 expert number mismatch"
assert c_strides2.shape[0] == w2_q.shape[
0], "C Strides 2 expert number mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
num_experts = w1_q.size(0)
m = a.size(0)
k = w1_q.size(1)
n = w2_q.size(1)
local_topk_ids = topk_ids_
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(expert_map[topk_ids_] != -1,
expert_map[topk_ids_], -1)
topk = local_topk_ids.size(1)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
if apply_router_weight_on_input:
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
# TODO: this only works for topK=1, will need to update for topK>1
a = a * topk_weights.to(out_dtype)
a_q, a1_scale = ops.scaled_fp8_quant(
a, a1_scale, use_per_token_if_dynamic=per_act_token)
device = a_q.device
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
per_channel_quant=per_act_token,
quant_dtype=torch.float8_e4m3fn,
),
CutlassExpertsFp8(
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
out_dtype,
),
)
expert_offsets = torch.empty((num_experts + 1),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
a_map_initializer = torch.empty
c2_initializer = torch.empty
if expert_map is not None:
# With expert_map each Rank processes only a subset of experts. As
# a result not all of a_map and c2 tensors are filled. We fill it
# zeros for correctness.
a_map_initializer = torch.zeros
c2_initializer = torch.zeros
a_map = a_map_initializer((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
c_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, a_map, c_map, num_experts, n,
k)
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype)
ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
expert_offsets[:-1], problem_sizes1, ab_strides1,
ab_strides1, c_strides1)
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
torch.ops._C.silu_and_mul(intermediate, c1)
intemediate_q, a2_scale = ops.scaled_fp8_quant(
intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
expert_offsets[:-1], problem_sizes2, ab_strides2,
ab_strides2, c_strides2)
# Gather tokens
c2 = c2[c_map].view(m, topk, k)
if not apply_router_weight_on_input:
c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
return c2.sum(dim=1)
return fn(
a,
w1_q,
w2_q,
topk_weights,
topk_ids,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
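The rewrite above is the heart of this PR: cutlass_moe_fp8 now composes a FusedMoEModularKernel from a prepare/finalize stage (quantization plus token dispatch and combine) and an experts stage (the per-expert GEMMs) instead of inlining the whole pipeline. The sketch below mirrors the batched test code earlier in this diff to show the same composition pattern in its simplest form; argument values are placeholders.
import torch

from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedExperts, BatchedPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)

def modular_moe_sketch(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                       score: torch.Tensor, topk: int) -> torch.Tensor:
    num_experts = w1.shape[0]
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
    # Prepare/finalize handles layout and dispatch; the experts object runs
    # the per-expert matmuls on whatever layout prepare() produced.
    kernel = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0),
        BatchedExperts(max_num_tokens=a.shape[0], world_size=1, dp_size=1),
    )
    return kernel(a, w1, w2, topk_weights, topk_ids,
                  global_num_experts=num_experts)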

View File

@ -1,16 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
import functools
import importlib.util
from typing import Optional
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_resize_cache)
from vllm.utils import round_up
@ -19,6 +20,19 @@ logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
@functools.cache
def deep_gemm_block_shape() -> list[int]:
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
block = dg.get_m_alignment_for_contiguous_layout()
return [block, block]
def _valid_deep_gemm_shape(M: int, N: int, K: int):
align = deep_gemm_block_shape()[0]
return align <= M and N % align == 0 and K % align == 0
def _valid_deep_gemm(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@ -29,89 +43,112 @@ def _valid_deep_gemm(hidden_states: torch.Tensor,
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
"""
if not has_deep_gemm:
logger.debug("DeepGemm disabled: deep_gemm not available.")
return False
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
# Expert maps not supported yet.
if expert_map is not None:
logger.debug("DeepGemm disabled: expert map NYI.")
return False
align = dg.get_m_alignment_for_contiguous_layout()
M = hidden_states.shape[0]
_, K, N = w2.shape
# For now, disable DeepGemm for small N until better permute/unpermute
# ops are available.
if N <= 512:
M = hidden_states.size(0)
_, K, N = w2.size()
if not _valid_deep_gemm_shape(M, N, K):
logger.debug("DeepGemm disabled: unalinged problem size.")
return False
if align > M or N % align != 0 or K % align != 0:
if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn):
logger.debug("DeepGemm disabled: invalid weight dtype(s).")
return False
return (hidden_states.is_contiguous() and w1.is_contiguous()
and w2.is_contiguous())
if (not hidden_states.is_contiguous() or not w1.is_contiguous()
or not w2.is_contiguous()):
logger.debug(
"DeepGemm disabled: weights or activations not contiguous.")
return False
return True
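An editorial illustration of the shape gate above. deep_gemm_block_shape() queries DeepGEMM lazily; the fixed alignment of 128 below is an assumption for the example only.
ASSUMED_ALIGN = 128  # stand-in for dg.get_m_alignment_for_contiguous_layout()

def valid_shape_sketch(M: int, N: int, K: int) -> bool:
    # M must cover at least one block; N and K must be exact multiples.
    return (ASSUMED_ALIGN <= M and N % ASSUMED_ALIGN == 0
            and K % ASSUMED_ALIGN == 0)

assert valid_shape_sketch(2048, 4096, 7168)      # aligned problem size
assert not valid_shape_sketch(7, 4096, 7168)     # M smaller than one block
assert not valid_shape_sketch(2048, 4100, 7168)  # N not a multiple of 128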
def _moe_permute(
curr_hidden_states: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
curr_topk_ids: torch.Tensor,
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self):
super().__init__()
self.block_shape = deep_gemm_block_shape()
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
block_m = self.block_shape[0]
M_sum = (M * topk) + num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)
workspace1 = M_sum * max(N * 2, K)
workspace2 = M_sum * N
return (workspace1, workspace2, a.dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
block_m: int,
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
top_k_num = curr_topk_ids.shape[1]
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
import deep_gemm as dg
tokens_in_chunk, _ = curr_hidden_states.shape
a1q = hidden_states
_, N, K = w1.size()
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids,
block_m,
assert global_num_experts != -1
assert w2.size(1) == K
a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
a1q,
a1q_scale,
topk_ids,
global_num_experts,
expert_map,
pad_sorted_ids=True))
self.block_shape[0],
)
inv_perm: Optional[torch.Tensor] = None
# Note: M_sum is different than the pre-permuted shape of a1q.
M_sum = a1q.size(0)
workspace1 = _resize_cache(workspace13, (M_sum, N))
workspace2 = _resize_cache(workspace2, (M_sum, N // 2))
workspace3 = _resize_cache(workspace13, (M_sum, K))
num_tokens = top_k_num * tokens_in_chunk
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids)
# Permute according to sorted token ids.
curr_hidden_states = _fp8_perm(curr_hidden_states,
sorted_token_ids // top_k_num)
self.activation(activation, workspace2, workspace1.view(-1, N))
if a1q_scale is not None:
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
a2q_scale: Optional[torch.Tensor] = None
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm)
a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, False,
self.block_shape)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
def _moe_unpermute_and_reduce(
out: torch.Tensor,
curr_hidden: torch.Tensor,
inv_perm: Optional[torch.Tensor],
topk_weight: torch.Tensor,
) -> None:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M, topk = topk_weight.shape
K = curr_hidden.shape[1]
curr_hidden = curr_hidden[inv_perm, ...]
curr_hidden = curr_hidden.view(-1, topk, K)
curr_hidden.mul_(topk_weight.view(M, -1, 1))
ops.moe_sum(curr_hidden, out)
workspace3 = workspace3[inv_perm, ...]
return workspace3
def deep_gemm_moe_fp8(
@ -128,6 +165,7 @@ def deep_gemm_moe_fp8(
expert_map: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input=False,
) -> torch.Tensor:
"""
This function computes an a8w8-quantized Mixture of Experts (MoE) layer
@ -166,129 +204,24 @@ def deep_gemm_moe_fp8(
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
assert expert_map is None, "Expert maps not supported yet"
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
assert w1.dtype == torch.float8_e4m3fn
assert w2.dtype == torch.float8_e4m3fn
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
assert a1_scale is None or a1_scale.dim(
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
0] == hidden_states.shape[0], "Input scale shape mismatch"
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
num_tokens, _ = hidden_states.shape
E, N, _ = w1.shape
K = w2.shape[1]
if global_num_experts == -1:
global_num_experts = E
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)
if inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
block_m = dg.get_m_alignment_for_contiguous_layout()
block_shape = [block_m, block_m]
assert w1_scale is not None
assert w2_scale is not None
# We attempt to transpose and align offline in Fp8MoEMethod, in which
# case these calls will be nops. Otherwise, they'll be performed every
# time the layer is executed.
w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()
M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)
num_chunks = (num_tokens // CHUNK_SIZE) + 1
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13 = torch.empty(M_sum * max(N, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
workspace1 = workspace13[:M_sum * N].view(M_sum, N)
workspace2 = torch.empty((M_sum, N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
workspace3 = workspace13[:M_sum * K].view(M_sum, K)
for chunk in range(num_chunks):
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE,
num_tokens))
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.shape
if tokens_in_chunk == 0:
break
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
a1q_scale: Optional[torch.Tensor] = None
qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
a1_scale, block_shape)
(qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
curr_topk_ids, global_num_experts,
expert_map, block_m)
# Adjust the intermediate cache size and config for the last chunk.
# Note that in most cases we only have one chunk so the cache size
# and config are already set correctly and do not need to be adjusted.
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
curr_M = sorted_token_ids.numel()
workspace1 = _resize_cache(workspace1, (curr_M, N))
workspace2 = _resize_cache(workspace2, (curr_M, N // 2))
workspace3 = _resize_cache(workspace3, (curr_M, K))
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1,
expert_ids)
if activation == "silu":
torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N))
elif activation == "gelu":
torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
a2q_scale: Optional[torch.Tensor] = None
qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale,
block_shape)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
_moe_unpermute_and_reduce(
out_hidden_states[begin_chunk_idx:end_chunk_idx],
workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights)
return out_hidden_states
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(quant_dtype=torch.float8_e4m3fn,
block_shape=deep_gemm_block_shape()),
DeepGemmExperts(),
)
return fn(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
inplace,
activation,
global_num_experts,
expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)

View File

@ -0,0 +1,755 @@
# SPDX-License-Identifier: Apache-2.0
"""Fused batched MoE kernel."""
from typing import Optional
import torch
import triton
import triton.language as tl
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
@triton.jit
def moe_mmk(
a_ptrs,
b_ptrs,
K,
expert_id,
a_scale_ptr,
b_scale_ptr,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_ak,
stride_bk,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Offsets and masks
offs_m,
offs_n,
mask_m,
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
compute_type: tl.constexpr,
use_w8a8: tl.constexpr,
use_w8a16: tl.constexpr):
offs_k = tl.arange(0, BLOCK_K)
if use_w8a16:
b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[
None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
if use_w8a8:
# block-wise
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
offs_bsn = offs_n // group_n
b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse +
offs_bsn * stride_bsn)
# tensor-wise
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + expert_id)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(a_ptrs,
mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K),
other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
# We accumulate along the K dimension.
if use_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_w8a8:
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_K
offs_ks = k_start // group_k
a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
mask=mask_m,
other=0.0)
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
accumulator += tl.dot(a, b) * a_scale[:,
None] * b_scale[None, :]
else:
if use_w8a8:
# acc used to enable fp8_fast_accum
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
if use_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
elif use_w8a8:
if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type)
else:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
return accumulator
@triton.jit
def expert_triton_kernel(
a_ptr, #[max_tokens, K]
b_ptr, #[K, N]
c_ptr, #[max_tokens, N]
expert_id,
compute_type: tl.constexpr,
# Dimensions
M,
N,
K,
# Quantization data
a_scale_ptr,
b_scale_ptr,
b_zp_ptr,
# strides
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Blockwise quantization data
group_n,
group_k,
# Quantization schemes
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
# Kernel config
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr):
offs_m = tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N) % N
offs_k = tl.arange(0, BLOCK_K)
mask_m = offs_m < M
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
accumulator = moe_mmk(
a_ptrs,
b_ptrs,
K,
expert_id,
a_scale_ptr,
b_scale_ptr,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_ak,
stride_bk,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Offsets and masks
offs_m,
offs_n,
mask_m,
# Block size for block-wise quantization
group_n,
group_k,
# Meta-parameters
BLOCK_M,
BLOCK_N,
BLOCK_K,
compute_type,
use_fp8_w8a8,
use_int8_w8a16)
# store in C
offs_cn = tl.arange(0, BLOCK_N)
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn
c_mask = mask_m[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
@triton.jit
def batched_triton_kernel(
a_ptr, # [E, max_num_tokens, K]
b_ptr, # [E, K, N]
c_ptr, # [E, max_num_tokens, N]
expert_num_tokens, # [E]
compute_type: tl.constexpr,
# Dimensions
max_num_tokens,
K,
N,
# Quantization data
a_scale_ptr,
b_scale_ptr,
b_zp_ptr,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_ae,
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_ce,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Blockwise quantization data
group_n: tl.constexpr,
group_k: tl.constexpr,
# Quantization schemes
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
# Kernel config
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr):
expert_id = tl.program_id(axis=0)
e_num_tokens = tl.load(expert_num_tokens + expert_id)
if e_num_tokens == 0:
# Early exit
return
pid_mn = tl.program_id(axis=1)
#num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
pid_m = pid_mn // num_pid_n
pid_n = pid_mn % num_pid_n
cta_m_start = pid_m * BLOCK_M
cta_n_start = pid_n * BLOCK_N
if cta_m_start >= e_num_tokens:
# Early exit
return
cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start)
cta_n_size = min(BLOCK_N, N - cta_n_start)
a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm +
cta_n_start * stride_cn)
expert_triton_kernel(
a_ptr,
b_ptr,
c_ptr,
expert_id,
compute_type,
cta_m_size, # M
cta_n_size, # N
K, # K
a_scale_ptr,
b_scale_ptr,
b_zp_ptr,
# Strides
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Blockwise quantization data
group_n,
group_k,
# Quantization schemes
use_fp8_w8a8,
use_int8_w8a16,
# Kernel config
BLOCK_M,
BLOCK_N,
BLOCK_K)
def invoke_moe_batched_triton_kernel(
A: torch.Tensor, # [E, max_tokens, K]
B: torch.Tensor, # [E, K, N]
C: torch.Tensor, # [E, max_tokens, N]
expert_num_tokens: torch.Tensor, # [E]
compute_type: tl.dtype,
# Quantization data
A_scale: torch.Tensor,
B_scale: torch.Tensor,
B_zp: torch.Tensor,
# Quantization schemes
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
config: dict[str, int],
block_shape: Optional[list[int]] = None):
assert not use_int4_w4a16
max_num_tokens = A.size(1)
K = A.size(2)
N = C.size(2)
BLOCK_M = config['BLOCK_SIZE_M']
BLOCK_N = config['BLOCK_SIZE_N']
BLOCK_K = config['BLOCK_SIZE_K']
assert (torch.compiler.is_compiling()
or torch.cuda.is_current_stream_capturing()
or max_num_tokens % BLOCK_M == 0)
grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
triton.cdiv(B.size(1), BLOCK_N))
batched_triton_kernel[grid](
A,
B,
C,
expert_num_tokens,
compute_type,
# Dimensions
max_num_tokens,
K,
N,
# Quantization data
A_scale,
B_scale,
B_zp,
# Strides
A.stride(0),
A.stride(1),
A.stride(2),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(0),
C.stride(1),
C.stride(2),
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
# Blockwise quantization data
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
# Quantization schemes
use_fp8_w8a8,
use_int8_w8a16,
# Kernel config
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K)
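# Illustrative grid arithmetic for the launch above (toy numbers, not taken
# from this commit): with E = 8 local experts, max_num_tokens = 256, N = 1024
# and a config of BLOCK_SIZE_M = 64, BLOCK_SIZE_N = 64, the grid is
#   (8, cdiv(256, 64) * cdiv(1024, 64)) = (8, 4 * 16) = (8, 64)
# i.e. one program per (expert, M-tile, N-tile). Programs whose expert has no
# tokens, or whose M-tile starts beyond expert_num_tokens, exit early inside
# batched_triton_kernel.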
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""
A reference prepare/finalize class that reorganizes the tokens into
expert batched format, i.e. E x max_num_tokens x K. This is the format
that the PPLX dispatch/combine kernels use.
"""
def __init__(self, max_num_tokens: Optional[int], world_size: int,
dp_size: int, rank: int):
super().__init__()
self.world_size = world_size
self.dp_size = dp_size
self.rank = rank
self.max_num_tokens = max_num_tokens
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
assert a1.dim() == 2
assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0)
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
num_tokens, hidden_dim = a1.size()
topk = topk_ids.size(1)
if self.max_num_tokens is None:
tokens_per_expert = torch.bincount(topk_ids.view(-1),
minlength=num_experts)
self.max_num_tokens = int(tokens_per_expert.max().item())
else:
tokens_per_expert = torch.zeros(num_experts,
dtype=torch.int,
device=a1.device)
assert num_experts % self.world_size == 0
num_local_experts = num_experts // self.world_size
b_a1 = torch.zeros(
(num_local_experts, self.max_num_tokens, hidden_dim),
dtype=a1.dtype,
device=a1.device)
first_expert = num_local_experts * self.rank
last_expert = first_expert + num_local_experts
for expert_id in range(first_expert, last_expert):
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
rows = torch.count_nonzero(topks.flatten())
b_a1[expert_id -
first_expert, :rows, :] = a1[:topks.numel()][topks]
tokens_per_expert[expert_id - first_expert] = rows
return b_a1, a1_scale, tokens_per_expert
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
num_tokens = topk_ids.size(0)
num_local_experts = fused_expert_output.size(0)
K = fused_expert_output.size(-1)
assert output.size(0) == num_tokens and output.size(1) == K
output.fill_(0)
first_expert = num_local_experts * self.rank
last_expert = first_expert + num_local_experts
for expert_id in range(first_expert, last_expert):
matching_tokens = topk_ids == expert_id
topks = torch.any(matching_tokens, dim=1).flatten()
rows = torch.count_nonzero(topks)
rhs = fused_expert_output[expert_id - first_expert, :rows, :]
if not apply_router_weight_on_input:
rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1))
output[topks] = output[topks] + rhs
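# A minimal, self-contained sketch (illustration only, not part of this
# commit) of the expert-batched layout that BatchedPrepareAndFinalize.prepare
# produces: tokens routed to each local expert are packed into the leading
# rows of an [E, max_num_tokens, K] buffer, and tokens_per_expert records how
# many rows of each slab are valid. The helper name and all sizes below are
# hypothetical toy values.
def _expert_batched_layout_example() -> None:
    num_tokens, hidden_dim, num_experts, max_num_tokens = 4, 8, 2, 4
    a1 = torch.randn(num_tokens, hidden_dim)
    topk_ids = torch.tensor([[0], [1], [0], [0]])  # topk = 1
    b_a1 = torch.zeros(num_experts, max_num_tokens, hidden_dim)
    tokens_per_expert = torch.zeros(num_experts, dtype=torch.int)
    for expert_id in range(num_experts):
        rows_mask = torch.any(topk_ids == expert_id, dim=1)
        rows = int(rows_mask.sum())
        b_a1[expert_id, :rows] = a1[rows_mask]
        tokens_per_expert[expert_id] = rows
    # b_a1[0, :3] now holds the 3 tokens routed to expert 0 and b_a1[1, :1]
    # the single token routed to expert 1; the remaining rows are zero padding.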
class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
A reference MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the pplx
dispatch/combine kernels use.
"""
def __init__(
self,
world_size: int,
dp_size: int,
max_num_tokens: Optional[int] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
block_shape: Optional[list[int]] = None,
block_m: Optional[int] = None,
):
super().__init__()
assert block_shape is None
assert block_m is None
assert not use_fp8_w8a8, "NYI"
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
self.max_num_tokens = max_num_tokens
self.world_size = world_size
self.dp_size = dp_size
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
assert a.dim() == 2
num_dp = self.world_size // self.dp_size
max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens
#print(f"WORKSPACE {max_num_tokens} {num_dp}")
workspace13 = num_experts * max_num_tokens * num_dp * K
workspace2 = max_num_tokens * num_dp * N
return (workspace13, workspace2, a.dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
assert hidden_states.dim() == 3
assert expert_num_tokens is not None
hidden_dim = hidden_states.size(-1)
if self.max_num_tokens is None:
max_num_tokens = hidden_states.size(1)
else:
max_num_tokens = self.max_num_tokens
num_dp = self.world_size // self.dp_size
num_experts = global_num_experts
out = _resize_cache(workspace13,
(num_experts, max_num_tokens * num_dp, hidden_dim))
num_local_experts = w1.size(0)
assert num_local_experts == w1.size(0), (
f"{num_local_experts} == {w1.size(0)}")
N = w1.size(1) // 2
# Not cudagraph friendly
assert (torch.compiler.is_compiling()
or torch.cuda.is_current_stream_capturing()
or torch.all(expert_num_tokens <= max_num_tokens * num_dp)), (
f"{expert_num_tokens} <= {max_num_tokens * num_dp}")
for expert in range(num_local_experts):
# Indexing expert_num_tokens doesn't work w/cudagraphs or inductor
if (torch.compiler.is_compiling()
or torch.cuda.is_current_stream_capturing()):
num = max_num_tokens * num_dp
else:
num = int(expert_num_tokens[expert].item())
tmp = _resize_cache(workspace2, (num, N))
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
self.activation(activation, tmp, input)
out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
return out
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
A Triton based MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the pplx
dispatch/combine kernels use.
"""
def __init__(
self,
max_num_tokens: Optional[int] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
block_shape: Optional[list[int]] = None,
world_size: int = 1,
dp_size: int = 1,
):
super().__init__()
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int8_w8a8 = use_int8_w8a8
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a16 = use_int8_w8a16
self.block_shape = block_shape
self.max_num_tokens = max_num_tokens
assert not use_int8_w8a8, "NYI"
assert not use_int4_w4a16, "NYI"
self.world_size = world_size
self.dp_size = dp_size
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
assert a.dim() == 2
num_dp = self.world_size // self.dp_size
max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens
workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
return (workspace13, workspace2, a.dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
# Check constraints.
if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
"Hidden size mismatch")
else:
assert hidden_states.size(-1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(-1)} "
f"!= {w1.size(2)}")
assert hidden_states.is_contiguous(
), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
]
# TODO: num_tokens -> max_num_tokens?
E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
hidden_states, w1, w2, topk_ids)
assert w1.size(0) == E
assert w2.size(0) == E
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
dtype=hidden_states.dtype)
config = try_get_optimal_moe_config(
w1.size(),
w2.size(),
top_k_num,
config_dtype,
num_tokens,
block_shape=self.block_shape,
)
if hidden_states.dtype == torch.bfloat16:
compute_type = tl.bfloat16
elif hidden_states.dtype == torch.float16:
compute_type = tl.float16
elif hidden_states.dtype == torch.float32:
compute_type = tl.float32
elif hidden_states.dtype == torch.float8_e4m3fn:
compute_type = tl.bfloat16
else:
raise ValueError(
f"Unsupported compute_type: {hidden_states.dtype}")
#print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N))
intermediate_cache2 = _resize_cache(workspace2,
(E, num_tokens, N // 2))
intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K))
# MM1
invoke_moe_batched_triton_kernel(A=hidden_states,
B=w1,
C=intermediate_cache1,
expert_num_tokens=expert_num_tokens,
compute_type=compute_type,
A_scale=a1q_scale,
B_scale=w1_scale,
B_zp=w1_zp,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
config=config,
block_shape=self.block_shape)
# TODO: would be nice to use expert_num_tokens here to reduce
# garbage compute
self.activation(activation, intermediate_cache2.view(-1, N // 2),
intermediate_cache1.view(-1, N))
#qintermediate_cache2 = intermediate_cache2
a2q_scale = a2_scale
# TODO (varun) : support w8a8
assert not self.use_fp8_w8a8
#if self.use_fp8_w8a8:
# qintermediate_cache2, a2q_scale = _fp8_quantize(
# intermediate_cache2, a2_scale, self.block_shape)
invoke_moe_batched_triton_kernel(A=intermediate_cache2,
B=w2,
C=intermediate_cache3,
expert_num_tokens=expert_num_tokens,
compute_type=compute_type,
A_scale=a2q_scale,
B_scale=w2_scale,
B_zp=w2_zp,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
config=config,
block_shape=self.block_shape)
return intermediate_cache3

View File

@ -8,16 +8,17 @@ from typing import Any, Callable, Optional
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
@ -484,6 +485,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8 or use_int8_w8a8:
assert B_scale is not None
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
== B_scale.shape[-2])
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
== B_scale.shape[-1])
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
M = A.shape[0]
num_tokens = M * top_k
@ -855,6 +870,7 @@ def fused_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
indices_type: Optional[torch.dtype] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
@ -865,9 +881,10 @@ def fused_topk(
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk_ids = torch.empty(
M,
topk,
dtype=torch.int32,
dtype=torch.int32 if indices_type is None else indices_type,
device=hidden_states.device)
token_expert_indices = torch.empty(M,
topk,
@ -962,6 +979,20 @@ def get_config_dtype_str(
return None
# TODO (bnell): use scalar_type instead of bools?
def get_config_qtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
) -> Optional[torch.dtype]:
if use_fp8_w8a8:
return torch.float8_e4m3fn
elif use_int8_w8a8:
return torch.int8
return None
def inplace_fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@ -1128,7 +1159,10 @@ def fused_experts(hidden_states: torch.Tensor,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
if (allow_deep_gemm and use_fp8_w8a8
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N = w1.shape[1]
if (allow_deep_gemm and use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8(
@ -1145,6 +1179,7 @@ def fused_experts(hidden_states: torch.Tensor,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
return dispatch_fused_experts_func(inplace)(
@ -1171,60 +1206,8 @@ def fused_experts(hidden_states: torch.Tensor,
block_shape=block_shape)
def moe_kernel_prepare_input(
A: torch.Tensor,
B: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if use_fp8_w8a8:
assert B_scale is not None
if block_shape is None:
# If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8 quantization, dynamic or static
A, A_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant)
else:
# activation block-wise fp8 quantization
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a8:
assert B_scale is not None
if block_shape is None:
# activation channel-wise int8 quantization
assert (per_channel_quant
), "int8 quantization only supports block or channel-wise"
A, A_scale = per_token_quant_int8(A)
else:
# activation block-wise int8 quantization
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
return A, A_scale
def fused_experts_impl(hidden_states: torch.Tensor,
def fused_experts_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
@ -1245,13 +1228,15 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None):
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
# Check constraints.
if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[
2], "Hidden size mismatch"
else:
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert hidden_states.shape[1] == w1.shape[2], (
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
@ -1261,7 +1246,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
torch.float32, torch.float16, torch.bfloat16
]
num_tokens, _ = hidden_states.shape
num_tokens = hidden_states.shape[0]
E, N, _ = w1.shape
K = w2.shape[1]
if global_num_experts == -1:
@ -1276,6 +1261,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16=use_int4_w4a16,
dtype=hidden_states.dtype)
qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16)
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
@ -1338,15 +1328,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input(
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
A=curr_hidden_states,
B=w1,
A_scale=a1_scale,
B_scale=w1_scale,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
qtype=qtype,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
@ -1357,7 +1342,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel(qcurr_hidden_states,
w1,
intermediate_cache1,
qa1_scale,
a1q_scale,
w1_scale,
w1_zp,
curr_topk_weights,
@ -1384,22 +1369,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
qintermediate_cache2, qa2_scale = moe_kernel_prepare_input(
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
A=intermediate_cache2,
B=w2,
A_scale=a2_scale,
B_scale=w2_scale,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
qtype=qtype,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
intermediate_cache3,
qa2_scale,
a2q_scale,
w2_scale,
w2_zp,
curr_topk_weights,
@ -1534,3 +1514,209 @@ def fused_moe(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
block_m: Optional[int] = None,
):
super().__init__()
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a8 = use_int8_w8a8
self.use_int8_w8a16 = use_int8_w8a16
self.block_shape = block_shape
self.block_m = block_m
self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16)
self.per_channel_quant = per_channel_quant
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
factor = num_experts if a.dim() == 3 else 1
workspace1 = M * topk * max(N * 2, K) * factor
workspace2 = M * topk * N * factor
return (workspace1, workspace2, a.dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
# Check constraints.
if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
"Hidden size mismatch")
else:
assert hidden_states.size(-1) == w1.size(2), \
(f"Hidden size mismatch {hidden_states.size(-1)} "
f"!= {w1.size(2)}")
assert hidden_states.is_contiguous(
), "Hidden_states must be contiguous"
assert hidden_states.dim() == 2
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
]
E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
hidden_states, w1, w2, topk_ids)
if global_num_experts == -1:
global_num_experts = E
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
dtype=hidden_states.dtype)
config = try_get_optimal_moe_config(
w1.shape,
w2.shape,
top_k_num,
config_dtype,
num_tokens,
block_shape=self.block_shape,
)
if hidden_states.dtype == torch.bfloat16:
compute_type = tl.bfloat16
elif hidden_states.dtype == torch.float16:
compute_type = tl.float16
elif hidden_states.dtype == torch.float32:
compute_type = tl.float32
elif hidden_states.dtype == torch.float8_e4m3fn:
compute_type = tl.bfloat16
else:
raise ValueError(
f"Unsupported compute_type: {hidden_states.dtype}")
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13,
(num_tokens, top_k_num, N))
intermediate_cache2 = _resize_cache(workspace2,
(num_tokens * top_k_num, N // 2))
intermediate_cache3 = _resize_cache(workspace13,
(num_tokens, top_k_num, K))
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
global_num_experts, expert_map))
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1q_scale,
w1_scale,
w1_zp,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape)
self.activation(activation, intermediate_cache2,
intermediate_cache1.view(-1, N))
a2q_scale: Optional[torch.Tensor] = None
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant,
self.block_shape)
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
intermediate_cache3,
a2q_scale,
w2_scale,
w2_zp,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
1,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape)
return intermediate_cache3
def modular_triton_fused_moe(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
) -> mk.FusedMoEModularKernel:
qtype = get_config_qtype(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
)
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
quant_dtype=qtype,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
),
TritonExperts(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
),
)

View File

@ -1,15 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
import importlib
import threading
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
from weakref import WeakValueDictionary
import torch
import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@ -26,8 +30,17 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
if current_platform.is_cuda_alike():
from .fused_moe import fused_experts
from .fused_batched_moe import (BatchedPrepareAndFinalize,
BatchedTritonExperts)
from .fused_moe import TritonExperts, fused_experts
from .modular_kernel import (FusedMoEModularKernel,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
if has_pplx:
from .pplx_prepare_finalize import PplxPrepareAndFinalize
else:
fused_experts = None # type: ignore
if is_rocm_aiter_moe_enabled():
@ -42,6 +55,179 @@ else:
fused_moe_pallas = None # type: ignore
logger = init_logger(__name__)
# Note: this limit is somewhat arbitrary and might be changed later.
# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
MOE_DP_CHUNK_SIZE = 256
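# For example (illustrative numbers, not from this commit): with E = 8
# experts, hidden_dim = 4096 and bf16 activations, one chunk buffers
#   8 * 256 * 4096 * 2 bytes = 16 MiB
# of activations.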
@dataclass
class FusedMoEParallelConfig:
tp_size: int
dp_size: int
ep_size: int
tp_rank: int
dp_rank: int
ep_rank: int
use_ep: bool # whether to use EP or not
@property
def use_pplx_kernels(self):
return self.dp_size > 1 and self.use_ep and has_pplx
@staticmethod
def make(tp_size_: int, dp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
"""
Determine MoE parallel configuration. Based on the input tp_size_,
dp_size_, ep_size_ and vllm's parallel config, determine what
levels of parallelism to use in the fused moe layer.
Args:
tp_size_ (int): tp_size passed into the FusedMoE constructor.
dp_size_ (int): dp_size passed into the FusedMoE constructor.
ep_size_ (int): ep_size passed into the FusedMoE constructor.
vllm_parallel_config (ParallelConfig): vllm's parallel config
object.
Examples:
When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1,
we simply return the sizes unaltered and the ranks set to 0.
Expert Parallelism is considered only when either dp_size_ or tp_size_
is non-trivial.
When TP = 2, DP = 1 and EP = False, the configuration on different
devices,
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
legend : {size, rank}
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
- Comment : Tensors are sharded across 2 devices.
When TP = 1, DP = 2 and EP = False, the configuration on different
devices,
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
- device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
- Comment: There are 2 engine instances and the tensors are sharded
across 2 devices.
When TP = 2, DP = 2 and EP = False, the configuration on different
devices,
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
- device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
- device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
- device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
- Comment: There are 2 engine instances and the tensors are sharded
across 4 devices.
When, TP = 2, DP = 1 and EP = True, the configuration on different
devices,
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
- Comment: The experts are split between the 2 devices.
When, TP = 1, DP = 2 and EP = True, the configuration on different
devices,
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
- device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
- Comment: There are 2 engine instances and the experts are split
between the 2 devices.
When TP = 2, DP = 2 and EP = True, the configuration on different
devices,
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
- device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
- device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
- device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
- Comment: There are 2 engine instances and the experts are split
between the 4 devices.
"""
def flatten_tp_across_dp(dp_rank: int):
tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
# There are actually dp_size_ * tp_size_ devices. Update tp_size
# and tp_rank so we shard across all devices.
tp_size = dp_size_ * tp_size_
tp_rank = dp_rank * tp_size_ + tp_rank
return tp_size, tp_rank
use_ep = (dp_size_ * tp_size_ > 1
and vllm_parallel_config.enable_expert_parallel)
dp_size = dp_size_
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
if not use_ep:
return FusedMoEParallelConfig(tp_size=tp_size,
tp_rank=tp_rank,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=1,
ep_rank=0,
use_ep=False)
# DP + EP / TP + EP / DP + TP + EP
assert use_ep
# In EP, each device owns a set of experts fully. There is no tensor
# parallelism; update tp_size, tp_rank, ep_size and ep_rank to reflect that.
ep_size = tp_size
ep_rank = tp_rank
return FusedMoEParallelConfig(tp_size=1,
tp_rank=0,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
use_ep=True)
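# Worked example (illustration only) of the flattening above for the
# TP = 2, DP = 2, EP = True case from the docstring: flatten_tp_across_dp
# yields
#   tp_size = dp_size_ * tp_size_ = 4
#   tp_rank = dp_rank * tp_size_ + tp_rank, i.e. 0..3 across the 4 devices
# and with use_ep=True these flattened values carry over to ep_size = 4 and
# ep_rank = 0..3, while TP collapses to {1, 0} on every device.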
# Adapted from pplx-kernels tests/all_to_all_utils.py
@dataclass
class MoEConfig:
num_experts: int
experts_per_token: int
hidden_dim: int
num_local_experts: int
moe_parallel_config: FusedMoEParallelConfig
in_dtype: torch.dtype # The activation type.
# TODO: add more quantization params, blocked, per-token, etc.
block_size: int = 128
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@property
def dp_size(self):
return self.moe_parallel_config.dp_size
@property
def ep_size(self):
return self.moe_parallel_config.ep_size
@property
def tp_rank(self):
return self.moe_parallel_config.tp_rank
@property
def dp_rank(self):
return self.moe_parallel_config.dp_rank
@property
def ep_rank(self):
return self.moe_parallel_config.ep_rank
@property
def use_ep(self):
return self.moe_parallel_config.use_ep
@property
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
@ -58,6 +244,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError
def set_prepare_finalize(
self,
dp_size: int,
world_size: int,
prepare_finalize: FusedMoEPrepareAndFinalize,
) -> bool:
return False
@abstractmethod
def apply(
self,
@ -80,12 +274,54 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise NotImplementedError
class AllToAllCache:
def __init__(self):
self._cache: WeakValueDictionary = WeakValueDictionary()
self._lock = threading.RLock() # Reentrant lock for thread safety
def destroy(self):
with self._lock:
# TODO: can we do del self._cache?
for _, a2a in self._cache.items():
a2a.destroy()
def get_or_create(self, **kwargs):
assert has_pplx
import pplx_kernels as pplx
# Create a hashable key from the kwargs
key = tuple(sorted((k, v) for k, v in kwargs.items()))
with self._lock:
instance = self._cache.get(key)
if instance is None:
# TODO (varun): Add support to switch to intranode
# when all communications are within the same
# node.
logger.debug("Create AllToAll %s", kwargs)
instance = pplx.AllToAll.internode(**kwargs)
self._cache[key] = instance
return instance
# Global singleton
_all_to_all_cache = AllToAllCache()
# Factory function as a cleaner interface
def get_all_to_all(**kwargs):
return _all_to_all_cache.get_or_create(**kwargs)
@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def __init__(self):
def __init__(self, moe: MoEConfig):
super().__init__()
self.fused_experts = fused_experts
self.moe = moe
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
if self.rocm_aiter_moe_enabled:
@ -193,6 +429,47 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
def set_prepare_finalize(
self,
dp_size: int,
world_size: int,
prepare_finalize: FusedMoEPrepareAndFinalize,
) -> bool:
assert self.fused_experts == fused_experts
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
logger.debug("BatchedTritonExperts %s", self.moe)
experts = BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=world_size,
dp_size=dp_size,
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
)
else:
logger.debug("TritonExperts %s", self.moe)
experts = TritonExperts(
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
per_channel_quant=False,
)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
return True
def forward_cuda(
self,
layer: torch.nn.Module,
@ -221,9 +498,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=torch.uint32 if self.moe.use_pplx_kernels else None)
if self.rocm_aiter_moe_enabled:
assert not apply_router_weight_on_input
assert expert_map is None
return self.rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
@ -232,8 +512,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
return fused_experts(
else:
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@ -243,7 +523,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)
expert_map=expert_map,
)
def forward_cpu(
self,
@ -399,6 +680,45 @@ def determine_expert_map(
return (local_num_experts, expert_map)
def _construct_prepare_finalize(
moe: MoEConfig, quant_config: Optional[QuantizationConfig]
) -> Optional[FusedMoEPrepareAndFinalize]:
max_num_tokens = MOE_DP_CHUNK_SIZE
world_size = moe.ep_size
dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP.
rank = moe.ep_rank
if moe.use_pplx_kernels:
logger.debug("using PplxPrepareAndFinalize")
all_to_all = get_all_to_all(
max_num_tokens=max_num_tokens,
num_experts=moe.num_experts,
experts_per_token=moe.experts_per_token, # topk
rank=rank,
world_size=world_size,
dp_size=dp_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
((moe.hidden_dim + moe.block_size - 1) //
moe.block_size * torch.float32.itemsize)))
return PplxPrepareAndFinalize(
all_to_all,
max_num_tokens=max_num_tokens,
world_size=world_size,
rank=rank,
dp_size=dp_size,
quant_dtype=moe.in_dtype,
)
return None
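# Worked example (illustration only) of the scale-byte sizing above: for fp8
# activations (in_dtype.itemsize == 1) with hidden_dim = 4096 and
# block_size = 128,
#   hidden_dim_scale_bytes = ceil(4096 / 128) * 4 = 32 * 4 = 128 bytes
# per token; for wider activation dtypes it is 0.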
class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
@ -449,21 +769,16 @@ class FusedMoE(torch.nn.Module):
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Note: here we guard against accessing the TP and DP groups when
# uninitialized (this happens when testing)
self.tp_size = (tp_size if tp_size is not None else
get_tensor_model_parallel_world_size())
tp_rank = 0 if self.tp_size == 1 else get_tensor_model_parallel_rank()
self.dp_size = (dp_size
if dp_size is not None else get_dp_group().world_size)
self.dp_rank = (0
if self.dp_size == 1 else get_dp_group().rank_in_group)
self.global_num_experts = num_experts
# Use expert parallelism instead of tensor parallelism?
vllm_config = get_current_vllm_config()
use_ep = (vllm_config.parallel_config.enable_expert_parallel
and self.tp_size * self.dp_size > 1)
self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make(
tp_size_=(tp_size if tp_size is not None else
get_tensor_model_parallel_world_size()),
dp_size_=(dp_size if dp_size is not None else
get_dp_group().world_size),
vllm_parallel_config=vllm_config.parallel_config))
self.global_num_experts = num_experts
# For smuggling this layer into the fused moe custom op
self.use_direct_call = self.dp_size == 1
@ -474,28 +789,17 @@ class FusedMoE(torch.nn.Module):
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
if use_ep:
# Set TP size to 1 to adjust for EP and adjust EP size and rank
# for DP attention.
self.ep_rank = tp_rank + self.tp_size * self.dp_rank
self.tp_rank = 0
self.ep_size = self.tp_size * self.dp_size
self.tp_size = 1
# Determine expert maps
if self.use_ep:
self.local_num_experts, self.expert_map = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts)
else:
# Adjust TP size for DP attention
self.tp_rank = tp_rank + self.tp_size * self.dp_rank
self.ep_rank = 0
self.tp_size = self.tp_size * self.dp_size
self.ep_size = 1
self.local_num_experts = self.global_num_experts
self.expert_map = None
self.local_num_experts, self.expert_map = (self.global_num_experts,
None)
self.top_k = top_k
self.global_num_experts = num_experts
assert intermediate_size % self.tp_size == 0
self.hidden_size = hidden_size
@ -520,14 +824,40 @@ class FusedMoE(torch.nn.Module):
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
moe = MoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype,
)
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
UnquantizedFusedMoEMethod())
quant_method = UnquantizedFusedMoEMethod(moe)
prepare_finalize = _construct_prepare_finalize(moe, quant_config)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None
quant_method = quant_config.get_quant_method(self, prefix)
# No pplx for quantized types yet.
prepare_finalize = None
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
if prepare_finalize is not None:
world_size = moe.ep_size
dp_size = int(moe.ep_size // moe.dp_size)
success = self.quant_method.set_prepare_finalize(
dp_size, world_size, prepare_finalize)
if not success:
logger.warning("DP+EP not supported for %s.",
type(self.quant_method))
moe_quant_params = {
"num_experts": self.local_num_experts,
@ -546,6 +876,38 @@ class FusedMoE(torch.nn.Module):
self.quant_method.create_weights(layer=self, **moe_quant_params)
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@property
def dp_size(self):
return self.moe_parallel_config.dp_size
@property
def ep_size(self):
return self.moe_parallel_config.ep_size
@property
def tp_rank(self):
return self.moe_parallel_config.tp_rank
@property
def dp_rank(self):
return self.moe_parallel_config.dp_rank
@property
def ep_rank(self):
return self.moe_parallel_config.ep_rank
@property
def use_ep(self):
return self.moe_parallel_config.use_ep
@property
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels
def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
@ -830,7 +1192,8 @@ class FusedMoE(torch.nn.Module):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None):
e_score_correction_bias: Optional[torch.Tensor] = None,
indices_type: Optional[torch.dtype] = None):
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
# DeekSeekv2 uses grouped_top_k
@ -846,21 +1209,52 @@ class FusedMoE(torch.nn.Module):
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
elif custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
renormalize=renormalize,
indices_type=indices_type,
)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
return topk_weights, topk_ids
def must_reduce_shared_expert_outputs(self) -> bool:
"""
The shared_experts are typically computed using the RowParallelLinear
layer. The result of this function is typically used as
the reduce_results argument to the module.
When just tensor-parallel is used, it is not required to reduce
the shared_experts results immediately. Instead we reduce once
at the end of the MoE op. (Refer to the DeepSeekV2MoE module)
With EP and the pplx kernels this is no longer viable, as all
GPU ranks in the DP group produce the complete set of hidden_states.
Therefore it is required that we reduce the shared_experts output
early.
"""
return self.use_pplx_kernels
def maybe_all_reduce_tensor_model_parallel(
self, final_hidden_states: torch.Tensor):
"""
The pplx combine kernel reduces across GPU ranks by default.
"""
if self.use_pplx_kernels:
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
if self.use_direct_call:
@ -869,9 +1263,62 @@ class FusedMoE(torch.nn.Module):
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor):
full_final_hidden_states = torch.empty_like(full_hidden_states)
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :]
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
)
if not skip_result_store:
full_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states)
ctx = get_forward_context()
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE
num_tokens = full_hidden_states.size(0)
for chunk_start_ in range(0, max_tokens_across_dp,
moe_dp_chunk_size_per_rank):
chunk_start = chunk_start_
chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
max_tokens_across_dp)
# Clamp the chunk to this rank's token count. Ranks that have run out
# of tokens still process a dummy chunk (with skip_result_store=True)
# so that collective ops across DP ranks stay in sync.
chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens)
process_chunk(chunk_start,
chunk_end,
skip_result_store=chunk_start_ >= num_tokens)
return full_final_hidden_states
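To make the chunk clamping concrete, here is a small worked trace of the loop above under assumed sizes: chunk size 4, 6 tokens on this rank, and 10 tokens on the busiest rank in the DP group.

# Worked example (values are assumptions, not taken from the code above).
chunk_size, num_tokens, max_tokens_across_dp = 4, 6, 10
chunks = []
for chunk_start_ in range(0, max_tokens_across_dp, chunk_size):
    chunk_end = min(chunk_start_ + chunk_size, max_tokens_across_dp)
    # Clamp to this rank's token count, as in process_chunk's caller.
    chunk_start = min(chunk_start_, num_tokens - 1)
    chunk_end = min(chunk_end, num_tokens)
    chunks.append((chunk_start, chunk_end, chunk_start_ >= num_tokens))
# The final "dummy" chunk still executes the quant_method so that collective
# ops stay matched across DP ranks, but its result is not stored.
assert chunks == [(0, 4, False), (4, 6, False), (5, 6, True)]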
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None
if self.moe_parallel_config.use_pplx_kernels:
return self.forward_impl_chunked(hidden_states, router_logits)
if self.dp_size > 1:
hidden_states, router_logits = get_ep_group().dispatch(

View File

@ -0,0 +1,364 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Optional
import torch
#
# This file defines a set of base classes used to make MoE kernels more modular.
# The goal is to be able to utilize different communication mechanisms with
# any fused MoE kernel without needing to have combinatoric implementations.
#
# The fused moe kernels are broken down into the following components:
#
# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]
#
# Each component will be independent of the others except for
# [Quantize-Dispatch] and [Combine] (see below). The components can then be
# mixed and matched so that DP+EP can be supported easily for multiple
# MoE kernel implementations.
#
# The following main classes are defined:
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
# inputs (e.g. quantization, distribution) and finalization of MoE outputs.
# The prepare method must take care of any needed quantization and the
# finalize method must apply weights and do the final reduction of the output.
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
# MoE operation. One important feature to note is that this class does not
# apply topk weights or reduce the final output.
# * FusedMoEModularKernel - an interface class that combines a
# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
# provide the standard fused MoE kernel interface.
#
# [Quantize-Prepare] and [Finalize] functionality is bundled into a single
# class `FusedMoEPrepareAndFinalize` since they could use collective
# communication mechanisms that need to be consistent.
#
def _moe_problem_size(
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""
Extract the MoE problem size from the given tensor arguments:
- a1: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.
Note: extracting the problem shape from the weight and activation tensors is
not obvious. It needs to be done this way specifically due to subtle issues
with particular kernels, e.g. the int4 kernels divide the trailing dimension
by two, so it's not "correct" to extract N or K from the trailing dimension
of w1 or w2. Similarly, some kernels transpose the weights, so this needs
to be kept in mind.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, N, _ = w1.size()
K = w2.size(1)
if a1.dim() == 2:
# Make sure we are using the correct a1 (pre-permute).
assert topk_ids.size(0) == a1.size(0), \
f"{topk_ids.size(0)} != {a1.size(0)}"
M = a1.size(0)
else:
assert a1.dim() == 3
assert a1.size(0) == E, f"{a1.size(0)} != {E}"
M = a1.size(1) # This is max_num_tokens
assert topk_ids.dim() == 2
topk = topk_ids.size(1)
return E, M, N, K, topk
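A small shape example under assumed sizes (8 experts, hidden size K=128, gated intermediate size 256, 16 tokens routed top-2); note that the returned N is the w1 output dimension, i.e. twice the intermediate size for silu_and_mul-style experts.

import torch

w1 = torch.empty(8, 512, 128)   # (E, 2 * intermediate, K)
w2 = torch.empty(8, 128, 256)   # (E, K, intermediate)
a1 = torch.empty(16, 128)       # (M, K), standard (non-batched) format
topk_ids = torch.zeros(16, 2, dtype=torch.int32)

E, M, N, K, topk = _moe_problem_size(a1, w1, w2, topk_ids)
assert (E, M, N, K, topk) == (8, 16, 512, 128, 2)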
class FusedMoEPrepareAndFinalize(ABC):
"""
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
described above.
"""
@abstractmethod
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform any quantization (and/or) dispatching needed
for this kernel.
- a1: The (unquantized) input to the MoE layer.
- a1_scale: Optional scales for a1
- a2_scale: Optional scales for the second MoE gemm. Required to make
sure the quantization is consistent for both gemms.
- topk_ids: The topk ids.
- topk_weights: The topk weights.
- num_experts: The total number of experts in the global expert space.
- expert_map: A tensor mapping expert indices from the global expert
space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching.
Returns a tuple of:
- quantized + dispatched a.
- quantized + dispatched a1_scales.
- an optional tensor with the number of tokens assigned to each local
expert when using the batched experts format, otherwise None.
"""
raise NotImplementedError
@abstractmethod
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
"""
Perform any combine plus apply weights and perform a reduction on the
fused experts output.
- output: The output tensor, written in place. Must be (M, K) shape.
- fused_expert_output: The unweighted, unreduced output of the fused
experts, it will have (M, topk, K) shape.
- topk_weights: The weights to be applied to the fused_experts_output.
- topk_ids: The topk_ids.
- apply_router_weight_on_input: When False, apply the weights to
fused_expert_output.
"""
raise NotImplementedError
class FusedMoEPermuteExpertsUnpermute(ABC):
"""
An abstract base class for the [Permute-Experts-Unpermute] step described
above.
"""
@abstractmethod
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
"""
Compute the number of elements for the temporary outputs of the two
gemms and activation in the fused expert function. Since the
gemms are independent, the workspace for the first gemm can be shared
with the workspace for the last gemm.
Returns a tuple of:
- Number of workspace13 elements: must be large enough to hold the
result of either expert gemm.
- Number of workspace2 elements: must be large enough to hold the
result of the activation function.
- Workspace type: The dtype to use for the workspace tensors.
"""
raise NotImplementedError
def activation(self, activation: str, output: torch.Tensor,
input: torch.Tensor) -> None:
assert output.size(-1) * 2 == input.size(-1)
if activation == "silu":
torch.ops._C.silu_and_mul(output, input)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(output, input)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
@abstractmethod
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
"""
This function computes the intermediate result of a Mixture of Experts
(MoE) layer using two sets of weights, w1 and w2.
Parameters:
- hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
must be large enough to hold output of either MoE gemm.
- workspace2 (torch.Tensor): A scratch tensor used for the activation
function.
- expert_num_tokens: An optional tensor containing the number of tokens
assigned to each expert when using batched experts format input.
Returns:
- torch.Tensor: The unweighted, unreduced output tensor
"""
raise NotImplementedError
class FusedMoEModularKernel(torch.nn.Module):
"""
This class combines a FusedMoEPrepareAndFinalize instance and
a FusedMoEPermuteExpertsUnpermute to provide an interface that
is compatible with the `fused_experts` function in fused_moe.py.
It takes care of managing any required scratch space.
Note: Instances of this class should only be used for a single model
layer due to any layer specific state that may be used by the component
objects.
"""
def __init__(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEPermuteExpertsUnpermute,
):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
def forward(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
of weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The topk weights applied at the end of
the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
a1 = hidden_states
E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids)
if global_num_experts == -1:
global_num_experts = E
output = a1 if inplace else torch.zeros_like(a1)
workspace13_shape, workspace2_shape, workspace_dtype = (
self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
global_num_experts))
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13 = torch.zeros(workspace13_shape,
device=a1.device,
dtype=workspace_dtype)
workspace2 = torch.zeros(workspace2_shape,
device=a1.device,
dtype=workspace_dtype)
a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts,
expert_map, apply_router_weight_on_input)
fused_out = self.fused_experts.apply(
a1q,
w1,
w2,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input)
return output
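Putting the pieces together, a minimal composition sketch; the module paths and default constructor arguments are assumptions based on the other files added in this commit (MoEPrepareAndFinalizeNoEP in prepare_finalize.py, TritonExperts in fused_moe.py), targeting the plain unquantized, non-EP path.

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)

kernel = mk.FusedMoEModularKernel(
    prepare_finalize=MoEPrepareAndFinalizeNoEP(),  # no quantization, no EP
    fused_experts=TritonExperts(),
)
# kernel(hidden_states, w1, w2, topk_weights, topk_ids) now mirrors the
# monolithic fused_experts() entry point in fused_moe.py.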

View File

@ -3,6 +3,74 @@ from typing import Optional
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
def _moe_permute(
curr_hidden_states: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
curr_topk_ids: torch.Tensor,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
block_m: int,
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
top_k_num = curr_topk_ids.size(1)
tokens_in_chunk = curr_hidden_states.size(0)
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids,
block_m,
global_num_experts,
expert_map,
pad_sorted_ids=True))
inv_perm: Optional[torch.Tensor] = None
num_tokens = top_k_num * tokens_in_chunk
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
# Permute according to sorted token ids.
curr_hidden_states = _fp8_perm(curr_hidden_states,
sorted_token_ids // top_k_num)
if a1q_scale is not None:
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm)
def _moe_unpermute_and_reduce(
out: torch.Tensor,
curr_hidden: torch.Tensor,
inv_perm: Optional[torch.Tensor],
topk_weight: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M, topk = topk_weight.size()
K = curr_hidden.size(-1)
if inv_perm is not None:
curr_hidden = curr_hidden[inv_perm, ...]
curr_hidden = curr_hidden.view(-1, topk, K)
if not apply_router_weight_on_input:
curr_hidden.mul_(topk_weight.view(M, -1, 1))
ops.moe_sum(curr_hidden, out)
def moe_permute(
hidden_states: torch.Tensor,
@ -42,7 +110,7 @@ def moe_permute(
- m_indices: m_indices for grouped gemm in deepgemm; `m_indices[i]` records
the group to which the i-th row of the LHS belongs.
"""
n_token, n_hidden = hidden_states.shape
n_token, n_hidden = hidden_states.size()
assert (n_hidden * hidden_states.element_size()
) % 16 == 0, "permue kernel need hidden dim align to 16B"
permuted_row_size = n_token * topk
@ -102,7 +170,7 @@ def moe_unpermute(
- hidden_states (torch.Tensor): The reduced and unpermuted activation
tensor.
"""
n_token, n_hidden = topk_weights.shape[0], permuted_hidden_states.shape[-1]
n_token, n_hidden = topk_weights.size(0), permuted_hidden_states.size(-1)
assert (n_hidden * permuted_hidden_states.element_size()
) % 16 == 0, "unpermue kernel need hidden dim align to 16B"
hidden_states = torch.empty((n_token, n_hidden),

View File

@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import pplx_kernels as pplx
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
# Note: use layer.get_all_to_all() to get an AllToAll instance.
# The max_num_tokens, world_size and dp_size must be the same
# as the ones used to create the AllToAll.
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def __init__(self,
a2a: pplx.AllToAll,
max_num_tokens: int,
world_size: int,
rank: int,
dp_size: int,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
super().__init__()
assert max_num_tokens > 0
self.a2a = a2a
self.block_shape = block_shape
self.max_num_tokens = max_num_tokens
self.world_size = world_size
self.rank = rank
self.dp_size = dp_size
self.quant_dtype = quant_dtype
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
rank_topk_weights: torch.Tensor,
rank_topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
assert rank_topk_ids.size(0) == num_tokens
# assert expert_map is None, "NYI"
# Is this always going to be a1.device?
device = a1.device
if apply_router_weight_on_input:
topk = rank_topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * rank_topk_weights.to(a1.dtype)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
self.quant_dtype,
per_act_token,
self.block_shape)
# rem_experts needs to be 0 (num_experts divisible by world_size) for pplx
# to work properly.
rem_experts = num_experts % self.world_size
assert rem_experts == 0
num_local_experts = ((num_experts // self.world_size) +
(1 if self.rank < rem_experts else 0))
expert_num_tokens = torch.empty(
num_local_experts,
dtype=torch.int32,
device=device,
)
num_dp = self.world_size // self.dp_size
expert_x = torch.empty(
(num_local_experts, self.max_num_tokens * num_dp, hidden_dim),
dtype=a1q.dtype,
device=device,
)
expert_x_scale: Optional[torch.Tensor] = None
if a1q.dtype.itemsize == 1:
float32_size = torch.float32.itemsize
block_size = (self.block_shape[0] if self.block_shape is not None
else 1) * float32_size
expert_x_scale = torch.empty(
(
num_experts,
expert_x.size(1),
(expert_x.size(2) + block_size - 1) // block_size,
),
dtype=torch.float32,
device=device,
)
# This argument is optional, defaults to indices.size(0)
# There's not much point setting this unless it is != indices.size(0)
bound_m: Optional[torch.Tensor] = None
self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=rank_topk_ids,
bound_m=bound_m,
)
return expert_x, expert_x_scale, expert_num_tokens
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
num_tokens = output.size(0) # M
# This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0)
bound_m: Optional[torch.Tensor] = None
assert topk_ids.size(0) == num_tokens, (
f"{topk_ids.size(0)} == {num_tokens}")
assert output.size(0) <= self.max_num_tokens, (
f"{output.size(0)} <= {self.max_num_tokens}")
assert output.size(1) == fused_expert_output.size(-1)
# Set weights to 1 if we already applied them during dispatch. This is hacky.
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
self.a2a.combine(out_tokens=output,
indices=topk_ids,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m)
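A hedged wiring sketch: plug PplxPrepareAndFinalize into the modular kernel, assuming an AllToAll handle created with matching max_num_tokens, world_size and dp_size (per the note at the top of this file). All sizes and the fp8 expert choice below are placeholders, not requirements.

import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts

def build_pplx_kernel(a2a, rank: int) -> mk.FusedMoEModularKernel:
    # a2a is an existing pplx AllToAll (e.g. from layer.get_all_to_all()).
    prepare_finalize = PplxPrepareAndFinalize(
        a2a,
        max_num_tokens=64,                # placeholder, must match the AllToAll
        world_size=8,                     # placeholder
        rank=rank,
        dp_size=2,                        # placeholder
        quant_dtype=torch.float8_e4m3fn,  # assumption: fp8 activations
    )
    return mk.FusedMoEModularKernel(prepare_finalize,
                                    TritonExperts(use_fp8_w8a8=True))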

View File

@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_unpermute_and_reduce)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def __init__(
self,
quant_dtype: Optional[torch.dtype] = None,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None,
):
super().__init__()
self.per_channel_quant = per_channel_quant
self.block_shape = block_shape
self.quant_dtype = quant_dtype
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
self.quant_dtype,
self.per_channel_quant,
self.block_shape)
return a1q, a1q_scale, None
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
_moe_unpermute_and_reduce(output, fused_expert_output, None,
topk_weights, apply_router_weight_on_input)

View File

@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None,
block_m: Optional[int] = None,
allow_deep_gemm: bool = False):
super().__init__()
self.triton_expert = TritonExperts(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
block_m=block_m)
self.deep_gemm_expert = DeepGemmExperts()
self.allow_deep_gemm = allow_deep_gemm
self.use_fp8_w8a8 = use_fp8_w8a8
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
return self.deep_gemm_expert.workspace_shapes(
a, M, N, K, topk, num_experts)
else:
return self.triton_expert.workspace_shapes(a, M, N, K, topk,
num_experts)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
N = w1.size(1)
if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
return self.deep_gemm_expert.apply(
hidden_states,
w1,
w2,
topk_ids,
activation,
global_num_experts,
expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_num_tokens,
)
else:
return self.triton_expert.apply(
hidden_states,
w1,
w2,
topk_ids,
activation,
global_num_experts,
expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_num_tokens,
)

View File

@ -7,6 +7,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.utils import cdiv
@ -15,26 +17,73 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches.
"""
assert prod(v) <= x.numel()
assert prod(
v) <= x.numel(), f"{prod(v)} <= {x.numel()}" # CUDAGRAPH unfriendly?
return x.flatten()[:prod(v)].view(*v)
def _fp8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
block_shape: Optional[list[int]],
per_act_token: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Perform fp8 quantization on the inputs. If a block_shape
is provided, the output will be blocked.
"""
if block_shape is None:
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
A, A_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_act_token)
else:
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
return A, A_scale
def _int8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
per_act_token: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Perform int8 quantization on the inputs. If a block_shape
is provided, the output will be blocked.
"""
# If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8/int8 quantization, dynamic or static
if block_shape is None:
assert per_act_token, \
"int8 quantization only supports block or channel-wise"
A, A_scale = per_token_quant_int8(A)
else:
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
return A, A_scale
def moe_kernel_quantize_input(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
qtype: Optional[torch.dtype],
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if qtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_channel_quant, block_shape)
elif qtype == torch.int8:
return _int8_quantize(A, A_scale, per_channel_quant, block_shape)
else:
assert A_scale is None
return A, A_scale
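A short usage sketch of the dispatch helper above; dtypes, shapes and the CUDA device are assumptions for illustration.

import torch

A = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16)

# fp8 path: dynamic per-tensor quantization when no scale is given.
a_fp8, s_fp8 = moe_kernel_quantize_input(A, None, torch.float8_e4m3fn,
                                         per_channel_quant=False)

# int8 path: only block- or channel-wise quantization is supported.
a_i8, s_i8 = moe_kernel_quantize_input(A, None, torch.int8,
                                       per_channel_quant=False,
                                       block_shape=[128, 128])

# No quantization requested: the input passes through untouched.
a_same, s_none = moe_kernel_quantize_input(A, None, None,
                                           per_channel_quant=False)
assert a_same is A and s_none is None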
@ -42,7 +91,7 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
"""
A permutation routine that works on fp8 types.
"""
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
if torch.is_floating_point(m) and m.dtype.itemsize == 1:
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
else:
return m[idx, ...]

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import functools
import importlib.util
from typing import Any, Callable, Optional
@ -9,6 +10,7 @@ from torch.nn import Module
from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
@ -434,6 +436,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"""
def __init__(self, quant_config: Fp8Config):
from vllm.model_executor.layers.fused_moe import fused_experts
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
@ -458,6 +461,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger.warning_once(
"DeepGemm not supported on the current platform.")
self.fused_experts = functools.partial(
fused_experts,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm)
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
@ -783,6 +791,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
def set_prepare_finalize(
self,
dp_size: int,
world_size: int,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
) -> bool:
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
if self.use_marlin or self.rocm_aiter_moe_enabled:
return False
experts = TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)
self.fused_experts = mk.FusedMoEModularKernel(
prepare_finalize,
experts,
)
return True
def apply(
self,
layer: torch.nn.Module,
@ -801,10 +834,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -819,6 +848,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_fused_experts)
return rocm_aiter_fused_experts(
x,
layer.w13_weight,
@ -835,8 +866,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size)
if self.use_marlin:
elif self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, (
@ -853,11 +883,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
quant_type_id=scalar_types.float8_e4m3fn.id,
global_num_experts=global_num_experts,
expert_map=expert_map)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
else:
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
@ -872,8 +902,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)

View File

@ -79,7 +79,6 @@ class DbrxExperts(FusedMoE):
prefix=prefix,
)
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.d_model = config.d_model
self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
self.tp_size)

View File

@ -31,9 +31,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
@ -143,7 +141,8 @@ class DeepseekV2MoE(nn.Module):
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
reduce_results=self.experts.must_reduce_shared_expert_outputs(
),
prefix=f"{prefix}.shared_experts",
)
@ -154,6 +153,7 @@ class DeepseekV2MoE(nn.Module):
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
@ -171,9 +171,11 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
return final_hidden_states.view(num_tokens, hidden_dim)

View File

@ -25,8 +25,7 @@ from transformers import Llama4TextConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -89,7 +88,7 @@ class Llama4MoE(nn.Module):
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.shared_expert",
reduce_results=False, # We need to do scatter before reduce
reduce_results=self.experts.must_reduce_shared_expert_outputs(),
)
def forward(self, hidden_states):
@ -102,7 +101,8 @@ class Llama4MoE(nn.Module):
experts_out = routed_out + shared_out
if self.tp_size > 1:
experts_out = tensor_model_parallel_all_reduce(experts_out)
experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
experts_out)
return experts_out

View File

@ -33,9 +33,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
@ -129,7 +127,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
intermediate_size=config.shared_expert_intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
reduce_results=self.experts.must_reduce_shared_expert_outputs(
),
)
else:
self.shared_expert = None
@ -156,7 +155,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states)
return final_hidden_states.view(orig_shape)

View File

@ -30,9 +30,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
@ -137,7 +135,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
router_logits=router_logits)
final_hidden_states = final_hidden_states
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states)
return final_hidden_states.view(orig_shape)

View File

@ -158,6 +158,7 @@ class CudaPlatformBase(Platform):
"currently not supported with CUDA Graphs.")
vllm_config.model_config.enforce_eager = True
compilation_config.use_cudagraph = False
compilation_config.use_inductor = False
@classmethod
def get_current_memory_usage(cls,

View File

@ -865,8 +865,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
return output
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens

View File

@ -341,7 +341,8 @@ def init_worker_distributed_environment(
distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
ensure_kv_transfer_initialized(vllm_config)

View File

@ -265,4 +265,5 @@ def init_tpu_worker_distributed_environment(
backend="gloo",
)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)

View File

@ -390,7 +390,8 @@ class CPUWorker(LocalOrDistributedWorkerBase):
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
def get_cache_block_size_bytes(self) -> int:
"""Return the size in bytes of a single KV cache block.

View File

@ -416,7 +416,8 @@ def init_worker_distributed_environment(
backend='hccl')
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
@ -442,7 +443,8 @@ def init_worker_distributed_environment(
torch.distributed.all_reduce(dummy_tensor_hpu)
assert dummy_tensor_hpu.item() == parallel_config.world_size
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len,

View File

@ -76,7 +76,8 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
)
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size)
self.parallel_config.pipeline_parallel_size,
self.parallel_config.enable_expert_parallel)
# Device initialization should happen after initializing the distributed
# runtime.

View File

@ -530,7 +530,8 @@ def init_worker_distributed_environment(
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
ensure_kv_transfer_initialized(vllm_config)

View File

@ -176,7 +176,8 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
# global all_reduce needed for overall oneccl warm up
torch.distributed.all_reduce(torch.zeros(1).xpu())