mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 15:47:09 +08:00
fix pplx a2a
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
5f4a501b9a
commit
e080e068ed
@ -100,13 +100,23 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
logger.debug("PPLX NVSHMEM UID = %s", uid)
|
||||
nvshmem_init(uid, self.rank, self.world_size)
|
||||
|
||||
self.handle_cache = Cache()
|
||||
# self.handle_cache = Cache()
|
||||
self.handle_caches = [Cache(), Cache()]
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
import pplx_kernels as pplx
|
||||
return self.handle_cache.get_or_create(
|
||||
return self.handle_caches[0].get_or_create(
|
||||
kwargs, pplx.AllToAll.internode
|
||||
if self.internode else pplx.AllToAll.intranode)
|
||||
|
||||
def get_handles(self, kwargs):
|
||||
import pplx_kernels as pplx
|
||||
first_handle = self.handle_caches[0].get_or_create(kwargs, pplx.AllToAll.internode
|
||||
if self.internode else pplx.AllToAll.intranode)
|
||||
second_handle = self.handle_caches[1].get_or_create(kwargs, pplx.AllToAll.internode
|
||||
if self.internode else pplx.AllToAll.intranode)
|
||||
return [first_handle, second_handle]
|
||||
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
@ -116,9 +126,10 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
with self.handle_cache._lock:
|
||||
for _, handle in self.handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
for handle_cache in self.handle_caches:
|
||||
with handle_cache._lock:
|
||||
for _, handle in handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
|
||||
if self.internode:
|
||||
from pplx_kernels.nvshmem import nvshmem_finalize
|
||||
|
||||
@ -272,10 +272,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
group_name=all2all_manager.cpu_group.group_name,
|
||||
)
|
||||
|
||||
handle = all2all_manager.get_handle(all_to_all_args)
|
||||
handles = all2all_manager.get_handles(all_to_all_args)
|
||||
|
||||
prepare_finalize = PplxPrepareAndFinalize(
|
||||
handle,
|
||||
handles,
|
||||
max_num_tokens=moe.max_num_tokens,
|
||||
world_size=all2all_manager.world_size,
|
||||
rank=all2all_manager.rank,
|
||||
|
||||
@ -45,7 +45,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
assert False
|
||||
num_tokens = a1.size(0) # M
|
||||
hidden_dim = a1.size(-1) # K
|
||||
ubatch_ctx = get_current_ubatch_context()
|
||||
@ -128,10 +127,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
dispatch(True) # Send
|
||||
# torch.cuda.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
|
||||
dispatch(False) # Recv
|
||||
# torch.cuda.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
torch.cuda.synchronize()
|
||||
@ -145,7 +144,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> None:
|
||||
assert False
|
||||
num_tokens = output.size(0) # M
|
||||
# This argument is optional
|
||||
# There's not much point setting this unless it is != topk_ids.size(0)
|
||||
@ -177,9 +175,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
combine(True)
|
||||
# torch.cuda.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
|
||||
combine(False)
|
||||
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
# torch.cuda.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@ -1363,10 +1363,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
@torch.inference_mode()
|
||||
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
|
||||
use_dummy_input):
|
||||
print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
|
||||
|
||||
if save_results:
|
||||
results.append(model_output)
|
||||
print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||
|
||||
def _run_ubatches(ubatch_slices, attn_metadata,
|
||||
is_dummy_run) -> torch.Tensor:
|
||||
|
||||
@ -8,6 +8,7 @@ from torch.library import custom_op
|
||||
|
||||
from vllm import forward_context
|
||||
from vllm.utils import current_stream
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
|
||||
|
||||
class UBatchContext:
|
||||
@ -69,28 +70,28 @@ class UBatchContext:
|
||||
torch.cuda.set_stream(self.current_stream)
|
||||
|
||||
def ctx_valid_state(self):
|
||||
# assert forward_context._forward_context == self.forward_context
|
||||
# assert current_stream() == self.current_stream
|
||||
# assert not self.cpu_wait_event.is_set()
|
||||
assert forward_context._forward_context == self.forward_context
|
||||
assert current_stream() == self.current_stream
|
||||
assert not self.cpu_wait_event.is_set()
|
||||
pass
|
||||
|
||||
def _signal_comm_done(self):
|
||||
# self.ctx_valid_state()
|
||||
self.ctx_valid_state()
|
||||
self.gpu_comm_done_event.record(self.comm_stream)
|
||||
|
||||
def _signal_compute_done(self):
|
||||
# self.ctx_valid_state()
|
||||
self.ctx_valid_state()
|
||||
self.gpu_compute_done_event.record(self.compute_stream)
|
||||
|
||||
def _wait_compute_done(self):
|
||||
# print(f"{self.id} Waiting on COMPUTE stream", flush=True)
|
||||
# self.ctx_valid_state()
|
||||
self.ctx_valid_state()
|
||||
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
||||
# print("Compute stream done", flush=True)
|
||||
|
||||
def _wait_comm_done(self):
|
||||
# print(f"{self.id} Waiting on COMM stream", flush=True)
|
||||
# self.ctx_valid_state()
|
||||
self.ctx_valid_state()
|
||||
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
||||
# print("Comm stream done", flush=True)
|
||||
|
||||
@ -104,42 +105,42 @@ class UBatchContext:
|
||||
|
||||
def _cpu_yield(self):
|
||||
# print(f"UBatchContext: {self.id} yielding CPU", flush=True)
|
||||
# self.ctx_valid_state()
|
||||
self.ctx_valid_state()
|
||||
self.cpu_signal_event.set()
|
||||
self.cpu_wait_event.wait()
|
||||
self.cpu_wait_event.clear()
|
||||
self._restore_context()
|
||||
# self.ctx_valid_state()
|
||||
self.ctx_valid_state()
|
||||
# print(f"UBatchContext: {self.id} resuming CPU", flush=True)
|
||||
|
||||
def yield_and_switch_from_compute_to_comm(self):
|
||||
assert current_stream() == self.compute_stream
|
||||
# dp_rank = get_dp_group().rank_in_group
|
||||
# print(f"DP: {dp_rank} UB: {self.id} "
|
||||
# f"Yield and switch from {self.stream_string()}", flush=True)
|
||||
# self.ctx_valid_state()
|
||||
dp_rank = get_dp_group().rank_in_group
|
||||
print(f"DP: {dp_rank} UB: {self.id} "
|
||||
f"Yield and switch from {self.stream_string()}", flush=True)
|
||||
self.ctx_valid_state()
|
||||
self._signal_compute_done()
|
||||
self._cpu_yield()
|
||||
# self.ctx_valid_state()
|
||||
self.ctx_valid_state()
|
||||
assert self.current_stream == self.compute_stream
|
||||
self.update_stream(self.comm_stream)
|
||||
# print(f"DP: {dp_rank} UB: {self.id} "
|
||||
# f"Resuming on stream {self.stream_string()}", flush=True)
|
||||
print(f"DP: {dp_rank} UB: {self.id} "
|
||||
f"Resuming on stream {self.stream_string()}", flush=True)
|
||||
self._wait_compute_done()
|
||||
|
||||
def yield_and_switch_from_comm_to_compute(self):
|
||||
assert current_stream() == self.comm_stream
|
||||
# dp_rank = get_dp_group().rank_in_group
|
||||
# print(f"DP: {dp_rank} UB: {self.id} "
|
||||
# f"Yield and switch from {self.stream_string()}", flush=True)
|
||||
# self.ctx_valid_state()
|
||||
dp_rank = get_dp_group().rank_in_group
|
||||
print(f"DP: {dp_rank} UB: {self.id} "
|
||||
f"Yield and switch from {self.stream_string()}", flush=True)
|
||||
self.ctx_valid_state()
|
||||
self._signal_comm_done()
|
||||
self._cpu_yield()
|
||||
# self.ctx_valid_state()
|
||||
self.ctx_valid_state()
|
||||
assert self.current_stream == self.comm_stream
|
||||
self.update_stream(self.compute_stream)
|
||||
# print(f"DP: {dp_rank} UB: {self.id} "
|
||||
# f"Resuming on stream {self.stream_string()}", flush=True)
|
||||
print(f"DP: {dp_rank} UB: {self.id} "
|
||||
f"Resuming on stream {self.stream_string()}", flush=True)
|
||||
self._wait_comm_done()
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user