mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:54:56 +08:00
[BugFix] Make sure to allocate worst case MoE workspace during profile run in the DP + EP case (#27426)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
1bed891f72
commit
1840c5cb18
@ -55,7 +55,7 @@ if TYPE_CHECKING:
|
||||
VLLM_CPU_SGL_KERNEL: bool = False
|
||||
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
|
||||
VLLM_XLA_CHECK_RECOMPILATION: bool = False
|
||||
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
|
||||
VLLM_FUSED_MOE_CHUNK_SIZE: int = 16 * 1024
|
||||
VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True
|
||||
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
|
||||
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
|
||||
@ -785,7 +785,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# Enable SPMD mode for TPU backend.
|
||||
"VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
|
||||
"VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int(
|
||||
os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")
|
||||
os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(16 * 1024))
|
||||
),
|
||||
# Control whether to use fused MoE activation chunking. Current chunking
|
||||
# logic is incompatible with torch.compile and causes IMA. See issue
|
||||
|
||||
@ -10,6 +10,9 @@ from typing import final
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import get_forward_context, is_forward_context_available
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
@ -26,6 +29,8 @@ from vllm.v1.worker.ubatching import (
|
||||
dbo_yield,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
#
|
||||
# This file defines a set of base classes used to make MoE kernels more modular.
|
||||
# The goal is to be able to utilize different communication mechanisms with
|
||||
@ -798,6 +803,42 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
buffers = self.shared_buffers[ubatch_idx]
|
||||
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
|
||||
|
||||
# Force worst-case allocation in profiling run for
|
||||
# "mk.FusedMoEModularKernel.Standard" formats where this is only bounded
|
||||
# by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with
|
||||
# DP+EP due to the random token routing.
|
||||
is_profile_run = (
|
||||
is_forward_context_available()
|
||||
and get_forward_context().attn_metadata is None
|
||||
)
|
||||
if is_profile_run and self.fused_experts.supports_chunking():
|
||||
parallel_config = get_current_vllm_config().parallel_config
|
||||
is_dp_ep = (
|
||||
parallel_config.data_parallel_size > 1
|
||||
and parallel_config.enable_expert_parallel
|
||||
)
|
||||
if is_dp_ep:
|
||||
max_workspace_13, max_workspace_2, max_fused_out_shape = (
|
||||
self.fused_experts.workspace_shapes(
|
||||
envs.VLLM_FUSED_MOE_CHUNK_SIZE,
|
||||
N,
|
||||
K,
|
||||
top_k,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
)
|
||||
)
|
||||
buffers.workspace13.get(
|
||||
max_workspace_13, device=device, dtype=workspace_dtype
|
||||
)
|
||||
buffers.workspace2.get(
|
||||
max_workspace_2, device=device, dtype=workspace_dtype
|
||||
)
|
||||
buffers.fused_out.get(
|
||||
max_fused_out_shape, device=device, dtype=workspace_dtype
|
||||
)
|
||||
|
||||
# Get intermediate workspace shapes based off the chunked M size.
|
||||
workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes(
|
||||
M_chunk,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user