mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 07:37:51 +08:00
fix pplx a2a
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
5f4a501b9a
commit
e080e068ed
@ -100,14 +100,24 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
|||||||
logger.debug("PPLX NVSHMEM UID = %s", uid)
|
logger.debug("PPLX NVSHMEM UID = %s", uid)
|
||||||
nvshmem_init(uid, self.rank, self.world_size)
|
nvshmem_init(uid, self.rank, self.world_size)
|
||||||
|
|
||||||
self.handle_cache = Cache()
|
# self.handle_cache = Cache()
|
||||||
|
self.handle_caches = [Cache(), Cache()]
|
||||||
|
|
||||||
def get_handle(self, kwargs):
|
def get_handle(self, kwargs):
|
||||||
import pplx_kernels as pplx
|
import pplx_kernels as pplx
|
||||||
return self.handle_cache.get_or_create(
|
return self.handle_caches[0].get_or_create(
|
||||||
kwargs, pplx.AllToAll.internode
|
kwargs, pplx.AllToAll.internode
|
||||||
if self.internode else pplx.AllToAll.intranode)
|
if self.internode else pplx.AllToAll.intranode)
|
||||||
|
|
||||||
|
def get_handles(self, kwargs):
|
||||||
|
import pplx_kernels as pplx
|
||||||
|
first_handle = self.handle_caches[0].get_or_create(kwargs, pplx.AllToAll.internode
|
||||||
|
if self.internode else pplx.AllToAll.intranode)
|
||||||
|
second_handle = self.handle_caches[1].get_or_create(kwargs, pplx.AllToAll.internode
|
||||||
|
if self.internode else pplx.AllToAll.intranode)
|
||||||
|
return [first_handle, second_handle]
|
||||||
|
|
||||||
|
|
||||||
def dispatch(self, hidden_states: torch.Tensor,
|
def dispatch(self, hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor):
|
router_logits: torch.Tensor):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -116,9 +126,10 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def destroy(self):
|
def destroy(self):
|
||||||
with self.handle_cache._lock:
|
for handle_cache in self.handle_caches:
|
||||||
for _, handle in self.handle_cache._cache.items():
|
with handle_cache._lock:
|
||||||
handle.destroy()
|
for _, handle in handle_cache._cache.items():
|
||||||
|
handle.destroy()
|
||||||
|
|
||||||
if self.internode:
|
if self.internode:
|
||||||
from pplx_kernels.nvshmem import nvshmem_finalize
|
from pplx_kernels.nvshmem import nvshmem_finalize
|
||||||
|
|||||||
@ -272,10 +272,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
group_name=all2all_manager.cpu_group.group_name,
|
group_name=all2all_manager.cpu_group.group_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
handle = all2all_manager.get_handle(all_to_all_args)
|
handles = all2all_manager.get_handles(all_to_all_args)
|
||||||
|
|
||||||
prepare_finalize = PplxPrepareAndFinalize(
|
prepare_finalize = PplxPrepareAndFinalize(
|
||||||
handle,
|
handles,
|
||||||
max_num_tokens=moe.max_num_tokens,
|
max_num_tokens=moe.max_num_tokens,
|
||||||
world_size=all2all_manager.world_size,
|
world_size=all2all_manager.world_size,
|
||||||
rank=all2all_manager.rank,
|
rank=all2all_manager.rank,
|
||||||
|
|||||||
@ -45,7 +45,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
assert False
|
|
||||||
num_tokens = a1.size(0) # M
|
num_tokens = a1.size(0) # M
|
||||||
hidden_dim = a1.size(-1) # K
|
hidden_dim = a1.size(-1) # K
|
||||||
ubatch_ctx = get_current_ubatch_context()
|
ubatch_ctx = get_current_ubatch_context()
|
||||||
@ -128,10 +127,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||||
dispatch(True) # Send
|
dispatch(True) # Send
|
||||||
# torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
|
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
|
||||||
dispatch(False) # Recv
|
dispatch(False) # Recv
|
||||||
# torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
|
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
|
||||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@ -145,7 +144,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert False
|
|
||||||
num_tokens = output.size(0) # M
|
num_tokens = output.size(0) # M
|
||||||
# This argument is optional
|
# This argument is optional
|
||||||
# There's not much point setting this unless it is != topk_ids.size(0)
|
# There's not much point setting this unless it is != topk_ids.size(0)
|
||||||
@ -177,9 +175,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
|
|
||||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||||
combine(True)
|
combine(True)
|
||||||
# torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
|
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
|
||||||
combine(False)
|
combine(False)
|
||||||
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
|
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
|
||||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||||
# torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|||||||
@ -1363,10 +1363,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
|
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
|
||||||
use_dummy_input):
|
use_dummy_input):
|
||||||
|
print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||||
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
|
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
|
||||||
|
|
||||||
if save_results:
|
if save_results:
|
||||||
results.append(model_output)
|
results.append(model_output)
|
||||||
|
print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||||
|
|
||||||
def _run_ubatches(ubatch_slices, attn_metadata,
|
def _run_ubatches(ubatch_slices, attn_metadata,
|
||||||
is_dummy_run) -> torch.Tensor:
|
is_dummy_run) -> torch.Tensor:
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from torch.library import custom_op
|
|||||||
|
|
||||||
from vllm import forward_context
|
from vllm import forward_context
|
||||||
from vllm.utils import current_stream
|
from vllm.utils import current_stream
|
||||||
|
from vllm.distributed.parallel_state import get_dp_group
|
||||||
|
|
||||||
|
|
||||||
class UBatchContext:
|
class UBatchContext:
|
||||||
@ -69,28 +70,28 @@ class UBatchContext:
|
|||||||
torch.cuda.set_stream(self.current_stream)
|
torch.cuda.set_stream(self.current_stream)
|
||||||
|
|
||||||
def ctx_valid_state(self):
|
def ctx_valid_state(self):
|
||||||
# assert forward_context._forward_context == self.forward_context
|
assert forward_context._forward_context == self.forward_context
|
||||||
# assert current_stream() == self.current_stream
|
assert current_stream() == self.current_stream
|
||||||
# assert not self.cpu_wait_event.is_set()
|
assert not self.cpu_wait_event.is_set()
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _signal_comm_done(self):
|
def _signal_comm_done(self):
|
||||||
# self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self.gpu_comm_done_event.record(self.comm_stream)
|
self.gpu_comm_done_event.record(self.comm_stream)
|
||||||
|
|
||||||
def _signal_compute_done(self):
|
def _signal_compute_done(self):
|
||||||
# self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self.gpu_compute_done_event.record(self.compute_stream)
|
self.gpu_compute_done_event.record(self.compute_stream)
|
||||||
|
|
||||||
def _wait_compute_done(self):
|
def _wait_compute_done(self):
|
||||||
# print(f"{self.id} Waiting on COMPUTE stream", flush=True)
|
# print(f"{self.id} Waiting on COMPUTE stream", flush=True)
|
||||||
# self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
||||||
# print("Compute stream done", flush=True)
|
# print("Compute stream done", flush=True)
|
||||||
|
|
||||||
def _wait_comm_done(self):
|
def _wait_comm_done(self):
|
||||||
# print(f"{self.id} Waiting on COMM stream", flush=True)
|
# print(f"{self.id} Waiting on COMM stream", flush=True)
|
||||||
# self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
||||||
# print("Comm stream done", flush=True)
|
# print("Comm stream done", flush=True)
|
||||||
|
|
||||||
@ -104,42 +105,42 @@ class UBatchContext:
|
|||||||
|
|
||||||
def _cpu_yield(self):
|
def _cpu_yield(self):
|
||||||
# print(f"UBatchContext: {self.id} yielding CPU", flush=True)
|
# print(f"UBatchContext: {self.id} yielding CPU", flush=True)
|
||||||
# self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self.cpu_signal_event.set()
|
self.cpu_signal_event.set()
|
||||||
self.cpu_wait_event.wait()
|
self.cpu_wait_event.wait()
|
||||||
self.cpu_wait_event.clear()
|
self.cpu_wait_event.clear()
|
||||||
self._restore_context()
|
self._restore_context()
|
||||||
# self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
# print(f"UBatchContext: {self.id} resuming CPU", flush=True)
|
# print(f"UBatchContext: {self.id} resuming CPU", flush=True)
|
||||||
|
|
||||||
def yield_and_switch_from_compute_to_comm(self):
|
def yield_and_switch_from_compute_to_comm(self):
|
||||||
assert current_stream() == self.compute_stream
|
assert current_stream() == self.compute_stream
|
||||||
# dp_rank = get_dp_group().rank_in_group
|
dp_rank = get_dp_group().rank_in_group
|
||||||
# print(f"DP: {dp_rank} UB: {self.id} "
|
print(f"DP: {dp_rank} UB: {self.id} "
|
||||||
# f"Yield and switch from {self.stream_string()}", flush=True)
|
f"Yield and switch from {self.stream_string()}", flush=True)
|
||||||
# self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self._signal_compute_done()
|
self._signal_compute_done()
|
||||||
self._cpu_yield()
|
self._cpu_yield()
|
||||||
# self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
assert self.current_stream == self.compute_stream
|
assert self.current_stream == self.compute_stream
|
||||||
self.update_stream(self.comm_stream)
|
self.update_stream(self.comm_stream)
|
||||||
# print(f"DP: {dp_rank} UB: {self.id} "
|
print(f"DP: {dp_rank} UB: {self.id} "
|
||||||
# f"Resuming on stream {self.stream_string()}", flush=True)
|
f"Resuming on stream {self.stream_string()}", flush=True)
|
||||||
self._wait_compute_done()
|
self._wait_compute_done()
|
||||||
|
|
||||||
def yield_and_switch_from_comm_to_compute(self):
|
def yield_and_switch_from_comm_to_compute(self):
|
||||||
assert current_stream() == self.comm_stream
|
assert current_stream() == self.comm_stream
|
||||||
# dp_rank = get_dp_group().rank_in_group
|
dp_rank = get_dp_group().rank_in_group
|
||||||
# print(f"DP: {dp_rank} UB: {self.id} "
|
print(f"DP: {dp_rank} UB: {self.id} "
|
||||||
# f"Yield and switch from {self.stream_string()}", flush=True)
|
f"Yield and switch from {self.stream_string()}", flush=True)
|
||||||
# self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self._signal_comm_done()
|
self._signal_comm_done()
|
||||||
self._cpu_yield()
|
self._cpu_yield()
|
||||||
# self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
assert self.current_stream == self.comm_stream
|
assert self.current_stream == self.comm_stream
|
||||||
self.update_stream(self.compute_stream)
|
self.update_stream(self.compute_stream)
|
||||||
# print(f"DP: {dp_rank} UB: {self.id} "
|
print(f"DP: {dp_rank} UB: {self.id} "
|
||||||
# f"Resuming on stream {self.stream_string()}", flush=True)
|
f"Resuming on stream {self.stream_string()}", flush=True)
|
||||||
self._wait_comm_done()
|
self._wait_comm_done()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user