single deepep handle

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-08-19 20:01:22 +00:00
parent 9f04a6cf57
commit de92ab523b

View File

@ -109,16 +109,17 @@ class PPLXAll2AllManager(All2AllManagerBase):
return self.handle_caches[0].get_or_create(
kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode)
def get_handles(self, kwargs):
import pplx_kernels as pplx
first_handle = self.handle_caches[0].get_or_create(kwargs, pplx.AllToAll.internode
first_handle = self.handle_caches[0].get_or_create(
kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode)
second_handle = self.handle_caches[1].get_or_create(kwargs, pplx.AllToAll.internode
second_handle = self.handle_caches[1].get_or_create(
kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode)
return [first_handle, second_handle]
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError
@ -223,6 +224,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
def __init__(self, cpu_group):
super().__init__(cpu_group)
self.handle_cache = self.handle_caches[0]
def _make_all2all_kwargs(
self,
@ -266,7 +268,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
import deep_ep
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
logger.debug("DeepEP all2all args %s", buffer_kwargs)
handle: deep_ep.Buffer = self.handle_caches[0].get_or_create(
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, deep_ep.Buffer)
# It is dangerous to set num sms outside this function. num_sms is not
# a part of the hash-key that identifies this object. If we are in a
@ -276,8 +278,6 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
return handle
def get_handles(self, kwargs):
import deep_ep
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
first_handle = self.handle_caches[0].get_or_create(buffer_kwargs, deep_ep.Buffer)
second_handle = self.handle_caches[1].get_or_create(buffer_kwargs, deep_ep.Buffer)
return [first_handle, second_handle]
handle = self.get_handle(kwargs)
# For DeepEP we use the same handle for microbatching
return [handle, handle]