diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index ede3dec5fe147..693e537dc44c3 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -147,7 +147,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): has_deepep = importlib.util.find_spec("deep_ep") is not None assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa super().__init__(cpu_group) - self.handle_cache = Cache() + self.handle_caches = [Cache(), Cache()] # This is the DeepEP default. Stick to it till we can establish # reasonable defaults based on profiling. @@ -174,6 +174,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): def __init__(self, cpu_group): super().__init__(cpu_group) + self.handle_cache = self.handle_caches[0] def _make_all2all_kwargs(self) -> dict[Any, Any]: # Defaults for internode and intranode are taken from DeepEP tests. @@ -265,7 +266,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): import deep_ep buffer_kwargs = self._make_all2all_kwargs(**kwargs) logger.debug("DeepEP all2all args %s", buffer_kwargs) - handle: deep_ep.Buffer = self.handle_cache.get_or_create( + handle: deep_ep.Buffer = self.handle_caches[0].get_or_create( buffer_kwargs, deep_ep.Buffer) # It is dangerous to set num sms outside this function. num_sms is not # a part of the hash-key that identifies this object. If we are in a @@ -273,3 +274,10 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): # in get_or_create must be updated. handle.set_num_sms(self.num_sms) return handle + + def get_handles(self, kwargs): + import deep_ep + buffer_kwargs = self._make_all2all_kwargs(**kwargs) + first_handle = self.handle_caches[0].get_or_create(buffer_kwargs, deep_ep.Buffer) + second_handle = self.handle_caches[1].get_or_create(buffer_kwargs, deep_ep.Buffer) + return [first_handle, second_handle] diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 3484a7a8a496a..c2aa5a831b0f5 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -7,6 +7,9 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) +from vllm.v1.worker.ubatching import ( + get_current_ubatch_context, yield_and_switch_from_comm_to_compute_impl, + yield_and_switch_from_compute_to_comm_impl) # DeepEP kernels quantize dispatch inputs in 128 element chunks. DEEPEP_QUANT_BLOCK_SIZE = 128 @@ -38,7 +41,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168] def __init__(self, - buffer: deep_ep.Buffer, + buffers: list[deep_ep.Buffer], world_size: int, dp_size: int, max_tokens_per_rank: int, @@ -47,7 +50,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): use_fp8_dispatch: bool = False): super().__init__() - self.buffer = buffer + self.buffers = buffers self.world_size = world_size self.dp_size = dp_size self.quant_dtype = quant_dtype @@ -127,9 +130,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): Optional[torch.Tensor], Optional[torch.Tensor]]: hidden_size = a1.size(1) - assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ - (f"Hidden Size {hidden_size} not in supported list of hidden sizes" - f"{self.SUPPORTED_HIDDEN_SIZES}") + ubatch_ctx = get_current_ubatch_context() + ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 + a2a_idx = 0 if ubatch_id == -1 else ubatch_id + # assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ + # (f"Hidden Size {hidden_size} not in supported list of hidden sizes" + # f"{self.SUPPORTED_HIDDEN_SIZES}") if self.use_fp8_dispatch: assert hidden_size % 128 == 0, \ @@ -150,7 +156,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # Dispatch expert_x, expert_num_tokens, self.handle, event, hook = \ - self.buffer.low_latency_dispatch(a1, + self.buffers[a2a_idx].low_latency_dispatch(a1, rank_topk_ids, self.max_tokens_per_rank, num_experts, @@ -168,6 +174,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): apply_router_weight_on_input: bool) -> None: assert self.handle is not None + ubatch_ctx = get_current_ubatch_context() + ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 + a2a_idx = 0 if ubatch_id == -1 else ubatch_id combine_topk_weights = topk_weights if apply_router_weight_on_input: @@ -175,7 +184,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): combine_topk_weights = torch.ones_like(topk_weights) # TODO (varun) : Enable zero copy mode - _, event, hook = self.buffer.low_latency_combine( + _, event, hook = self.buffers[a2a_idx].low_latency_combine( fused_expert_output, topk_ids, combine_topk_weights, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 56e134be9be26..97099b9b6f5c0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -377,12 +377,12 @@ class FusedMoEMethodBase(QuantizeMethodBase): num_global_experts=moe.num_experts, num_local_experts=moe.num_experts // all2all_manager.world_size) - handle = all2all_manager.get_handle(all_to_all_args) + handles = all2all_manager.get_handles(all_to_all_args) # Note (varun): Whether to use FP8 dispatch or not needs some # profiling. Turning it off for now. prepare_finalize = DeepEPLLPrepareAndFinalize( - handle, + handles, world_size=all2all_manager.world_size, dp_size=all2all_manager.dp_world_size, max_tokens_per_rank=moe.max_num_tokens, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 842700ea23583..e95f4594c0a13 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1467,6 +1467,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): dtype=self.model_config.dtype, device=self.device)) + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + slice(0, num_tokens), None, False) + + return input_ids, positions, inputs_embeds, intermediate_tensors def _get_model_inputs(self, tokens_slice: slice,