misc cleanups to prepare for rebase

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-02 14:13:30 +00:00
parent 252bf0809e
commit 0323e29153
4 changed files with 25 additions and 46 deletions

View File

@ -740,7 +740,7 @@ class PiecewiseBackend:
# manage the memory during cuda graph capture
return output
if self.is_debugging_mode or envs.VLLM_CUDAGRAPH_SANITIZER:
if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)

View File

@ -972,9 +972,7 @@ def pplx_finalize():
logger.debug("PPLX NVSHMEM finalize")
from vllm.model_executor.layers.fused_moe.layer import (
_all_to_all_cache)
for cache in _all_to_all_cache:
cache.destroy()
# _all_to_all_cache.destroy()
_all_to_all_cache.destroy()
nvshmem_finalize()

View File

@ -117,7 +117,6 @@ if TYPE_CHECKING:
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
VLLM_ALL2ALL_BACKEND: str = "naive"
VLLM_CUDAGRAPH_SANITIZER: bool = False
def get_default_cache_root():

View File

@ -306,24 +306,13 @@ class AllToAllCache:
self._cache[key] = instance
return instance
from typing import List
_all_to_all_cache: List[AllToAllCache] = [AllToAllCache(), AllToAllCache()]
# Global singleton
_all_to_all_cache = AllToAllCache()
# Factory function with cache ID support
def get_all_to_all(cache_id: int, **kwargs):
    """Select one of the module-level caches by ID and get or create an entry.

    ``cache_id`` indexes into ``_all_to_all_cache``; only 0 and 1 are
    accepted. All keyword arguments are forwarded unchanged to the selected
    cache's ``get_or_create``.

    Args:
        cache_id: Index of the cache to use (0 or 1).
        **kwargs: Forwarded verbatim to ``get_or_create``.

    Returns:
        Whatever ``get_or_create`` returns for the selected cache.
    """
    assert cache_id in (0, 1), f"cache_id must be 0 or 1, got {cache_id}"
    selected_cache = _all_to_all_cache[cache_id]
    return selected_cache.get_or_create(**kwargs)
# Factory function as a cleaner interface
def get_all_to_all(**kwargs):
    """Forward ``kwargs`` to the module-level cache's ``get_or_create``.

    Thin convenience wrapper around ``_all_to_all_cache``; the result of
    ``get_or_create`` is returned unchanged.
    """
    cache = _all_to_all_cache
    return cache.get_or_create(**kwargs)
@CustomOp.register("unquantized_fused_moe")
@ -453,8 +442,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
# print("BatchedTritonExperts %s", self.moe)
logger.debug("BatchedTritonExperts %s", self.moe)
experts = BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=world_size,
@ -466,7 +454,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
block_shape=None,
)
else:
# print("TritonExperts %s", self.moe)
logger.debug("TritonExperts %s", self.moe)
experts = TritonExperts(
use_fp8_w8a8=False,
use_int8_w8a8=False,
@ -702,26 +690,25 @@ def _construct_prepare_finalize(
if moe.use_pplx_kernels:
logger.debug("using PplxPrepareAndFinalize")
kwargs = {
"max_num_tokens" :max_num_tokens,
"num_experts" :moe.num_experts,
"experts_per_token" :moe.experts_per_token, # topk
"rank" :rank,
"world_size" :world_size,
"dp_size" :dp_size,
"hidden_dim":moe.hidden_dim,
"hidden_dim_bytes" :moe.hidden_dim * moe.in_dtype.itemsize,
"hidden_dim_scale_bytes" :(0 if moe.in_dtype.itemsize != 1 else
all_to_all = get_all_to_all(
max_num_tokens=max_num_tokens,
num_experts=moe.num_experts,
experts_per_token=moe.experts_per_token, # topk
rank=rank,
world_size=world_size,
dp_size=dp_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
((moe.hidden_dim + moe.block_size - 1) //
moe.block_size * torch.float32.itemsize)),
}
a2as = [get_all_to_all(0, **kwargs), get_all_to_all(1, **kwargs)]
moe.block_size * torch.float32.itemsize)))
return PplxPrepareAndFinalize(
a2as,
all_to_all,
max_num_tokens=max_num_tokens,
world_size=world_size,
rank=rank,
@ -1311,9 +1298,6 @@ class FusedMoE(torch.nn.Module):
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE
# if (ubatch_ctdx := get_current_ubatch_context()) is not None:
# print("in fused moe, ubatch:", ubatch_ctdx.id, "chunk size:", max_tokens_across_dp, "moe_dp_chunk_size_per_rank", moe_dp_chunk_size_per_rank)
num_tokens = full_hidden_states.size(0)
for chunk_start_ in range(0, max_tokens_across_dp,
moe_dp_chunk_size_per_rank):
@ -1412,8 +1396,6 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None
# if (ubatch_ctx := get_current_ubatch_context()) is not None:
# print("in fused moe, ubatch:", ubatch_ctx.id, self)
return self.forward_impl(hidden_states, router_logits)