[Kernels] Modular kernel refactor (#24812)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent f08919b7d1
commit da364615fc
@@ -209,18 +209,18 @@ class Config:
         info = prepare_finalize_info(self.prepare_finalize_type)
         return info.backend
 
-    def is_valid(self):
+    def is_valid(self) -> tuple[bool, Optional[str]]:
         # Check prepare-finalize and fused-experts compatibility
         if self.is_batched_prepare_finalize():
             if not self.is_batched_fused_experts():
-                return False
+                return False, "Mismatched format."
         else:
             if not self.is_standard_fused_experts():
-                return False
+                return False, "Mismatched format."
 
         use_chunking = self.fused_moe_chunk_size is not None
         if use_chunking and not self.is_fe_supports_chunking():
-            return False
+            return False, "Chunking not supported."
 
         # Check quantization sanity
         if (
@@ -229,7 +229,7 @@ class Config:
             + int(self.quant_block_shape is not None)
         ) > 1:
             # invalid quant config
-            return False
+            return False, f"Bad quant_config {self.quant_config}."
 
         # check type support
         if self.quant_dtype is None:
@@ -237,34 +237,43 @@ class Config:
                 self.dtype not in self.pf_supported_types()
                 or self.dtype not in self.fe_supported_types()
             ):
-                return False
+                return False, (
+                    f"Unsupported type {self.dtype} not in "
+                    f"{self.pf_supported_types()} and "
+                    f"{self.fe_supported_types()}."
+                )
         else:
             if (
                 self.quant_dtype not in self.pf_supported_types()
                 or self.quant_dtype not in self.fe_supported_types()
             ):
-                return False
+                return False, (
+                    f"Unsupported quant type {self.quant_dtype} "
+                    f"not in {self.pf_supported_types()} and "
+                    f"{self.fe_supported_types()}."
+                )
 
         # Check block quantization support
         is_block_quatized = self.quant_block_shape is not None
         if is_block_quatized and self.quant_dtype is None:
-            return False
+            return False, "No block quantization support."
 
         if is_block_quatized and not self.is_block_quant_supported():
-            return False
+            return False, "Mismatched block quantization support."
 
         # deep_gemm only works with block-quantized
         if self.needs_deep_gemm() and not is_block_quatized:
-            return False
+            return False, "Needs DeepGEMM but not block quantized."
 
         # Check dependencies (turn into asserts?)
         if self.needs_deep_ep() and not has_deep_ep():
-            return False
+            return False, "Needs DeepEP, but DeepEP not available."
        if self.needs_deep_gemm() and not has_deep_gemm():
-            return False
+            return False, "Needs DeepGEMM, but DeepGEMM not available."
        if self.needs_pplx() and not has_pplx():  # noqa: SIM103
-            return False
+            return False, "Needs PPLX, but PPLX not available."
 
-        return True
+        return True, None
 
 
 @dataclass
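
Since is_valid() now returns a (valid, reason) pair instead of a bare bool, callers unpack both values. A minimal usage sketch (the surrounding harness code here is hypothetical, not part of the diff):

    valid, reason = config.is_valid()
    if not valid:
        # reason is a short human-readable string, e.g. "Chunking not supported."
        print(f"Skipping config {config}: {reason}")
    else:
        run_config(config)  # placeholder for whatever consumes a valid config
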
@@ -140,7 +140,7 @@ def make_feature_matrix(csv_file_path: str):
     )
 
     success = None
-    if config.is_valid():
+    if config.is_valid()[0]:
         print(f"Running config : {config.describe()} ...")
         try:
             weights: WeightTensors = WeightTensors.make(config)
@@ -244,7 +244,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
     register_prepare_and_finalize(
         FlashInferCutlassMoEPrepareAndFinalize,
         standard_format,
-        nvfp4_types,
+        nvfp4_types + fp8_types,
         blocked_quantization_support=True,
         backend=None,
         force_multigpu=True,
@@ -254,7 +254,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
     register_experts(
         FlashInferExperts,
         standard_format,
-        nvfp4_types,
+        nvfp4_types + fp8_types,
         blocked_quantization_support=True,
         supports_chunking=True,
         # Note: this is a hack to get it to run for now
@@ -274,17 +274,15 @@ if has_deep_gemm() and is_deep_gemm_supported():
         needs_matching_quant=False,
         needs_deep_gemm=True,
     )
-    (
-        register_experts(
-            DeepGemmExperts,
-            standard_format,
-            fp8_types,
-            blocked_quantization_support=True,
-            supports_chunking=True,
-            supports_expert_map=True,
-            needs_matching_quant=False,
-            needs_deep_gemm=True,
-        ),
+    register_experts(
+        DeepGemmExperts,
+        standard_format,
+        fp8_types,
+        blocked_quantization_support=True,
+        supports_chunking=True,
+        supports_expert_map=True,
+        needs_matching_quant=False,
+        needs_deep_gemm=True,
     )
     register_experts(
         BatchedTritonOrDeepGemmExperts,
@@ -464,7 +462,7 @@ def make_fused_experts(
         print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
         experts = BatchedTritonOrDeepGemmExperts(**kwargs)
     elif fused_experts_type == DeepGemmExperts:
-        print("Making DeepGemmExperts {quant_config} ...")
+        print(f"Making DeepGemmExperts {quant_config} ...")
         experts = DeepGemmExperts(quant_config)
     elif fused_experts_type == TritonExperts:
         kwargs = quant_kwargs
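
For context, register_prepare_and_finalize() and register_experts() record per-kernel capabilities (supported formats and dtypes, chunking, expert-map support, backend requirements) that the test tooling later consults. A rough sketch of the idea only, with hypothetical names that are not the project's actual API:

    from dataclasses import dataclass

    @dataclass
    class KernelInfo:
        # Hypothetical capability record; field names are illustrative only.
        supported_types: tuple
        blocked_quantization_support: bool = False
        supports_chunking: bool = False

    _REGISTRY: dict[type, KernelInfo] = {}

    def register_kernel(cls: type, info: KernelInfo) -> None:
        _REGISTRY[cls] = info
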
@@ -5,7 +5,7 @@ import copy
 import textwrap
 import traceback
 from itertools import product
-from typing import Optional
+from typing import Any, Optional
 
 import pytest
 import torch
@@ -13,10 +13,9 @@ import torch
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.platforms import current_platform
-from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
+from vllm.utils import cuda_device_count_stateless, has_deep_ep, has_deep_gemm, has_pplx
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 
 from ...utils import multi_gpu_test
 from .modular_kernel_tools.common import (
     Config,
     RankTensors,
@@ -132,7 +131,8 @@ def rank_worker(
 
 
 def run(config: Config, verbose: bool):
-    assert config.is_valid()
+    assert config.is_valid()[0]
     assert not is_nyi_config(config)
 
     weights: WeightTensors = WeightTensors.make(config)
 
@@ -168,17 +168,77 @@ def is_nyi_config(config: Config) -> bool:
     return not info.supports_expert_map
 
 
-@pytest.mark.parametrize("k", Ks)
-@pytest.mark.parametrize("n", Ns)
-@pytest.mark.parametrize("e", Es)
-@pytest.mark.parametrize("dtype", DTYPEs)
-@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
+def generate_valid_test_cases(
+    world_size: int, prepare_finalize_types
+) -> list[tuple[Any, ...]]:
+    cases = []
+    total = 0
+
+    for k, n, e, dtype, quant_config, combination, chunk_size in product(
+        Ks,
+        Ns,
+        Es,
+        DTYPEs,
+        MK_QUANT_CONFIGS,
+        product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
+        FUSED_MOE_CHUNK_SIZEs,
+    ):
+        total = total + 1
+
+        config = Config(
+            Ms=Ms,
+            K=k,
+            N=n,
+            E=e,
+            topks=TOPKs,
+            dtype=dtype,
+            quant_config=quant_config,
+            prepare_finalize_type=combination[0],
+            fused_experts_type=combination[1],
+            fused_moe_chunk_size=chunk_size,
+            world_size=world_size,
+        )
+
+        # TODO(bnell): figure out how to get verbose flag here.
+        verbose = False  # pytestconfig.getoption('verbose') > 0
+
+        valid, reason = config.is_valid()
+
+        if not valid:
+            if verbose:
+                print(f"Test config {config} is not valid: {reason}")
+            continue
+
+        if is_nyi_config(config):
+            if verbose:
+                print(f"Test config {config} is nyi.")
+            continue
+
+        cases.append(
+            (
+                k,
+                n,
+                e,
+                dtype,
+                quant_config,
+                combination[0],
+                combination[1],
+                chunk_size,
+                world_size,
+            )
+        )
+
+    print(f"{len(cases)} of {total} valid configs generated.")
+
+    return cases
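
The function above filters invalid and not-yet-implemented combinations at collection time, so they never become pytest items, instead of skipping them inside the test body. A stripped-down sketch of the same pattern with placeholder names (not the real parameter lists):

    import pytest
    from itertools import product

    SIZES = [128, 256]            # placeholder parameter space
    DTYPES = ["bf16", "fp8"]

    def is_supported(size, dtype):
        # stand-in for Config.is_valid(); the real logic lives in the Config class
        return not (dtype == "fp8" and size < 256)

    def generate_cases():
        # Only valid combinations are returned, so pytest never collects the rest.
        return [(s, d) for s, d in product(SIZES, DTYPES) if is_supported(s, d)]

    @pytest.mark.parametrize("size,dtype", generate_cases())
    def test_kernel(size, dtype):
        assert is_supported(size, dtype)
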
 @pytest.mark.parametrize(
-    "combination", product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
+    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
+    generate_valid_test_cases(
+        world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
+    ),
 )
-@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
-@pytest.mark.parametrize("world_size", [2])
 @multi_gpu_test(num_gpus=2)
 @meets_multi_gpu_requirements
 def test_modular_kernel_combinations_multigpu(
     k: int,
@@ -186,13 +246,19 @@ def test_modular_kernel_combinations_multigpu(
     e: int,
     dtype: torch.dtype,
     quant_config: Optional[TestMoEQuantConfig],
-    combination: tuple[
-        mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
-    ],
-    fused_moe_chunk_size: Optional[int],
+    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
+    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
+    chunk_size: Optional[int],
     world_size: int,
     pytestconfig,
 ):
+    if cuda_device_count_stateless() < world_size:
+        pytest.skip(
+            f"Not enough GPUs available to run, got "
+            f"{cuda_device_count_stateless()} expected "
+            f"{world_size}."
+        )
+
     config = Config(
         Ms=Ms,
         K=k,
@@ -201,42 +267,30 @@ def test_modular_kernel_combinations_multigpu(
         topks=TOPKs,
         dtype=dtype,
         quant_config=quant_config,
-        prepare_finalize_type=combination[0],
-        fused_experts_type=combination[1],
-        fused_moe_chunk_size=fused_moe_chunk_size,
+        prepare_finalize_type=prepare_finalize_type,
+        fused_experts_type=fused_experts_type,
+        fused_moe_chunk_size=chunk_size,
         world_size=world_size,
     )
 
-    if not config.is_valid():
-        pytest.skip(f"Tests config {config} is not valid. Skipping ...")
-
-    if is_nyi_config(config):
-        pytest.skip(f"Tests config {config} is nyi. Skipping ...")
-
     verbosity = pytestconfig.getoption("verbose")
     run(config, verbosity > 0)
 
 
-@pytest.mark.parametrize("k", Ks)
-@pytest.mark.parametrize("n", Ns)
-@pytest.mark.parametrize("e", Es)
-@pytest.mark.parametrize("dtype", DTYPEs)
-@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
 @pytest.mark.parametrize(
-    "combination", product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
+    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
+    generate_valid_test_cases(
+        world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
+    ),
 )
-@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
-@pytest.mark.parametrize("world_size", [1])
 def test_modular_kernel_combinations_singlegpu(
     k: int,
     n: int,
     e: int,
     dtype: torch.dtype,
     quant_config: Optional[TestMoEQuantConfig],
-    combination: tuple[
-        mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
-    ],
-    fused_moe_chunk_size: Optional[int],
+    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
+    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
+    chunk_size: Optional[int],
     world_size: int,
     pytestconfig,
 ):
@@ -248,18 +302,12 @@ def test_modular_kernel_combinations_singlegpu(
         topks=TOPKs,
         dtype=dtype,
         quant_config=quant_config,
-        prepare_finalize_type=combination[0],
-        fused_experts_type=combination[1],
-        fused_moe_chunk_size=fused_moe_chunk_size,
+        prepare_finalize_type=prepare_finalize_type,
+        fused_experts_type=fused_experts_type,
+        fused_moe_chunk_size=chunk_size,
         world_size=world_size,
     )
 
-    if not config.is_valid():
-        pytest.skip(f"Tests config {config} is not valid. Skipping ...")
-
-    if is_nyi_config(config):
-        pytest.skip(f"Tests config {config} is nyi. Skipping ...")
-
     verbosity = pytestconfig.getoption("verbose")
     run(config, verbosity > 0)
 
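
Note that the comma-separated argnames string in the parametrize calls has to list the fields in exactly the order in which generate_valid_test_cases builds each tuple, because pytest unpacks every case tuple positionally into the named arguments. A tiny illustration of that mechanism (hypothetical test, not from the diff):

    import pytest

    @pytest.mark.parametrize("m,n", [(1, 2), (3, 4)])
    def test_positional_unpacking(m, n):
        # each case tuple is unpacked positionally into (m, n)
        assert n == m + 1
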
@@ -247,29 +247,24 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
     def workspace_shapes(
         self,
-        a: torch.Tensor,
-        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
-        expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
-    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
-        assert a.dim() == 2
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         # FIXME (varun): We should be able to dispatch only from the leader
         # DP ranks in the case of TP > 1. At the moment, all the Ranks
         # end up sending their tokens. This needs to be fixed.
         num_dispatchers = self.num_dispatchers
         num_experts = local_num_experts
-        max_num_tokens = (
-            a.size(0) if self.max_num_tokens is None else self.max_num_tokens
-        )
+        max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
         workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N))
         workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
         output = (num_experts, max_num_tokens * num_dispatchers, K)
-        return (workspace13, workspace2, output, a.dtype)
+        return (workspace13, workspace2, output)
 
     def apply(
         self,
@@ -300,7 +295,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
         assert w2.size(1) == K
 
-        E, max_num_tokens, N, K, top_k_num = self.moe_problem_size(
+        E, max_num_tokens, N, K, _ = self.moe_problem_size(
             hidden_states, w1, w2, topk_ids
         )
 
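
To make the batched workspace arithmetic above concrete, here is a worked example with made-up sizes (illustrative only, not defaults): with local_num_experts = 8, max_num_tokens = 64, num_dispatchers = 4, N = 1024 and K = 512, the new code returns

    workspace13 = (8, 64 * 4, max(512, 1024)) = (8, 256, 1024)
    workspace2  = (8, 64 * 4, 1024 // 2)      = (8, 256, 512)
    output      = (8, 64 * 4, 512)            = (8, 256, 512)

and the dtype that used to be returned alongside the shapes is now supplied by the separate workspace_dtype() hook, which defaults to the activation dtype.
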
|
||||
@ -99,10 +99,11 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
assert bte_war is not None
|
||||
return bte_war
|
||||
|
||||
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
|
||||
return act_dtype
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
@ -110,15 +111,13 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||
# even if we fall back to triton later, e.g. if expert maps are set.
|
||||
if self.allow_deep_gemm:
|
||||
assert self.batched_deep_gemm_experts is not None
|
||||
return self.batched_deep_gemm_experts.workspace_shapes(
|
||||
a,
|
||||
aq,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
@ -130,8 +129,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
else:
|
||||
assert self.batched_triton_experts is not None
|
||||
return self.batched_triton_experts.workspace_shapes(
|
||||
a,
|
||||
aq,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
|
||||
@ -366,10 +366,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
# topk weights and reduction are fused in moe_unpermute cuda kernel
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
|
||||
return self.out_dtype if self.out_dtype is not None else act_dtype
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
@ -377,16 +378,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
workspace1 = (M * topk, max(N, K))
|
||||
workspace2 = (M * topk, max(N // 2, K))
|
||||
output = (M, K)
|
||||
return (
|
||||
workspace1,
|
||||
workspace2,
|
||||
output,
|
||||
self.out_dtype if self.out_dtype is not None else a.dtype,
|
||||
)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
|
||||
class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
@ -428,11 +424,11 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
# TODO(bnell): maybe remove need for passing aq to workspace_shapes
|
||||
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
|
||||
return self.out_dtype if self.out_dtype is not None else act_dtype
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
@ -440,19 +436,13 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
padded_M = aq.size(1)
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
num_dp = self.num_dispatchers
|
||||
assert num_dp is not None
|
||||
workspace1 = (self.max_experts_per_worker, padded_M * num_dp, max(N, K))
|
||||
workspace2 = (self.max_experts_per_worker, padded_M * num_dp, max(N // 2, K))
|
||||
output = (self.max_experts_per_worker, padded_M, K)
|
||||
return (
|
||||
workspace1,
|
||||
workspace2,
|
||||
output,
|
||||
self.out_dtype if self.out_dtype is not None else a.dtype,
|
||||
)
|
||||
workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K))
|
||||
workspace2 = (self.max_experts_per_worker, M * num_dp, max(N // 2, K))
|
||||
output = (self.max_experts_per_worker, M, K)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
|
||||
def cutlass_moe_fp8(
|
||||
@ -767,10 +757,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
|
||||
return self.out_dtype if self.out_dtype is not None else act_dtype
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
@ -778,25 +769,19 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
workspace1: tuple[int, ...] = ()
|
||||
workspace2: tuple[int, ...] = ()
|
||||
output: tuple[int, ...] = ()
|
||||
if self.use_batched_format:
|
||||
padded_M = aq.size(1)
|
||||
workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
|
||||
workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
|
||||
output = (self.max_experts_per_worker, padded_M, K)
|
||||
workspace1 = (self.max_experts_per_worker, M, max(N, K))
|
||||
workspace2 = (self.max_experts_per_worker, M, (N // 2))
|
||||
output = (self.max_experts_per_worker, M, K)
|
||||
else:
|
||||
workspace1 = (M * topk, max(2 * N, K))
|
||||
workspace2 = (M * topk, N)
|
||||
output = (M, K)
|
||||
return (
|
||||
workspace1,
|
||||
workspace2,
|
||||
output,
|
||||
self.out_dtype if self.out_dtype is not None else a.dtype,
|
||||
)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
@ -198,8 +198,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
@ -207,7 +205,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
assert self.block_shape is not None
|
||||
block_m = self.block_shape[0]
|
||||
M_sum = compute_aligned_M(
|
||||
@ -218,7 +216,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace1 = (M_sum, max(N, K))
|
||||
workspace2 = (M_sum, max(N // 2, K))
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output, a.dtype)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
@ -70,6 +70,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def num_dispatchers(self) -> int:
|
||||
return self.num_dispatchers_
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
@ -73,6 +73,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def num_dispatchers(self) -> int:
|
||||
return self.num_dispatchers_
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
@ -90,8 +90,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
@ -99,7 +97,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# We use global_num_experts due to how moe_align_block_size handles
|
||||
# expert_maps.
|
||||
"""
|
||||
@ -118,14 +116,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
- Note: in order for activation chunking to work, the first dimension
|
||||
of each tuple must be the number of tokens.
|
||||
"""
|
||||
aq_m, aq_n = aq.shape
|
||||
workspace1 = (M, K)
|
||||
workspace2 = (0,)
|
||||
output_shape = (aq_m, aq_n * 2) if self.quant_dtype == "nvfp4" else (aq_m, aq_n)
|
||||
workspace_dtype = a.dtype
|
||||
workspace1 = output_shape
|
||||
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
|
||||
# The workspace is determined by `aq`, since it comes after any
|
||||
# potential communication op and is involved in the expert computation.
|
||||
return (workspace1, workspace2, output_shape, workspace_dtype)
|
||||
return (workspace1, workspace2, output_shape)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
@ -11,6 +11,9 @@ from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
|
||||
|
||||
@ -45,6 +48,9 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def num_dispatchers(self) -> int:
|
||||
return self.num_dispatchers_
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def _apply_router_weight_on_input(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
@ -194,6 +200,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceNoOP)
|
||||
|
||||
if self.use_dp:
|
||||
fused_expert_output = get_dp_group().reduce_scatterv(
|
||||
fused_expert_output, dim=0, sizes=get_local_sizes()
|
||||
|
||||
@ -509,6 +509,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def num_dispatchers(self) -> int:
|
||||
return self.num_dispatchers_
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
@ -665,8 +668,6 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
@ -674,14 +675,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
assert a.dim() == 2
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
num_dp = self.num_dispatchers
|
||||
num_experts = local_num_experts
|
||||
workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
|
||||
workspace2 = (self.max_num_tokens * num_dp, N)
|
||||
output = workspace13
|
||||
return (workspace13, workspace2, output, a.dtype)
|
||||
return (workspace13, workspace2, output)
|
||||
|
||||
def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
assert self.quant_config.is_quantized
|
||||
@ -862,8 +862,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
@ -871,15 +869,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
assert a.dim() == 2
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
num_dp = self.num_dispatchers
|
||||
num_experts = local_num_experts
|
||||
max_num_tokens = self.max_num_tokens
|
||||
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
|
||||
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
|
||||
output = (num_experts, max_num_tokens * num_dp, K)
|
||||
return (workspace13, workspace2, output, a.dtype)
|
||||
return (workspace13, workspace2, output)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
@ -331,8 +331,6 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
@ -340,7 +338,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Modular Kernel provisions output buffer from workspace1. However in
|
||||
# the fused_marlin_moe() function, the final torch.sum(), is defined
|
||||
# essentially as,
|
||||
@ -360,7 +358,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace2 = (M * topk * max(2 * N, K),)
|
||||
output = (M, K)
|
||||
|
||||
return (workspace1, workspace2, output, a.dtype)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
@ -1954,8 +1954,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
@ -1963,11 +1961,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
workspace1 = (M, topk, max(N // 2, K))
|
||||
workspace2 = (M, topk, max(N, K))
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output, a.dtype)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
@ -255,8 +255,6 @@ class OAITritonExperts(BaseOAITritonExperts):
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
@ -264,12 +262,12 @@ class OAITritonExperts(BaseOAITritonExperts):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# workspace are allocated inside the kernel
|
||||
workspace1 = (M, K)
|
||||
workspace2 = (0, 0)
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output, a.dtype)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
@ -283,6 +283,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def using_modular_kernel(self) -> bool:
|
||||
return self.fused_experts is not None
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
@ -1237,39 +1241,25 @@ class FusedMoE(CustomOp):
|
||||
self.batched_hidden_states: Optional[torch.Tensor] = None
|
||||
self.batched_router_logits: Optional[torch.Tensor] = None
|
||||
|
||||
# TODO(bnell): flashinfer uses non-batched format.
|
||||
# Does it really need a batched buffer?
|
||||
if (
|
||||
self.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_parallel_config.use_deepep_ll_kernels
|
||||
or self.moe_config.use_flashinfer_cutlass_kernels
|
||||
):
|
||||
if self.use_dp_chunking:
|
||||
states_shape: tuple[int, ...]
|
||||
logits_shape: tuple[int, ...]
|
||||
|
||||
# Note here we use `num_experts` which is logical expert count
|
||||
if vllm_config.parallel_config.enable_dbo:
|
||||
self.batched_hidden_states = torch.zeros(
|
||||
(2, moe.max_num_tokens, self.hidden_size),
|
||||
dtype=moe.in_dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
|
||||
# Note here we use `num_experts` which is logical expert count
|
||||
self.batched_router_logits = torch.zeros(
|
||||
(2, moe.max_num_tokens, num_experts),
|
||||
dtype=moe.in_dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
states_shape = (2, moe.max_num_tokens, self.hidden_size)
|
||||
logits_shape = (2, moe.max_num_tokens, num_experts)
|
||||
else:
|
||||
self.batched_hidden_states = torch.zeros(
|
||||
(moe.max_num_tokens, self.hidden_size),
|
||||
dtype=moe.in_dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
states_shape = (moe.max_num_tokens, self.hidden_size)
|
||||
logits_shape = (moe.max_num_tokens, num_experts)
|
||||
|
||||
# Note here we use `num_experts` which is logical expert count
|
||||
self.batched_router_logits = torch.zeros(
|
||||
(moe.max_num_tokens, num_experts),
|
||||
dtype=moe.in_dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
self.batched_hidden_states = torch.zeros(
|
||||
states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
|
||||
)
|
||||
|
||||
self.batched_router_logits = torch.zeros(
|
||||
logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
|
||||
)
|
||||
|
||||
@property
|
||||
def shared_experts(self) -> Optional[torch.nn.Module]:
|
||||
@ -1323,6 +1313,16 @@ class FusedMoE(CustomOp):
|
||||
and self.moe_config.use_flashinfer_cutlass_kernels
|
||||
)
|
||||
|
||||
@property
|
||||
def use_dp_chunking(self) -> bool:
|
||||
# Route to the chunked forward path using the FlashInfer Cutlass kernel
|
||||
# only when data parallelism (DP) is enabled.
|
||||
return (
|
||||
self.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_parallel_config.use_deepep_ll_kernels
|
||||
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
|
||||
)
|
||||
|
||||
def update_expert_map(self):
|
||||
# ep_size and ep_rank should already be updated
|
||||
assert self.expert_map is not None
|
||||
@ -1987,21 +1987,17 @@ class FusedMoE(CustomOp):
|
||||
Therefore it is required that we reduce the shared_experts output
|
||||
early.
|
||||
"""
|
||||
assert self.quant_method is not None
|
||||
return (
|
||||
self.use_pplx_kernels
|
||||
or self.use_deepep_ht_kernels
|
||||
or self.use_deepep_ll_kernels
|
||||
self.quant_method.fused_experts is not None
|
||||
and self.quant_method.fused_experts.output_is_reduced()
|
||||
)
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
|
||||
"""
|
||||
The pplx combine kernel reduces across GPU ranks by default.
|
||||
Some combine kernels reduce across GPU ranks by default.
|
||||
"""
|
||||
if (
|
||||
self.use_pplx_kernels
|
||||
or self.use_deepep_ht_kernels
|
||||
or self.use_deepep_ll_kernels
|
||||
):
|
||||
if self.must_reduce_shared_expert_outputs():
|
||||
return final_hidden_states
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
@ -2209,23 +2205,11 @@ class FusedMoE(CustomOp):
|
||||
|
||||
self.ensure_moe_quant_config()
|
||||
|
||||
# Route to the chunked forward path using the FlashInfer Cutlass kernel
|
||||
# only when data parallelism (DP) is enabled.
|
||||
_use_flashinfer_cutlass_kernels = (
|
||||
self.dp_size > 1 and self.use_flashinfer_cutlass_kernels
|
||||
)
|
||||
|
||||
if (
|
||||
self.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_parallel_config.use_deepep_ll_kernels
|
||||
or _use_flashinfer_cutlass_kernels
|
||||
):
|
||||
if self.use_dp_chunking:
|
||||
return self.forward_impl_chunked(hidden_states, router_logits)
|
||||
|
||||
do_naive_dispatch_combine: bool = (
|
||||
self.dp_size > 1
|
||||
and not self.moe_parallel_config.use_deepep_ht_kernels
|
||||
and not self.moe_config.use_flashinfer_cutlass_kernels
|
||||
self.dp_size > 1 and not self.quant_method.using_modular_kernel
|
||||
)
|
||||
|
||||
# If there are shared experts but we are not using a modular kernel, the
|
||||
|
||||
@@ -337,6 +337,14 @@ class FusedMoEPrepareAndFinalize(ABC):
     def num_dispatchers(self) -> int:
         raise NotImplementedError
 
+    @abstractmethod
+    def output_is_reduced(self) -> bool:
+        """
+        Indicates whether or not the output of finalize is reduced across all
+        ranks.
+        """
+        raise NotImplementedError
+
 
 # TODO: add supported activations method (return string)
 class FusedMoEPermuteExpertsUnpermute(ABC):
@@ -493,11 +501,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         """
         raise NotImplementedError
 
+    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
+        """
+        Workspace type: The dtype to use for the workspace tensors.
+        """
+        return act_dtype
+
     @abstractmethod
     def workspace_shapes(
         self,
-        a: torch.Tensor,
-        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@@ -505,22 +517,33 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         global_num_experts: int,
         local_num_experts: int,
         expert_tokens_meta: Optional[ExpertTokensMetadata],
-    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         """
         Compute the shapes for the temporary and final outputs of the two gemms
         and activation in the fused expert function. Since the gemms are
         independent, the workspace for the first gemm can be shared with the
         workspace for the last gemm.
 
+        Inputs:
+        - M: number of tokens.
+        - N: Row (or column) dimension of expert weights.
+        - K: hidden dimension
+        - topk: The number of top-k experts to select.
+        - global_num_experts: global number of experts.
+        - local_num_experts: local number of experts due to DP/EP.
+        - expert_tokens_meta: number of tokens per expert metadata for batched
+          format.
+
         Returns a tuple of:
         - workspace13 shape tuple: must be large enough to hold the
           result of either expert gemm.
         - workspace2 shape tuple: must be large enough to hold the
          result of the activation function.
        - output shape tuple: must be exact size of the final gemm output.
-        - Workspace type: The dtype to use for the workspace tensors.
-        - Note: in order for activation chunking to work, the first dimension
-          of each tuple must be the number of tokens.
+        - Note: workspace shapes can be 0 if the workspace is not needed.
+          But in order for activation chunking to work, the first dimension
+          of each tuple must be the number of tokens when the shape is
+          not 0.
         """
         raise NotImplementedError
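
Under the revised interface an experts implementation returns only the three shape tuples; the workspace dtype comes from workspace_dtype(), which defaults to the activation dtype. A self-contained sketch of the shape logic a standard-format implementation might return (the sizes and helper name are illustrative, not vLLM API):

    def toy_workspace_shapes(
        M: int, N: int, K: int, topk: int
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        workspace13 = (M, topk, max(N, K))  # big enough for either gemm result
        workspace2 = (M, topk, N // 2)      # big enough for the activation output
        output = (M, K)                     # exact shape of the final gemm output
        return (workspace13, workspace2, output)

    print(toy_workspace_shapes(M=16, N=1024, K=512, topk=2))
    # ((16, 2, 1024), (16, 2, 512), (16, 512))
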
||||
@ -561,7 +584,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
This function computes the intermediate result of a Mixture of Experts
|
||||
(MoE) layer using two sets of weights, w1 and w2.
|
||||
@ -600,7 +623,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _chunk_scales(
|
||||
def _slice_scales(
|
||||
scales: Optional[torch.Tensor], start: int, end: int
|
||||
) -> Optional[torch.Tensor]:
|
||||
if scales is not None:
|
||||
@ -615,9 +638,10 @@ class SharedResizableBuffer:
|
||||
def __init__(self):
|
||||
self.buffer = None
|
||||
|
||||
def get(self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype):
|
||||
if shape == () or shape is None:
|
||||
return None
|
||||
def get(
|
||||
self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
assert shape != ()
|
||||
shape_numel = prod(shape)
|
||||
if (
|
||||
self.buffer is None
|
||||
@@ -678,131 +702,63 @@ class FusedMoEModularKernel(torch.nn.Module):
                 f"{fused_experts.activation_formats[0]}"
             )
 
-    def _do_fused_experts(
+    def output_is_reduced(self) -> bool:
+        """
+        Indicates whether or not the output of fused MoE kernel
+        is reduced across all ranks.
+        """
+        return self.prepare_finalize.output_is_reduced()
+
+    def _chunk_info(self, M: int) -> tuple[int, int]:
+        """
+        Compute number of chunks and chunk size for given M.
+        If chunking is not supported, set the CHUNK_SIZE to M so we
+        get num_chunks == 1. Take max(M, 1) to avoid divide by zero.
+        If there are no tokens to process, the number of chunks will be zero.
+        """
+        CHUNK_SIZE = (
+            max(M, 1)
+            if not self.fused_experts.supports_chunking()
+            else min(M, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
+        )
+        num_chunks = cdiv(M, CHUNK_SIZE)
+        # If there are no tokens, then there should be no loop iterations.
+        assert M > 0 or num_chunks == 0
+        return num_chunks, CHUNK_SIZE
 
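
A quick worked pass through _chunk_info, assuming VLLM_FUSED_MOE_CHUNK_SIZE is 1024 (the actual environment default may differ): with M = 2500 tokens and a kernel that supports chunking, CHUNK_SIZE = min(2500, 1024) = 1024 and num_chunks = cdiv(2500, 1024) = 3 (chunks of 1024, 1024 and 452); with the same M but chunking unsupported, CHUNK_SIZE = max(2500, 1) = 2500 and num_chunks = 1; and, as the docstring notes, when there are no tokens the method yields zero chunks so the per-chunk loop body never executes.
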
|
||||
def _allocate_buffers(
|
||||
self,
|
||||
fused_out: Optional[torch.Tensor],
|
||||
a1: torch.Tensor,
|
||||
a1q: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
out_dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
M_chunk: int,
|
||||
M_full: int,
|
||||
N: int,
|
||||
K: int,
|
||||
top_k: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> torch.Tensor:
|
||||
_, M, N, K, top_k = self.fused_experts.moe_problem_size(a1q, w1, w2, topk_ids)
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Allocate temporary and output buffers for the fused experts op.
|
||||
Inputs:
|
||||
- out_dtype: output type of workspace and output tensors.
|
||||
- device: the device of the workspace and output tensors.
|
||||
See `workspace_shapes` for a description of the remainder of arguments.
|
||||
Returns a tuple of (workspace13, workspace2, output) tensors.
|
||||
"""
|
||||
assert M_full > 0 and M_chunk > 0
|
||||
|
||||
(workspace13_shape, workspace2_shape, fused_out_shape, workspace_dtype) = (
|
||||
self.fused_experts.workspace_shapes(
|
||||
a1,
|
||||
a1q,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
top_k,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
)
|
||||
)
|
||||
num_chunks, _ = self._chunk_info(M_full)
|
||||
|
||||
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
|
||||
ubatch_idx = dbo_current_ubatch_id()
|
||||
buffers = self.shared_buffers[ubatch_idx]
|
||||
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
workspace13 = buffers.workspace13.get(
|
||||
workspace13_shape, device=a1.device, dtype=workspace_dtype
|
||||
)
|
||||
workspace2 = buffers.workspace2.get(
|
||||
workspace2_shape, device=a1.device, dtype=workspace_dtype
|
||||
)
|
||||
|
||||
assert fused_out is None or fused_out.shape == fused_out_shape, (
|
||||
f"fused_out {fused_out.shape} but expected {fused_out_shape}"
|
||||
)
|
||||
if fused_out is None:
|
||||
# reuse workspace13 for the output
|
||||
fused_out = _resize_cache(workspace13, fused_out_shape)
|
||||
|
||||
self.fused_experts.apply(
|
||||
fused_out,
|
||||
a1q,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
return fused_out
|
||||
|
||||
def _maybe_chunk_fused_experts(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1q: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> torch.Tensor:
|
||||
_, M, N, K, top_k = self.fused_experts.moe_problem_size(a1q, w1, w2, topk_ids)
|
||||
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
num_chunks = cdiv(M, CHUNK_SIZE)
|
||||
|
||||
# TODO(bnell): get rid of one level here, update slice functions
|
||||
# to nops on num_chunks==1
|
||||
|
||||
if not self.fused_experts.supports_chunking() or num_chunks == 1:
|
||||
return self._do_fused_experts(
|
||||
fused_out=None,
|
||||
a1=a1,
|
||||
a1q=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=self.fused_experts.a2_scale,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
# Chunking required case
|
||||
assert num_chunks > 1
|
||||
|
||||
# Construct the entire output that can then be processed in chunks.
|
||||
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
|
||||
a1,
|
||||
a1q,
|
||||
M,
|
||||
# Get intermediate workspace shapes based off the chunked M size.
|
||||
workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes(
|
||||
M_chunk,
|
||||
N,
|
||||
K,
|
||||
top_k,
|
||||
@ -810,102 +766,338 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
)
|
||||
ubatch_idx = dbo_current_ubatch_id()
|
||||
buffers = self.shared_buffers[ubatch_idx]
|
||||
fused_out = buffers.fused_out.get(
|
||||
fused_out_shape, device=a1q.device, dtype=a1.dtype
|
||||
|
||||
# Get final output shape based on the full M size.
|
||||
_, _, fused_out_shape = self.fused_experts.workspace_shapes(
|
||||
M_full,
|
||||
N,
|
||||
K,
|
||||
top_k,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
)
|
||||
|
||||
def slice_input_tensors(
|
||||
chunk_idx: int,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
]:
|
||||
s = chunk_idx * CHUNK_SIZE
|
||||
e = min(s + CHUNK_SIZE, M)
|
||||
return (
|
||||
a1q[s:e],
|
||||
_chunk_scales(a1q_scale, s, e),
|
||||
_chunk_scales(self.fused_experts.a2_scale, s, e),
|
||||
topk_ids[s:e],
|
||||
topk_weights[s:e],
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
workspace13 = buffers.workspace13.get(
|
||||
workspace13_shape, device=device, dtype=workspace_dtype
|
||||
)
|
||||
workspace2 = buffers.workspace2.get(
|
||||
workspace2_shape, device=device, dtype=workspace_dtype
|
||||
)
|
||||
|
||||
# Construct the entire output that can then be processed in chunks.
|
||||
# Reuse workspace13 for the output in the non-chunked case as long
|
||||
# as it is large enough. This will not always be the case for standard
|
||||
# format experts and with experts that have empty workspaces.
|
||||
if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
|
||||
fused_out = _resize_cache(workspace13, fused_out_shape)
|
||||
else:
|
||||
fused_out = buffers.fused_out.get(
|
||||
fused_out_shape, device=device, dtype=out_dtype
|
||||
)
|
||||
|
||||
def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
|
||||
assert fused_out.size(0) % M == 0, (
|
||||
f"fused_out shape {fused_out.shape} vs M {M}"
|
||||
)
|
||||
factor = fused_out.size(0) // M
|
||||
out_chunk_size = CHUNK_SIZE * factor
|
||||
s = chunk_idx * out_chunk_size
|
||||
e = min(s + out_chunk_size, fused_out.size(0))
|
||||
return fused_out[s:e]
|
||||
return workspace13, workspace2, fused_out
|
||||
|
||||
def slice_expert_tokens_metadata(
|
||||
full_expert_tokens_meta: ExpertTokensMetadata,
|
||||
chunk_topk_ids: torch.Tensor,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
) -> ExpertTokensMetadata:
|
||||
# The existing expert_num_tokens is for the entire a1q
|
||||
# input. Chunking forces recomputation of the number
|
||||
# of tokens assigned to each expert.
|
||||
c_expert_num_tokens = count_expert_num_tokens(
|
||||
chunk_topk_ids, local_num_experts, expert_map
|
||||
@staticmethod
|
||||
def _slice_output_tensor(
|
||||
fused_out: torch.Tensor,
|
||||
chunk_idx: int,
|
||||
num_chunks: int,
|
||||
CHUNK_SIZE: int,
|
||||
M: int,
|
||||
) -> torch.Tensor:
|
||||
if num_chunks == 1:
|
||||
return fused_out
|
||||
|
||||
assert fused_out.size(0) % M == 0, f"fused_out shape {fused_out.shape} vs M {M}"
|
||||
factor = fused_out.size(0) // M
|
||||
out_chunk_size = CHUNK_SIZE * factor
|
||||
s = chunk_idx * out_chunk_size
|
||||
e = min(s + out_chunk_size, fused_out.size(0))
|
||||
return fused_out[s:e]
|
||||
|
||||
@staticmethod
|
||||
def _slice_expert_tokens_metadata(
|
||||
num_chunks: int,
|
||||
full_expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
chunk_topk_ids: torch.Tensor,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
) -> Optional[ExpertTokensMetadata]:
|
||||
if num_chunks == 1 or full_expert_tokens_meta is None:
|
||||
return full_expert_tokens_meta
|
||||
|
||||
# The existing expert_num_tokens is for the entire a1q
|
||||
# input. Chunking forces recomputation of the number
|
||||
# of tokens assigned to each expert.
|
||||
c_expert_num_tokens = count_expert_num_tokens(
|
||||
chunk_topk_ids, local_num_experts, expert_map
|
||||
)
|
||||
|
||||
c_expert_num_tokens_cpu = None
|
||||
need_expert_num_tokens_cpu = (
|
||||
full_expert_tokens_meta.expert_num_tokens_cpu is not None
|
||||
)
|
||||
if need_expert_num_tokens_cpu:
|
||||
# This is blocking as some implementations need the count
|
||||
# on the CPU to determine appropriate input/out fused-moe
|
||||
# buffers
|
||||
c_expert_num_tokens_cpu = c_expert_num_tokens.to("cpu", non_blocking=False)
|
||||
|
||||
return ExpertTokensMetadata(
|
||||
expert_num_tokens=c_expert_num_tokens,
|
||||
expert_num_tokens_cpu=c_expert_num_tokens_cpu,
|
||||
)
|
||||
|
||||
def _prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
Optional[ExpertTokensMetadata],
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
]:
|
||||
"""
|
||||
The _prepare method is a wrapper around self.prepare_finalize.prepare
|
||||
that handles DBO and async.
|
||||
"""
|
||||
if not self.prepare_finalize.supports_async():
|
||||
# We shouldn't be running an a2a kernel that doesn't
|
||||
# support async prepare/finalize
|
||||
# TODO(lucas): enable in follow-up
|
||||
assert not dbo_enabled()
|
||||
|
||||
(
|
||||
a1q,
|
||||
a1q_scale,
|
||||
expert_tokens_meta,
|
||||
_expert_topk_ids,
|
||||
_expert_topk_weights,
|
||||
) = self.prepare_finalize.prepare(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
else:
|
||||
# Overlap shared expert compute with all2all dispatch.
|
||||
dbo_maybe_run_recv_hook()
|
||||
prepare_ret = self.prepare_finalize.prepare_async(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
|
||||
c_expert_num_tokens_cpu = None
|
||||
need_expert_num_tokens_cpu = (
|
||||
full_expert_tokens_meta.expert_num_tokens_cpu is not None
|
||||
# TODO(lucas): refactor this in the alternative schedules followup
|
||||
# currently unpack if we have hook + receiver pair or just
|
||||
# receiver (see finalize_async docstring)
|
||||
hook, receiver = (
|
||||
prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
|
||||
)
|
||||
if need_expert_num_tokens_cpu:
|
||||
# This is blocking as some implementations need the count
|
||||
# on the CPU to determine appropriate input/out fused-moe
|
||||
# buffers
|
||||
c_expert_num_tokens_cpu = c_expert_num_tokens.to(
|
||||
"cpu", non_blocking=False
|
||||
)
|
||||
|
||||
return ExpertTokensMetadata(
|
||||
expert_num_tokens=c_expert_num_tokens,
|
||||
expert_num_tokens_cpu=c_expert_num_tokens_cpu,
|
||||
if hook is not None:
|
||||
if dbo_enabled():
|
||||
# If DBO is being used, register the hook with the ubatch
|
||||
# context and call it in dbo_maybe_run_recv_hook instead of
|
||||
# passing it to the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
else:
|
||||
hook()
|
||||
|
||||
(
|
||||
a1q,
|
||||
a1q_scale,
|
||||
expert_tokens_meta,
|
||||
_expert_topk_ids,
|
||||
_expert_topk_weights,
|
||||
) = receiver()
|
||||
|
||||
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
|
||||
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
|
||||
topk_weights = (
|
||||
topk_weights if _expert_topk_weights is None else _expert_topk_weights
|
||||
)
|
||||
|
||||
return a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights
|
||||
|
||||
def _fused_experts(
|
||||
self,
|
||||
in_dtype: torch.dtype,
|
||||
a1q: torch.Tensor,
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
) -> torch.Tensor:
|
||||
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
|
||||
a1q, w1, w2, topk_ids
|
||||
)
|
||||
|
||||
num_chunks, CHUNK_SIZE = self._chunk_info(M_full)
|
||||
|
||||
def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
|
||||
if num_chunks == 1:
|
||||
# Use a1q.size(0) here since batched format does not
|
||||
# keep M in the first dimension.
|
||||
return 0, a1q.size(0)
|
||||
else:
|
||||
s = chunk_idx * CHUNK_SIZE
|
||||
e = min(s + CHUNK_SIZE, M_full)
|
||||
return s, e
|
||||
|
||||
# This happens when none of the tokens from the all2all reach this
|
||||
# EP rank. Also, note that this is only relevant for CUDAGraph
|
||||
# incompatible all2all kernels like the DeepEP high-throughput
|
||||
# kernels. CUDAGraph compatible all2all kernels like the pplx
|
||||
# kernels and the DeepEP low-latency kernels are always batched
|
||||
# and can never run into the tensor.numel() == 0 case.
if M_full == 0:
assert num_chunks == 0
workspace13 = None
workspace2 = None
fused_out = torch.empty_like(a1q)
else:
assert num_chunks > 0
workspace13, workspace2, fused_out = self._allocate_buffers(
in_dtype,
a1q.device,
CHUNK_SIZE,
M_full,
N,
K,
top_k,
global_num_experts,
local_num_experts,
expert_tokens_meta,
)

for chunk_idx in range(num_chunks):
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
slice_input_tensors(chunk_idx)
s, e = input_chunk_range(chunk_idx)

c_expert_tokens_meta = self._slice_expert_tokens_metadata(
num_chunks,
expert_tokens_meta,
topk_ids[s:e],
local_num_experts,
expert_map,
)

c_expert_tokens_meta = None
if expert_tokens_meta is not None:
c_expert_tokens_meta = slice_expert_tokens_metadata(
expert_tokens_meta, c_topk_ids, local_num_experts, expert_map
)
c_fused_out = self._slice_output_tensor(
fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full
)

self._do_fused_experts(
fused_out=slice_output_tensor(chunk_idx),
a1=a1,
a1q=c_a1q,
self.fused_experts.apply(
output=c_fused_out,
hidden_states=a1q[s:e],
w1=w1,
w2=w2,
topk_weights=c_topk_weights,
topk_ids=c_topk_ids,
topk_weights=topk_weights[s:e],
topk_ids=topk_ids[s:e],
activation=activation,
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
a1q_scale=c_a1q_scale,
a2_scale=c_a2_scale,
a1q_scale=_slice_scales(a1q_scale, s, e),
a2_scale=_slice_scales(self.fused_experts.a2_scale, s, e),
workspace13=workspace13,
workspace2=workspace2,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
)

return fused_out
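
# Minimal standalone sketch (not the vLLM helpers above; the chunk size 2048
# is an assumed example value): how _fused_experts splits M_full tokens into
# CHUNK_SIZE-sized ranges, mirroring input_chunk_range().
def _example_chunk_ranges(m_full: int, chunk_size: int) -> list[tuple[int, int]]:
    num_chunks = (m_full + chunk_size - 1) // chunk_size  # ceil division
    return [
        (i * chunk_size, min((i + 1) * chunk_size, m_full))
        for i in range(num_chunks)
    ]

# _example_chunk_ranges(5000, 2048) == [(0, 2048), (2048, 4096), (4096, 5000)]
# _example_chunk_ranges(0, 2048) == []  # matches the M_full == 0 early-out above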

def _finalize(
self,
output: torch.Tensor,
fused_out: torch.Tensor,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""
The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async finalize, and shared-expert overlap.
"""
shared_output: Optional[torch.Tensor] = None

if not self.prepare_finalize.supports_async():
assert not dbo_enabled()

self.prepare_finalize.finalize(
output,
fused_out,
topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
else:
finalize_ret = self.prepare_finalize.finalize_async(
output,
fused_out,
topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)

if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)

# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
# receiver (see finalize_async docstring)
hook, receiver = (
finalize_ret
if isinstance(finalize_ret, tuple)
else (None, finalize_ret)
)

if hook is not None:
if dbo_enabled():
# If DBO is being used, register the hook with the ubatch
# context and call it in dbo_maybe_run_recv_hook instead of
# passing it to the receiver.
dbo_register_recv_hook(hook)
dbo_yield()
else:
hook()

receiver()

if self.shared_experts is None:
return output
else:
assert shared_output is not None
return shared_output, output
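
# Simplified sketch (stand-in callables, not the real FusedMoEPrepareAndFinalize
# API) of the hook/receiver convention unpacked above: an async prepare/finalize
# may return either a bare receiver, or a (hook, receiver) pair where the hook
# drives the communication and the receiver materializes the result.
from typing import Callable, Optional, Union

AsyncRet = Union[
    Callable[[], None],
    tuple[Optional[Callable[[], None]], Callable[[], None]],
]

def _run_async_stage_example(ret: AsyncRet) -> None:
    hook, receiver = ret if isinstance(ret, tuple) else (None, ret)
    if hook is not None:
        # Under DBO the hook would instead be registered with the ubatch
        # context; here it is just invoked inline.
        hook()
    receiver()

# _run_async_stage_example(lambda: None)
# _run_async_stage_example((lambda: None, lambda: None))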

def forward(
self,
hidden_states: torch.Tensor,
@ -947,156 +1139,45 @@ class FusedMoEModularKernel(torch.nn.Module):
- torch.Tensor: The output tensor after applying the MoE layer.
"""

a1 = hidden_states
output = a1 if inplace and self.shared_experts is None else torch.zeros_like(a1)
if inplace and self.shared_experts is None:
output = hidden_states
else:
output = torch.zeros_like(hidden_states)

local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = local_num_experts

if not self.prepare_finalize.supports_async():
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
# TODO(lucas): enable in follow-up
assert not dbo_enabled()

(
a1q,
a1q_scale,
expert_tokens_meta,
_expert_topk_ids,
_expert_topk_weights,
) = self.prepare_finalize.prepare(
a1,
topk_weights,
topk_ids,
global_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
)
else:
# Overlap shared expert compute with all2all dispatch.
dbo_maybe_run_recv_hook()
prepare_ret = self.prepare_finalize.prepare_async(
a1,
topk_weights,
topk_ids,
global_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
)

# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
# receiver (see finalize_async docstring)
hook, receiver = (
prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
)

if hook is not None:
if dbo_enabled():
# If DBO is being used, register the hook with the ubatch
# context and call it in dbo_maybe_run_recv_hook instead of
# passing it to the receiver.
dbo_register_recv_hook(hook)
dbo_yield()
else:
hook()

(
a1q,
a1q_scale,
expert_tokens_meta,
_expert_topk_ids,
_expert_topk_weights,
) = receiver()

# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
topk_weights = (
topk_weights if _expert_topk_weights is None else _expert_topk_weights
a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights = self._prepare(
hidden_states,
topk_weights,
topk_ids,
global_num_experts,
expert_map,
apply_router_weight_on_input,
)

fused_out = None
fused_out = self._fused_experts(
in_dtype=hidden_states.dtype,
a1q=a1q,
a1q_scale=a1q_scale,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_tokens_meta=expert_tokens_meta,
)

if a1q.numel() == 0:
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput
# kernels. CUDAGraph compatible all2all kernels like the pplx
# kernels and the DeepEP low-latency kernels are always batched
# and can never run into the tensor.numel() == 0 case.
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
else:
fused_out = self._maybe_chunk_fused_experts(
a1=a1,
a1q=a1q,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
a1q_scale=a1q_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
)

shared_output: Optional[torch.Tensor] = None

if not self.prepare_finalize.supports_async():
assert not dbo_enabled()

self.prepare_finalize.finalize(
output,
fused_out,
topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)
else:
finalize_ret = self.prepare_finalize.finalize_async(
output,
fused_out,
topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)

if self.shared_experts is not None:
shared_output = self.shared_experts(a1)

# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
# receiver (see finalize_async docstring)
hook, receiver = (
finalize_ret
if isinstance(finalize_ret, tuple)
else (None, finalize_ret)
)

if hook is not None:
if dbo_enabled():
# If DBO is being used, register the hook with the ubatch
# context and call it in dbo_maybe_run_recv_hook instead of
# passing it to the receiver.
dbo_register_recv_hook(hook)
dbo_yield()
else:
hook()

receiver()

if self.shared_experts is None:
return output
else:
assert shared_output is not None
return shared_output, output
return self._finalize(
output,
fused_out,
hidden_states,
topk_weights,
topk_ids,
apply_router_weight_on_input,
)
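
# High-level sketch (hypothetical helper, not the vLLM API) of the structure
# forward() now delegates to: a prepare/dispatch stage, an expert-compute
# stage, and a finalize/combine stage composed in sequence.
from typing import Callable, TypeVar

_T = TypeVar("_T")

def _moe_pipeline_example(
    x: _T,
    prepare: Callable[[_T], _T],
    experts: Callable[[_T], _T],
    finalize: Callable[[_T], _T],
) -> _T:
    # The real kernel threads quantized activations, scales and expert token
    # metadata between the stages rather than a single value.
    return finalize(experts(prepare(x)))

# _moe_pipeline_example(2, lambda v: v + 1, lambda v: v * 10, lambda v: v - 3) == 27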

@ -91,6 +91,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def num_dispatchers(self) -> int:
return self.num_dispatchers_

def output_is_reduced(self) -> bool:
return True

def supports_async(self) -> bool:
return True


@ -27,6 +27,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def num_dispatchers(self) -> int:
return 1

def output_is_reduced(self) -> bool:
return False
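
# Hedged sketch (an assumption about intent, with stand-in types): a caller
# could consult output_is_reduced() to decide whether the finalized output
# still needs a reduction of its own, e.g. when no all2all combine has
# already summed contributions across ranks.
def _maybe_reduce_example(output, prepare_finalize, reduce_fn):
    if prepare_finalize.output_is_reduced():
        return output  # combine already produced the reduced result
    return reduce_fn(output)  # e.g. an all-reduce over the parallel group

class _NoReducePF:
    def output_is_reduced(self) -> bool:
        return False

# _maybe_reduce_example([1, 2], _NoReducePF(), reduce_fn=sum) == 3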

def prepare(
self,
a1: torch.Tensor,

@ -83,8 +83,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
@ -92,7 +90,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
@ -101,8 +99,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes(
a,
aq,
M,
N,
K,
@ -113,8 +109,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
)
else:
return self.triton_expert.workspace_shapes(
a,
aq,
M,
N,
K,

@ -52,8 +52,6 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):

def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
@ -61,14 +59,12 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# The workspaces for this implementation are managed by flashinfer.
# TODO(varun): workspace1 could be used as the output tensor. This
# is error-prone. Allow `workspace_shapes` to return None workspaces.
workspace1 = (M, K)
workspace2 = (0, 0)
workspace1 = (0,)
workspace2 = (0,)
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
return (workspace1, workspace2, output)
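
# Minimal sketch (hypothetical caller, not vLLM code) of the updated
# workspace_shapes() contract in this diff: three shape tuples
# (workspace13, workspace2, output), with the dtype dropped from the return
# value and instead chosen by the caller (cf. in_dtype in _allocate_buffers).
import torch

def _allocate_from_shapes_example(shapes, device="cpu", dtype=torch.bfloat16):
    w13_shape, w2_shape, out_shape = shapes
    workspace13 = torch.empty(w13_shape, device=device, dtype=dtype)
    workspace2 = torch.empty(w2_shape, device=device, dtype=dtype)
    output = torch.empty(out_shape, device=device, dtype=dtype)
    return workspace13, workspace2, output

# e.g. _allocate_from_shapes_example(((0,), (0,), (4, 8)))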

def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int, local_num_experts: int):
# Number of tokens in the input tensor.