mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-27 05:58:42 +08:00
[Kernel] DeepEP dispatch-combine kernel integration (#18434)
Signed-off-by: Varun <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
parent
01eee40536
commit
fa98d77773
@ -516,9 +516,8 @@ void topk_softmax(
|
||||
topk,
|
||||
stream);
|
||||
}
|
||||
else
|
||||
else if (topk_indices.scalar_type() == at::ScalarType::UInt32)
|
||||
{
|
||||
assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
|
||||
vllm::moe::topkGatingSoftmaxKernelLauncher(
|
||||
gating_output.data_ptr<float>(),
|
||||
topk_weights.data_ptr<float>(),
|
||||
@ -530,4 +529,17 @@ void topk_softmax(
|
||||
topk,
|
||||
stream);
|
||||
}
|
||||
else {
|
||||
assert(topk_indices.scalar_type() == at::ScalarType::Int64);
|
||||
vllm::moe::topkGatingSoftmaxKernelLauncher(
|
||||
gating_output.data_ptr<float>(),
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<int64_t>(),
|
||||
token_expert_indices.data_ptr<int>(),
|
||||
softmax_workspace.data_ptr<float>(),
|
||||
num_tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
0
tests/kernels/moe/__init__.py
Normal file
0
tests/kernels/moe/__init__.py
Normal file
188
tests/kernels/moe/deepep_utils.py
Normal file
188
tests/kernels/moe/deepep_utils.py
Normal file
@ -0,0 +1,188 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
DeepEP test utilities
|
||||
"""
|
||||
import dataclasses
|
||||
import importlib
|
||||
import traceback
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.multiprocessing import (
|
||||
spawn) # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
|
||||
if has_deep_ep:
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||
DeepEPHTPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||
DeepEPLLPrepareAndFinalize)
|
||||
|
||||
## Parallel Processes Utils
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProcessGroupInfo:
|
||||
world_size: int
|
||||
world_local_size: int
|
||||
rank: int
|
||||
node_rank: int
|
||||
local_rank: int
|
||||
device: torch.device
|
||||
|
||||
|
||||
def _worker_parallel_launch(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
world_local_size: int,
|
||||
node_rank: int,
|
||||
init_method: str,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
rank = node_rank * world_local_size + local_rank
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device("cuda", local_rank)
|
||||
torch.distributed.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
init_method=init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
device_id=device,
|
||||
)
|
||||
barrier = torch.tensor([rank], device=device)
|
||||
torch.distributed.all_reduce(barrier)
|
||||
|
||||
try:
|
||||
worker(
|
||||
ProcessGroupInfo(
|
||||
world_size=world_size,
|
||||
world_local_size=world_local_size,
|
||||
rank=rank,
|
||||
node_rank=node_rank,
|
||||
local_rank=local_rank,
|
||||
device=device,
|
||||
),
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def parallel_launch(
|
||||
world_size: int,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
assert not kwargs
|
||||
spawn(
|
||||
_worker_parallel_launch,
|
||||
args=(
|
||||
world_size,
|
||||
world_size,
|
||||
0,
|
||||
"tcp://localhost:29500",
|
||||
worker,
|
||||
) + args,
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
|
||||
## DeepEP specific utils
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DeepEPHTArgs:
|
||||
num_local_experts: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DeepEPLLArgs:
|
||||
max_tokens_per_rank: int
|
||||
hidden_size: int
|
||||
num_experts: int
|
||||
use_fp8_dispatch: bool
|
||||
|
||||
|
||||
def make_deepep_ht_a2a(pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
ht_args: DeepEPHTArgs,
|
||||
q_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
|
||||
import deep_ep
|
||||
|
||||
# high throughput a2a
|
||||
num_nvl_bytes = 1024 * 1024 * 1024 # 1GB
|
||||
num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
|
||||
buffer = deep_ep.Buffer(group=pg,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=low_latency_mode,
|
||||
num_qps_per_rank=num_qps_per_rank)
|
||||
return DeepEPHTPrepareAndFinalize(buffer=buffer,
|
||||
world_size=pgi.world_size,
|
||||
rank=pgi.rank,
|
||||
dp_size=dp_size,
|
||||
rank_expert_offset=pgi.rank *
|
||||
ht_args.num_local_experts,
|
||||
quant_dtype=q_dtype,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
def make_deepep_ll_a2a(pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
deepep_ll_args: DeepEPLLArgs,
|
||||
q_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
|
||||
import deep_ep
|
||||
|
||||
# low-latency a2a
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
|
||||
pgi.world_size, deepep_ll_args.num_experts)
|
||||
|
||||
buffer = deep_ep.Buffer(group=pg,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=deepep_ll_args.num_experts //
|
||||
pgi.world_size)
|
||||
return DeepEPLLPrepareAndFinalize(
|
||||
buffer=buffer,
|
||||
world_size=pgi.world_size,
|
||||
dp_size=dp_size,
|
||||
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
|
||||
quant_dtype=q_dtype,
|
||||
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
|
||||
)
|
||||
|
||||
|
||||
def make_deepep_a2a(pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
deepep_ht_args: Optional[DeepEPHTArgs],
|
||||
deepep_ll_args: Optional[DeepEPLLArgs],
|
||||
q_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
if deepep_ht_args is not None:
|
||||
assert deepep_ll_args is None
|
||||
return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
|
||||
block_shape)
|
||||
|
||||
assert deepep_ll_args is not None
|
||||
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype)
|
||||
371
tests/kernels/moe/test_deepep_deepgemm_moe.py
Normal file
371
tests/kernels/moe/test_deepep_deepgemm_moe.py
Normal file
@ -0,0 +1,371 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Test DeepEP + DeepGEMM integration
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import importlib
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch.distributed
|
||||
from torch.distributed import ProcessGroup
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .deepep_utils import ProcessGroupInfo, parallel_launch
|
||||
|
||||
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
|
||||
|
||||
try:
|
||||
import deep_gemm
|
||||
has_deep_gemm = True
|
||||
except ImportError:
|
||||
has_deep_gemm = False
|
||||
|
||||
if has_deep_ep:
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||
DeepEPHTPrepareAndFinalize)
|
||||
|
||||
from .deepep_utils import DeepEPHTArgs, make_deepep_a2a
|
||||
|
||||
if has_deep_gemm:
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
DeepGemmExperts)
|
||||
|
||||
requires_deep_ep = pytest.mark.skipif(
|
||||
not has_deep_ep,
|
||||
reason="Requires deep_ep kernels",
|
||||
)
|
||||
|
||||
requires_deep_gemm = pytest.mark.skipif(
|
||||
not has_deep_gemm,
|
||||
reason="Requires deep_gemm kernels",
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(
|
||||
x: torch.Tensor,
|
||||
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros(
|
||||
(deep_gemm.ceil_div(m, 128) * 128,
|
||||
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
|
||||
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
||||
return x_scaled_sub, scales
|
||||
|
||||
|
||||
def make_block_quant_fp8_weights(
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
block_size: list[int],
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return weights w1, w2, w1q, w2q, w1_scale, w2_scale
|
||||
"""
|
||||
dtype = torch.bfloat16
|
||||
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
w1_bf16 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
|
||||
w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
|
||||
|
||||
w2_bf16 = torch.randn((e, k, n), dtype=dtype) / 10
|
||||
w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
|
||||
k_tiles_w1 = (k + block_k - 1) // block_k
|
||||
n_tiles_w2 = (k + block_n - 1) // block_n
|
||||
k_tiles_w2 = (n + block_k - 1) // block_k
|
||||
|
||||
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
|
||||
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
|
||||
|
||||
w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
|
||||
assert w1_s.shape == (e, (2 * n + 127) // 128, (k + 127) // 128)
|
||||
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
|
||||
|
||||
for i in range(e):
|
||||
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
|
||||
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
|
||||
|
||||
return w1, w2, w1_s, w2_s
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestConfig:
|
||||
topk: int
|
||||
m: int
|
||||
k: int
|
||||
n: int
|
||||
num_experts: int
|
||||
block_size: list[int]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestTensors:
|
||||
rank_tokens: torch.Tensor # all ranks make this many tokens
|
||||
rank_token_scales: Optional[torch.Tensor]
|
||||
topk: torch.Tensor
|
||||
topk_weights: torch.Tensor
|
||||
config: TestConfig
|
||||
|
||||
@staticmethod
|
||||
def make(config: TestConfig, rank) -> "TestTensors":
|
||||
|
||||
dtype = torch.bfloat16
|
||||
topk, m, k, block_size = (config.topk, config.m, config.k,
|
||||
config.block_size)
|
||||
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
rank_tokens = torch.randn(
|
||||
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
|
||||
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
|
||||
|
||||
block_k = block_size[1]
|
||||
_, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k)
|
||||
|
||||
topk_ids = torch.randint(
|
||||
low=0,
|
||||
high=config.num_experts,
|
||||
size=(m, topk),
|
||||
device=torch.cuda.current_device()).to(dtype=torch.int64)
|
||||
|
||||
topk_weights = torch.randn(topk_ids.shape,
|
||||
dtype=torch.float32,
|
||||
device=torch.cuda.current_device())
|
||||
|
||||
return TestTensors(rank_tokens=rank_tokens,
|
||||
rank_token_scales=rank_token_scales,
|
||||
topk=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
config=config)
|
||||
|
||||
|
||||
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||
num_local_experts: int, q_dtype: Optional[torch.dtype],
|
||||
block_shape: list[int]) -> FusedMoEModularKernel:
|
||||
|
||||
a2a: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
|
||||
pg=pg,
|
||||
pgi=pgi,
|
||||
dp_size=dp_size,
|
||||
deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts),
|
||||
deepep_ll_args=None,
|
||||
q_dtype=q_dtype,
|
||||
block_shape=block_shape)
|
||||
|
||||
fused_experts = DeepGemmExperts()
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||
fused_experts=fused_experts)
|
||||
return mk
|
||||
|
||||
|
||||
def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||
test_tensors: TestTensors, w1: torch.Tensor,
|
||||
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
num_experts: int) -> torch.Tensor:
|
||||
|
||||
num_local_experts = w1.size(0)
|
||||
|
||||
def build_expert_map():
|
||||
num_local_experts = w1.size(0)
|
||||
expert_map = torch.full((num_experts, ),
|
||||
fill_value=-1,
|
||||
dtype=torch.int32)
|
||||
s = pgi.rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
|
||||
return expert_map.to(device=torch.cuda.current_device(),
|
||||
dtype=torch.int32)
|
||||
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
|
||||
# Make modular kernel
|
||||
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||
pg, pgi, dp_size, num_local_experts, q_dtype,
|
||||
test_tensors.config.block_size)
|
||||
|
||||
a1_scale = test_tensors.rank_token_scales
|
||||
|
||||
out = mk.forward(hidden_states=test_tensors.rank_tokens,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
topk_ids=test_tensors.topk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=None,
|
||||
w2_zp=None,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=None,
|
||||
apply_router_weight_on_input=False)
|
||||
return out
|
||||
|
||||
|
||||
def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||
a1_scale: torch.Tensor, block_shape: list[int]):
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_shape,
|
||||
# Make sure this is set to False so we
|
||||
# dont end up comparing the same implementation.
|
||||
allow_deep_gemm=False)
|
||||
|
||||
|
||||
def _deep_ep_moe(
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
config: TestConfig,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
):
|
||||
current_platform.seed_everything(pgi.rank)
|
||||
|
||||
w1 = w1.to(device=torch.cuda.current_device())
|
||||
w2 = w2.to(device=torch.cuda.current_device())
|
||||
w1_scale = w1_scale.to(device=torch.cuda.current_device())
|
||||
w2_scale = w2_scale.to(device=torch.cuda.current_device())
|
||||
|
||||
pg = torch.distributed.new_group(list(range(pgi.world_size)))
|
||||
test_tensors = TestTensors.make(config, pgi.rank)
|
||||
block_shape = [
|
||||
w1.size(1) // w1_scale.size(1),
|
||||
w1.size(2) // w1_scale.size(2)
|
||||
]
|
||||
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
# Reference
|
||||
triton_moe = triton_impl(a=test_tensors.rank_tokens,
|
||||
topk_ids=test_tensors.topk,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=test_tensors.rank_token_scales,
|
||||
block_shape=block_shape)
|
||||
|
||||
# Slice experts for this rank.
|
||||
num_local_experts = config.num_experts // pgi.world_size
|
||||
e_start = num_local_experts * pgi.rank
|
||||
e_end = e_start + num_local_experts
|
||||
w1_ep = w1[e_start:e_end]
|
||||
w2_ep = w2[e_start:e_end]
|
||||
w1_scale_ep = w1_scale[e_start:e_end]
|
||||
w2_scale_ep = w2_scale[e_start:e_end]
|
||||
|
||||
deepep_moe = deep_ep_moe_impl(
|
||||
pg,
|
||||
pgi,
|
||||
dp_size,
|
||||
test_tensors,
|
||||
w1_ep,
|
||||
w2_ep,
|
||||
w1_scale_ep,
|
||||
w2_scale_ep,
|
||||
config.num_experts,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
triton_moe,
|
||||
deepep_moe,
|
||||
atol=6e-2,
|
||||
rtol=6e-2,
|
||||
)
|
||||
|
||||
|
||||
MNKs = [
|
||||
(8, 128, 128),
|
||||
(8, 128, 512),
|
||||
(8, 512, 512),
|
||||
(3, 1024, 2048),
|
||||
(32, 128, 1024),
|
||||
(45, 512, 2048),
|
||||
(64, 1024, 1024),
|
||||
(129, 128, 256),
|
||||
(129, 1024, 2048),
|
||||
(222, 1024, 2048),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mnk", MNKs)
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("topk", [2, 6])
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@requires_deep_ep
|
||||
@requires_deep_gemm
|
||||
def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
|
||||
world_dp_size: tuple[int, int]):
|
||||
|
||||
m, n, k = mnk
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
if topk > num_experts:
|
||||
pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")
|
||||
|
||||
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
|
||||
block_size = [block_m, block_m]
|
||||
|
||||
world_size, dp_size = world_dp_size
|
||||
config = TestConfig(
|
||||
topk=topk,
|
||||
m=m,
|
||||
k=k,
|
||||
n=n,
|
||||
num_experts=num_experts,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
|
||||
num_experts, n, k, block_size)
|
||||
|
||||
parallel_launch(world_size, _deep_ep_moe, dp_size, config, w1, w2,
|
||||
w1_scale, w2_scale)
|
||||
459
tests/kernels/moe/test_deepep_moe.py
Normal file
459
tests/kernels/moe/test_deepep_moe.py
Normal file
@ -0,0 +1,459 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Test deepep dispatch-combine logic
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import importlib
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch.distributed
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .deepep_utils import ProcessGroupInfo, parallel_launch
|
||||
|
||||
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
|
||||
|
||||
if has_deep_ep:
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||
DeepEPHTPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||
DeepEPLLPrepareAndFinalize)
|
||||
|
||||
from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
||||
|
||||
requires_deep_ep = pytest.mark.skipif(
|
||||
not has_deep_ep,
|
||||
reason="Requires deep_ep kernels",
|
||||
)
|
||||
|
||||
MAX_TOKENS_PER_RANK = 64
|
||||
|
||||
|
||||
def make_weights(
|
||||
e, n, k, dtype
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return weights w1, w2, w1_scale, w2_scale
|
||||
"""
|
||||
if dtype in [torch.float16, torch.bfloat16]:
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
return w1, w2, None, None
|
||||
|
||||
# per-out-channel weight quantization
|
||||
assert dtype == torch.float8_e4m3fn
|
||||
w1 = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float16)
|
||||
w2 = torch.empty((e, k, n), device="cuda", dtype=torch.float16)
|
||||
|
||||
n_b_scales = 2 * n
|
||||
k_b_scales = k
|
||||
w1_q = torch.empty_like(w1, dtype=dtype)
|
||||
w2_q = torch.empty_like(w2, dtype=dtype)
|
||||
w1_scale = torch.empty((e, n_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
for expert in range(e):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
|
||||
w1[expert], use_per_token_if_dynamic=True)
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
w2[expert], use_per_token_if_dynamic=True)
|
||||
return w1_q, w2_q, w1_scale, w2_scale
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestConfig:
|
||||
dtype: torch.dtype
|
||||
topk: int
|
||||
m: int
|
||||
k: int
|
||||
n: int
|
||||
num_experts: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestTensors:
|
||||
rank_tokens: torch.Tensor # all ranks make this many tokens
|
||||
rank_token_scales: Optional[torch.Tensor]
|
||||
topk: torch.Tensor
|
||||
topk_weights: torch.Tensor
|
||||
config: TestConfig
|
||||
|
||||
@staticmethod
|
||||
def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors":
|
||||
# TODO (varun) - check that float16 works ?
|
||||
assert config.dtype in [torch.bfloat16, torch.float8_e4m3fn]
|
||||
token_dtype = (torch.bfloat16 if config.dtype == torch.float8_e4m3fn
|
||||
else config.dtype)
|
||||
rank_tokens = torch.randn(
|
||||
(config.m, config.k), device="cuda", dtype=token_dtype) / 10
|
||||
rank_token_scales = None
|
||||
if config.dtype == torch.float8_e4m3fn:
|
||||
# low_latency_mode kernels dont support per-token quant.
|
||||
_, rank_token_scales = ops.scaled_fp8_quant(
|
||||
rank_tokens, use_per_token_if_dynamic=not low_latency_mode)
|
||||
|
||||
topk = torch.randint(low=0,
|
||||
high=config.num_experts,
|
||||
size=(config.m, config.topk),
|
||||
device="cuda").to(dtype=torch.int64)
|
||||
topk_weights = torch.randn(topk.shape,
|
||||
dtype=torch.float32,
|
||||
device="cuda")
|
||||
return TestTensors(rank_tokens=rank_tokens,
|
||||
rank_token_scales=rank_token_scales,
|
||||
topk=topk,
|
||||
topk_weights=topk_weights,
|
||||
config=config)
|
||||
|
||||
|
||||
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
low_latency_mode: bool, hidden_size: int, dp_size: int,
|
||||
num_experts: int, num_local_experts: int,
|
||||
q_dtype: Optional[torch.dtype],
|
||||
use_fp8_dispatch: bool) -> FusedMoEModularKernel:
|
||||
|
||||
is_quantized = q_dtype is not None
|
||||
|
||||
ht_args: Optional[DeepEPHTArgs] = None
|
||||
ll_args: Optional[DeepEPLLArgs] = None
|
||||
|
||||
if low_latency_mode:
|
||||
ll_args = DeepEPLLArgs(max_tokens_per_rank=MAX_TOKENS_PER_RANK,
|
||||
hidden_size=hidden_size,
|
||||
num_experts=num_experts,
|
||||
use_fp8_dispatch=use_fp8_dispatch)
|
||||
else:
|
||||
assert not use_fp8_dispatch, (
|
||||
"FP8 Dispatch is valid only for low-latency kernels")
|
||||
ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)
|
||||
|
||||
a2a : Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = \
|
||||
make_deepep_a2a(pg = pg,
|
||||
pgi = pgi,
|
||||
dp_size = dp_size,
|
||||
q_dtype = q_dtype,
|
||||
block_shape = None,
|
||||
deepep_ht_args = ht_args,
|
||||
deepep_ll_args = ll_args)
|
||||
|
||||
if low_latency_mode:
|
||||
fused_experts = BatchedTritonExperts(
|
||||
max_num_tokens=MAX_TOKENS_PER_RANK,
|
||||
world_size=pgi.world_size,
|
||||
dp_size=dp_size,
|
||||
use_fp8_w8a8=is_quantized,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False)
|
||||
else:
|
||||
fused_experts = TritonExperts(use_fp8_w8a8=is_quantized,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
per_channel_quant=False)
|
||||
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||
fused_experts=fused_experts)
|
||||
return mk
|
||||
|
||||
|
||||
def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
low_latency_mode: bool, dp_size: int,
|
||||
test_tensors: TestTensors, w1: torch.Tensor,
|
||||
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor], num_experts: int,
|
||||
use_fp8_dispatch: bool) -> torch.Tensor:
|
||||
|
||||
num_local_experts = w1.size(0)
|
||||
|
||||
def build_expert_map():
|
||||
num_local_experts = w1.size(0)
|
||||
expert_map = torch.full((num_experts, ),
|
||||
fill_value=-1,
|
||||
dtype=torch.int32)
|
||||
s = pgi.rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
|
||||
return expert_map.to(device=torch.cuda.current_device(),
|
||||
dtype=torch.int32)
|
||||
|
||||
hidden_size = test_tensors.rank_tokens.size(1)
|
||||
is_quantized = w1.dtype == torch.float8_e4m3fn
|
||||
q_dtype = None
|
||||
if is_quantized:
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
|
||||
# Make modular kernel
|
||||
mk: FusedMoEModularKernel = make_modular_kernel(pg, pgi, low_latency_mode,
|
||||
hidden_size, dp_size,
|
||||
num_experts,
|
||||
num_local_experts, q_dtype,
|
||||
use_fp8_dispatch)
|
||||
|
||||
out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
|
||||
total_num_tokens = test_tensors.rank_tokens.size(0)
|
||||
|
||||
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
|
||||
rank_tokens_chunk = test_tensors.rank_tokens[chunk_start:chunk_end]
|
||||
topk_weights_chunk = test_tensors.topk_weights[chunk_start:chunk_end]
|
||||
topk_chunk = test_tensors.topk[chunk_start:chunk_end]
|
||||
rank_token_scales_chunk = test_tensors.rank_token_scales
|
||||
if rank_token_scales_chunk is not None and rank_token_scales_chunk.size(
|
||||
0) == total_num_tokens:
|
||||
# per act token
|
||||
rank_token_scales_chunk = rank_token_scales_chunk[
|
||||
chunk_start:chunk_end]
|
||||
|
||||
out = mk.forward(hidden_states=rank_tokens_chunk,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights_chunk,
|
||||
topk_ids=topk_chunk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=None,
|
||||
w2_zp=None,
|
||||
a1_scale=rank_token_scales_chunk,
|
||||
a2_scale=None,
|
||||
apply_router_weight_on_input=False)
|
||||
|
||||
if not skip_result_store:
|
||||
out_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
out, non_blocking=True)
|
||||
|
||||
max_num_tokens_per_dp = (MAX_TOKENS_PER_RANK
|
||||
if low_latency_mode else total_num_tokens)
|
||||
|
||||
for chunk_start_ in range(0, total_num_tokens, max_num_tokens_per_dp):
|
||||
chunk_start = chunk_start_
|
||||
chunk_end = min(chunk_start + max_num_tokens_per_dp, total_num_tokens)
|
||||
# clamp start and end
|
||||
chunk_start = min(chunk_start, total_num_tokens - 1)
|
||||
chunk_end = min(chunk_end, total_num_tokens)
|
||||
|
||||
process_chunk(chunk_start,
|
||||
chunk_end,
|
||||
skip_result_store=chunk_start_ >= total_num_tokens)
|
||||
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
|
||||
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor], using_fp8_dispatch: bool):
|
||||
|
||||
a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
|
||||
test_tensors.topk_weights)
|
||||
if using_fp8_dispatch:
|
||||
# The DeepEP implementation is requested to dispatch using FP8.
|
||||
# For numerical stability for testing, emulate the fp8 dispatch by
|
||||
# blockwise quant and de-quant.
|
||||
a = test_tensors.rank_tokens
|
||||
aq, aq_scale = per_token_group_quant_fp8(a, 128)
|
||||
a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
|
||||
a.shape).to(a.dtype)
|
||||
|
||||
is_quantized = w1.dtype == torch.float8_e4m3fn
|
||||
a_dtype = a.dtype
|
||||
if is_quantized:
|
||||
w1 = w1.to(dtype=torch.float32) * w1_scale
|
||||
w2 = w2.to(dtype=torch.float32) * w2_scale
|
||||
a = a.to(dtype=torch.float32)
|
||||
|
||||
m, _ = a.shape
|
||||
topk = topk_ids.size(1)
|
||||
out = torch.zeros_like(a)
|
||||
|
||||
for i in range(m):
|
||||
a_i = a[i]
|
||||
o_i = out[i]
|
||||
for j in range(topk):
|
||||
e = topk_ids[i][j]
|
||||
e_w = topk_weights[i][j]
|
||||
w1_e = w1[e]
|
||||
w2_e = w2[e]
|
||||
o_i += (SiluAndMul()
|
||||
(a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1)) * e_w
|
||||
|
||||
if is_quantized:
|
||||
out = out.to(dtype=a_dtype)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _deep_ep_moe(
|
||||
pgi: ProcessGroupInfo,
|
||||
low_latency_mode: bool,
|
||||
dp_size: int,
|
||||
config: TestConfig,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
use_fp8_dispatch: bool,
|
||||
):
|
||||
|
||||
if not low_latency_mode:
|
||||
assert not use_fp8_dispatch, (
|
||||
"FP8 dispatch interface is available only in low-latency mode")
|
||||
|
||||
is_quantized = w1.dtype == torch.float8_e4m3fn
|
||||
w1 = w1.to(device=torch.cuda.current_device())
|
||||
w2 = w2.to(device=torch.cuda.current_device())
|
||||
if is_quantized:
|
||||
w1_scale = w1_scale.to( # type: ignore
|
||||
device=torch.cuda.current_device())
|
||||
w2_scale = w2_scale.to( # type: ignore
|
||||
device=torch.cuda.current_device())
|
||||
|
||||
pg = torch.distributed.new_group(list(range(pgi.world_size)))
|
||||
test_tensors = TestTensors.make(config, low_latency_mode)
|
||||
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
# Reference
|
||||
torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
|
||||
w2_scale, use_fp8_dispatch)
|
||||
|
||||
# Splice experts for this rank.
|
||||
num_local_experts = config.num_experts // pgi.world_size
|
||||
e_start = num_local_experts * pgi.rank
|
||||
e_end = e_start + num_local_experts
|
||||
w1_ep = w1[e_start:e_end]
|
||||
w2_ep = w2[e_start:e_end]
|
||||
|
||||
w1_scale_ep, w2_scale_ep = None, None
|
||||
if is_quantized:
|
||||
w1_scale_ep = w1_scale[e_start:e_end] # type: ignore
|
||||
w2_scale_ep = w2_scale[e_start:e_end] # type: ignore
|
||||
deepep_combined = deep_ep_moe_impl(
|
||||
pg,
|
||||
pgi,
|
||||
low_latency_mode,
|
||||
dp_size,
|
||||
test_tensors,
|
||||
w1_ep,
|
||||
w2_ep,
|
||||
w1_scale_ep,
|
||||
w2_scale_ep,
|
||||
config.num_experts,
|
||||
use_fp8_dispatch,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
torch_combined,
|
||||
deepep_combined,
|
||||
atol=6e-2,
|
||||
rtol=6e-2,
|
||||
)
|
||||
|
||||
|
||||
MNKs = [
|
||||
(1, 128, 128),
|
||||
(2, 128, 512),
|
||||
(3, 1024, 2048),
|
||||
(32, 128, 1024),
|
||||
(45, 512, 2048),
|
||||
(64, 1024, 1024),
|
||||
(222, 1024, 2048),
|
||||
]
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("mnk", MNKs)
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("topk", [6])
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@requires_deep_ep
|
||||
def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
|
||||
num_experts: int, topk: int, world_dp_size: tuple[int,
|
||||
int]):
|
||||
low_latency_mode = False
|
||||
use_fp8_dispatch = False
|
||||
m, n, k = mnk
|
||||
|
||||
current_platform.seed_everything(7)
|
||||
world_size, dp_size = world_dp_size
|
||||
config = TestConfig(dtype=dtype,
|
||||
topk=topk,
|
||||
m=m,
|
||||
k=k,
|
||||
n=n,
|
||||
num_experts=num_experts)
|
||||
|
||||
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
|
||||
|
||||
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
|
||||
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
|
||||
|
||||
|
||||
MNKs = [
|
||||
(1, 128, 2560),
|
||||
(2, 128, 2560),
|
||||
(3, 1024, 2560),
|
||||
(32, 128, 2560),
|
||||
(45, 512, 2560),
|
||||
(64, 1024, 2560),
|
||||
(222, 1024, 2560),
|
||||
]
|
||||
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]
|
||||
USE_FP8_DISPATCH = [True, False]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("mnk", MNKs)
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("topk", [6])
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
|
||||
@requires_deep_ep
|
||||
def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
|
||||
num_experts: int, topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
use_fp8_dispatch: bool):
|
||||
|
||||
low_latency_mode = True
|
||||
m, n, k = mnk
|
||||
|
||||
if (low_latency_mode
|
||||
and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES):
|
||||
pytest.skip(
|
||||
f"Skipping test as hidden size {k} is not in list of supported "
|
||||
f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}"
|
||||
)
|
||||
|
||||
current_platform.seed_everything(7)
|
||||
world_size, dp_size = world_dp_size
|
||||
config = TestConfig(dtype=dtype,
|
||||
topk=topk,
|
||||
m=m,
|
||||
k=k,
|
||||
n=n,
|
||||
num_experts=num_experts)
|
||||
|
||||
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
|
||||
|
||||
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
|
||||
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
|
||||
@ -1856,6 +1856,8 @@ class ParallelConfig:
|
||||
factors.append(self.pipeline_parallel_size)
|
||||
factors.append(self.tensor_parallel_size)
|
||||
factors.append(self.enable_expert_parallel)
|
||||
factors.append(self.data_parallel_size)
|
||||
factors.append(envs.VLLM_ALL2ALL_BACKEND)
|
||||
return hashlib.sha256(str(factors).encode()).hexdigest()
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib.util
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -129,3 +129,147 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
from pplx_kernels.nvshmem import nvshmem_finalize
|
||||
logger.debug("PPLX NVSHMEM finalize")
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
has_deepep = importlib.util.find_spec("deep_ep") is not None
|
||||
assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
|
||||
super().__init__(cpu_group)
|
||||
self.handle_cache = Cache()
|
||||
|
||||
# This is the DeepEP default. Stick to it till we can establish
|
||||
# reasonable defaults based on profiling.
|
||||
self.num_sms = 20
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def _make_all2all_kwargs(self) -> dict[Any, Any]:
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = 1024 * 1024 * 1024
|
||||
num_rdma_bytes = None
|
||||
num_qps_per_rank = None
|
||||
|
||||
if self.internode:
|
||||
num_rdma_bytes = 1024 * 1024 * 1024
|
||||
num_qps_per_rank = self.num_sms // 2
|
||||
else:
|
||||
assert self.intranode
|
||||
num_rdma_bytes = 0
|
||||
num_qps_per_rank = 1
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
assert num_qps_per_rank is not None
|
||||
return dict(group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
|
||||
assert len(kwargs) == 0, (
|
||||
"DeepEPHTAll2AllManager expects no arguments. All the required "
|
||||
"args are computed in the Manager itself.")
|
||||
|
||||
import deep_ep
|
||||
buffer_kwargs = self._make_all2all_kwargs()
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer)
|
||||
# It is dangerous to set num sms outside this function. num_sms is not
|
||||
# a part of the hash-key that identifies this object. If we are in a
|
||||
# situation where we make objects with different num_sms, the hash key
|
||||
# in get_or_create must be updated.
|
||||
handle.set_num_sms(self.num_sms)
|
||||
return handle
|
||||
|
||||
|
||||
class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP Low-Latency kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def _make_all2all_kwargs(
|
||||
self,
|
||||
max_num_tokens_per_dp_rank: int,
|
||||
token_hidden_size: int,
|
||||
num_ep_ranks: int,
|
||||
num_global_experts: int,
|
||||
num_local_experts: int,
|
||||
) -> dict[Any, Any]:
|
||||
"""
|
||||
max_num_tokens_per_dp_rank : the maximum number of tokens a DP rank
|
||||
can dispatch all the ranks must hold the same value.
|
||||
token_hidden_size: the hidden dimension of each token.
|
||||
num_ep_ranks: the number of EP group ranks.
|
||||
num_global_experts: Number of experts in the model.
|
||||
num_local_experts: Number of experts in an EP rank.
|
||||
"""
|
||||
import deep_ep
|
||||
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = 1024 * 1024 * 1024
|
||||
num_qps_per_rank = num_local_experts
|
||||
num_rdma_bytes = None
|
||||
|
||||
if self.internode:
|
||||
num_rdma_bytes = 1024 * 1024 * 1024
|
||||
else:
|
||||
assert self.intranode
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
|
||||
hidden=token_hidden_size,
|
||||
num_ranks=num_ep_ranks,
|
||||
num_experts=num_global_experts)
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
return dict(group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=num_qps_per_rank)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
"""
|
||||
The kwargs for DeepEPLLAll2AllManager is dictated by
|
||||
_make_all2all_kwargs.
|
||||
"""
|
||||
import deep_ep
|
||||
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer)
|
||||
# It is dangerous to set num sms outside this function. num_sms is not
|
||||
# a part of the hash-key that identifies this object. If we are in a
|
||||
# situation where we make objects with different num_sms, the hash key
|
||||
# in get_or_create must be updated.
|
||||
handle.set_num_sms(self.num_sms)
|
||||
return handle
|
||||
|
||||
@ -67,6 +67,14 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
from .all2all import PPLXAll2AllManager
|
||||
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
|
||||
logger.info("Using PPLX all2all manager.")
|
||||
elif all2all_backend == "deepep_high_throughput":
|
||||
from .all2all import DeepEPHTAll2AllManager
|
||||
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
|
||||
logger.info("Using DeepEP High-Throughput all2all manager.")
|
||||
elif all2all_backend == "deepep_low_latency":
|
||||
from .all2all import DeepEPLLAll2AllManager
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
|
||||
logger.info("Using DeepEP Low-Latency all2all manager.")
|
||||
else:
|
||||
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
|
||||
|
||||
|
||||
@ -826,6 +826,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# Available options:
|
||||
# - "naive": naive all2all implementation using all-reduce
|
||||
# - "pplx": use pplx kernels
|
||||
# - "deepep_high_throughput", use deepep high-throughput kernels
|
||||
# - "deepep_low_latency", use deepep low-latency kernels
|
||||
"VLLM_ALL2ALL_BACKEND":
|
||||
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
|
||||
|
||||
|
||||
@ -12,8 +12,8 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
_moe_permute)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
|
||||
_resize_cache)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache, per_token_group_quant_fp8)
|
||||
from vllm.utils import round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -34,10 +34,8 @@ def _valid_deep_gemm_shape(M: int, N: int, K: int):
|
||||
return align <= M and N % align == 0 and K % align == 0
|
||||
|
||||
|
||||
def _valid_deep_gemm(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None) -> bool:
|
||||
def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
|
||||
w2: torch.Tensor) -> bool:
|
||||
"""
|
||||
Check if the given problem size is supported by the DeepGemm grouped
|
||||
gemm kernel. All of M, N, K and the quantization block_shape must be
|
||||
@ -47,10 +45,6 @@ def _valid_deep_gemm(hidden_states: torch.Tensor,
|
||||
logger.debug("DeepGemm disabled: deep_gemm not available.")
|
||||
return False
|
||||
|
||||
if expert_map is not None:
|
||||
logger.debug("DeepGemm disabled: expert map NYI.")
|
||||
return False
|
||||
|
||||
M = hidden_states.size(0)
|
||||
_, K, N = w2.size()
|
||||
if not _valid_deep_gemm_shape(M, N, K):
|
||||
@ -116,7 +110,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
a1q = hidden_states
|
||||
_, N, K = w1.size()
|
||||
|
||||
assert global_num_experts != -1
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = w1.size(0)
|
||||
|
||||
assert w2.size(1) == K
|
||||
|
||||
a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
|
||||
@ -128,6 +124,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self.block_shape[0],
|
||||
)
|
||||
|
||||
if expert_map is not None:
|
||||
# DeepGemm (Grouped Contiguous) kernel needs a valid B index
|
||||
# for all rows of A. To that effect, simply compute with
|
||||
# the 0th weight matrix.
|
||||
# Note that this relies on the fact that corresponding topk
|
||||
# weights would be 0 during weight multiplication.
|
||||
expert_ids = torch.where(expert_ids == -1, 0, expert_ids)
|
||||
|
||||
# Note: M_sum is different than the pre-permuted shape of a1q.
|
||||
M_sum = a1q.size(0)
|
||||
workspace1 = _resize_cache(workspace13, (M_sum, N))
|
||||
@ -140,9 +144,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self.activation(activation, workspace2, workspace1.view(-1, N))
|
||||
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, False,
|
||||
self.block_shape)
|
||||
a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
|
||||
self.block_shape[1],
|
||||
column_major_scales=True)
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
|
||||
|
||||
@ -0,0 +1,236 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional
|
||||
|
||||
import deep_ep
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
|
||||
|
||||
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
Prepare/Finalize using DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
buffer: deep_ep.Buffer,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
dp_size: int,
|
||||
rank_expert_offset: int,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
super().__init__()
|
||||
self.buffer = buffer
|
||||
self.world_size = world_size
|
||||
self.rank = rank
|
||||
self.dp_size = dp_size
|
||||
self.rank_expert_offset = rank_expert_offset
|
||||
self.quant_dtype = quant_dtype
|
||||
self.block_shape = block_shape
|
||||
# The dispatch function returns a handle that the combine function
|
||||
# requires. We store the handle here so it is available to the
|
||||
# combine function.
|
||||
self.handle = None
|
||||
|
||||
# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
|
||||
self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]
|
||||
|
||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> Optional[torch.dtype]:
|
||||
return torch.int64
|
||||
|
||||
def _get_dispatch_config(self) -> Optional[deep_ep.Config]:
|
||||
if self.dp_size not in self.available_rank_configs:
|
||||
return None
|
||||
return deep_ep.Buffer.get_dispatch_config(self.dp_size)
|
||||
|
||||
def _get_combine_config(self) -> Optional[deep_ep.Config]:
|
||||
if self.dp_size not in self.available_rank_configs:
|
||||
return None
|
||||
return deep_ep.Buffer.get_combine_config(self.dp_size)
|
||||
|
||||
def _do_quant(self, tokens: torch.Tensor,
|
||||
token_scales: Optional[torch.Tensor], per_act_token: bool):
|
||||
tokens, token_scales = moe_kernel_quantize_input(
|
||||
tokens, token_scales, self.quant_dtype, per_act_token,
|
||||
self.block_shape)
|
||||
return tokens, token_scales
|
||||
|
||||
def _do_dispatch(self, tokens: torch.Tensor,
|
||||
token_scales: Optional[torch.Tensor],
|
||||
rank_topk_ids: torch.Tensor,
|
||||
rank_topk_weights: torch.Tensor, num_experts: int):
|
||||
|
||||
has_scales = token_scales is not None
|
||||
|
||||
(num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens,
|
||||
is_token_in_rank, event) = self.buffer.get_dispatch_layout(
|
||||
topk_idx=rank_topk_ids,
|
||||
num_experts=num_experts,
|
||||
previous_event=None,
|
||||
async_finish=False,
|
||||
allocate_on_comm_stream=False)
|
||||
|
||||
token_data = tokens
|
||||
if has_scales:
|
||||
token_data = (tokens, token_scales)
|
||||
|
||||
(
|
||||
token_data, expert_topk_ids, expert_topk_weights,
|
||||
expert_num_tokens_per_expert_list, self.handle, event
|
||||
) = self.buffer.dispatch(
|
||||
x=token_data,
|
||||
handle=None,
|
||||
num_tokens_per_rank=num_tokens_per_rank,
|
||||
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
|
||||
is_token_in_rank=is_token_in_rank,
|
||||
num_tokens_per_expert=expert_num_tokens,
|
||||
topk_idx=rank_topk_ids,
|
||||
topk_weights=rank_topk_weights,
|
||||
# expert_alignment rounds the number of tokens per expert
|
||||
# to this value.
|
||||
expert_alignment=1,
|
||||
config=self._get_dispatch_config(),
|
||||
previous_event=None,
|
||||
async_finish=False,
|
||||
allocate_on_comm_stream=False)
|
||||
|
||||
if has_scales:
|
||||
expert_x, expert_x_scale = token_data
|
||||
else:
|
||||
expert_x, expert_x_scale = token_data, None
|
||||
|
||||
# The existing MOE kernels assume that all entries of topk_ids are
|
||||
# valid. To that effect, set the -1s in expert_topk_ids to some expert
|
||||
# outside this rank so the expert_map can remap it to -1 when safe.
|
||||
# With Expert Parallel, the experts are divided amongst the rank
|
||||
# sequentially. For rank 0, set it to num_experts - 1 and for all other
|
||||
# ranks set it to 0 as we know that expert_map will have a -1 in those
|
||||
# regions for those ranks.
|
||||
#
|
||||
# DeepEP's topk_ids output refers to the local experts directly. Offset
|
||||
# the topk_ids to move it back to the global experts space so it aligns
|
||||
# with existing vLLM interfaces.
|
||||
expert_topk_ids = torch.where(
|
||||
expert_topk_ids == -1,
|
||||
num_experts - 1 if self.rank_expert_offset == 0 else 0,
|
||||
expert_topk_ids + self.rank_expert_offset)
|
||||
|
||||
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
|
||||
expert_topk_weights)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
rank_topk_weights: torch.Tensor,
|
||||
rank_topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = rank_topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1")
|
||||
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
||||
|
||||
# Check if there is a block_shape / or if we can infer the quantization
|
||||
# schemes from the scales.
|
||||
per_token_quant = None
|
||||
if all([x is None for x in [self.block_shape, a1_scale, a2_scale]
|
||||
]) and self.quant_dtype is not None:
|
||||
# Quantization required despite none of the inputs suggesting
|
||||
# quantization. Fallback to per_token_dynamic quant.
|
||||
per_token_quant = True
|
||||
else:
|
||||
per_token_quant = ((self.block_shape is not None) or
|
||||
(a1_scale is not None and a1_scale.numel() != 1)
|
||||
or (a2_scale is not None
|
||||
and a2_scale.numel() != 1))
|
||||
|
||||
if per_token_quant:
|
||||
a1q, a1q_scale = self._do_quant(a1, a1_scale, per_act_token=True)
|
||||
(expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
|
||||
expert_topk_weights) = self._do_dispatch(
|
||||
tokens=a1q,
|
||||
token_scales=a1q_scale,
|
||||
rank_topk_ids=rank_topk_ids,
|
||||
rank_topk_weights=rank_topk_weights,
|
||||
num_experts=num_experts)
|
||||
else:
|
||||
# DeepEP kernels only support dispatching per-token-quant
|
||||
# quantization. dispatch in bfloat16.
|
||||
(expert_x, _, expert_num_tokens, expert_topk_ids,
|
||||
expert_topk_weights) = self._do_dispatch(
|
||||
tokens=a1,
|
||||
token_scales=None,
|
||||
rank_topk_ids=rank_topk_ids,
|
||||
rank_topk_weights=rank_topk_weights,
|
||||
num_experts=num_experts)
|
||||
# quantize now
|
||||
expert_x_scale = None
|
||||
if expert_x.numel() != 0:
|
||||
expert_x, expert_x_scale = self._do_quant(expert_x,
|
||||
a1_scale,
|
||||
per_act_token=False)
|
||||
|
||||
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
|
||||
expert_topk_weights)
|
||||
|
||||
def _apply_weights_and_reduce(self, num_tokens: int,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
output_dtype: torch.dtype):
|
||||
|
||||
if fused_expert_output.ndim == 2:
|
||||
hidden_dim = fused_expert_output.size(-1)
|
||||
fused_expert_output = fused_expert_output.view(
|
||||
num_tokens, -1, hidden_dim)
|
||||
|
||||
if not apply_router_weight_on_input:
|
||||
# The DeepEP combine kernels don't do the topk weight
|
||||
# multiplication. We multiply the weights locally.
|
||||
fused_expert_output = fused_expert_output.to(torch.float32)
|
||||
fused_expert_output = fused_expert_output * topk_weights.view(
|
||||
fused_expert_output.size(0), -1, 1)
|
||||
fused_expert_output = fused_expert_output.to(output_dtype)
|
||||
|
||||
return fused_expert_output.sum(dim=1).to(output_dtype)
|
||||
|
||||
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool) -> None:
|
||||
|
||||
assert self.handle is not None
|
||||
|
||||
# fused_expert_output can have 0 tokens - This happens when none of the
|
||||
# tokens from the all2all reach this EP rank.
|
||||
if fused_expert_output.numel() != 0:
|
||||
fused_expert_output = self._apply_weights_and_reduce(
|
||||
num_tokens=topk_ids.size(0),
|
||||
fused_expert_output=fused_expert_output,
|
||||
topk_weights=topk_weights,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
output_dtype=output.dtype)
|
||||
|
||||
combined_x, _, event = self.buffer.combine(
|
||||
x=fused_expert_output,
|
||||
handle=self.handle,
|
||||
topk_weights=None,
|
||||
config=self._get_combine_config(),
|
||||
previous_event=None,
|
||||
async_finish=False,
|
||||
allocate_on_comm_stream=False)
|
||||
# Respect inplace outputs.
|
||||
output.copy_(combined_x, non_blocking=True)
|
||||
@ -0,0 +1,152 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional
|
||||
|
||||
import deep_ep
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
|
||||
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
|
||||
DEEPEP_QUANT_BLOCK_SIZE = 128
|
||||
|
||||
|
||||
def dequant_fp8(expert_x_fp8: torch.Tensor,
|
||||
expert_x_scales: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Return dequantized tensor in fp32
|
||||
"""
|
||||
# TODO (varun) : Optimize leverage num_tokens_per_expert counts
|
||||
assert expert_x_fp8.is_contiguous()
|
||||
expert_x_scales = expert_x_scales.contiguous()
|
||||
num_experts = expert_x_fp8.size(0)
|
||||
|
||||
expert_x_fp32 = expert_x_fp8.to(torch.float32).view(
|
||||
num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE)
|
||||
expert_x_scales = expert_x_scales.view(num_experts, -1, 1)
|
||||
return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.shape)
|
||||
|
||||
|
||||
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""
Prepare/Finalize using DeepEP low-latency kernels.
"""

# DeepEP low-latency kernels are compiled only for certain
# specific hidden sizes.
SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168]

def __init__(self,
buffer: deep_ep.Buffer,
world_size: int,
dp_size: int,
max_tokens_per_rank: int,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
use_fp8_dispatch: bool = False):
super().__init__()

self.buffer = buffer
self.world_size = world_size
self.dp_size = dp_size
self.quant_dtype = quant_dtype
self.block_shape = block_shape
self.max_tokens_per_rank = max_tokens_per_rank
self.use_fp8_dispatch = use_fp8_dispatch
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
self.handle = None

def max_num_tokens_per_rank(self) -> Optional[int]:
return self.max_tokens_per_rank

def topk_indices_dtype(self) -> Optional[torch.dtype]:
return torch.int64

def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
rank_topk_weights: torch.Tensor,
rank_topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:

hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
(f"Hidden Size {hidden_size} not in supported list of hidden sizes"
f"{self.SUPPORTED_HIDDEN_SIZES}")

if self.use_fp8_dispatch:
assert hidden_size % 128 == 0, \
"DeepEP kernels quantize the inputs in blocks of shape 128"

# Quantize
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
assert not per_act_token, (
"low_latency kernels don't support per-act-token quant")

if apply_router_weight_on_input:
topk = rank_topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * rank_topk_weights.to(a1.dtype)

# Dispatch
expert_x, expert_num_tokens, self.handle, event, hook = \
self.buffer.low_latency_dispatch(a1,
rank_topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
async_finish=False,
return_recv_hook=False)

if self.use_fp8_dispatch:
# TODO (varun) : In the case of dynamic quantization, we could
# probably skip the quant below and use the results directly.
# Although note that the deepep quant is per token 128 elements.
expert_x_fp8, expert_x_scales = expert_x
expert_x = dequant_fp8(expert_x_fp8,
expert_x_scales).to(dtype=a1.dtype)

num_experts = expert_x.size(0)
hidden_dim = expert_x.size(-1)

expert_x = expert_x.view((-1, expert_x.size(-1)))
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x, a1_scale, self.quant_dtype, per_act_token,
self.block_shape)
expert_x = expert_x.view((num_experts, -1, hidden_dim))

return (expert_x, expert_x_scale, expert_num_tokens, None, None)

def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> None:

assert self.handle is not None

combine_topk_weights = topk_weights
if apply_router_weight_on_input:
# weights have already been applied.
combine_topk_weights = torch.ones_like(topk_weights)

# TODO (varun) : Enable zero copy mode
_, event, hook = self.buffer.low_latency_combine(
fused_expert_output,
topk_ids,
combine_topk_weights,
self.handle,
async_finish=False,
zero_copy=False,
return_recv_hook=False,
out=output)
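# Toy sketch (no DeepEP dependency) of the handle lifetime used above:
# prepare() stores the handle returned by dispatch so finalize() can pass it
# back to combine. FakeBuffer is a stand-in, not the deep_ep Buffer API.
import torch

class FakeBuffer:
    def dispatch(self, x):
        handle = {"orig_shape": x.shape}
        return x.clone(), handle
    def combine(self, x, handle):
        assert x.shape == handle["orig_shape"]
        return x

class ToyPrepareFinalize:
    def __init__(self, buffer):
        self.buffer = buffer
        self.handle = None
    def prepare(self, a1):
        out, self.handle = self.buffer.dispatch(a1)
        return out
    def finalize(self, output, fused_out):
        assert self.handle is not None
        output.copy_(self.buffer.combine(fused_out, self.handle))

pf = ToyPrepareFinalize(FakeBuffer())
a = torch.randn(4, 8)
y = torch.empty_like(a)
pf.finalize(y, pf.prepare(a) * 2)
print(torch.allclose(y, a * 2))  # True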
@ -10,7 +10,8 @@ import triton.language as tl
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input)


@triton.jit
@ -397,6 +398,12 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.rank = rank
self.max_num_tokens = max_num_tokens

def max_num_tokens_per_rank(self) -> Optional[int]:
return self.max_num_tokens

def topk_indices_dtype(self) -> Optional[torch.dtype]:
return None

def prepare(
self,
a1: torch.Tensor,
@ -407,7 +414,8 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
assert a1.dim() == 2
assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0)
@ -450,7 +458,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
first_expert, :rows, :] = a1[:topks.numel()][topks]
tokens_per_expert[expert_id - first_expert] = rows

return b_a1, a1_scale, tokens_per_expert
return b_a1, a1_scale, tokens_per_expert, None, None

def finalize(
self,
@ -601,6 +609,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None,
world_size: int = 1,
dp_size: int = 1,
@ -611,12 +620,15 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a16 = use_int8_w8a16
self.block_shape = block_shape
self.per_channel_quant = per_channel_quant
self.max_num_tokens = max_num_tokens
assert not use_int8_w8a8, "NYI"
assert not use_int4_w4a16, "NYI"
self.world_size = world_size
self.dp_size = dp_size

assert not use_int8_w8a8, "NYI"
assert not use_int4_w4a16, "NYI"
assert self.block_shape is None, "NYI"

def workspace_shapes(
self,
a: torch.Tensor,
@ -670,8 +682,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
]

# TODO: num_tokens -> max_num_tokens?
E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
hidden_states, w1, w2, topk_ids)

assert w1.size(0) == E
@ -687,7 +698,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2.size(),
top_k_num,
config_dtype,
num_tokens,
max_num_tokens,
block_shape=self.block_shape,
)

@ -706,10 +717,12 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
#print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N))
intermediate_cache1 = _resize_cache(workspace13,
(E, max_num_tokens, N))
intermediate_cache2 = _resize_cache(workspace2,
(E, num_tokens, N // 2))
intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K))
(E, max_num_tokens, N // 2))
intermediate_cache3 = _resize_cache(workspace13,
(E, max_num_tokens, K))

# MM1
invoke_moe_batched_triton_kernel(A=hidden_states,
@ -731,15 +744,20 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.activation(activation, intermediate_cache2.view(-1, N // 2),
intermediate_cache1.view(-1, N))

#qintermediate_cache2 = intermediate_cache2
a2q_scale = a2_scale
# TODO (varun) : support w8a8
assert not self.use_fp8_w8a8
#if self.use_fp8_w8a8:
# qintermediate_cache2, a2q_scale = _fp8_quantize(
# intermediate_cache2, a2_scale, self.block_shape)
ic2_hidden_size = intermediate_cache2.size(-1)
intermediate_cache2 = intermediate_cache2.view(-1, ic2_hidden_size)

invoke_moe_batched_triton_kernel(A=intermediate_cache2,
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
A=intermediate_cache2,
A_scale=a2_scale,
qtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else None,
per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape)

qintermediate_cache2 = qintermediate_cache2.view(
(E, -1, ic2_hidden_size))

invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
B=w2,
C=intermediate_cache3,
expert_num_tokens=expert_num_tokens,
@ -752,5 +770,4 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int4_w4a16=self.use_int4_w4a16,
config=config,
block_shape=self.block_shape)

return intermediate_cache3

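# Sketch of the new quantize-between-GEMMs step above: the activation output
# of the first batched GEMM is (optionally) quantized before it feeds the
# second GEMM. Per-tensor symmetric int8 is used here purely as a stand-in
# for the fp8 path in the diff; this is not the vLLM helper itself.
import torch

def quantize_per_tensor_int8(x: torch.Tensor):
    scale = x.abs().amax().clamp(min=1e-8) / 127.0
    xq = torch.clamp((x / scale).round(), -127, 127).to(torch.int8)
    return xq, scale

act = torch.randn(16, 64)          # intermediate_cache2 stand-in
act_q, act_scale = quantize_per_tensor_int8(act)
deq = act_q.to(torch.float32) * act_scale
print(torch.allclose(deq, act, atol=act_scale.item()))  # True (error < one quant step)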
@ -1164,7 +1164,7 @@ def fused_experts(hidden_states: torch.Tensor,
# permute/unpermute ops are available.
N = w1.shape[1]
if (allow_deep_gemm and use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
and _valid_deep_gemm(hidden_states, w1, w2)):
assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8(
hidden_states=hidden_states,

@ -5,7 +5,7 @@ import importlib
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
from typing import Callable, Optional, Union

import torch
import torch.nn.functional as F
@ -30,16 +30,19 @@ from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op

has_pplx = importlib.util.find_spec("pplx_kernels") is not None
has_deepep = importlib.util.find_spec("deep_ep") is not None

if current_platform.is_cuda_alike():
from .fused_batched_moe import (BatchedPrepareAndFinalize,
BatchedTritonExperts)
from .fused_batched_moe import BatchedTritonExperts
from .fused_moe import TritonExperts, fused_experts
from .modular_kernel import (FusedMoEModularKernel,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
if has_pplx:
from .pplx_prepare_finalize import PplxPrepareAndFinalize
if has_deepep:
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize
else:
fused_experts = None  # type: ignore
FusedMoEPermuteExpertsUnpermute = None  # type: ignore
@ -71,10 +74,24 @@ class FusedMoEParallelConfig:

use_ep: bool  # whether to use EP or not

@property
def use_all2all_kernels(self):
return self.dp_size > 1 and self.use_ep

@property
def use_pplx_kernels(self):
return self.dp_size > 1 and self.use_ep and \
envs.VLLM_ALL2ALL_BACKEND == "pplx"
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "pplx")

@property
def use_deepep_ht_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")

@property
def use_deepep_ll_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")

@staticmethod
def make(tp_size_: int, dp_size_: int,
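# Minimal sketch of the backend-selection properties above, using a plain
# environment variable in place of vllm.envs. The class name and the
# _backend_is helper are illustrative, not part of vLLM.
import os
from dataclasses import dataclass

@dataclass
class ToyParallelConfig:
    dp_size: int
    use_ep: bool

    @property
    def use_all2all_kernels(self) -> bool:
        return self.dp_size > 1 and self.use_ep

    def _backend_is(self, name: str) -> bool:
        return (self.use_all2all_kernels
                and os.environ.get("VLLM_ALL2ALL_BACKEND", "") == name)

    @property
    def use_pplx_kernels(self) -> bool:
        return self._backend_is("pplx")

    @property
    def use_deepep_ht_kernels(self) -> bool:
        return self._backend_is("deepep_high_throughput")

    @property
    def use_deepep_ll_kernels(self) -> bool:
        return self._backend_is("deepep_low_latency")

cfg = ToyParallelConfig(dp_size=2, use_ep=True)
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_low_latency"
print(cfg.use_deepep_ll_kernels, cfg.use_pplx_kernels)  # True False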
@ -231,6 +248,14 @@ class MoEConfig:
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels

@property
def use_deepep_ht_kernels(self):
return self.moe_parallel_config.use_deepep_ht_kernels

@property
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels


class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
@ -252,7 +277,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None

prepare_finalize = None
quant_dtype = None
act_quant_block_size = None
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
if isinstance(quant_config, Fp8Config):
act_quant_block_size = quant_config.weight_block_size
quant_dtype = torch.float8_e4m3fn

prepare_finalize: Optional[Union[PplxPrepareAndFinalize,
DeepEPHTPrepareAndFinalize,
DeepEPLLPrepareAndFinalize]] = None
if moe.use_pplx_kernels:
all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens,
@ -288,8 +322,49 @@ class FusedMoEMethodBase(QuantizeMethodBase):
dp_size=all2all_manager.tp_group.world_size,
quant_dtype=moe.in_dtype,
)
elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size

all_to_all_args = dict()
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = DeepEPHTPrepareAndFinalize(
handle,
world_size=all2all_manager.world_size,
rank=all2all_manager.rank,
dp_size=all2all_manager.dp_world_size,
rank_expert_offset=all2all_manager.rank *
moe.num_local_experts,
quant_dtype=quant_dtype,
block_shape=act_quant_block_size,
)

elif moe.use_deepep_ll_kernels:
assert moe.dp_size == all2all_manager.dp_world_size

all_to_all_args = dict(
max_num_tokens_per_dp_rank=moe.max_num_tokens,
token_hidden_size=moe.hidden_dim,
num_ep_ranks=all2all_manager.world_size,
num_global_experts=moe.num_experts,
num_local_experts=moe.num_experts //
all2all_manager.world_size)
handle = all2all_manager.get_handle(all_to_all_args)

# Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now.
prepare_finalize = DeepEPLLPrepareAndFinalize(
handle,
world_size=all2all_manager.world_size,
dp_size=all2all_manager.dp_world_size,
max_tokens_per_rank=moe.max_num_tokens,
quant_dtype=quant_dtype,
block_shape=act_quant_block_size,
use_fp8_dispatch=False,
)

self.topk_indices_dtype = None
if prepare_finalize is not None:
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
@ -297,7 +372,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
)

def select_gemm_impl(
self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]
self, prepare_finalize: FusedMoEPrepareAndFinalize
) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
# gemm implementation
@ -334,6 +409,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def __init__(self, moe: MoEConfig):
super().__init__()
self.fused_experts = fused_experts  # type: ignore
self.topk_indices_dtype = None
self.moe = moe

self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
@ -343,8 +419,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else:
self.rocm_aiter_fused_experts = None  # type: ignore

def select_gemm_impl(
self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]):
def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize):

assert self.fused_experts == fused_experts

@ -353,11 +428,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

experts: Optional[FusedMoEPermuteExpertsUnpermute] = None

if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
use_batched_experts = prepare_finalize.max_num_tokens_per_rank(
) is not None
if use_batched_experts:
logger.debug("BatchedTritonExperts %s", self.moe)
assert self.moe.dp_size == all2all_manager.dp_world_size
experts = BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
max_num_tokens=self.moe.max_num_tokens,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
@ -366,6 +443,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
per_channel_quant=False,
)
else:
logger.debug("TritonExperts %s", self.moe)
@ -494,6 +572,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:

topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -505,7 +584,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=torch.uint32 if self.moe.use_pplx_kernels else None)
indices_type=self.topk_indices_dtype)

if self.rocm_aiter_moe_enabled:
assert expert_map is None
@ -806,11 +885,8 @@ class FusedMoE(torch.nn.Module):
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None

if quant_config is None:
quant_method = UnquantizedFusedMoEMethod(moe)
else:
quant_method = quant_config.get_quant_method(self, prefix)
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))

assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
@ -836,7 +912,8 @@ class FusedMoE(torch.nn.Module):
# Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
if self.moe_parallel_config.use_pplx_kernels:
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
act_dtype = vllm_config.model_config.dtype
self.batched_hidden_states = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.hidden_size),
@ -880,6 +957,14 @@ class FusedMoE(torch.nn.Module):
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels

@property
def use_deepep_ht_kernels(self):
return self.moe_parallel_config.use_deepep_ht_kernels

@property
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels

def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
@ -1210,19 +1295,21 @@ class FusedMoE(torch.nn.Module):
When just tensor-parallel is used, it is not required to reduce
the shared_experts results immediately. Instead we reduce once at the
end of the MoE op. (Refer to DeepSeekV2MoE module)
With EP and the pplx kernels - this is no longer viable as all
With EP and all2all kernels - this is no longer viable as all
GPU ranks in DP produce the complete set of hidden_states.
Therefore it is required that we reduce the shared_experts output
early.
"""
return self.use_pplx_kernels
return (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels)

def maybe_all_reduce_tensor_model_parallel(
self, final_hidden_states: torch.Tensor):
"""
The pplx combine kernel reduces across GPU ranks by default.
"""
if self.use_pplx_kernels:
if (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels):
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
@ -1289,7 +1376,7 @@ class FusedMoE(torch.nn.Module):

ctx = get_forward_context()
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens

num_tokens = full_hidden_states.size(0)
for chunk_start_ in range(0, max_tokens_across_dp,
@ -1310,12 +1397,17 @@ class FusedMoE(torch.nn.Module):
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None
if self.moe_parallel_config.use_pplx_kernels:
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
return self.forward_impl_chunked(hidden_states, router_logits)

if self.dp_size > 1:
do_naive_dispatch_combine: bool = (
self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels)
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)

# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
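# Sketch of the dispatch/combine gating above: the naive EP-group
# dispatch/combine is only used when DP > 1 and the all2all backend does not
# already perform the exchange itself (the DeepEP high-throughput kernels do).
def do_naive_dispatch_combine(dp_size: int, use_deepep_ht_kernels: bool) -> bool:
    return dp_size > 1 and not use_deepep_ht_kernels

print(do_naive_dispatch_combine(2, False))  # True
print(do_naive_dispatch_combine(2, True))   # False
print(do_naive_dispatch_combine(1, False))  # False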
@ -1335,12 +1427,12 @@ class FusedMoE(torch.nn.Module):
apply_router_weight_on_input=self.apply_router_weight_on_input,
)

if self.dp_size > 1:
if do_naive_dispatch_combine:
final_hidden_states = get_ep_group().combine(final_hidden_states)

if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.)
final_hidden_states = tensor_model_parallel_all_reduce(
# Default set to False. (May have to add shared expert outputs.)
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)

return final_hidden_states

@ -94,7 +94,8 @@ class FusedMoEPrepareAndFinalize(ABC):
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform any quantization (and/or) dispatching needed
for this kernel.
@ -113,6 +114,10 @@ class FusedMoEPrepareAndFinalize(ABC):
Returns a tuple of:
- quantized + dispatched a.
- quantized + dispatched a1_scales.
- Optional tensor as big as number of local experts that contains the
number of tokens assigned to each local expert.
- Optional dispatched expert topk IDs
- Optional dispatched expert topk weights
"""
raise NotImplementedError

@ -138,6 +143,27 @@ class FusedMoEPrepareAndFinalize(ABC):
"""
raise NotImplementedError

@abstractmethod
def topk_indices_dtype(self) -> Optional[torch.dtype]:
"""
The PrepareFinalize All2All implementations generally constrain the
dtype of the topk_ids they support. This function returns the
required topk indices dtype so it can be respected.
Return None if there are no such restrictions.
"""
raise NotImplementedError

@abstractmethod
def max_num_tokens_per_rank(self) -> Optional[int]:
"""
Some PrepareFinalize All2All implementations are batched, meaning
they can process only a set number of tokens at a time. This
function returns the batch size, i.e. the maximum number of tokens
the implementation can process at a time.
Return None if there are no such restrictions.
"""
raise NotImplementedError

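# Hypothetical minimal implementation sketch of the two new hooks, showing how
# a backend advertises its topk-id dtype and per-rank batch limit. The class
# name below is illustrative, not part of vLLM.
from typing import Optional
import torch

class NoOpPrepareAndFinalize:
    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        # e.g. pplx wants uint32 and DeepEP low-latency wants int64;
        # return None when any integer dtype is acceptable.
        return None

    def max_num_tokens_per_rank(self) -> Optional[int]:
        # Non-batched implementation: no per-rank token limit.
        return None

pf = NoOpPrepareAndFinalize()
print(pf.topk_indices_dtype(), pf.max_num_tokens_per_rank())  # None None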
class FusedMoEPermuteExpertsUnpermute(ABC):
"""
@ -261,6 +287,61 @@ class FusedMoEModularKernel(torch.nn.Module):
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts

def _do_fused_experts(
self,
a1: torch.Tensor,  # input to forward fn
a1q: torch.Tensor,  # output of prepare fn
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
expert_num_tokens: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor]) -> torch.Tensor:

_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)

# Use a1 here to decipher the correct workspace datatype
workspace13_shape, workspace2_shape, workspace_dtype = (
self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
global_num_experts))

# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13 = torch.zeros(workspace13_shape,
device=a1.device,
dtype=workspace_dtype)
workspace2 = torch.zeros(workspace2_shape,
device=a1.device,
dtype=workspace_dtype)

fused_out = self.fused_experts.apply(
a1q,
w1,
w2,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)

return fused_out

def forward(
self,
hidden_states: torch.Tensor,
@ -315,49 +396,48 @@ class FusedMoEModularKernel(torch.nn.Module):
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""

a1 = hidden_states
E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids)

if global_num_experts == -1:
global_num_experts = E

output = a1 if inplace else torch.zeros_like(a1)

workspace13_shape, workspace2_shape, workspace_dtype = (
self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
global_num_experts))
if global_num_experts == -1:
global_num_experts = w1.size(0)

# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13 = torch.zeros(workspace13_shape,
device=a1.device,
dtype=workspace_dtype)
workspace2 = torch.zeros(workspace2_shape,
device=a1.device,
dtype=workspace_dtype)
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids,
global_num_experts, expert_map, apply_router_weight_on_input)
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
topk_weights = (topk_weights if _expert_topk_weights is None else
_expert_topk_weights)

a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts,
expert_map, apply_router_weight_on_input)

fused_out = self.fused_experts.apply(
a1q,
w1,
w2,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
fused_out = None
if a1q.numel() == 0:
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput
# kernels. CUDAGraph compatible all2all kernels like the pplx
# kernels and the DeepEP low-latency kernels are always batched
# and can never run into the tensor.numel() == 0 case.
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
else:
fused_out = self._do_fused_experts(
a1=a1,
a1q=a1q,
w1=w1,
w2=w2,
topk_ids=topk_ids,
expert_num_tokens=expert_num_tokens,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale)

self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input)

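# Condensed sketch (pure PyTorch, illustrative callables) of the control flow
# introduced above: prepare, run the expert GEMMs only when tokens actually
# arrived on this rank, then finalize into the output buffer.
import torch

def toy_modular_moe(a1, prepare, experts, finalize):
    a1q, extra = prepare(a1)
    if a1q.numel() == 0:
        # No tokens reached this EP rank; skip the expert GEMMs entirely.
        fused_out = torch.empty_like(a1q)
    else:
        fused_out = experts(a1q, extra)
    out = torch.zeros_like(a1)
    finalize(out, fused_out)
    return out

out = toy_modular_moe(
    torch.randn(4, 8),
    prepare=lambda x: (x, None),
    experts=lambda x, _: x * 2,
    finalize=lambda o, f: o.copy_(f),
)
print(out.shape)  # torch.Size([4, 8])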
@ -25,7 +25,7 @@ def _moe_permute(
"""
top_k_num = curr_topk_ids.size(1)

tokens_in_chunk = curr_hidden_states.sizze(0)
tokens_in_chunk = curr_hidden_states.size(0)

sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids,
@ -37,11 +37,12 @@ def _moe_permute(
inv_perm: Optional[torch.Tensor] = None

num_tokens = top_k_num * tokens_in_chunk
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]

# Permute according to sorted token ids.
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)

curr_hidden_states = _fp8_perm(curr_hidden_states,
sorted_token_ids // top_k_num)

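# Standalone check of the permute/inverse-permute trick used above: the
# argsort of a permutation is its inverse, so applying it after the permute
# recovers the original row order.
import torch

x = torch.arange(12, dtype=torch.float32).view(6, 2)
perm = torch.randperm(6)
inv_perm = torch.argsort(perm)
assert torch.equal(x[perm][inv_perm], x)
print("inverse permutation recovers the original order")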
@ -32,6 +32,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.dp_size = dp_size
self.quant_dtype = quant_dtype

def max_num_tokens_per_rank(self) -> Optional[int]:
return self.max_num_tokens

def topk_indices_dtype(self) -> Optional[torch.dtype]:
return torch.uint32

def prepare(
self,
a1: torch.Tensor,
@ -42,7 +48,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
num_tokens = a1.size(0)  # M
hidden_dim = a1.size(-1)  # K

@ -115,7 +122,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
bound_m=bound_m,
)

return expert_x, expert_x_scale, expert_num_tokens
return expert_x, expert_x_scale, expert_num_tokens, None, None

def finalize(
self,

@ -24,6 +24,12 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
self.block_shape = block_shape
self.quant_dtype = quant_dtype

def max_num_tokens_per_rank(self) -> Optional[int]:
return None

def topk_indices_dtype(self) -> Optional[torch.dtype]:
return None

def prepare(
self,
a1: torch.Tensor,
@ -34,7 +40,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:

if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
@ -47,7 +55,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
self.per_channel_quant,
self.block_shape)

return a1q, a1q_scale, None
return a1q, a1q_scale, None, None, None

def finalize(
self,

@ -29,9 +29,10 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
per_channel_quant=per_channel_quant,
block_shape=block_shape,
block_m=block_m)
self.deep_gemm_expert = DeepGemmExperts()
self.allow_deep_gemm = allow_deep_gemm
self.use_fp8_w8a8 = use_fp8_w8a8
self.deep_gemm_expert = DeepGemmExperts(
) if self.allow_deep_gemm else None

def workspace_shapes(
self,
@ -46,6 +47,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes(
a, M, N, K, topk, num_experts)
else:
@ -73,7 +75,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
) -> torch.Tensor:
N = w1.size(1)
if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
and _valid_deep_gemm(hidden_states, w1, w2)):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.apply(
hidden_states,
w1,

@ -18,8 +18,8 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches.
"""
assert prod(
v) <= x.numel(), f"{prod(v)} <= {x.numel()}"  # CUDAGRAPH unfriendly?
assert prod(v) <= x.numel(
), f"{v} ({prod(v)}) <= {x.shape} ({x.numel()})"  # CUDAGRAPH unfriendly?
return x.flatten()[:prod(v)].view(*v)

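# Standalone illustration of the _resize_cache contract shown above: take a
# preallocated workspace, keep only prod(v) elements, and view them as v.
from math import prod
import torch

def resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
    assert prod(v) <= x.numel(), f"{v} ({prod(v)}) <= {x.shape} ({x.numel()})"
    return x.flatten()[:prod(v)].view(*v)

workspace = torch.empty(2 * 16 * 32)
cache = resize_cache(workspace, (2, 16, 32))
print(cache.shape)  # torch.Size([2, 16, 32])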
@ -3,7 +3,7 @@

import functools
import importlib.util
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union

import torch
import torch.nn.functional as F
@ -452,6 +452,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if envs.VLLM_USE_DEEP_GEMM:
if not has_deep_gemm:
logger.warning_once("Failed to import DeepGemm kernels.")
elif not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
" DeepGemm kernels")
elif (current_platform.is_cuda()
and current_platform.has_device_capability(90)):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
@ -460,8 +463,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger.warning_once(
"DeepGemm not supported on the current platform.")

self.topk_indices_dtype = None
self.fused_experts = functools.partial(  # type: ignore
fused_experts,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm)

@ -765,18 +770,39 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w2_input_scale

def select_gemm_impl(self, prepare_finalize):

from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)

assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet.")

experts = TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)
experts: Optional[Union[BatchedTritonExperts,
TritonOrDeepGemmExperts]] = None
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
use_batched_experts = max_num_tokens_per_rank is not None

if use_batched_experts:
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens_per_rank,
world_size=prepare_finalize.world_size,
dp_size=prepare_finalize.dp_size,
use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
)
else:
experts = TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)

assert experts is not None
return experts

def apply(
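# Sketch of the selection rule above: a prepare/finalize backend that reports
# a per-rank batch limit gets batched experts, otherwise the standard path is
# used. ToyBatchedExperts/ToyStandardExperts are placeholders, not vLLM classes.
from typing import Optional

class ToyBatchedExperts:
    def __init__(self, max_num_tokens: int):
        self.max_num_tokens = max_num_tokens

class ToyStandardExperts:
    pass

def select_gemm_impl(max_num_tokens_per_rank: Optional[int]):
    if max_num_tokens_per_rank is not None:
        return ToyBatchedExperts(max_num_tokens=max_num_tokens_per_rank)
    return ToyStandardExperts()

print(type(select_gemm_impl(64)).__name__)    # ToyBatchedExperts
print(type(select_gemm_impl(None)).__name__)  # ToyStandardExperts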
@ -797,6 +823,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:

topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -808,6 +835,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)

if self.rocm_aiter_moe_enabled:
@ -855,7 +883,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,

@ -154,6 +154,21 @@ class CudaPlatformBase(Platform):
logger.info(
"Forcing kv cache block size to 64 for FlashMLA backend.")

if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
and parallel_config.data_parallel_size > 1
and vllm_config.compilation_config.use_cudagraph):
logger.info(
"Data Parallel: Forcing enforce eager to be True since DP "
"with DeepEP high-throughput kernels is not CUDA Graph "
"compatible. The DeepEP low-latency kernels are CUDA Graph "
"compatible. Set the all_to_all backend to deepep_low_latency "
"to use those kernels instead.")
vllm_config.compilation_config.use_cudagraph = False
vllm_config.model_config.enforce_eager = True
# TODO (varun): Turning this ON gives incorrect results for the
# Deepseek-V2-lite model.
vllm_config.compilation_config.use_inductor = False

@classmethod
def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None

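# Illustrative launch-time configuration based on the log message above
# (the exact deployment command is an assumption; only the backend names
# come from the diff):
#   VLLM_ALL2ALL_BACKEND=deepep_low_latency     -> CUDA-graph-compatible kernels
#   VLLM_ALL2ALL_BACKEND=deepep_high_throughput -> forces eager mode under DP > 1
import os
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_low_latency"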