[Kernel] DeepEP dispatch-combine kernel integration (#18434)

Signed-off-by: Varun <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Authored by Varun Sundar Rabindranath on 2025-06-03 15:30:02 -04:00; committed by GitHub
parent 01eee40536
commit fa98d77773
23 changed files with 1950 additions and 122 deletions


@ -516,9 +516,8 @@ void topk_softmax(
topk,
stream);
}
else
else if (topk_indices.scalar_type() == at::ScalarType::UInt32)
{
assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
@ -530,4 +529,17 @@ void topk_softmax(
topk,
stream);
}
else {
assert(topk_indices.scalar_type() == at::ScalarType::Int64);
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int64_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
}
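
The hunk above extends topk_softmax to dispatch on the index dtype: the existing UInt32 branch serves the pplx path, and the new Int64 branch lets the DeepEP prepare/finalize classes (which report torch.int64 topk indices) reuse the same entry point. Below is a minimal PyTorch sketch of the operation the kernel implements (softmax over the gating logits followed by top-k), not the CUDA kernel itself:

import torch

def topk_gating_softmax_reference(gating_output: torch.Tensor, topk: int):
    # Softmax over experts, then top-k selection; torch.topk returns
    # int64 indices, matching the new Int64 branch above.
    probs = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_indices = torch.topk(probs, k=topk, dim=-1)
    return topk_weights, topk_indices

weights, indices = topk_gating_softmax_reference(torch.randn(4, 8), topk=2)
assert indices.dtype == torch.int64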


@ -0,0 +1,188 @@
# SPDX-License-Identifier: Apache-2.0
"""
DeepEP test utilities
"""
import dataclasses
import importlib
import traceback
from typing import Callable, Optional
import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
## Parallel Processes Utils
P = ParamSpec("P")
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
"tcp://localhost:29500",
worker,
) + args,
nprocs=world_size,
join=True,
)
## DeepEP specific utils
@dataclasses.dataclass
class DeepEPHTArgs:
num_local_experts: int
@dataclasses.dataclass
class DeepEPLLArgs:
max_tokens_per_rank: int
hidden_size: int
num_experts: int
use_fp8_dispatch: bool
def make_deepep_ht_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
ht_args: DeepEPHTArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
import deep_ep
# high throughput a2a
num_nvl_bytes = 1024 * 1024 * 1024 # 1GB
num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
buffer = deep_ep.Buffer(group=pg,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=low_latency_mode,
num_qps_per_rank=num_qps_per_rank)
return DeepEPHTPrepareAndFinalize(buffer=buffer,
world_size=pgi.world_size,
rank=pgi.rank,
dp_size=dp_size,
rank_expert_offset=pgi.rank *
ht_args.num_local_experts,
quant_dtype=q_dtype,
block_shape=block_shape)
def make_deepep_ll_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ll_args: DeepEPLLArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
import deep_ep
# low-latency a2a
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
pgi.world_size, deepep_ll_args.num_experts)
buffer = deep_ep.Buffer(group=pg,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=deepep_ll_args.num_experts //
pgi.world_size)
return DeepEPLLPrepareAndFinalize(
buffer=buffer,
world_size=pgi.world_size,
dp_size=dp_size,
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
quant_dtype=q_dtype,
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
)
def make_deepep_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ht_args: Optional[DeepEPHTArgs],
deepep_ll_args: Optional[DeepEPLLArgs],
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
if deepep_ht_args is not None:
assert deepep_ll_args is None
return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
block_shape)
assert deepep_ll_args is not None
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype)
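
A usage sketch for the launch helpers above, assuming it lives next to deepep_utils in the tests package and two CUDA devices are available; `my_worker` is an illustrative name, not part of this file:

from .deepep_utils import ProcessGroupInfo, parallel_launch

def my_worker(pgi: ProcessGroupInfo, message: str) -> None:
    # Runs once per spawned rank, with the process group already initialized
    # by _worker_parallel_launch.
    print(f"rank {pgi.rank}/{pgi.world_size} on {pgi.device}: {message}")

parallel_launch(2, my_worker, "hello from DeepEP test utils")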


@ -0,0 +1,371 @@
# SPDX-License-Identifier: Apache-2.0
"""
Test DeepEP + DeepGEMM integration
"""
import dataclasses
import importlib
from typing import Optional
import pytest
import torch.distributed
from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform
from .deepep_utils import ProcessGroupInfo, parallel_launch
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
try:
import deep_gemm
has_deep_gemm = True
except ImportError:
has_deep_gemm = False
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from .deepep_utils import DeepEPHTArgs, make_deepep_a2a
if has_deep_gemm:
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts)
requires_deep_ep = pytest.mark.skipif(
not has_deep_ep,
reason="Requires deep_ep kernels",
)
requires_deep_gemm = pytest.mark.skipif(
not has_deep_gemm,
reason="Requires deep_gemm kernels",
)
P = ParamSpec("P")
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(deep_gemm.ceil_div(m, 128) * 128,
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def make_block_quant_fp8_weights(
e: int,
n: int,
k: int,
block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return block-quantized weights w1, w2 and their per-block scales w1_scale, w2_scale
"""
dtype = torch.bfloat16
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
w1_bf16 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
w2_bf16 = torch.randn((e, k, n), dtype=dtype) / 10
w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w2 = (n + block_k - 1) // block_k
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
device="cuda",
dtype=torch.float32)
w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
device="cuda",
dtype=torch.float32)
assert w1_s.shape == (e, (2 * n + 127) // 128, (k + 127) // 128)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
return w1, w2, w1_s, w2_s
@dataclasses.dataclass
class TestConfig:
topk: int
m: int
k: int
n: int
num_experts: int
block_size: list[int]
@dataclasses.dataclass
class TestTensors:
rank_tokens: torch.Tensor  # tokens for this rank; every rank makes the same number
rank_token_scales: Optional[torch.Tensor]
topk: torch.Tensor
topk_weights: torch.Tensor
config: TestConfig
@staticmethod
def make(config: TestConfig, rank) -> "TestTensors":
dtype = torch.bfloat16
topk, m, k, block_size = (config.topk, config.m, config.k,
config.block_size)
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
rank_tokens = torch.randn(
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
block_k = block_size[1]
_, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k)
topk_ids = torch.randint(
low=0,
high=config.num_experts,
size=(m, topk),
device=torch.cuda.current_device()).to(dtype=torch.int64)
topk_weights = torch.randn(topk_ids.shape,
dtype=torch.float32,
device=torch.cuda.current_device())
return TestTensors(rank_tokens=rank_tokens,
rank_token_scales=rank_token_scales,
topk=topk_ids,
topk_weights=topk_weights,
config=config)
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
num_local_experts: int, q_dtype: Optional[torch.dtype],
block_shape: list[int]) -> FusedMoEModularKernel:
a2a: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
pg=pg,
pgi=pgi,
dp_size=dp_size,
deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts),
deepep_ll_args=None,
q_dtype=q_dtype,
block_shape=block_shape)
fused_experts = DeepGemmExperts()
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
return mk
def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
test_tensors: TestTensors, w1: torch.Tensor,
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
num_experts: int) -> torch.Tensor:
num_local_experts = w1.size(0)
def build_expert_map():
num_local_experts = w1.size(0)
expert_map = torch.full((num_experts, ),
fill_value=-1,
dtype=torch.int32)
s = pgi.rank * num_local_experts
e = s + num_local_experts
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
return expert_map.to(device=torch.cuda.current_device(),
dtype=torch.int32)
q_dtype = torch.float8_e4m3fn
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
pg, pgi, dp_size, num_local_experts, q_dtype,
test_tensors.config.block_size)
a1_scale = test_tensors.rank_token_scales
out = mk.forward(hidden_states=test_tensors.rank_tokens,
w1=w1,
w2=w2,
topk_weights=test_tensors.topk_weights,
topk_ids=test_tensors.topk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=None,
w2_zp=None,
a1_scale=a1_scale,
a2_scale=None,
apply_router_weight_on_input=False)
return out
def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
topk_weights: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
a1_scale: torch.Tensor, block_shape: list[int]):
return fused_experts(
hidden_states=a,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
block_shape=block_shape,
# Make sure this is set to False so we
# don't end up comparing the same implementation.
allow_deep_gemm=False)
def _deep_ep_moe(
pgi: ProcessGroupInfo,
dp_size: int,
config: TestConfig,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
):
current_platform.seed_everything(pgi.rank)
w1 = w1.to(device=torch.cuda.current_device())
w2 = w2.to(device=torch.cuda.current_device())
w1_scale = w1_scale.to(device=torch.cuda.current_device())
w2_scale = w2_scale.to(device=torch.cuda.current_device())
pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, pgi.rank)
block_shape = [
w1.size(1) // w1_scale.size(1),
w1.size(2) // w1_scale.size(2)
]
with set_current_vllm_config(VllmConfig()):
# Reference
triton_moe = triton_impl(a=test_tensors.rank_tokens,
topk_ids=test_tensors.topk,
topk_weights=test_tensors.topk_weights,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=test_tensors.rank_token_scales,
block_shape=block_shape)
# Slice experts for this rank.
num_local_experts = config.num_experts // pgi.world_size
e_start = num_local_experts * pgi.rank
e_end = e_start + num_local_experts
w1_ep = w1[e_start:e_end]
w2_ep = w2[e_start:e_end]
w1_scale_ep = w1_scale[e_start:e_end]
w2_scale_ep = w2_scale[e_start:e_end]
deepep_moe = deep_ep_moe_impl(
pg,
pgi,
dp_size,
test_tensors,
w1_ep,
w2_ep,
w1_scale_ep,
w2_scale_ep,
config.num_experts,
)
torch.testing.assert_close(
triton_moe,
deepep_moe,
atol=6e-2,
rtol=6e-2,
)
MNKs = [
(8, 128, 128),
(8, 128, 512),
(8, 512, 512),
(3, 1024, 2048),
(32, 128, 1024),
(45, 512, 2048),
(64, 1024, 1024),
(129, 128, 256),
(129, 1024, 2048),
(222, 1024, 2048),
]
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
world_dp_size: tuple[int, int]):
m, n, k = mnk
current_platform.seed_everything(7)
if topk > num_experts:
pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]
world_size, dp_size = world_dp_size
config = TestConfig(
topk=topk,
m=m,
k=k,
n=n,
num_experts=num_experts,
block_size=block_size,
)
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
num_experts, n, k, block_size)
parallel_launch(world_size, _deep_ep_moe, dp_size, config, w1, w2,
w1_scale, w2_scale)
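
A small illustration (with made-up sizes, not taken from the parameter lists above) of how `_deep_ep_moe` recovers the quantization block shape from the weight and per-block scale tensors:

import torch

E, two_n, k = 4, 256, 512            # illustrative sizes
w1 = torch.empty((E, two_n, k), dtype=torch.float8_e4m3fn)
w1_scale = torch.empty((E, two_n // 128, k // 128), dtype=torch.float32)

block_shape = [
    w1.size(1) // w1_scale.size(1),  # 256 // 2 = 128
    w1.size(2) // w1_scale.size(2),  # 512 // 4 = 128
]
assert block_shape == [128, 128]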


@ -0,0 +1,459 @@
# SPDX-License-Identifier: Apache-2.0
"""
Test DeepEP dispatch-combine logic
"""
import dataclasses
import importlib
from typing import Optional, Union
import pytest
import torch.distributed
from torch.distributed import ProcessGroup
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform
from .deepep_utils import ProcessGroupInfo, parallel_launch
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
requires_deep_ep = pytest.mark.skipif(
not has_deep_ep,
reason="Requires deep_ep kernels",
)
MAX_TOKENS_PER_RANK = 64
def make_weights(
e, n, k, dtype
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return weights w1, w2, w1_scale, w2_scale
"""
if dtype in [torch.float16, torch.bfloat16]:
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
return w1, w2, None, None
# per-out-channel weight quantization
assert dtype == torch.float8_e4m3fn
w1 = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float16)
w2 = torch.empty((e, k, n), device="cuda", dtype=torch.float16)
n_b_scales = 2 * n
k_b_scales = k
w1_q = torch.empty_like(w1, dtype=dtype)
w2_q = torch.empty_like(w2, dtype=dtype)
w1_scale = torch.empty((e, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1),
device="cuda",
dtype=torch.float32)
for expert in range(e):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
w1[expert], use_per_token_if_dynamic=True)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
w2[expert], use_per_token_if_dynamic=True)
return w1_q, w2_q, w1_scale, w2_scale
@dataclasses.dataclass
class TestConfig:
dtype: torch.dtype
topk: int
m: int
k: int
n: int
num_experts: int
@dataclasses.dataclass
class TestTensors:
rank_tokens: torch.Tensor  # tokens for this rank; every rank makes the same number
rank_token_scales: Optional[torch.Tensor]
topk: torch.Tensor
topk_weights: torch.Tensor
config: TestConfig
@staticmethod
def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors":
# TODO (varun): check whether float16 works
assert config.dtype in [torch.bfloat16, torch.float8_e4m3fn]
token_dtype = (torch.bfloat16 if config.dtype == torch.float8_e4m3fn
else config.dtype)
rank_tokens = torch.randn(
(config.m, config.k), device="cuda", dtype=token_dtype) / 10
rank_token_scales = None
if config.dtype == torch.float8_e4m3fn:
# low_latency_mode kernels don't support per-token quant.
_, rank_token_scales = ops.scaled_fp8_quant(
rank_tokens, use_per_token_if_dynamic=not low_latency_mode)
topk = torch.randint(low=0,
high=config.num_experts,
size=(config.m, config.topk),
device="cuda").to(dtype=torch.int64)
topk_weights = torch.randn(topk.shape,
dtype=torch.float32,
device="cuda")
return TestTensors(rank_tokens=rank_tokens,
rank_token_scales=rank_token_scales,
topk=topk,
topk_weights=topk_weights,
config=config)
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
low_latency_mode: bool, hidden_size: int, dp_size: int,
num_experts: int, num_local_experts: int,
q_dtype: Optional[torch.dtype],
use_fp8_dispatch: bool) -> FusedMoEModularKernel:
is_quantized = q_dtype is not None
ht_args: Optional[DeepEPHTArgs] = None
ll_args: Optional[DeepEPLLArgs] = None
if low_latency_mode:
ll_args = DeepEPLLArgs(max_tokens_per_rank=MAX_TOKENS_PER_RANK,
hidden_size=hidden_size,
num_experts=num_experts,
use_fp8_dispatch=use_fp8_dispatch)
else:
assert not use_fp8_dispatch, (
"FP8 Dispatch is valid only for low-latency kernels")
ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)
a2a : Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = \
make_deepep_a2a(pg = pg,
pgi = pgi,
dp_size = dp_size,
q_dtype = q_dtype,
block_shape = None,
deepep_ht_args = ht_args,
deepep_ll_args = ll_args)
if low_latency_mode:
fused_experts = BatchedTritonExperts(
max_num_tokens=MAX_TOKENS_PER_RANK,
world_size=pgi.world_size,
dp_size=dp_size,
use_fp8_w8a8=is_quantized,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False)
else:
fused_experts = TritonExperts(use_fp8_w8a8=is_quantized,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
return mk
def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
low_latency_mode: bool, dp_size: int,
test_tensors: TestTensors, w1: torch.Tensor,
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], num_experts: int,
use_fp8_dispatch: bool) -> torch.Tensor:
num_local_experts = w1.size(0)
def build_expert_map():
num_local_experts = w1.size(0)
expert_map = torch.full((num_experts, ),
fill_value=-1,
dtype=torch.int32)
s = pgi.rank * num_local_experts
e = s + num_local_experts
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
return expert_map.to(device=torch.cuda.current_device(),
dtype=torch.int32)
hidden_size = test_tensors.rank_tokens.size(1)
is_quantized = w1.dtype == torch.float8_e4m3fn
q_dtype = None
if is_quantized:
q_dtype = torch.float8_e4m3fn
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(pg, pgi, low_latency_mode,
hidden_size, dp_size,
num_experts,
num_local_experts, q_dtype,
use_fp8_dispatch)
out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
total_num_tokens = test_tensors.rank_tokens.size(0)
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
rank_tokens_chunk = test_tensors.rank_tokens[chunk_start:chunk_end]
topk_weights_chunk = test_tensors.topk_weights[chunk_start:chunk_end]
topk_chunk = test_tensors.topk[chunk_start:chunk_end]
rank_token_scales_chunk = test_tensors.rank_token_scales
if rank_token_scales_chunk is not None and rank_token_scales_chunk.size(
0) == total_num_tokens:
# per act token
rank_token_scales_chunk = rank_token_scales_chunk[
chunk_start:chunk_end]
out = mk.forward(hidden_states=rank_tokens_chunk,
w1=w1,
w2=w2,
topk_weights=topk_weights_chunk,
topk_ids=topk_chunk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=None,
w2_zp=None,
a1_scale=rank_token_scales_chunk,
a2_scale=None,
apply_router_weight_on_input=False)
if not skip_result_store:
out_hidden_states[chunk_start:chunk_end, :].copy_(
out, non_blocking=True)
max_num_tokens_per_dp = (MAX_TOKENS_PER_RANK
if low_latency_mode else total_num_tokens)
for chunk_start_ in range(0, total_num_tokens, max_num_tokens_per_dp):
chunk_start = chunk_start_
chunk_end = min(chunk_start + max_num_tokens_per_dp, total_num_tokens)
# clamp start and end
chunk_start = min(chunk_start, total_num_tokens - 1)
chunk_end = min(chunk_end, total_num_tokens)
process_chunk(chunk_start,
chunk_end,
skip_result_store=chunk_start_ >= total_num_tokens)
return out_hidden_states
def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], using_fp8_dispatch: bool):
a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
test_tensors.topk_weights)
if using_fp8_dispatch:
# The DeepEP implementation is requested to dispatch using FP8.
# To keep the reference numerically comparable, emulate the fp8
# dispatch with a blockwise quant and de-quant.
a = test_tensors.rank_tokens
aq, aq_scale = per_token_group_quant_fp8(a, 128)
a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
a.shape).to(a.dtype)
is_quantized = w1.dtype == torch.float8_e4m3fn
a_dtype = a.dtype
if is_quantized:
w1 = w1.to(dtype=torch.float32) * w1_scale
w2 = w2.to(dtype=torch.float32) * w2_scale
a = a.to(dtype=torch.float32)
m, _ = a.shape
topk = topk_ids.size(1)
out = torch.zeros_like(a)
for i in range(m):
a_i = a[i]
o_i = out[i]
for j in range(topk):
e = topk_ids[i][j]
e_w = topk_weights[i][j]
w1_e = w1[e]
w2_e = w2[e]
o_i += (SiluAndMul()
(a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1)) * e_w
if is_quantized:
out = out.to(dtype=a_dtype)
return out
def _deep_ep_moe(
pgi: ProcessGroupInfo,
low_latency_mode: bool,
dp_size: int,
config: TestConfig,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
use_fp8_dispatch: bool,
):
if not low_latency_mode:
assert not use_fp8_dispatch, (
"FP8 dispatch interface is available only in low-latency mode")
is_quantized = w1.dtype == torch.float8_e4m3fn
w1 = w1.to(device=torch.cuda.current_device())
w2 = w2.to(device=torch.cuda.current_device())
if is_quantized:
w1_scale = w1_scale.to( # type: ignore
device=torch.cuda.current_device())
w2_scale = w2_scale.to( # type: ignore
device=torch.cuda.current_device())
pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, low_latency_mode)
with set_current_vllm_config(VllmConfig()):
# Reference
torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
w2_scale, use_fp8_dispatch)
# Slice experts for this rank.
num_local_experts = config.num_experts // pgi.world_size
e_start = num_local_experts * pgi.rank
e_end = e_start + num_local_experts
w1_ep = w1[e_start:e_end]
w2_ep = w2[e_start:e_end]
w1_scale_ep, w2_scale_ep = None, None
if is_quantized:
w1_scale_ep = w1_scale[e_start:e_end] # type: ignore
w2_scale_ep = w2_scale[e_start:e_end] # type: ignore
deepep_combined = deep_ep_moe_impl(
pg,
pgi,
low_latency_mode,
dp_size,
test_tensors,
w1_ep,
w2_ep,
w1_scale_ep,
w2_scale_ep,
config.num_experts,
use_fp8_dispatch,
)
torch.testing.assert_close(
torch_combined,
deepep_combined,
atol=6e-2,
rtol=6e-2,
)
MNKs = [
(1, 128, 128),
(2, 128, 512),
(3, 1024, 2048),
(32, 128, 1024),
(45, 512, 2048),
(64, 1024, 1024),
(222, 1024, 2048),
]
DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
num_experts: int, topk: int, world_dp_size: tuple[int,
int]):
low_latency_mode = False
use_fp8_dispatch = False
m, n, k = mnk
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
config = TestConfig(dtype=dtype,
topk=topk,
m=m,
k=k,
n=n,
num_experts=num_experts)
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
MNKs = [
(1, 128, 2560),
(2, 128, 2560),
(3, 1024, 2560),
(32, 128, 2560),
(45, 512, 2560),
(64, 1024, 2560),
(222, 1024, 2560),
]
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]
USE_FP8_DISPATCH = [True, False]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@requires_deep_ep
def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
num_experts: int, topk: int,
world_dp_size: tuple[int, int],
use_fp8_dispatch: bool):
low_latency_mode = True
m, n, k = mnk
if (low_latency_mode
and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES):
pytest.skip(
f"Skipping test as hidden size {k} is not in list of supported "
f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}"
)
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
config = TestConfig(dtype=dtype,
topk=topk,
m=m,
k=k,
n=n,
num_experts=num_experts)
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
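
A worked example (hypothetical token count) of the chunking loop in `deep_ep_moe_impl`: with MAX_TOKENS_PER_RANK = 64 and 150 tokens, the low-latency path processes three chunks:

MAX_TOKENS_PER_RANK = 64
total_num_tokens = 150  # illustrative value

chunks = []
for chunk_start in range(0, total_num_tokens, MAX_TOKENS_PER_RANK):
    chunk_end = min(chunk_start + MAX_TOKENS_PER_RANK, total_num_tokens)
    chunks.append((chunk_start, chunk_end))

assert chunks == [(0, 64), (64, 128), (128, 150)]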


@ -1856,6 +1856,8 @@ class ParallelConfig:
factors.append(self.pipeline_parallel_size)
factors.append(self.tensor_parallel_size)
factors.append(self.enable_expert_parallel)
factors.append(self.data_parallel_size)
factors.append(envs.VLLM_ALL2ALL_BACKEND)
return hashlib.sha256(str(factors).encode()).hexdigest()
def __post_init__(self) -> None:


@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import torch
import torch.distributed as dist
@ -129,3 +129,147 @@ class PPLXAll2AllManager(All2AllManagerBase):
from pplx_kernels.nvshmem import nvshmem_finalize
logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
class DeepEPAll2AllManagerBase(All2AllManagerBase):
"""
Base class for the DeepEP All2All communication managers (high-throughput and low-latency).
"""
def __init__(self, cpu_group):
has_deepep = importlib.util.find_spec("deep_ep") is not None
assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
super().__init__(cpu_group)
self.handle_cache = Cache()
# This is the DeepEP default. Stick to it till we can establish
# reasonable defaults based on profiling.
self.num_sms = 20
def get_handle(self, kwargs):
raise NotImplementedError
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
pass
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
"""
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = 1024 * 1024 * 1024
num_rdma_bytes = None
num_qps_per_rank = None
if self.internode:
num_rdma_bytes = 1024 * 1024 * 1024
num_qps_per_rank = self.num_sms // 2
else:
assert self.intranode
num_rdma_bytes = 0
num_qps_per_rank = 1
assert num_rdma_bytes is not None
assert num_qps_per_rank is not None
return dict(group=self.cpu_group,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank)
def get_handle(self, kwargs):
assert len(kwargs) == 0, (
"DeepEPHTAll2AllManager expects no arguments. All the required "
"args are computed in the Manager itself.")
import deep_ep
buffer_kwargs = self._make_all2all_kwargs()
logger.debug("DeepEP all2all args %s", buffer_kwargs)
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, deep_ep.Buffer)
# It is dangerous to set num sms outside this function. num_sms is not
# a part of the hash-key that identifies this object. If we are in a
# situation where we make objects with different num_sms, the hash key
# in get_or_create must be updated.
handle.set_num_sms(self.num_sms)
return handle
class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
"""
All2All communication based on DeepEP Low-Latency kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def _make_all2all_kwargs(
self,
max_num_tokens_per_dp_rank: int,
token_hidden_size: int,
num_ep_ranks: int,
num_global_experts: int,
num_local_experts: int,
) -> dict[Any, Any]:
"""
max_num_tokens_per_dp_rank: the maximum number of tokens a DP rank
can dispatch. All ranks must hold the same value.
token_hidden_size: the hidden dimension of each token.
num_ep_ranks: the number of EP group ranks.
num_global_experts: Number of experts in the model.
num_local_experts: Number of experts in an EP rank.
"""
import deep_ep
# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = 1024 * 1024 * 1024
num_qps_per_rank = num_local_experts
num_rdma_bytes = None
if self.internode:
num_rdma_bytes = 1024 * 1024 * 1024
else:
assert self.intranode
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
hidden=token_hidden_size,
num_ranks=num_ep_ranks,
num_experts=num_global_experts)
assert num_rdma_bytes is not None
return dict(group=self.cpu_group,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank)
def get_handle(self, kwargs):
"""
The kwargs for DeepEPLLAll2AllManager is dictated by
_make_all2all_kwargs.
"""
import deep_ep
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
logger.debug("DeepEP all2all args %s", buffer_kwargs)
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, deep_ep.Buffer)
# It is dangerous to set num sms outside this function. num_sms is not
# a part of the hash-key that identifies this object. If we are in a
# situation where we make objects with different num_sms, the hash key
# in get_or_create must be updated.
handle.set_num_sms(self.num_sms)
return handle
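
A hedged sketch of how a caller obtains a cached low-latency buffer from the manager above; `ll_manager` and the numeric values are illustrative, and the kwarg names mirror `_make_all2all_kwargs`:

def get_ll_buffer(ll_manager):
    # ll_manager: an already-constructed DeepEPLLAll2AllManager (hypothetical
    # instance). Returns a cached deep_ep.Buffer; repeated calls with the
    # same kwargs reuse the same buffer via handle_cache.
    return ll_manager.get_handle(
        dict(max_num_tokens_per_dp_rank=256,  # illustrative value
             token_hidden_size=5120,          # illustrative value
             num_ep_ranks=8,
             num_global_experts=64,
             num_local_experts=8))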


@ -67,6 +67,14 @@ class CudaCommunicator(DeviceCommunicatorBase):
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
logger.info("Using PPLX all2all manager.")
elif all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
logger.info("Using DeepEP High-Throughput all2all manager.")
elif all2all_backend == "deepep_low_latency":
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Low-Latency all2all manager.")
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")


@ -826,6 +826,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Available options:
# - "naive": naive all2all implementation using all-reduce
# - "pplx": use pplx kernels
# - "deepep_high_throughput": use deepep high-throughput kernels
# - "deepep_low_latency": use deepep low-latency kernels
"VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
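
A minimal example of selecting one of the new backends, assuming the variable is set before vLLM initializes its expert-parallel communicator:

import os

# Valid values, per the list above: "naive", "pplx",
# "deepep_high_throughput", "deepep_low_latency".
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_high_throughput"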


@ -12,8 +12,8 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_resize_cache)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, per_token_group_quant_fp8)
from vllm.utils import round_up
logger = init_logger(__name__)
@ -34,10 +34,8 @@ def _valid_deep_gemm_shape(M: int, N: int, K: int):
return align <= M and N % align == 0 and K % align == 0
def _valid_deep_gemm(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
expert_map: Optional[torch.Tensor] = None) -> bool:
def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor) -> bool:
"""
Check if the given problem size is supported by the DeepGemm grouped
gemm kernel. All of M, N, K and the quantization block_shape must be
@ -47,10 +45,6 @@ def _valid_deep_gemm(hidden_states: torch.Tensor,
logger.debug("DeepGemm disabled: deep_gemm not available.")
return False
if expert_map is not None:
logger.debug("DeepGemm disabled: expert map NYI.")
return False
M = hidden_states.size(0)
_, K, N = w2.size()
if not _valid_deep_gemm_shape(M, N, K):
@ -116,7 +110,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a1q = hidden_states
_, N, K = w1.size()
assert global_num_experts != -1
if global_num_experts == -1:
global_num_experts = w1.size(0)
assert w2.size(1) == K
a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
@ -128,6 +124,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.block_shape[0],
)
if expert_map is not None:
# DeepGemm (Grouped Contiguous) kernel needs a valid B index
# for all rows of A. To that effect, simply compute with
# the 0th weight matrix.
# Note that this relies on the fact that corresponding topk
# weights would be 0 during weight multiplication.
expert_ids = torch.where(expert_ids == -1, 0, expert_ids)
# Note: M_sum is different than the pre-permuted shape of a1q.
M_sum = a1q.size(0)
workspace1 = _resize_cache(workspace13, (M_sum, N))
@ -140,9 +144,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.activation(activation, workspace2, workspace1.view(-1, N))
a2q_scale: Optional[torch.Tensor] = None
a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, False,
self.block_shape)
a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
self.block_shape[1],
column_major_scales=True)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
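
A tiny illustration of the expert_ids masking added above: rows routed to experts that are not on this rank carry -1 and get pointed at expert 0 so the grouped GEMM always has a valid B index; per the comment in the diff, their topk weight of 0 removes the bogus contribution later.

import torch

expert_ids = torch.tensor([0, 1, -1, 2, -1])
masked = torch.where(expert_ids == -1, 0, expert_ids)
assert masked.tolist() == [0, 1, 0, 2, 0]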


@ -0,0 +1,236 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import deep_ep
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""
Prepare/Finalize using DeepEP High-Throughput kernels.
"""
def __init__(self,
buffer: deep_ep.Buffer,
world_size: int,
rank: int,
dp_size: int,
rank_expert_offset: int,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
super().__init__()
self.buffer = buffer
self.world_size = world_size
self.rank = rank
self.dp_size = dp_size
self.rank_expert_offset = rank_expert_offset
self.quant_dtype = quant_dtype
self.block_shape = block_shape
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
self.handle = None
# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]
def max_num_tokens_per_rank(self) -> Optional[int]:
return None
def topk_indices_dtype(self) -> Optional[torch.dtype]:
return torch.int64
def _get_dispatch_config(self) -> Optional[deep_ep.Config]:
if self.dp_size not in self.available_rank_configs:
return None
return deep_ep.Buffer.get_dispatch_config(self.dp_size)
def _get_combine_config(self) -> Optional[deep_ep.Config]:
if self.dp_size not in self.available_rank_configs:
return None
return deep_ep.Buffer.get_combine_config(self.dp_size)
def _do_quant(self, tokens: torch.Tensor,
token_scales: Optional[torch.Tensor], per_act_token: bool):
tokens, token_scales = moe_kernel_quantize_input(
tokens, token_scales, self.quant_dtype, per_act_token,
self.block_shape)
return tokens, token_scales
def _do_dispatch(self, tokens: torch.Tensor,
token_scales: Optional[torch.Tensor],
rank_topk_ids: torch.Tensor,
rank_topk_weights: torch.Tensor, num_experts: int):
has_scales = token_scales is not None
(num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens,
is_token_in_rank, event) = self.buffer.get_dispatch_layout(
topk_idx=rank_topk_ids,
num_experts=num_experts,
previous_event=None,
async_finish=False,
allocate_on_comm_stream=False)
token_data = tokens
if has_scales:
token_data = (tokens, token_scales)
(
token_data, expert_topk_ids, expert_topk_weights,
expert_num_tokens_per_expert_list, self.handle, event
) = self.buffer.dispatch(
x=token_data,
handle=None,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=expert_num_tokens,
topk_idx=rank_topk_ids,
topk_weights=rank_topk_weights,
# expert_alignment rounds the number of tokens per expert
# to this value.
expert_alignment=1,
config=self._get_dispatch_config(),
previous_event=None,
async_finish=False,
allocate_on_comm_stream=False)
if has_scales:
expert_x, expert_x_scale = token_data
else:
expert_x, expert_x_scale = token_data, None
# The existing MOE kernels assume that all entries of topk_ids are
# valid. To that effect, set the -1s in expert_topk_ids to some expert
# outside this rank so the expert_map can remap it to -1 when safe.
# With Expert Parallel, the experts are divided amongst the ranks
# sequentially. For rank 0, set it to num_experts - 1 and for all other
# ranks set it to 0 as we know that expert_map will have a -1 in those
# regions for those ranks.
#
# DeepEP's topk_ids output refers to the local experts directly. Offset
# the topk_ids to move it back to the global experts space so it aligns
# with existing vLLM interfaces.
expert_topk_ids = torch.where(
expert_topk_ids == -1,
num_experts - 1 if self.rank_expert_offset == 0 else 0,
expert_topk_ids + self.rank_expert_offset)
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
expert_topk_weights)
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
rank_topk_weights: torch.Tensor,
rank_topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
if apply_router_weight_on_input:
topk = rank_topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * rank_topk_weights.to(a1.dtype)
# Check if there is a block_shape, or if we can infer the
# quantization scheme from the scales.
per_token_quant = None
if all([x is None for x in [self.block_shape, a1_scale, a2_scale]
]) and self.quant_dtype is not None:
# Quantization required despite none of the inputs suggesting
# quantization. Fall back to per-token dynamic quantization.
per_token_quant = True
else:
per_token_quant = ((self.block_shape is not None) or
(a1_scale is not None and a1_scale.numel() != 1)
or (a2_scale is not None
and a2_scale.numel() != 1))
if per_token_quant:
a1q, a1q_scale = self._do_quant(a1, a1_scale, per_act_token=True)
(expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=rank_topk_ids,
rank_topk_weights=rank_topk_weights,
num_experts=num_experts)
else:
# The DeepEP dispatch kernels only support per-token quantization.
# For other quantization schemes, dispatch in bfloat16 and quantize
# after the dispatch.
(expert_x, _, expert_num_tokens, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1,
token_scales=None,
rank_topk_ids=rank_topk_ids,
rank_topk_weights=rank_topk_weights,
num_experts=num_experts)
# quantize now
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = self._do_quant(expert_x,
a1_scale,
per_act_token=False)
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
expert_topk_weights)
def _apply_weights_and_reduce(self, num_tokens: int,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
apply_router_weight_on_input: bool,
output_dtype: torch.dtype):
if fused_expert_output.ndim == 2:
hidden_dim = fused_expert_output.size(-1)
fused_expert_output = fused_expert_output.view(
num_tokens, -1, hidden_dim)
if not apply_router_weight_on_input:
# The DeepEP combine kernels don't do the topk weight
# multiplication. We multiply the weights locally.
fused_expert_output = fused_expert_output.to(torch.float32)
fused_expert_output = fused_expert_output * topk_weights.view(
fused_expert_output.size(0), -1, 1)
fused_expert_output = fused_expert_output.to(output_dtype)
return fused_expert_output.sum(dim=1).to(output_dtype)
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> None:
assert self.handle is not None
# fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank.
if fused_expert_output.numel() != 0:
fused_expert_output = self._apply_weights_and_reduce(
num_tokens=topk_ids.size(0),
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
apply_router_weight_on_input=apply_router_weight_on_input,
output_dtype=output.dtype)
combined_x, _, event = self.buffer.combine(
x=fused_expert_output,
handle=self.handle,
topk_weights=None,
config=self._get_combine_config(),
previous_event=None,
async_finish=False,
allocate_on_comm_stream=False)
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
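
A reference sketch (assumed shapes and dtypes) of `_apply_weights_and_reduce`: the high-throughput combine kernel does not apply topk weights, so expert outputs are weighted and reduced locally before `buffer.combine` is called.

import torch

num_tokens, topk, hidden = 4, 2, 8  # illustrative sizes
fused_expert_output = torch.randn(num_tokens, topk, hidden)
topk_weights = torch.rand(num_tokens, topk)

weighted = fused_expert_output.to(torch.float32) * topk_weights.view(
    num_tokens, -1, 1)
reduced = weighted.sum(dim=1).to(torch.bfloat16)
assert reduced.shape == (num_tokens, hidden)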


@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import deep_ep
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
DEEPEP_QUANT_BLOCK_SIZE = 128
def dequant_fp8(expert_x_fp8: torch.Tensor,
expert_x_scales: torch.Tensor) -> torch.Tensor:
"""
Return dequantized tensor in fp32
"""
# TODO (varun): Optimize by leveraging the num_tokens_per_expert counts
assert expert_x_fp8.is_contiguous()
expert_x_scales = expert_x_scales.contiguous()
num_experts = expert_x_fp8.size(0)
expert_x_fp32 = expert_x_fp8.to(torch.float32).view(
num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE)
expert_x_scales = expert_x_scales.view(num_experts, -1, 1)
return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.shape)
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""
Prepare/Finalize using DeepEP low-latency kernels.
"""
# DeepEP low-latency kernels are compiled only for certain
# specific hidden sizes.
SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168]
def __init__(self,
buffer: deep_ep.Buffer,
world_size: int,
dp_size: int,
max_tokens_per_rank: int,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
use_fp8_dispatch: bool = False):
super().__init__()
self.buffer = buffer
self.world_size = world_size
self.dp_size = dp_size
self.quant_dtype = quant_dtype
self.block_shape = block_shape
self.max_tokens_per_rank = max_tokens_per_rank
self.use_fp8_dispatch = use_fp8_dispatch
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
self.handle = None
def max_num_tokens_per_rank(self) -> Optional[int]:
return self.max_tokens_per_rank
def topk_indices_dtype(self) -> Optional[torch.dtype]:
return torch.int64
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
rank_topk_weights: torch.Tensor,
rank_topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
(f"Hidden Size {hidden_size} not in supported list of hidden sizes "
f"{self.SUPPORTED_HIDDEN_SIZES}")
if self.use_fp8_dispatch:
assert hidden_size % 128 == 0, \
"DeepEP kernels quantize the inputs in blocks of shape 128"
# Quantize
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
assert not per_act_token, (
"low_latency kernels don't support per-act-token quant")
if apply_router_weight_on_input:
topk = rank_topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * rank_topk_weights.to(a1.dtype)
# Dispatch
expert_x, expert_num_tokens, self.handle, event, hook = \
self.buffer.low_latency_dispatch(a1,
rank_topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
async_finish=False,
return_recv_hook=False)
if self.use_fp8_dispatch:
# TODO (varun) : In the case of dynamic quantization, we could
# probably skip the quant below and use the results directly.
# Note, however, that the DeepEP quantization is per token in 128-element blocks.
expert_x_fp8, expert_x_scales = expert_x
expert_x = dequant_fp8(expert_x_fp8,
expert_x_scales).to(dtype=a1.dtype)
num_experts = expert_x.size(0)
hidden_dim = expert_x.size(-1)
expert_x = expert_x.view((-1, expert_x.size(-1)))
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x, a1_scale, self.quant_dtype, per_act_token,
self.block_shape)
expert_x = expert_x.view((num_experts, -1, hidden_dim))
return (expert_x, expert_x_scale, expert_num_tokens, None, None)
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> None:
assert self.handle is not None
combine_topk_weights = topk_weights
if apply_router_weight_on_input:
# weights have already been applied.
combine_topk_weights = torch.ones_like(topk_weights)
# TODO (varun) : Enable zero copy mode
_, event, hook = self.buffer.low_latency_combine(
fused_expert_output,
topk_ids,
combine_topk_weights,
self.handle,
async_finish=False,
zero_copy=False,
return_recv_hook=False,
out=output)
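
A small numeric sketch of the `dequant_fp8` helper above, assuming deep_ep is installed so the module imports cleanly: DeepEP's fp8 dispatch quantizes each token in 128-element blocks, so dequantization broadcasts one scale per 128-element chunk. The shapes here are made up.

import torch

from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
    dequant_fp8)

num_experts, tokens_per_expert, hidden = 2, 3, 256  # hidden = 2 blocks of 128
expert_x_fp8 = torch.randn(num_experts, tokens_per_expert,
                           hidden).to(torch.float8_e4m3fn)
expert_x_scales = torch.rand(num_experts, tokens_per_expert, hidden // 128)

dequant = dequant_fp8(expert_x_fp8, expert_x_scales)
assert dequant.shape == expert_x_fp8.shape and dequant.dtype == torch.float32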


@ -10,7 +10,8 @@ import triton.language as tl
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input)
@triton.jit
@ -397,6 +398,12 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.rank = rank
self.max_num_tokens = max_num_tokens
def max_num_tokens_per_rank(self) -> Optional[int]:
return self.max_num_tokens
def topk_indices_dtype(self) -> Optional[torch.dtype]:
return None
def prepare(
self,
a1: torch.Tensor,
@ -407,7 +414,8 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
assert a1.dim() == 2
assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0)
@ -450,7 +458,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
first_expert, :rows, :] = a1[:topks.numel()][topks]
tokens_per_expert[expert_id - first_expert] = rows
return b_a1, a1_scale, tokens_per_expert
return b_a1, a1_scale, tokens_per_expert, None, None
def finalize(
self,
@ -601,6 +609,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None,
world_size: int = 1,
dp_size: int = 1,
@ -611,12 +620,15 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a16 = use_int8_w8a16
self.block_shape = block_shape
self.per_channel_quant = per_channel_quant
self.max_num_tokens = max_num_tokens
assert not use_int8_w8a8, "NYI"
assert not use_int4_w4a16, "NYI"
self.world_size = world_size
self.dp_size = dp_size
assert not use_int8_w8a8, "NYI"
assert not use_int4_w4a16, "NYI"
assert self.block_shape is None, "NYI"
def workspace_shapes(
self,
a: torch.Tensor,
@ -670,8 +682,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
]
# TODO: num_tokens -> max_num_tokens?
E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
hidden_states, w1, w2, topk_ids)
assert w1.size(0) == E
@ -687,7 +698,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2.size(),
top_k_num,
config_dtype,
num_tokens,
max_num_tokens,
block_shape=self.block_shape,
)
@ -706,10 +717,12 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
#print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N))
intermediate_cache1 = _resize_cache(workspace13,
(E, max_num_tokens, N))
intermediate_cache2 = _resize_cache(workspace2,
(E, num_tokens, N // 2))
intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K))
(E, max_num_tokens, N // 2))
intermediate_cache3 = _resize_cache(workspace13,
(E, max_num_tokens, K))
# MM1
invoke_moe_batched_triton_kernel(A=hidden_states,
@ -731,15 +744,20 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.activation(activation, intermediate_cache2.view(-1, N // 2),
intermediate_cache1.view(-1, N))
#qintermediate_cache2 = intermediate_cache2
a2q_scale = a2_scale
# TODO (varun) : support w8a8
assert not self.use_fp8_w8a8
#if self.use_fp8_w8a8:
# qintermediate_cache2, a2q_scale = _fp8_quantize(
# intermediate_cache2, a2_scale, self.block_shape)
ic2_hidden_size = intermediate_cache2.size(-1)
intermediate_cache2 = intermediate_cache2.view(-1, ic2_hidden_size)
invoke_moe_batched_triton_kernel(A=intermediate_cache2,
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
A=intermediate_cache2,
A_scale=a2_scale,
qtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else None,
per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape)
qintermediate_cache2 = qintermediate_cache2.view(
(E, -1, ic2_hidden_size))
invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
B=w2,
C=intermediate_cache3,
expert_num_tokens=expert_num_tokens,
@ -752,5 +770,4 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int4_w4a16=self.use_int4_w4a16,
config=config,
block_shape=self.block_shape)
return intermediate_cache3


@ -1164,7 +1164,7 @@ def fused_experts(hidden_states: torch.Tensor,
# permute/unpermute ops are available.
N = w1.shape[1]
if (allow_deep_gemm and use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
and _valid_deep_gemm(hidden_states, w1, w2)):
assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8(
hidden_states=hidden_states,


@ -5,7 +5,7 @@ import importlib
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
@ -30,16 +30,19 @@ from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
has_deepep = importlib.util.find_spec("deep_ep") is not None
if current_platform.is_cuda_alike():
from .fused_batched_moe import (BatchedPrepareAndFinalize,
BatchedTritonExperts)
from .fused_batched_moe import BatchedTritonExperts
from .fused_moe import TritonExperts, fused_experts
from .modular_kernel import (FusedMoEModularKernel,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
if has_pplx:
from .pplx_prepare_finalize import PplxPrepareAndFinalize
if has_deepep:
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize
else:
fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore
@ -71,10 +74,24 @@ class FusedMoEParallelConfig:
use_ep: bool # whether to use EP or not
@property
def use_all2all_kernels(self):
return self.dp_size > 1 and self.use_ep
@property
def use_pplx_kernels(self):
return self.dp_size > 1 and self.use_ep and \
envs.VLLM_ALL2ALL_BACKEND == "pplx"
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "pplx")
@property
def use_deepep_ht_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")
@property
def use_deepep_ll_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
@staticmethod
def make(tp_size_: int, dp_size_: int,
@ -231,6 +248,14 @@ class MoEConfig:
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels
@property
def use_deepep_ht_kernels(self):
return self.moe_parallel_config.use_deepep_ht_kernels
@property
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
@ -252,7 +277,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize = None
quant_dtype = None
act_quant_block_size = None
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
if isinstance(quant_config, Fp8Config):
act_quant_block_size = quant_config.weight_block_size
quant_dtype = torch.float8_e4m3fn
prepare_finalize: Optional[Union[PplxPrepareAndFinalize,
DeepEPHTPrepareAndFinalize,
DeepEPLLPrepareAndFinalize]] = None
if moe.use_pplx_kernels:
all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens,
@ -288,8 +322,49 @@ class FusedMoEMethodBase(QuantizeMethodBase):
dp_size=all2all_manager.tp_group.world_size,
quant_dtype=moe.in_dtype,
)
elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
all_to_all_args = dict()
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = DeepEPHTPrepareAndFinalize(
handle,
world_size=all2all_manager.world_size,
rank=all2all_manager.rank,
dp_size=all2all_manager.dp_world_size,
rank_expert_offset=all2all_manager.rank *
moe.num_local_experts,
quant_dtype=quant_dtype,
block_shape=act_quant_block_size,
)
elif moe.use_deepep_ll_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
all_to_all_args = dict(
max_num_tokens_per_dp_rank=moe.max_num_tokens,
token_hidden_size=moe.hidden_dim,
num_ep_ranks=all2all_manager.world_size,
num_global_experts=moe.num_experts,
num_local_experts=moe.num_experts //
all2all_manager.world_size)
handle = all2all_manager.get_handle(all_to_all_args)
# Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now.
prepare_finalize = DeepEPLLPrepareAndFinalize(
handle,
world_size=all2all_manager.world_size,
dp_size=all2all_manager.dp_world_size,
max_tokens_per_rank=moe.max_num_tokens,
quant_dtype=quant_dtype,
block_shape=act_quant_block_size,
use_fp8_dispatch=False,
)
self.topk_indices_dtype = None
if prepare_finalize is not None:
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
@ -297,7 +372,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
)
def select_gemm_impl(
self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]
self, prepare_finalize: FusedMoEPrepareAndFinalize
) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
# gemm implementation
@ -334,6 +409,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def __init__(self, moe: MoEConfig):
super().__init__()
self.fused_experts = fused_experts # type: ignore
self.topk_indices_dtype = None
self.moe = moe
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
@ -343,8 +419,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else:
self.rocm_aiter_fused_experts = None # type: ignore
def select_gemm_impl(
self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]):
def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize):
assert self.fused_experts == fused_experts
@ -353,11 +428,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
use_batched_experts = prepare_finalize.max_num_tokens_per_rank(
) is not None
if use_batched_experts:
logger.debug("BatchedTritonExperts %s", self.moe)
assert self.moe.dp_size == all2all_manager.dp_world_size
experts = BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
max_num_tokens=self.moe.max_num_tokens,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
@ -366,6 +443,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
per_channel_quant=False,
)
else:
logger.debug("TritonExperts %s", self.moe)
@ -494,6 +572,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -505,7 +584,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=torch.uint32 if self.moe.use_pplx_kernels else None)
indices_type=self.topk_indices_dtype)
if self.rocm_aiter_moe_enabled:
assert expert_map is None
@ -806,11 +885,8 @@ class FusedMoE(torch.nn.Module):
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
if quant_config is None:
quant_method = UnquantizedFusedMoEMethod(moe)
else:
quant_method = quant_config.get_quant_method(self, prefix)
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
@ -836,7 +912,8 @@ class FusedMoE(torch.nn.Module):
# Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
if self.moe_parallel_config.use_pplx_kernels:
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
act_dtype = vllm_config.model_config.dtype
self.batched_hidden_states = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.hidden_size),
@ -880,6 +957,14 @@ class FusedMoE(torch.nn.Module):
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels
@property
def use_deepep_ht_kernels(self):
return self.moe_parallel_config.use_deepep_ht_kernels
@property
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
@ -1210,19 +1295,21 @@ class FusedMoE(torch.nn.Module):
When just tensor-parallel is used, it is not required to reduce
the shared_experts results immediately. Instead we reduce once at the
end of the MoE op. (Refer to DeepSeekV2MoE module)
With EP and the pplx kernels - this is no longer viable as all
With EP and all2all kernels - this is no longer viable as all
GPU ranks in DP produce the complete set of hidden_states.
Therefore it is required that we reduce the shared_experts output
early.
"""
return self.use_pplx_kernels
return (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels)
def maybe_all_reduce_tensor_model_parallel(
self, final_hidden_states: torch.Tensor):
"""
The pplx combine kernel reduces across GPU ranks by default.
"""
if self.use_pplx_kernels:
if (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels):
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
@ -1289,7 +1376,7 @@ class FusedMoE(torch.nn.Module):
ctx = get_forward_context()
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
num_tokens = full_hidden_states.size(0)
for chunk_start_ in range(0, max_tokens_across_dp,
@ -1310,12 +1397,17 @@ class FusedMoE(torch.nn.Module):
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None
if self.moe_parallel_config.use_pplx_kernels:
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
return self.forward_impl_chunked(hidden_states, router_logits)
if self.dp_size > 1:
do_naive_dispatch_combine: bool = (
self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels)
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
@ -1335,12 +1427,12 @@ class FusedMoE(torch.nn.Module):
apply_router_weight_on_input=self.apply_router_weight_on_input,
)
if self.dp_size > 1:
if do_naive_dispatch_combine:
final_hidden_states = get_ep_group().combine(final_hidden_states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.)
final_hidden_states = tensor_model_parallel_all_reduce(
# Default set to False. (May have to add shared expert outputs.)
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)
return final_hidden_states

View File

@ -94,7 +94,8 @@ class FusedMoEPrepareAndFinalize(ABC):
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform any quantization (and/or) dispatching needed
for this kernel.
@ -113,6 +114,10 @@ class FusedMoEPrepareAndFinalize(ABC):
Returns a tuple of:
- quantized + dispatched a.
- quantized + dispatched a1_scales.
- Optional tensor, as big as the number of local experts, that contains
the number of tokens assigned to each local expert.
- Optional dispatched expert topk IDs.
- Optional dispatched expert topk weights.
"""
raise NotImplementedError
@ -138,6 +143,27 @@ class FusedMoEPrepareAndFinalize(ABC):
"""
raise NotImplementedError
@abstractmethod
def topk_indices_dtype(self) -> Optional[torch.dtype]:
"""
The PrepareFinalize All2All implementations generally constrain the
dtype of the topk_ids they support. This function returns the
required topk indices dtype so it can be respected.
Return None if there are no such restrictions.
"""
raise NotImplementedError
@abstractmethod
def max_num_tokens_per_rank(self) -> Optional[int]:
"""
Some PrepareFinalize All2All implementations are batched, meaning
they can process only a set number of tokens at a time. This
function returns the batch size, i.e. the maximum number of tokens
the implementation can process at a time.
Return None if there are no such restrictions.
"""
raise NotImplementedError
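As an illustration only (not part of this diff; the class name is hypothetical and the other abstract methods are omitted), a non-batched implementation with no dtype restriction would satisfy the two new hooks like so:
from typing import Optional
import torch

class UnbatchedPrepareAndFinalize(FusedMoEPrepareAndFinalize):
    # prepare() / finalize() omitted for brevity.

    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        # No restriction on topk_ids dtype; pplx, by contrast, requires
        # torch.uint32 here.
        return None

    def max_num_tokens_per_rank(self) -> Optional[int]:
        # Not a batched implementation, so there is no per-rank token cap.
        return None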
class FusedMoEPermuteExpertsUnpermute(ABC):
"""
@ -261,6 +287,61 @@ class FusedMoEModularKernel(torch.nn.Module):
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
def _do_fused_experts(
self,
a1: torch.Tensor, # input to forward fn
a1q: torch.Tensor, # output of prepare fn
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
expert_num_tokens: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor]) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
# Use a1 here to determine the correct workspace datatype
workspace13_shape, workspace2_shape, workspace_dtype = (
self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
global_num_experts))
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13 = torch.zeros(workspace13_shape,
device=a1.device,
dtype=workspace_dtype)
workspace2 = torch.zeros(workspace2_shape,
device=a1.device,
dtype=workspace_dtype)
fused_out = self.fused_experts.apply(
a1q,
w1,
w2,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
return fused_out
def forward(
self,
hidden_states: torch.Tensor,
@ -315,49 +396,48 @@ class FusedMoEModularKernel(torch.nn.Module):
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
a1 = hidden_states
E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids)
if global_num_experts == -1:
global_num_experts = E
output = a1 if inplace else torch.zeros_like(a1)
workspace13_shape, workspace2_shape, workspace_dtype = (
self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
global_num_experts))
if global_num_experts == -1:
global_num_experts = w1.size(0)
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13 = torch.zeros(workspace13_shape,
device=a1.device,
dtype=workspace_dtype)
workspace2 = torch.zeros(workspace2_shape,
device=a1.device,
dtype=workspace_dtype)
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids,
global_num_experts, expert_map, apply_router_weight_on_input)
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
topk_weights = (topk_weights if _expert_topk_weights is None else
_expert_topk_weights)
a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts,
expert_map, apply_router_weight_on_input)
fused_out = self.fused_experts.apply(
a1q,
w1,
w2,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
fused_out = None
if a1q.numel() == 0:
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput
# kernels. CUDAGraph compatible all2all kernels like the pplx
# kernels and the DeepEP low-latency kernels are always batched
# and can never run into the tensor.numel() == 0 case.
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
else:
fused_out = self._do_fused_experts(
a1=a1,
a1q=a1q,
w1=w1,
w2=w2,
topk_ids=topk_ids,
expert_num_tokens=expert_num_tokens,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale)
self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input)

View File

@ -25,7 +25,7 @@ def _moe_permute(
"""
top_k_num = curr_topk_ids.size(1)
tokens_in_chunk = curr_hidden_states.sizze(0)
tokens_in_chunk = curr_hidden_states.size(0)
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids,
@ -37,11 +37,12 @@ def _moe_permute(
inv_perm: Optional[torch.Tensor] = None
num_tokens = top_k_num * tokens_in_chunk
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
# Permute according to sorted token ids.
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
curr_hidden_states = _fp8_perm(curr_hidden_states,
sorted_token_ids // top_k_num)

View File

@ -32,6 +32,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.dp_size = dp_size
self.quant_dtype = quant_dtype
def max_num_tokens_per_rank(self) -> Optional[int]:
return self.max_num_tokens
def topk_indices_dtype(self) -> Optional[torch.dtype]:
return torch.uint32
def prepare(
self,
a1: torch.Tensor,
@ -42,7 +48,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
@ -115,7 +122,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
bound_m=bound_m,
)
return expert_x, expert_x_scale, expert_num_tokens
return expert_x, expert_x_scale, expert_num_tokens, None, None
def finalize(
self,

View File

@ -24,6 +24,12 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
self.block_shape = block_shape
self.quant_dtype = quant_dtype
def max_num_tokens_per_rank(self) -> Optional[int]:
return None
def topk_indices_dtype(self) -> Optional[torch.dtype]:
return None
def prepare(
self,
a1: torch.Tensor,
@ -34,7 +40,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
@ -47,7 +55,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
self.per_channel_quant,
self.block_shape)
return a1q, a1q_scale, None
return a1q, a1q_scale, None, None, None
def finalize(
self,

View File

@ -29,9 +29,10 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
per_channel_quant=per_channel_quant,
block_shape=block_shape,
block_m=block_m)
self.deep_gemm_expert = DeepGemmExperts()
self.allow_deep_gemm = allow_deep_gemm
self.use_fp8_w8a8 = use_fp8_w8a8
self.deep_gemm_expert = (DeepGemmExperts()
if self.allow_deep_gemm else None)
def workspace_shapes(
self,
@ -46,6 +47,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes(
a, M, N, K, topk, num_experts)
else:
@ -73,7 +75,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
) -> torch.Tensor:
N = w1.size(1)
if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
and _valid_deep_gemm(hidden_states, w1, w2)):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.apply(
hidden_states,
w1,

View File

@ -18,8 +18,8 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches.
"""
assert prod(
v) <= x.numel(), f"{prod(v)} <= {x.numel()}" # CUDAGRAPH unfriendly?
assert prod(v) <= x.numel(
), f"{v} ({prod(v)}) <= {x.shape} ({x.numel()})" # CUDAGRAPH unfriendly?
return x.flatten()[:prod(v)].view(*v)
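A small usage sketch (shapes are arbitrary assumptions) showing the intent of _resize_cache: the flat workspace is allocated once and then viewed down to the exact shape a given chunk needs:
import torch
from math import prod

# Preallocated flat workspace, at least as large as any view taken from it.
workspace13 = torch.zeros(16384, dtype=torch.bfloat16)
# View the first 8 * 2 * 512 = 8192 elements as a (tokens, topk, hidden) cache.
cache1 = _resize_cache(workspace13, (8, 2, 512))
assert cache1.shape == (8, 2, 512)
assert prod(cache1.shape) <= workspace13.numel()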

View File

@ -3,7 +3,7 @@
import functools
import importlib.util
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
import torch
import torch.nn.functional as F
@ -452,6 +452,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if envs.VLLM_USE_DEEP_GEMM:
if not has_deep_gemm:
logger.warning_once("Failed to import DeepGemm kernels.")
elif not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
" DeepGemm kernels")
elif (current_platform.is_cuda()
and current_platform.has_device_capability(90)):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
@ -460,8 +463,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger.warning_once(
"DeepGemm not supported on the current platform.")
self.topk_indices_dtype = None
self.fused_experts = functools.partial( # type: ignore
fused_experts,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm)
@ -765,18 +770,39 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w2_input_scale
def select_gemm_impl(self, prepare_finalize):
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet.")
experts = TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)
experts: Optional[Union[BatchedTritonExperts,
TritonOrDeepGemmExperts]] = None
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
use_batched_experts = max_num_tokens_per_rank is not None
if use_batched_experts:
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens_per_rank,
world_size=prepare_finalize.world_size,
dp_size=prepare_finalize.dp_size,
use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
)
else:
experts = TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)
assert experts is not None
return experts
def apply(
@ -797,6 +823,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -808,6 +835,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
if self.rocm_aiter_moe_enabled:
@ -855,7 +883,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,

View File

@ -154,6 +154,21 @@ class CudaPlatformBase(Platform):
logger.info(
"Forcing kv cache block size to 64 for FlashMLA backend.")
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
and parallel_config.data_parallel_size > 1
and vllm_config.compilation_config.use_cudagraph):
logger.info(
"Data Parallel: Forcing enforce eager to be True since DP "
"with the DeepEP high-throughput kernels is not CUDA Graph "
"compatible. The DeepEP low-latency kernels are CUDA Graph "
"compatible. Set the all_to_all backend to deepep_low_latency "
"to use those kernels instead.")
vllm_config.compilation_config.use_cudagraph = False
vllm_config.model_config.enforce_eager = True
# TODO (varun): Turning this ON gives incorrect results for the
# Deepseek-V2-lite model.
vllm_config.compilation_config.use_inductor = False
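As the log message suggests, the fallback can be avoided by selecting the CUDA Graph compatible backend up front; a hedged sketch (only the env var name and value come from this diff, setting it programmatically is an assumption):
import os

# Select the DeepEP low-latency kernels so CUDA Graphs stay enabled even
# with data parallelism; must be set before the vLLM config is built.
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_low_latency"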
@classmethod
def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None