[Misc] Turn MOE_DP_CHUNK_SIZE into an env var (#19506)

Varun Sundar Rabindranath 2025-06-12 14:01:16 -04:00 committed by GitHub
parent 017ef648e9
commit 9d880f594d
2 changed files with 18 additions and 8 deletions


@@ -112,6 +112,7 @@ if TYPE_CHECKING:
     VLLM_DP_SIZE: int = 1
     VLLM_DP_MASTER_IP: str = ""
     VLLM_DP_MASTER_PORT: int = 0
+    VLLM_MOE_DP_CHUNK_SIZE: int = 256
     VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
@@ -773,6 +774,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_DP_MASTER_PORT":
     lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")),
 
+    # In the context of executing MoE models with Data-Parallel, Expert-Parallel
+    # and Batched All-to-All dispatch/combine kernels, VLLM_MOE_DP_CHUNK_SIZE
+    # dictates the quantum of tokens that can be dispatched from a DP
+    # rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE
+    # units.
+    "VLLM_MOE_DP_CHUNK_SIZE":
+    lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")),
+
     # Randomize inputs during dummy runs when using Data Parallel
     "VLLM_RANDOMIZE_DP_DUMMY_INPUTS":
     lambda: os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1",
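
A brief usage sketch: the new knob is consumed through vllm.envs like the other entries in environment_variables, so it can be overridden per process. The snippet below assumes the envs module resolves these attributes lazily through the registered lambdas; the value 512 is only an illustration.

import os

# Override the default of 256; this must happen before the value is consumed
# (e.g. before engine construction).
os.environ["VLLM_MOE_DP_CHUNK_SIZE"] = "512"

import vllm.envs as envs

# The entry registered above is a lambda, so the value is parsed from the
# environment when the attribute is accessed.
assert envs.VLLM_MOE_DP_CHUNK_SIZE == 512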


@@ -61,10 +61,6 @@ else:
     fused_moe_pallas = None  # type: ignore
 logger = init_logger(__name__)
 
-# Note: this limit is somewhat arbitrary and might be changed later.
-# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
-MOE_DP_CHUNK_SIZE = 256
-
 
 @dataclass
 class FusedMoEParallelConfig:
@@ -218,7 +214,12 @@ class MoEConfig:
     # TODO: add more quantization params, blocked, per-token, etc.
     block_size: int = 128
 
-    max_num_tokens: int = MOE_DP_CHUNK_SIZE
+    max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
+
+    def __post_init__(self):
+        if self.dp_size > 1:
+            logger.debug("Using MOEConfig::max_num_tokens=%d",
+                         self.max_num_tokens)
 
     @property
     def tp_size(self):
@@ -913,7 +914,7 @@ class FusedMoE(torch.nn.Module):
             moe_parallel_config=self.moe_parallel_config,
             in_dtype=params_dtype,
             quant_dtype=quant_dtype,
-            max_num_tokens=MOE_DP_CHUNK_SIZE,
+            max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
         )
         self.moe_config = moe
         self.quant_config = quant_config
@@ -952,12 +953,12 @@
                 or self.moe_parallel_config.use_deepep_ll_kernels):
             act_dtype = vllm_config.model_config.dtype
 
             self.batched_hidden_states = torch.zeros(
-                (MOE_DP_CHUNK_SIZE, self.hidden_size),
+                (envs.VLLM_MOE_DP_CHUNK_SIZE, self.hidden_size),
                 dtype=act_dtype,
                 device=torch.cuda.current_device())
 
             self.batched_router_logits = torch.zeros(
-                (MOE_DP_CHUNK_SIZE, self.global_num_experts),
+                (envs.VLLM_MOE_DP_CHUNK_SIZE, self.global_num_experts),
                 dtype=act_dtype,
                 device=torch.cuda.current_device())
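
For intuition about how this limit is used: the two buffers above are sized to VLLM_MOE_DP_CHUNK_SIZE rows, and the batched dispatch/combine path processes a rank's activations in units of at most that many tokens. Below is a simplified, self-contained sketch of that chunking idea, not vLLM's actual dispatch/combine driver; the identity "expert computation" is a hypothetical stand-in so the example stays runnable.

import torch


def chunked_moe_forward(hidden_states: torch.Tensor,
                        max_num_tokens: int) -> torch.Tensor:
    """Process tokens in fixed-size chunks, mirroring a staging buffer of
    shape (VLLM_MOE_DP_CHUNK_SIZE, hidden_size)."""
    num_tokens, hidden_size = hidden_states.shape
    out = torch.empty_like(hidden_states)
    # Staging buffer, analogous to self.batched_hidden_states above.
    buffer = torch.zeros(max_num_tokens, hidden_size,
                         dtype=hidden_states.dtype,
                         device=hidden_states.device)
    for start in range(0, num_tokens, max_num_tokens):
        end = min(start + max_num_tokens, num_tokens)
        n = end - start
        buffer[:n].copy_(hidden_states[start:end])
        # Hypothetical stand-in for dispatch/combine + expert MLPs: identity.
        out[start:end] = buffer[:n]
    return out


# Example: 1000 tokens processed in chunks of 256 (the default chunk size).
x = torch.randn(1000, 64)
assert torch.equal(chunked_moe_forward(x, max_num_tokens=256), x)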