[Feat] Refactor for parallel_config in FusedMoEModularKernel (#30282)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent b337647aa0
commit 3778673ea8
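At a glance, the change below swaps the engine-wide ParallelConfig argument of FusedMoEModularKernel for the MoE-specific FusedMoEParallelConfig that the layer already carries. A minimal sketch of the new call shape, assuming a vLLM tree that includes this commit; the helper name is illustrative, and prepare_finalize, fused_experts, and layer stand for objects the surrounding code already provides:

import vllm.model_executor.layers.fused_moe.modular_kernel as mk

def build_modular_kernel(prepare_finalize, fused_experts, layer):
    # Previously the kernel took parallel_config (a ParallelConfig) and fell
    # back to get_current_vllm_config().parallel_config when it was None.
    # Now the caller passes the layer's own FusedMoEParallelConfig instead.
    return mk.FusedMoEModularKernel(
        prepare_finalize=prepare_finalize,
        fused_experts=fused_experts,
        moe_parallel_config=layer.moe_parallel_config,
    )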
@@ -594,7 +594,8 @@ def make_modular_kernel(
     )
 
     modular_kernel = mk.FusedMoEModularKernel(
-        prepare_finalize=prepare_finalize, fused_experts=fused_experts
+        prepare_finalize=prepare_finalize,
+        fused_experts=fused_experts,
     )
 
     return modular_kernel
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 import pytest
 import torch
 
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
@@ -107,6 +108,19 @@ class TestData:
         layer.w2_input_scale = a2_scale
         layer.w13_weight_scale = w13_weight_scale
         layer.w2_weight_scale = w2_weight_scale
+        # Setup dummy config.
+        layer.moe_parallel_config = mk.FusedMoEParallelConfig(
+            tp_size=1,
+            pcp_size=1,
+            dp_size=1,
+            ep_size=1,
+            tp_rank=1,
+            pcp_rank=1,
+            dp_rank=1,
+            ep_rank=1,
+            use_ep=False,
+            all2all_backend="naive",
+        )
 
         register_moe_scaling_factors(layer)
 
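Because the kernel no longer consults the global vLLM config, the test fixture attaches an explicit parallel config to its dummy layer. The same construction in isolation, as a sketch: every field value is taken verbatim from the hunk above, the helper name is illustrative, and the single-rank sizes plus the "naive" all2all backend are the test's placeholders rather than recommended settings.

import vllm.model_executor.layers.fused_moe.modular_kernel as mk

def make_dummy_moe_parallel_config() -> mk.FusedMoEParallelConfig:
    # Single-worker placeholder: with dp_size=1 and use_ep=False the kernel's
    # DP+EP path (is_dp_ep) stays disabled.
    return mk.FusedMoEParallelConfig(
        tp_size=1,
        pcp_size=1,
        dp_size=1,
        ep_size=1,
        tp_rank=1,
        pcp_rank=1,
        dp_rank=1,
        ep_rank=1,
        use_ep=False,
        all2all_backend="naive",
    )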
@@ -460,7 +460,6 @@ def cutlass_moe_fp8(
     expert_map: torch.Tensor | None = None,
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
-    parallel_config=None,
 ) -> torch.Tensor:
     """
     This function computes a a8w8-quantized Mixture of Experts (MoE) layer
@@ -538,7 +537,6 @@ def cutlass_moe_fp8(
             c_strides2=c_strides2,
             quant_config=quant_config,
         ),
-        parallel_config=parallel_config,
     )
 
     return fn(
@@ -293,7 +293,7 @@ def deep_gemm_moe_fp8(
     expert_map: torch.Tensor | None = None,
     a1_scale: torch.Tensor | None = None,
     a2_scale: torch.Tensor | None = None,
-    apply_router_weight_on_input=False,
+    apply_router_weight_on_input: bool = False,
 ) -> torch.Tensor:
     """
     This function computes a a8w8-quantized Mixture of Experts (MoE) layer
@@ -43,11 +43,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
         prepare_finalize: FusedMoEPrepareAndFinalize,
         shared_experts: torch.nn.Module | None,
     ) -> "FusedMoEModularMethod":
-        parallel_config = getattr(
-            getattr(moe_layer, "vllm_config", None),
-            "parallel_config",
-            None,
-        )
         return FusedMoEModularMethod(
             old_quant_method,
             FusedMoEModularKernel(
@@ -55,7 +50,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
                 old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
                 shared_experts,
                 getattr(moe_layer, "shared_experts_stream", None),
-                parallel_config=parallel_config,
+                moe_parallel_config=moe_layer.moe_parallel_config,
             ),
         )
 
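The two hunks above drop the defensive getattr chain: rather than digging an engine-wide parallel_config out of a possibly-missing vllm_config attribute, the factory forwards the moe_parallel_config that the FusedMoE layer already carries (the compressed-tensors and FlashInfer call sites later in the diff get the same treatment). A minimal before/after sketch; the function name is illustrative and moe_layer stands for the real layer object:

def resolve_moe_parallel_config(moe_layer):
    # Before: tolerate layers without a vllm_config attribute.
    #   parallel_config = getattr(
    #       getattr(moe_layer, "vllm_config", None), "parallel_config", None
    #   )
    # After: the layer exposes its own FusedMoEParallelConfig directly.
    return moe_layer.moe_parallel_config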
@@ -10,10 +10,12 @@ from typing import final
 import torch
 
 import vllm.envs as envs
-from vllm.config import ParallelConfig, get_current_vllm_config
 from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEParallelConfig,
+    FusedMoEQuantConfig,
+)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache,
     count_expert_num_tokens,
@@ -681,7 +683,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         fused_experts: FusedMoEPermuteExpertsUnpermute,
         shared_experts: torch.nn.Module | None = None,
         shared_experts_stream: torch.cuda.Stream | None = None,
-        parallel_config: ParallelConfig | None = None,
+        moe_parallel_config: FusedMoEParallelConfig | None = None,
     ):
         super().__init__()
         self.prepare_finalize = prepare_finalize
@@ -689,12 +691,15 @@ class FusedMoEModularKernel(torch.nn.Module):
         self.shared_experts = shared_experts
         self.shared_experts_stream = shared_experts_stream
 
-        # cache whether this worker is using DP+EP
-        if parallel_config is None:
-            parallel_config = get_current_vllm_config().parallel_config
+        # prefer an explicit FusedMoEParallelConfig when available (from
+        # FusedMoE layers / tests).
+        # if not provided, assume this kernel is
+        # running in a non-DP+EP context
+        self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config
         self.is_dp_ep = (
-            parallel_config.data_parallel_size > 1
-            and parallel_config.enable_expert_parallel
+            moe_parallel_config is not None
+            and moe_parallel_config.dp_size > 1
+            and moe_parallel_config.use_ep
         )
 
         self._post_init_setup()
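With the constructor hunk above, is_dp_ep is derived purely from the optional FusedMoEParallelConfig rather than the engine's ParallelConfig, and a missing config now means "not DP+EP" instead of "ask the global vLLM config". A standalone sketch of that predicate, using a reduced stand-in dataclass with only the two fields the check reads (the real FusedMoEParallelConfig lives in vllm.model_executor.layers.fused_moe.config):

from dataclasses import dataclass


@dataclass
class MoEParallelConfigStub:
    # Stand-in for FusedMoEParallelConfig, trimmed to the fields used below.
    dp_size: int
    use_ep: bool


def is_dp_ep(cfg: MoEParallelConfigStub | None) -> bool:
    # Mirrors the new constructor logic: no config means no DP+EP path;
    # otherwise require data parallelism (dp_size > 1) plus expert parallelism.
    return cfg is not None and cfg.dp_size > 1 and cfg.use_ep


assert is_dp_ep(None) is False
assert is_dp_ep(MoEParallelConfigStub(dp_size=2, use_ep=True)) is True
assert is_dp_ep(MoEParallelConfigStub(dp_size=2, use_ep=False)) is False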
@@ -1266,9 +1266,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 ab_strides2=self.ab_strides2,
                 c_strides1=self.c_strides1,
                 c_strides2=self.ab_strides1_c_strides2,
-                parallel_config=getattr(
-                    getattr(layer, "vllm_config", None), "parallel_config", None
-                ),
             )
 
         else:
@@ -247,11 +247,6 @@ def flashinfer_cutlass_moe_fp8(
     assert quant_config is not None
 
     # Construct modular kernel with block-scale support when requested.
-    parallel_config = getattr(
-        getattr(layer, "vllm_config", None),
-        "parallel_config",
-        None,
-    )
     fused_experts = mk.FusedMoEModularKernel(
         build_flashinfer_fp8_cutlass_moe_prepare_finalize(
             moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
@@ -262,7 +257,7 @@ def flashinfer_cutlass_moe_fp8(
             out_dtype=hidden_states.dtype,
             use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         ),
-        parallel_config=parallel_config,
+        moe_parallel_config=layer.moe_parallel_config,
     )
 
     return fused_experts(