mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 03:47:03 +08:00
misc cleanups to prepare for rebase
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
252bf0809e
commit
0323e29153
@ -740,7 +740,7 @@ class PiecewiseBackend:
|
||||
# manage the memory during cuda graph capture
|
||||
return output
|
||||
|
||||
if self.is_debugging_mode or envs.VLLM_CUDAGRAPH_SANITIZER:
|
||||
if self.is_debugging_mode:
|
||||
# check if the input addresses are the same
|
||||
new_input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
|
||||
@ -972,9 +972,7 @@ def pplx_finalize():
|
||||
logger.debug("PPLX NVSHMEM finalize")
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
_all_to_all_cache)
|
||||
for cache in _all_to_all_cache:
|
||||
cache.destroy()
|
||||
# _all_to_all_cache.destroy()
|
||||
_all_to_all_cache.destroy()
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
|
||||
@ -117,7 +117,6 @@ if TYPE_CHECKING:
|
||||
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
|
||||
VLLM_ALL2ALL_BACKEND: str = "naive"
|
||||
VLLM_CUDAGRAPH_SANITIZER: bool = False
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
|
||||
@ -306,24 +306,13 @@ class AllToAllCache:
|
||||
self._cache[key] = instance
|
||||
return instance
|
||||
|
||||
|
||||
from typing import List
|
||||
_all_to_all_cache: List[AllToAllCache] = [AllToAllCache(), AllToAllCache()]
|
||||
# Global singleton
|
||||
_all_to_all_cache = AllToAllCache()
|
||||
|
||||
|
||||
# Factory function with cache ID support
|
||||
def get_all_to_all(cache_id: int, **kwargs):
|
||||
"""Get or create an AllToAll instance from the specified cache.
|
||||
|
||||
Args:
|
||||
cache_id: Integer ID of the cache to use (0 or 1)
|
||||
**kwargs: Arguments passed to AllToAll creation
|
||||
|
||||
Returns:
|
||||
AllToAll instance from the specified cache
|
||||
"""
|
||||
assert cache_id in (0, 1), f"cache_id must be 0 or 1, got {cache_id}"
|
||||
return _all_to_all_cache[cache_id].get_or_create(**kwargs)
|
||||
# Factory function as a cleaner interface
|
||||
def get_all_to_all(**kwargs):
|
||||
return _all_to_all_cache.get_or_create(**kwargs)
|
||||
|
||||
|
||||
@CustomOp.register("unquantized_fused_moe")
|
||||
@ -453,8 +442,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
if isinstance(prepare_finalize,
|
||||
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
|
||||
# print("BatchedTritonExperts %s", self.moe)
|
||||
|
||||
logger.debug("BatchedTritonExperts %s", self.moe)
|
||||
experts = BatchedTritonExperts(
|
||||
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||
world_size=world_size,
|
||||
@ -466,7 +454,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
block_shape=None,
|
||||
)
|
||||
else:
|
||||
# print("TritonExperts %s", self.moe)
|
||||
logger.debug("TritonExperts %s", self.moe)
|
||||
experts = TritonExperts(
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a8=False,
|
||||
@ -702,26 +690,25 @@ def _construct_prepare_finalize(
|
||||
|
||||
if moe.use_pplx_kernels:
|
||||
logger.debug("using PplxPrepareAndFinalize")
|
||||
|
||||
kwargs = {
|
||||
"max_num_tokens" :max_num_tokens,
|
||||
"num_experts" :moe.num_experts,
|
||||
"experts_per_token" :moe.experts_per_token, # topk
|
||||
"rank" :rank,
|
||||
"world_size" :world_size,
|
||||
"dp_size" :dp_size,
|
||||
"hidden_dim":moe.hidden_dim,
|
||||
"hidden_dim_bytes" :moe.hidden_dim * moe.in_dtype.itemsize,
|
||||
"hidden_dim_scale_bytes" :(0 if moe.in_dtype.itemsize != 1 else
|
||||
|
||||
all_to_all = get_all_to_all(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_experts=moe.num_experts,
|
||||
experts_per_token=moe.experts_per_token, # topk
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
dp_size=dp_size,
|
||||
hidden_dim=moe.hidden_dim,
|
||||
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
|
||||
# For blocked per token: set to
|
||||
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
||||
# For per-token: set to sizeof(float32)
|
||||
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
|
||||
((moe.hidden_dim + moe.block_size - 1) //
|
||||
moe.block_size * torch.float32.itemsize)),
|
||||
}
|
||||
|
||||
|
||||
a2as = [get_all_to_all(0, **kwargs), get_all_to_all(1, **kwargs)]
|
||||
moe.block_size * torch.float32.itemsize)))
|
||||
|
||||
return PplxPrepareAndFinalize(
|
||||
a2as,
|
||||
all_to_all,
|
||||
max_num_tokens=max_num_tokens,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
@ -1311,9 +1298,6 @@ class FusedMoE(torch.nn.Module):
|
||||
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
|
||||
moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE
|
||||
|
||||
# if (ubatch_ctdx := get_current_ubatch_context()) is not None:
|
||||
# print("in fused moe, ubatch:", ubatch_ctdx.id, "chunk size:", max_tokens_across_dp, "moe_dp_chunk_size_per_rank", moe_dp_chunk_size_per_rank)
|
||||
|
||||
num_tokens = full_hidden_states.size(0)
|
||||
for chunk_start_ in range(0, max_tokens_across_dp,
|
||||
moe_dp_chunk_size_per_rank):
|
||||
@ -1412,8 +1396,6 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
assert self.quant_method is not None
|
||||
# if (ubatch_ctx := get_current_ubatch_context()) is not None:
|
||||
# print("in fused moe, ubatch:", ubatch_ctx.id, self)
|
||||
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user