[Kernel] Integrate batched/masked deepgemm kernel (#19111)
Signed-off-by: Varun <vsundarr@redhat.com>
Co-authored-by: Varun <vsundarr@redhat.com>
Parent: ef3f98b59f · Commit: c3fd4d669a
@@ -162,12 +162,14 @@ def make_deepep_ll_a2a(pg: ProcessGroup,
                               low_latency_mode=True,
                               num_qps_per_rank=deepep_ll_args.num_experts //
                               pgi.world_size)

     return DeepEPLLPrepareAndFinalize(
         buffer=buffer,
         world_size=pgi.world_size,
         dp_size=dp_size,
         max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
         quant_dtype=q_dtype,
+        block_shape=block_shape,
         use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
     )

@@ -185,4 +187,5 @@ def make_deepep_a2a(pg: ProcessGroup,
                              block_shape)

     assert deepep_ll_args is not None
-    return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype)
+    return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
+                              block_shape)
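Note: with `block_shape=[128, 128]` the second element is the group size used when quantizing activations to fp8 (the DeepEP low-latency dispatch and `per_token_group_quant_fp8` both work on 128-element groups). A self-contained, plain-torch sketch of that per-token-group quantization, for intuition only and not vLLM's implementation:

```python
# Plain-torch sketch of per-token-group fp8 quantization with group size 128,
# i.e. what the second element of block_shape=[128, 128] means on the
# activation side. Illustration only, not vLLM's per_token_group_quant_fp8.
import torch

def group_quant_fp8(x: torch.Tensor, group_size: int = 128):
    """Quantize each contiguous group of `group_size` elements per row."""
    assert x.size(-1) % group_size == 0
    orig_shape = x.shape
    xg = x.view(-1, group_size).float()
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    scales = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    xq = (xg / scales).to(torch.float8_e4m3fn).view(orig_shape)
    return xq, scales.view(*orig_shape[:-1], -1)

x = torch.randn(4, 2560)
xq, s = group_quant_fp8(x)
print(xq.dtype, s.shape)  # torch.float8_e4m3fn torch.Size([4, 20])
```

Threading `block_shape` through `DeepEPLLPrepareAndFinalize` lets the prepare step use this group size rather than assuming a per-tensor scale.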
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 """
 Test DeepEP + DeepGEMM integration
+DeepGEMM are gemm kernels specialized for the
+fp8 block-quantized case.
 """

 import dataclasses
@@ -33,10 +35,14 @@ except ImportError:
 if has_deep_ep:
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
         DeepEPHTPrepareAndFinalize)
+    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
+        DeepEPLLPrepareAndFinalize)

-    from .deepep_utils import DeepEPHTArgs, make_deepep_a2a
+    from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a

 if has_deep_gemm:
+    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+        BatchedDeepGemmExperts)
     from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
         DeepGemmExperts)
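Note: these imports are guarded by `has_deep_ep` / `has_deep_gemm` availability flags, and the tests themselves are gated by skip markers built from them. A generic sketch of that guard pattern (the exact reason strings in vLLM's test may differ):

```python
# Generic sketch of the optional-dependency guard pattern used by this test.
import importlib.util

import pytest

has_deep_ep = importlib.util.find_spec("deep_ep") is not None
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None

requires_deep_ep = pytest.mark.skipif(not has_deep_ep,
                                      reason="requires the DeepEP library")
requires_deep_gemm = pytest.mark.skipif(not has_deep_gemm,
                                        reason="requires the DeepGEMM library")
```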
@@ -53,6 +59,13 @@ requires_deep_gemm = pytest.mark.skipif(
 P = ParamSpec("P")


+def next_power_of_2(x):
+    import math
+    if x == 0:
+        return 1
+    return 2**math.ceil(math.log2(x))
+
+
 def per_block_cast_to_fp8(
         x: torch.Tensor,
         block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
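Note: `next_power_of_2` sizes the low-latency token buffers, and `per_block_cast_to_fp8` (whose body is not shown in this hunk) produces the 128x128 block-quantized fp8 weights the DeepGEMM kernels expect. A self-contained sketch of both ideas in plain torch, not the test's exact code:

```python
# Sketch: power-of-two rounding plus a 128x128 block-quantized fp8 cast.
import math

import torch

def next_power_of_2(x: int) -> int:
    return 1 if x == 0 else 2 ** math.ceil(math.log2(x))

def block_cast_to_fp8(x: torch.Tensor, block: int = 128):
    """Quantize a 2-D tensor in (block x block) tiles to float8_e4m3fn."""
    m, n = x.shape
    pad_m, pad_n = -m % block, -n % block
    xp = torch.nn.functional.pad(x, (0, pad_n, 0, pad_m)).float()
    tiles = xp.view(xp.size(0) // block, block, xp.size(1) // block, block)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scales = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12) / fp8_max
    xq = (tiles / scales).to(torch.float8_e4m3fn).view_as(xp)[:m, :n]
    return xq, scales.squeeze(1).squeeze(-1)  # one scale per (128, 128) tile

assert next_power_of_2(45) == 64 and next_power_of_2(64) == 64
wq, ws = block_cast_to_fp8(torch.randn(384, 512))
print(wq.shape, ws.shape)  # torch.Size([384, 512]) torch.Size([3, 4])
```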
@@ -126,6 +139,9 @@ class TestConfig:
     n: int
     num_experts: int
     block_size: list[int]
+    # configs for testing low-latency kernels
+    low_latency: bool
+    use_fp8_dispatch: Optional[bool] = False


 @dataclasses.dataclass
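Note: the two new fields split the parametrization into a high-throughput mode (`low_latency=False`, `use_fp8_dispatch=None`) and a low-latency mode. A sketch of both configurations; the dataclass below only mirrors the fields visible in the diff instead of importing the test:

```python
# Sketch of the two test modes enabled by the new TestConfig fields.
import dataclasses
from typing import Optional

@dataclasses.dataclass
class TestConfig:
    topk: int
    m: int
    k: int
    n: int
    num_experts: int
    block_size: list[int]
    # configs for testing low-latency kernels
    low_latency: bool
    use_fp8_dispatch: Optional[bool] = False

# High-throughput DeepEP + DeepGemmExperts path.
ht_config = TestConfig(topk=2, m=222, k=2048, n=1024, num_experts=32,
                       block_size=[128, 128], low_latency=False,
                       use_fp8_dispatch=None)

# Low-latency DeepEP + BatchedDeepGemmExperts path.
ll_config = TestConfig(topk=2, m=222, k=2560, n=1024, num_experts=32,
                       block_size=[128, 128], low_latency=True,
                       use_fp8_dispatch=False)
```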
@@ -170,9 +186,43 @@ class TestTensors:
                        config=config)


-def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
-                        num_local_experts: int, q_dtype: Optional[torch.dtype],
-                        block_shape: list[int]) -> FusedMoEModularKernel:
+def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
+                           max_tokens_per_rank: int, dp_size: int,
+                           hidden_size: int, q_dtype: Optional[torch.dtype],
+                           test_config: TestConfig) -> FusedMoEModularKernel:
+
+    assert test_config.low_latency
+    assert test_config.use_fp8_dispatch is not None
+
+    a2a: DeepEPLLPrepareAndFinalize = make_deepep_a2a(
+        pg=pg,
+        pgi=pgi,
+        dp_size=dp_size,
+        deepep_ht_args=None,
+        deepep_ll_args=DeepEPLLArgs(
+            max_tokens_per_rank=max_tokens_per_rank,
+            hidden_size=hidden_size,
+            num_experts=test_config.num_experts,
+            use_fp8_dispatch=test_config.use_fp8_dispatch),
+        q_dtype=q_dtype,
+        block_shape=test_config.block_size)
+
+    fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank,
+                                           world_size=pgi.world_size,
+                                           dp_size=dp_size,
+                                           block_shape=test_config.block_size)
+    mk = FusedMoEModularKernel(prepare_finalize=a2a,
+                               fused_experts=fused_experts)
+    return mk
+
+
+def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
+                           dp_size: int, num_local_experts: int,
+                           q_dtype: Optional[torch.dtype],
+                           test_config: TestConfig) -> FusedMoEModularKernel:
+
+    assert not test_config.low_latency
+    assert test_config.use_fp8_dispatch is None

     a2a: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
         pg=pg,
@@ -181,7 +231,7 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
         deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts),
         deepep_ll_args=None,
         q_dtype=q_dtype,
-        block_shape=block_shape)
+        block_shape=test_config.block_size)

     fused_experts = DeepGemmExperts()
     mk = FusedMoEModularKernel(prepare_finalize=a2a,
@@ -189,12 +239,42 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
     return mk


-def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
-                     test_tensors: TestTensors, w1: torch.Tensor,
-                     w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
-                     w2_scale: Optional[torch.Tensor],
-                     num_experts: int) -> torch.Tensor:
+def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
+                        num_local_experts: int,
+                        test_tensors: TestTensors) -> FusedMoEModularKernel:
+
+    q_dtype = torch.float8_e4m3fn
+    test_config = test_tensors.config
+
+    mk: FusedMoEModularKernel
+    # Make modular kernel
+    if test_config.low_latency:
+        max_tokens_per_rank = max(
+            64, next_power_of_2(test_tensors.rank_tokens.size(0)))
+        hidden_size = test_tensors.rank_tokens.size(-1)
+
+        mk = make_ll_modular_kernel(pg=pg,
+                                    pgi=pgi,
+                                    max_tokens_per_rank=max_tokens_per_rank,
+                                    dp_size=dp_size,
+                                    hidden_size=hidden_size,
+                                    q_dtype=q_dtype,
+                                    test_config=test_config)
+    else:
+        mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts,
+                                    q_dtype, test_config)
+
+    return mk
+
+
+def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
+                             dp_size: int, test_tensors: TestTensors,
+                             w1: torch.Tensor, w2: torch.Tensor,
+                             w1_scale: Optional[torch.Tensor],
+                             w2_scale: Optional[torch.Tensor]) -> torch.Tensor:
+
+    test_config = test_tensors.config
+    num_experts = test_config.num_experts
     num_local_experts = w1.size(0)

     def build_expert_map():
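Note: the only non-obvious piece of the new `make_modular_kernel` dispatcher is the low-latency token budget: the per-rank token count is rounded up to a power of two with a floor of 64. A tiny sketch of that sizing rule and of which kernel stack each mode builds:

```python
# Sketch of the dispatcher's sizing rule and mode selection.
import math

def ll_max_tokens_per_rank(m: int) -> int:
    """Round the per-rank token count up to a power of two, floor of 64."""
    npow2 = 1 if m == 0 else 2 ** math.ceil(math.log2(m))
    return max(64, npow2)

def kernel_stack(low_latency: bool) -> str:
    """Which prepare/finalize + experts pair the dispatcher builds."""
    return ("DeepEPLLPrepareAndFinalize + BatchedDeepGemmExperts"
            if low_latency else
            "DeepEPHTPrepareAndFinalize + DeepGemmExperts")

assert [ll_max_tokens_per_rank(m) for m in (1, 45, 64, 222)] == [64, 64, 64, 256]
print(kernel_stack(True))
```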
@@ -208,14 +288,17 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
         return expert_map.to(device=torch.cuda.current_device(),
                              dtype=torch.int32)

-    q_dtype = torch.float8_e4m3fn
-
     # Make modular kernel
     mk: FusedMoEModularKernel = make_modular_kernel(
-        pg, pgi, dp_size, num_local_experts, q_dtype,
-        test_tensors.config.block_size)
+        pg=pg,
+        pgi=pgi,
+        dp_size=dp_size,
+        num_local_experts=num_local_experts,
+        test_tensors=test_tensors)

-    a1_scale = test_tensors.rank_token_scales
+    # Low-Latency kernels can't dispatch scales.
+    a1_scale = (None
+                if test_config.low_latency else test_tensors.rank_token_scales)

     out = mk.forward(hidden_states=test_tensors.rank_tokens,
                      w1=w1,
@@ -258,7 +341,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
                   allow_deep_gemm=False)


-def _deep_ep_moe(
+def _test_deepep_deepgemm_moe(
     pgi: ProcessGroupInfo,
     dp_size: int,
     config: TestConfig,
@@ -302,7 +385,7 @@ def _deep_ep_moe(
     w1_scale_ep = w1_scale[e_start:e_end]
     w2_scale_ep = w2_scale[e_start:e_end]

-    deepep_moe = deep_ep_moe_impl(
+    deepep_moe = deepep_deepgemm_moe_impl(
         pg,
         pgi,
         dp_size,
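Note: the `e_start:e_end` slicing above is standard expert-parallel sharding: each rank keeps a contiguous slice of the expert weights plus an expert map sending global expert ids to local ids (or -1 for experts owned elsewhere). A generic sketch of that layout; the test's `build_expert_map` may differ in details:

```python
# Generic sketch of expert-parallel sharding: rank r owns experts
# [e_start, e_end) and the expert map sends global ids to local ids
# (-1 for experts owned by other ranks).
import torch

def shard_experts(num_experts: int, world_size: int, rank: int):
    num_local = num_experts // world_size
    e_start, e_end = rank * num_local, (rank + 1) * num_local
    expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
    expert_map[e_start:e_end] = torch.arange(num_local, dtype=torch.int32)
    return e_start, e_end, expert_map

e_start, e_end, emap = shard_experts(num_experts=32, world_size=2, rank=1)
print(e_start, e_end)        # 16 32
print(emap[14:18].tolist())  # [-1, -1, 0, 1]
```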
@@ -311,7 +394,6 @@ def _deep_ep_moe(
         w2_ep,
         w1_scale_ep,
         w2_scale_ep,
-        config.num_experts,
     )

     torch.testing.assert_close(
@@ -335,15 +417,21 @@ MNKs = [
     (222, 1024, 2048),
 ]

+TOPKS = [2, 6]
+NUM_EXPERTS = [32]
+
+
 @pytest.mark.parametrize("mnk", MNKs)
-@pytest.mark.parametrize("num_experts", [32])
-@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOPKS)
 @pytest.mark.parametrize("world_dp_size", [(2, 1)])
 @requires_deep_ep
 @requires_deep_gemm
-def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
-                     world_dp_size: tuple[int, int]):
+def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
+                                topk: int, world_dp_size: tuple[int, int]):
+    """
+    Tests for High-Throughput DeepEP + DeepGemm integration.
+    """

     m, n, k = mnk
     current_platform.seed_everything(7)
@@ -354,6 +442,58 @@ def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
     block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
     block_size = [block_m, block_m]

+    world_size, dp_size = world_dp_size
+    config = TestConfig(topk=topk,
+                        m=m,
+                        k=k,
+                        n=n,
+                        num_experts=num_experts,
+                        block_size=block_size,
+                        low_latency=False,
+                        use_fp8_dispatch=None)
+
+    w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
+        num_experts, n, k, block_size)
+
+    parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
+                    w2, w1_scale, w2_scale)
+
+
+MNKs = [
+    (1, 128, 2560),
+    (2, 128, 2560),
+    (3, 1024, 2560),
+    (32, 128, 2560),
+    (45, 512, 2560),
+    (64, 1024, 2560),
+    (222, 1024, 2560),
+]
+
+# Fix tests for USE_FP8_DISPATCH=True
+USE_FP8_DISPATCH = [False]
+
+
+@pytest.mark.parametrize("mnk", MNKs)
+@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOPKS)
+@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
+@pytest.mark.parametrize("block_size", [[128, 128]])
+@pytest.mark.parametrize("world_dp_size", [(2, 1)])
+@requires_deep_ep
+@requires_deep_gemm
+def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
+                                           int], num_experts: int, topk: int,
+                                use_fp8_dispatch: bool, block_size: list[int],
+                                world_dp_size: tuple[int, int]):
+    """
+    Tests for Low-Latency DeepEP + DeepGemm integration.
+    """
+
+    m, n, k = mnk
+    current_platform.seed_everything(7)
+
+    if topk > num_experts:
+        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")
+
     world_size, dp_size = world_dp_size
     config = TestConfig(
         topk=topk,
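Note: both tests fan out one process per rank through `parallel_launch`. The sketch below is a generic `torch.multiprocessing` launcher of the same shape, not vLLM's `parallel_launch`; the gloo backend and the environment setup are assumptions chosen for portability:

```python
# Generic sketch of a multi-process test launcher in the spirit of
# parallel_launch(world_size, worker, *args). Not vLLM's implementation.
import os

import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank: int, world_size: int, fn, args):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29555")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    try:
        fn(rank, world_size, *args)
    finally:
        dist.destroy_process_group()

def parallel_launch_sketch(world_size: int, fn, *args) -> None:
    mp.spawn(_worker, args=(world_size, fn, args), nprocs=world_size, join=True)

def _hello(rank: int, world_size: int) -> None:
    print(f"rank {rank}/{world_size} is up")

if __name__ == "__main__":
    parallel_launch_sketch(2, _hello)
```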
@@ -362,10 +502,12 @@ def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
         n=n,
         num_experts=num_experts,
         block_size=block_size,
+        low_latency=True,
+        use_fp8_dispatch=use_fp8_dispatch,
     )

     w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
         num_experts, n, k, block_size)

-    parallel_launch(world_size, _deep_ep_moe, dp_size, config, w1, w2,
-                    w1_scale, w2_scale)
+    parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
+                    w2, w1_scale, w2_scale)
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py — new file, 124 lines
@@ -0,0 +1,124 @@
+# SPDX-License-Identifier: Apache-2.0
+import importlib.util
+from typing import Optional
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.utils import (
+    _resize_cache, per_token_group_quant_fp8)
+
+logger = init_logger(__name__)
+
+has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
+
+
+class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    # The Deep Gemm kernels only support block size of 128
+    DEEPGEMM_BLOCK_SHAPE = 128
+
+    def __init__(self, max_num_tokens: int, world_size: int, dp_size: int,
+                 block_shape: list[int]):
+        """
+        max_num_tokens: Maximum number of tokens from a DP Rank
+        world_size: Number of EP ranks
+        dp_size: Number of data-parallel ranks
+        block_shape: Block quantization block shape
+        """
+        super().__init__()
+        self.max_num_tokens = max_num_tokens
+        self.world_size = world_size
+        self.dp_size = dp_size
+        self.block_shape = block_shape
+
+        assert (len(self.block_shape) == 2 and all(
+            [v == self.DEEPGEMM_BLOCK_SHAPE for v in self.block_shape]))
+
+    def workspace_shapes(
+        self,
+        a: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        num_experts: int,
+    ) -> tuple[int, int, torch.dtype]:
+        assert a.dim() == 2
+        num_dp = self.world_size // self.dp_size
+        max_num_tokens = a.size(
+            0) if self.max_num_tokens is None else self.max_num_tokens
+        workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
+        workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
+        return (workspace13, workspace2, a.dtype)
+
+    def apply(
+        self,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_num_tokens: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        import deep_gemm as dg
+        assert hidden_states.ndim == 3
+
+        a1q = hidden_states
+        _, N, K = w1.size()
+
+        if global_num_experts == -1:
+            global_num_experts = w1.size(0)
+
+        assert w2.size(1) == K
+
+        E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
+            hidden_states, w1, w2, topk_ids)
+
+        workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
+        workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))
+        workspace3 = _resize_cache(workspace13, (E, max_num_tokens, K))
+
+        # (from deepgemm docs) : A value hint (which is a value on CPU)
+        # for the M expectation of each batch, correctly setting this value
+        # may lead to better performance.
+        expected_m = max_num_tokens
+
+        dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a1q, a1q_scale),
+                                                 (w1, w1_scale),
+                                                 out=workspace1,
+                                                 masked_m=expert_num_tokens,
+                                                 expected_m=expected_m)
+
+        # TODO (varun) [Optimization]: Use a batched version of activation.
+        # Similarly for the quant below.
+        self.activation(activation, workspace2, workspace1.view(-1, N))
+
+        w2_hidden_size = workspace2.size(-1)
+        workspace2 = workspace2.view(-1, w2_hidden_size)
+
+        a2q_scale: Optional[torch.Tensor] = None
+        a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
+                                                   self.block_shape[1],
+                                                   column_major_scales=False)
+        a2q = a2q.view(E, max_num_tokens, -1)
+        a2q_scale = a2q_scale.view(E, max_num_tokens, -1)
+
+        dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
+                                                 (w2, w2_scale),
+                                                 out=workspace3,
+                                                 masked_m=expert_num_tokens,
+                                                 expected_m=expected_m)
+
+        return workspace3
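Note: the core primitive above is DeepGEMM's masked grouped GEMM. Activations arrive as an `(E, max_num_tokens, K)` slab per expert, `masked_m[e]` says how many rows of expert `e` are real, and `expected_m` is only a CPU-side size hint. A plain-torch reference of those semantics (ignoring the fp8 block scaling), for intuition only:

```python
# Reference semantics of a masked grouped GEMM, in plain torch and without the
# fp8 scaling: out[e, :masked_m[e]] = x[e, :masked_m[e]] @ w[e].T
# Rows beyond masked_m[e] are padding and their output is unspecified.
import torch

def masked_grouped_gemm_ref(x: torch.Tensor,        # (E, max_tokens, K)
                            w: torch.Tensor,        # (E, N, K)
                            masked_m: torch.Tensor  # (E,) valid rows per expert
                            ) -> torch.Tensor:
    E, max_tokens, _ = x.shape
    N = w.size(1)
    out = torch.zeros(E, max_tokens, N, dtype=x.dtype)
    for e in range(E):
        m = int(masked_m[e])
        out[e, :m] = x[e, :m] @ w[e].T
    return out

E, max_tokens, K, N = 4, 8, 16, 32
x = torch.randn(E, max_tokens, K)
w = torch.randn(E, N, K)
masked_m = torch.tensor([8, 3, 0, 5])
print(masked_grouped_gemm_ref(x, w, masked_m).shape)  # torch.Size([4, 8, 32])
```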
@@ -0,0 +1,116 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+    BatchedDeepGemmExperts)
+from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+    BatchedTritonExperts)
+
+
+class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(self,
+                 max_num_tokens: int,
+                 world_size: int,
+                 dp_size: int,
+                 use_fp8_w8a8: bool = False,
+                 use_int8_w8a8: bool = False,
+                 use_int8_w8a16: bool = False,
+                 use_int4_w4a16: bool = False,
+                 per_channel_quant: bool = False,
+                 block_shape: Optional[list[int]] = None,
+                 allow_deep_gemm: bool = False):
+        super().__init__()
+        assert not use_int8_w8a8, "NYI"
+        assert not use_int8_w8a16, "NYI"
+        assert not use_int4_w4a16, "NYI"
+
+        self.max_num_tokens = max_num_tokens
+        self.world_size = world_size
+        self.dp_size = dp_size
+        self.use_fp8_w8a8 = use_fp8_w8a8
+        self.use_int8_w8a8 = use_int8_w8a8
+        self.use_int8_w8a16 = use_int8_w8a16
+        self.use_int4_w4a16 = use_int4_w4a16
+        self.per_channel_quant = per_channel_quant
+        self.block_shape = block_shape
+        self.allow_deep_gemm = allow_deep_gemm
+
+        # BatchedTritonKernel doesn't support block quantization
+        # at the moment.
+        self.batched_triton_experts = BatchedTritonExperts(
+            max_num_tokens=self.max_num_tokens,
+            use_fp8_w8a8=self.use_fp8_w8a8,
+            use_int8_w8a8=self.use_int8_w8a8,
+            use_int8_w8a16=self.use_int8_w8a16,
+            use_int4_w4a16=self.use_int4_w4a16,
+            per_channel_quant=self.per_channel_quant,
+            block_shape=self.block_shape,
+            world_size=self.world_size,
+            dp_size=self.dp_size) if self.block_shape is None else None
+
+        is_fp8_128_block_quantized = (self.use_fp8_w8a8
+                                      and self.block_shape is not None
+                                      and len(self.block_shape) == 2 and all(
+                                          [b == 128
+                                           for b in self.block_shape]))
+        self.batched_deep_gemm_experts = BatchedDeepGemmExperts(
+            max_num_tokens=self.max_num_tokens,
+            world_size=self.world_size,
+            dp_size=self.dp_size,
+            block_shape=self.block_shape,  # type: ignore[arg-type]
+        ) if (self.allow_deep_gemm and is_fp8_128_block_quantized) else None
+
+    def workspace_shapes(
+        self,
+        a: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        num_experts: int,
+    ) -> tuple[int, int, torch.dtype]:
+        # Note: the deep gemm workspaces are strictly larger than the triton
+        # workspaces so we can be pessimistic here and allocate for DeepGemm
+        # even if we fall back to triton later, e.g. if expert maps are set.
+        if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
+            return self.batched_deep_gemm_experts.workspace_shapes(
+                a, M, N, K, topk, num_experts)
+        else:
+            assert self.batched_triton_experts is not None
+            return self.batched_triton_experts.workspace_shapes(
+                a, M, N, K, topk, num_experts)
+
+    def apply(
+        self,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_num_tokens: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        use_batched_deep_gemm_experts = (self.allow_deep_gemm
+                                         and self.batched_deep_gemm_experts
+                                         is not None)
+        experts = (self.batched_deep_gemm_experts
+                   if use_batched_deep_gemm_experts else
+                   self.batched_triton_experts)
+        assert experts is not None
+        return experts.apply(hidden_states, w1, w2, topk_ids, activation,
+                             global_num_experts, expert_map, w1_scale,
+                             w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale,
+                             workspace13, workspace2, expert_num_tokens)
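Note: the wrapper only instantiates the DeepGEMM path when the quantization scheme matches what `BatchedDeepGemmExperts` supports: fp8 w8a8 with a [128, 128] block shape and `allow_deep_gemm` set. A small sketch of that eligibility predicate, mirroring the constructor logic above:

```python
# Sketch of the DeepGEMM eligibility check used when wiring up the wrapper.
from typing import Optional

DEEPGEMM_BLOCK = 128

def can_use_batched_deep_gemm(use_fp8_w8a8: bool,
                              block_shape: Optional[list[int]],
                              allow_deep_gemm: bool) -> bool:
    is_fp8_128_block = (use_fp8_w8a8 and block_shape is not None
                        and len(block_shape) == 2
                        and all(b == DEEPGEMM_BLOCK for b in block_shape))
    return allow_deep_gemm and is_fp8_128_block

assert can_use_batched_deep_gemm(True, [128, 128], True)
assert not can_use_batched_deep_gemm(True, None, True)        # Triton fallback
assert not can_use_batched_deep_gemm(True, [128, 128], False)
```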
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional
+from typing import Optional, Union

 import deep_ep
 import torch
@@ -65,6 +65,54 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     def topk_indices_dtype(self) -> Optional[torch.dtype]:
         return torch.int64

+    def _do_quant(
+        self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
+        a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
+        a1_dtype: torch.dtype
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+
+        block_k = self.block_shape[1] if self.block_shape is not None else None
+        if self.use_fp8_dispatch:
+            if block_k == DEEPEP_QUANT_BLOCK_SIZE:
+                # DeepEP kernels did the quantization for us.
+                x, x_scales = x
+                return x, x_scales
+
+            # Dequant to get back the tokens in the datatype we dispatched in.
+            x_fp8, x_scales = x
+            x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)
+
+        assert isinstance(x, torch.Tensor)
+
+        # Check if there is a block_shape / or if we can infer the quantization
+        # schemes from the scales.
+        per_token_quant = None
+        if all([v is None for v in [self.block_shape, a1_scale, a2_scale]
+                ]) and self.quant_dtype is not None:
+            # Quantization required despite none of the inputs suggesting
+            # quantization. Fallback to per_token_dynamic quant.
+            per_token_quant = True
+        else:
+            per_token_quant = ((self.block_shape is not None) or
+                               (a1_scale is not None and a1_scale.numel() != 1)
+                               or (a2_scale is not None
+                                   and a2_scale.numel() != 1))
+
+        num_experts, max_tokens, hidden_dim = x.size()
+
+        # TODO (varun): Optimization - Use a batched version of quant
+        x = x.view((-1, hidden_dim))
+        x, x_scales = moe_kernel_quantize_input(x, a1_scale, self.quant_dtype,
+                                                per_token_quant,
+                                                self.block_shape)
+        x = x.view((num_experts, -1, hidden_dim))
+
+        if per_token_quant:
+            assert x_scales is not None
+            x_scales = x_scales.view(num_experts, max_tokens, -1)
+
+        return x, x_scales
+
     def prepare(
         self,
         a1: torch.Tensor,
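Note: `_do_quant` has to pick between per-token(-group) and per-tensor quantization whenever the dispatch did not already hand back fp8 data. The decision restated as a stand-alone sketch (same logic as above, extracted for readability):

```python
# Sketch of _do_quant's per-token-vs-per-tensor decision, extracted from the
# hunk above. block_shape / a1_scale / a2_scale are the quantities the method
# receives; quant_dtype is the target dtype (None means no quantization).
from typing import Optional

import torch

def wants_per_token_quant(block_shape: Optional[list[int]],
                          a1_scale: Optional[torch.Tensor],
                          a2_scale: Optional[torch.Tensor],
                          quant_dtype: Optional[torch.dtype]) -> bool:
    if (block_shape is None and a1_scale is None and a2_scale is None
            and quant_dtype is not None):
        # Nothing suggests a scheme, but quantization is still required:
        # fall back to dynamic per-token quantization.
        return True
    return ((block_shape is not None)
            or (a1_scale is not None and a1_scale.numel() != 1)
            or (a2_scale is not None and a2_scale.numel() != 1))

assert wants_per_token_quant([128, 128], None, None, torch.float8_e4m3fn)
assert wants_per_token_quant(None, None, None, torch.float8_e4m3fn)
assert not wants_per_token_quant(None, torch.tensor([1.0]), None,
                                 torch.float8_e4m3fn)
```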
@@ -87,11 +135,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         assert hidden_size % 128 == 0, \
             "DeepEP kernels quantize the inputs in blocks of shape 128"

-        # Quantize
-        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
+        has_per_token_scales = a1_scale.numel(
+        ) != 1 if a1_scale is not None else (
             a2_scale.numel() != 1 if a2_scale is not None else False)
-        assert not per_act_token, (
-            "low_latency kernels don't support per-act-token quant")
+        assert not has_per_token_scales, (
+            "low_latency kernels doesn't support dispatching per-token scales")

         if apply_router_weight_on_input:
             topk = rank_topk_ids.size(1)
@@ -110,22 +158,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                                        async_finish=False,
                                        return_recv_hook=False)

-        if self.use_fp8_dispatch:
-            # TODO (varun) : In the case of dynamic quantization, we could
-            # probably skip the quant below and use the results directly.
-            # Although note that the deepep quant is per token 128 elements.
-            expert_x_fp8, expert_x_scales = expert_x
-            expert_x = dequant_fp8(expert_x_fp8,
-                                   expert_x_scales).to(dtype=a1.dtype)
-
-        num_experts = expert_x.size(0)
-        hidden_dim = expert_x.size(-1)
-
-        expert_x = expert_x.view((-1, expert_x.size(-1)))
-        expert_x, expert_x_scale = moe_kernel_quantize_input(
-            expert_x, a1_scale, self.quant_dtype, per_act_token,
-            self.block_shape)
-        expert_x = expert_x.view((num_experts, -1, hidden_dim))
+        expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale,
+                                                  a1.dtype)

         return (expert_x, expert_x_scale, expert_num_tokens, None, None)
@@ -771,21 +771,21 @@ class Fp8MoEMethod(FusedMoEMethodBase):

     def select_gemm_impl(self, prepare_finalize):

-        from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-            BatchedTritonExperts)
+        from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
+            BatchedTritonOrDeepGemmExperts)
         from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
             TritonOrDeepGemmExperts)

         assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
             "Marlin and ROCm AITER are not supported with all2all yet.")

-        experts: Optional[Union[BatchedTritonExperts,
+        experts: Optional[Union[BatchedTritonOrDeepGemmExperts,
                                 TritonOrDeepGemmExperts]] = None
         max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
         use_batched_experts = max_num_tokens_per_rank is not None

         if use_batched_experts:
-            experts = BatchedTritonExperts(
+            experts = BatchedTritonOrDeepGemmExperts(
                 max_num_tokens=max_num_tokens_per_rank,
                 world_size=prepare_finalize.world_size,
                 dp_size=prepare_finalize.dp_size,
@@ -793,7 +793,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 use_int8_w8a8=False,
                 use_int8_w8a16=False,
                 use_int4_w4a16=False,
-                block_shape=None,
+                per_channel_quant=False,
+                block_shape=self.quant_config.weight_block_size,
+                allow_deep_gemm=self.allow_deep_gemm,
             )
         else:
             experts = TritonOrDeepGemmExperts(
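Note: with this change the batched all2all path always receives a `BatchedTritonOrDeepGemmExperts` and lets that wrapper decide between Triton and DeepGEMM internally. A sketch of the selection shape in `select_gemm_impl`, using stand-in stubs rather than vLLM's classes:

```python
# Sketch of the gemm-impl selection in Fp8MoEMethod.select_gemm_impl after
# this change. The classes below are stand-in stubs, not vLLM's; only the
# shape of the decision (batched vs. non-batched) mirrors the diff.
from typing import Optional

class BatchedExpertsStub:
    def __init__(self, max_num_tokens: int, allow_deep_gemm: bool,
                 block_shape: Optional[list[int]]):
        self.kind = "BatchedTritonOrDeepGemmExperts"

class NonBatchedExpertsStub:
    def __init__(self):
        self.kind = "TritonOrDeepGemmExperts"

def select_gemm_impl_sketch(max_num_tokens_per_rank: Optional[int],
                            weight_block_size: Optional[list[int]],
                            allow_deep_gemm: bool):
    # Batched experts whenever the prepare/finalize advertises a per-rank
    # token budget (i.e. the batched all2all path is in use).
    if max_num_tokens_per_rank is not None:
        return BatchedExpertsStub(max_num_tokens=max_num_tokens_per_rank,
                                  allow_deep_gemm=allow_deep_gemm,
                                  block_shape=weight_block_size)
    return NonBatchedExpertsStub()

print(select_gemm_impl_sketch(64, [128, 128], True).kind)
print(select_gemm_impl_sketch(None, [128, 128], True).kind)
```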