From 0323e29153b41479f173cd5c4acc973e9713c80a Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Mon, 2 Jun 2025 14:13:30 +0000
Subject: [PATCH] misc cleanups to prepare for rebase

Signed-off-by: Sage Moore
---
 vllm/compilation/backends.py                  |  2 +-
 vllm/distributed/parallel_state.py            |  4 +-
 vllm/envs.py                                  |  1 -
 vllm/model_executor/layers/fused_moe/layer.py | 64 +++++++------------
 4 files changed, 25 insertions(+), 46 deletions(-)

diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index ae2372b2a6e6b..0c1381a565c16 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -740,7 +740,7 @@ class PiecewiseBackend:
             # manage the memory during cuda graph capture
             return output
 
-        if self.is_debugging_mode or envs.VLLM_CUDAGRAPH_SANITIZER:
+        if self.is_debugging_mode:
             # check if the input addresses are the same
             new_input_addresses = [
                 x.data_ptr() for x in args if isinstance(x, torch.Tensor)
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 2edd8ec790021..51c519d8f8623 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -972,9 +972,7 @@ def pplx_finalize():
         logger.debug("PPLX NVSHMEM finalize")
         from vllm.model_executor.layers.fused_moe.layer import (
             _all_to_all_cache)
-        for cache in _all_to_all_cache:
-            cache.destroy()
-        # _all_to_all_cache.destroy()
+        _all_to_all_cache.destroy()
         nvshmem_finalize()
diff --git a/vllm/envs.py b/vllm/envs.py
index 9d226b298cefb..cd545f32c4301 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -117,7 +117,6 @@ if TYPE_CHECKING:
     VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
     VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
     VLLM_ALL2ALL_BACKEND: str = "naive"
-    VLLM_CUDAGRAPH_SANITIZER: bool = False
 
 
 def get_default_cache_root():
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index ab825610d66a1..822bde906dc97 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -306,24 +306,13 @@ class AllToAllCache:
             self._cache[key] = instance
             return instance
 
-
-from typing import List
-_all_to_all_cache: List[AllToAllCache] = [AllToAllCache(), AllToAllCache()]
+# Global singleton
+_all_to_all_cache = AllToAllCache()
 
 
-# Factory function with cache ID support
-def get_all_to_all(cache_id: int, **kwargs):
-    """Get or create an AllToAll instance from the specified cache.
-
-    Args:
-        cache_id: Integer ID of the cache to use (0 or 1)
-        **kwargs: Arguments passed to AllToAll creation
-
-    Returns:
-        AllToAll instance from the specified cache
-    """
-    assert cache_id in (0, 1), f"cache_id must be 0 or 1, got {cache_id}"
-    return _all_to_all_cache[cache_id].get_or_create(**kwargs)
+# Factory function as a cleaner interface
+def get_all_to_all(**kwargs):
+    return _all_to_all_cache.get_or_create(**kwargs)
 
 
 @CustomOp.register("unquantized_fused_moe")
@@ -453,8 +442,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
         if isinstance(prepare_finalize,
                       (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
-            # print("BatchedTritonExperts %s", self.moe)
-
+            logger.debug("BatchedTritonExperts %s", self.moe)
             experts = BatchedTritonExperts(
                 max_num_tokens=MOE_DP_CHUNK_SIZE,
                 world_size=world_size,
@@ -466,7 +454,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 block_shape=None,
             )
         else:
-            # print("TritonExperts %s", self.moe)
+            logger.debug("TritonExperts %s", self.moe)
             experts = TritonExperts(
                 use_fp8_w8a8=False,
                 use_int8_w8a8=False,
@@ -702,26 +690,25 @@ def _construct_prepare_finalize(
 
     if moe.use_pplx_kernels:
         logger.debug("using PplxPrepareAndFinalize")
-
-        kwargs = {
-            "max_num_tokens" :max_num_tokens,
-            "num_experts" :moe.num_experts,
-            "experts_per_token" :moe.experts_per_token,  # topk
-            "rank" :rank,
-            "world_size" :world_size,
-            "dp_size" :dp_size,
-            "hidden_dim":moe.hidden_dim,
-            "hidden_dim_bytes" :moe.hidden_dim * moe.in_dtype.itemsize,
-            "hidden_dim_scale_bytes" :(0 if moe.in_dtype.itemsize != 1 else
-                ((moe.hidden_dim + moe.block_size - 1) //
-                 moe.block_size * torch.float32.itemsize)),
-        }
-
-
-        a2as = [get_all_to_all(0, **kwargs), get_all_to_all(1, **kwargs)]
+
+        all_to_all = get_all_to_all(
+            max_num_tokens=max_num_tokens,
+            num_experts=moe.num_experts,
+            experts_per_token=moe.experts_per_token,  # topk
+            rank=rank,
+            world_size=world_size,
+            dp_size=dp_size,
+            hidden_dim=moe.hidden_dim,
+            hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
+            # For blocked per token: set to
+            #   ceil_div(hidden_dim, block_size) * sizeof(float32)
+            # For per-token: set to sizeof(float32)
+            hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
+                ((moe.hidden_dim + moe.block_size - 1) //
+                 moe.block_size * torch.float32.itemsize)))
 
         return PplxPrepareAndFinalize(
-            a2as,
+            all_to_all,
             max_num_tokens=max_num_tokens,
             world_size=world_size,
             rank=rank,
@@ -1311,9 +1298,6 @@ class FusedMoE(torch.nn.Module):
         max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
         moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE
 
-        # if (ubatch_ctdx := get_current_ubatch_context()) is not None:
-        #     print("in fused moe, ubatch:", ubatch_ctdx.id, "chunk size:", max_tokens_across_dp, "moe_dp_chunk_size_per_rank", moe_dp_chunk_size_per_rank)
-
         num_tokens = full_hidden_states.size(0)
         for chunk_start_ in range(0, max_tokens_across_dp,
                                   moe_dp_chunk_size_per_rank):
@@ -1412,8 +1396,6 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
     assert self.quant_method is not None
-    # if (ubatch_ctx := get_current_ubatch_context()) is not None:
-    #     print("in fused moe, ubatch:", ubatch_ctx.id, self)
 
     return self.forward_impl(hidden_states, router_logits)
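
Note (illustrative, not part of the diff): the patch collapses the two-entry
AllToAll cache list into one global AllToAllCache reached through
get_all_to_all(**kwargs). Below is a minimal, runnable sketch of that
get-or-create pattern. It is written under assumptions: _FakeHandle is a
hypothetical stand-in for the pplx-kernels AllToAll object, and the cache key
built from sorted kwargs is inferred from the calling convention in
_construct_prepare_finalize, not copied from vLLM's implementation.

    import threading
    from typing import Any


    class _FakeHandle:
        """Hypothetical stand-in for a pplx AllToAll instance."""

        def __init__(self, **kwargs: Any):
            self.kwargs = kwargs

        def destroy(self) -> None:
            # A real handle would release its NVSHMEM resources here.
            pass


    class AllToAllCache:
        def __init__(self) -> None:
            self._cache: dict[tuple, _FakeHandle] = {}
            self._lock = threading.RLock()

        def get_or_create(self, **kwargs: Any) -> _FakeHandle:
            # One instance per distinct configuration (rank, world_size,
            # hidden_dim, ...); sorted items give a stable, hashable key.
            key = tuple(sorted(kwargs.items()))
            with self._lock:
                instance = self._cache.get(key)
                if instance is None:
                    instance = _FakeHandle(**kwargs)
                    self._cache[key] = instance
                return instance

        def destroy(self) -> None:
            # Single teardown point, matching the lone
            # _all_to_all_cache.destroy() call in pplx_finalize above.
            with self._lock:
                for handle in self._cache.values():
                    handle.destroy()
                self._cache.clear()


    # Global singleton plus factory, mirroring the patched layer.py.
    _all_to_all_cache = AllToAllCache()


    def get_all_to_all(**kwargs: Any) -> _FakeHandle:
        return _all_to_all_cache.get_or_create(**kwargs)


    if __name__ == "__main__":
        a = get_all_to_all(rank=0, world_size=2, hidden_dim=4096)
        b = get_all_to_all(rank=0, world_size=2, hidden_dim=4096)
        assert a is b  # same kwargs -> same cached instance
        _all_to_all_cache.destroy()

Because pplx_finalize destroys the cache exactly once before
nvshmem_finalize(), a single global cache keeps teardown ordering trivial;
the old two-cache list forced callers to pick a cache_id and made finalize
loop over caches.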