diff --git a/vllm/envs.py b/vllm/envs.py
index f24ae64396f33..921052821ee3a 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -112,6 +112,7 @@ if TYPE_CHECKING:
     VLLM_DP_SIZE: int = 1
     VLLM_DP_MASTER_IP: str = ""
     VLLM_DP_MASTER_PORT: int = 0
+    VLLM_MOE_DP_CHUNK_SIZE: int = 256
     VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
@@ -773,6 +774,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_DP_MASTER_PORT":
     lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")),
 
+    # In the context of executing MoE models with Data-Parallel, Expert-Parallel
+    # and Batched All-to-All dispatch/combine kernels, VLLM_MOE_DP_CHUNK_SIZE
+    # dictates the quantum of tokens that can be dispatched from a DP
+    # rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE
+    # units.
+    "VLLM_MOE_DP_CHUNK_SIZE":
+    lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")),
+
     # Randomize inputs during dummy runs when using Data Parallel
     "VLLM_RANDOMIZE_DP_DUMMY_INPUTS":
     lambda: os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1",
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index cf8e4ee6509cc..1fd8f2175886a 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -61,10 +61,6 @@ else:
     fused_moe_pallas = None  # type: ignore
 logger = init_logger(__name__)
 
-# Note: this limit is somewhat arbitrary and might be changed later.
-# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
-MOE_DP_CHUNK_SIZE = 256
-
 
 @dataclass
 class FusedMoEParallelConfig:
@@ -218,7 +214,12 @@ class MoEConfig:
     # TODO: add more quantization params, blocked, per-token, etc.
     block_size: int = 128
 
-    max_num_tokens: int = MOE_DP_CHUNK_SIZE
+    max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
+
+    def __post_init__(self):
+        if self.dp_size > 1:
+            logger.debug("Using MOEConfig::max_num_tokens=%d",
+                         self.max_num_tokens)
 
     @property
     def tp_size(self):
@@ -913,7 +914,7 @@ class FusedMoE(torch.nn.Module):
             moe_parallel_config=self.moe_parallel_config,
             in_dtype=params_dtype,
             quant_dtype=quant_dtype,
-            max_num_tokens=MOE_DP_CHUNK_SIZE,
+            max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
         )
         self.moe_config = moe
         self.quant_config = quant_config
@@ -952,12 +953,12 @@ class FusedMoE(torch.nn.Module):
                 or self.moe_parallel_config.use_deepep_ll_kernels):
             act_dtype = vllm_config.model_config.dtype
             self.batched_hidden_states = torch.zeros(
-                (MOE_DP_CHUNK_SIZE, self.hidden_size),
+                (envs.VLLM_MOE_DP_CHUNK_SIZE, self.hidden_size),
                 dtype=act_dtype,
                 device=torch.cuda.current_device())
 
             self.batched_router_logits = torch.zeros(
-                (MOE_DP_CHUNK_SIZE, self.global_num_experts),
+                (envs.VLLM_MOE_DP_CHUNK_SIZE, self.global_num_experts),
                 dtype=act_dtype,
                 device=torch.cuda.current_device())
 
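
Usage note: the patch turns the previously hard-coded per-round token budget into the VLLM_MOE_DP_CHUNK_SIZE environment variable. The sketch below is a hypothetical, standalone illustration of the chunked processing the env-var comment describes (each DP rank walks its activations in VLLM_MOE_DP_CHUNK_SIZE-sized slices); the names process_moe_chunk and forward_chunked are invented for the example and do not appear in the patch, and only the chunk-size arithmetic reflects the change above.

import math
import os

import torch

# Chunk size read the same way the patch reads it in vllm/envs.py.
CHUNK_SIZE = int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256"))


def process_moe_chunk(hidden: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    # Placeholder for one dispatch -> expert compute -> combine round;
    # identity here so the sketch runs on its own.
    return hidden


def forward_chunked(hidden: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    # Split this rank's activations into CHUNK_SIZE-token slices and process
    # them one round at a time, as described in the env-var comment.
    num_tokens = hidden.shape[0]
    outputs = []
    for i in range(math.ceil(num_tokens / CHUNK_SIZE)):
        start = i * CHUNK_SIZE
        end = min(start + CHUNK_SIZE, num_tokens)
        outputs.append(process_moe_chunk(hidden[start:end], logits[start:end]))
    return torch.cat(outputs, dim=0)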