From de92ab523b969054c22d360c84294bc9de283d9a Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 19 Aug 2025 20:01:22 +0000 Subject: [PATCH] single deepep handle Signed-off-by: Lucas Wilkinson --- .../device_communicators/all2all.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index dc0db5a71e859..08e453b5eee73 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -109,16 +109,17 @@ class PPLXAll2AllManager(All2AllManagerBase): return self.handle_caches[0].get_or_create( kwargs, pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode) - + def get_handles(self, kwargs): import pplx_kernels as pplx - first_handle = self.handle_caches[0].get_or_create(kwargs, pplx.AllToAll.internode + first_handle = self.handle_caches[0].get_or_create( + kwargs, pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode) - second_handle = self.handle_caches[1].get_or_create(kwargs, pplx.AllToAll.internode + second_handle = self.handle_caches[1].get_or_create( + kwargs, pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode) return [first_handle, second_handle] - def dispatch(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): raise NotImplementedError @@ -223,6 +224,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): def __init__(self, cpu_group): super().__init__(cpu_group) + self.handle_cache = self.handle_caches[0] def _make_all2all_kwargs( self, @@ -266,7 +268,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): import deep_ep buffer_kwargs = self._make_all2all_kwargs(**kwargs) logger.debug("DeepEP all2all args %s", buffer_kwargs) - handle: deep_ep.Buffer = self.handle_caches[0].get_or_create( + handle: deep_ep.Buffer = self.handle_cache.get_or_create( buffer_kwargs, deep_ep.Buffer) # It is dangerous to set num sms outside this function. num_sms is not # a part of the hash-key that identifies this object. If we are in a @@ -276,8 +278,6 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): return handle def get_handles(self, kwargs): - import deep_ep - buffer_kwargs = self._make_all2all_kwargs(**kwargs) - first_handle = self.handle_caches[0].get_or_create(buffer_kwargs, deep_ep.Buffer) - second_handle = self.handle_caches[1].get_or_create(buffer_kwargs, deep_ep.Buffer) - return [first_handle, second_handle] \ No newline at end of file + handle = self.get_handle(kwargs) + # For DeepEP we use the same handle for microbatching + return [handle, handle]