mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 05:15:42 +08:00
single deepep handle
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
9f04a6cf57
commit
de92ab523b
@@ -112,13 +112,14 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
|||||||
|
|
||||||
def get_handles(self, kwargs):
|
def get_handles(self, kwargs):
|
||||||
import pplx_kernels as pplx
|
import pplx_kernels as pplx
|
||||||
first_handle = self.handle_caches[0].get_or_create(kwargs, pplx.AllToAll.internode
|
first_handle = self.handle_caches[0].get_or_create(
|
||||||
|
kwargs, pplx.AllToAll.internode
|
||||||
if self.internode else pplx.AllToAll.intranode)
|
if self.internode else pplx.AllToAll.intranode)
|
||||||
second_handle = self.handle_caches[1].get_or_create(kwargs, pplx.AllToAll.internode
|
second_handle = self.handle_caches[1].get_or_create(
|
||||||
|
kwargs, pplx.AllToAll.internode
|
||||||
if self.internode else pplx.AllToAll.intranode)
|
if self.internode else pplx.AllToAll.intranode)
|
||||||
return [first_handle, second_handle]
|
return [first_handle, second_handle]
|
||||||
|
|
||||||
|
|
||||||
def dispatch(self, hidden_states: torch.Tensor,
|
def dispatch(self, hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor):
|
router_logits: torch.Tensor):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -223,6 +224,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
|||||||
|
|
||||||
def __init__(self, cpu_group):
|
def __init__(self, cpu_group):
|
||||||
super().__init__(cpu_group)
|
super().__init__(cpu_group)
|
||||||
|
self.handle_cache = self.handle_caches[0]
|
||||||
|
|
||||||
def _make_all2all_kwargs(
|
def _make_all2all_kwargs(
|
||||||
self,
|
self,
|
||||||
@@ -266,7 +268,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
|||||||
import deep_ep
|
import deep_ep
|
||||||
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
|
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
|
||||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||||
handle: deep_ep.Buffer = self.handle_caches[0].get_or_create(
|
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||||
buffer_kwargs, deep_ep.Buffer)
|
buffer_kwargs, deep_ep.Buffer)
|
||||||
# It is dangerous to set num sms outside this function. num_sms is not
|
# It is dangerous to set num sms outside this function. num_sms is not
|
||||||
# a part of the hash-key that identifies this object. If we are in a
|
# a part of the hash-key that identifies this object. If we are in a
|
||||||
@@ -276,8 +278,6 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
|||||||
return handle
|
return handle
|
||||||
|
|
||||||
def get_handles(self, kwargs):
|
def get_handles(self, kwargs):
|
||||||
import deep_ep
|
handle = self.get_handle(kwargs)
|
||||||
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
|
# For DeepEP we use the same handle for microbatching
|
||||||
first_handle = self.handle_caches[0].get_or_create(buffer_kwargs, deep_ep.Buffer)
|
return [handle, handle]
|
||||||
second_handle = self.handle_caches[1].get_or_create(buffer_kwargs, deep_ep.Buffer)
|
|
||||||
return [first_handle, second_handle]
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user