more fixes

This commit is contained in:
Sage Moore 2025-05-30 21:17:06 +00:00
parent 5b0249b86e
commit 62da375465
8 changed files with 23 additions and 16 deletions

View File

@ -33,6 +33,7 @@ from time import sleep
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.utils import get_open_port, FlexibleArgumentParser from vllm.utils import get_open_port, FlexibleArgumentParser
from vllm import LLM, EngineArgs from vllm import LLM, EngineArgs
import torch
def parse_args(): def parse_args():
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
@ -82,7 +83,7 @@ def main(args, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
"The president of the United States is", "The president of the United States is",
"The capital of France is", "The capital of France is",
"The future of AI is", "The future of AI is",
] * 100 ]
# with DP, each rank should process different prompts. # with DP, each rank should process different prompts.
# usually all the DP ranks process a full dataset, # usually all the DP ranks process a full dataset,

View File

@ -972,7 +972,9 @@ def pplx_finalize():
logger.debug("PPLX NVSHMEM finalize") logger.debug("PPLX NVSHMEM finalize")
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
_all_to_all_cache) _all_to_all_cache)
_all_to_all_cache.destroy() for cache in _all_to_all_cache:
cache.destroy()
# _all_to_all_cache.destroy()
nvshmem_finalize() nvshmem_finalize()

View File

@ -453,7 +453,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
if isinstance(prepare_finalize, if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
print("BatchedTritonExperts %s", self.moe) # print("BatchedTritonExperts %s", self.moe)
experts = BatchedTritonExperts( experts = BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE, max_num_tokens=MOE_DP_CHUNK_SIZE,
@ -466,7 +466,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
block_shape=None, block_shape=None,
) )
else: else:
print("TritonExperts %s", self.moe) # print("TritonExperts %s", self.moe)
experts = TritonExperts( experts = TritonExperts(
use_fp8_w8a8=False, use_fp8_w8a8=False,
use_int8_w8a8=False, use_int8_w8a8=False,

View File

@ -135,7 +135,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# torch.cuda.synchronize() # torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True) # print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default") yield_and_switch_from_comm_to_compute_impl(schedule="default")
torch.cuda.synchronize()
return expert_x, expert_x_scale, expert_num_tokens return expert_x, expert_x_scale, expert_num_tokens
def finalize( def finalize(
@ -179,6 +179,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# torch.cuda.synchronize() # torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True) # print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
combine(False) combine(False)
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True) # print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default") yield_and_switch_from_comm_to_compute_impl(schedule="default")
# torch.cuda.synchronize()

View File

@ -481,7 +481,7 @@ class MLACommonMetadataBuilder(Generic[M]):
device = self.runner.device device = self.runner.device
block_table = self.block_table block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[req_slice] block_table_tensor = block_table.get_device_tensor()[req_slice]
print(f"num_reqs: {num_reqs} bloc_table_shape: {block_table_tensor.shape}") # print(f"num_reqs: {num_reqs} bloc_table_shape: {block_table_tensor.shape}")
slot_mapping = block_table.slot_mapping_cpu[token_slice].to( slot_mapping = block_table.slot_mapping_cpu[token_slice].to(
device, non_blocking=True).long() device, non_blocking=True).long()

View File

@ -1301,16 +1301,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_ids, positions, inputs_embeds, intermediate_tensors = \ input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(token_slice, use_dummy_input) model_inputs(token_slice, use_dummy_input)
with context: with context:
if isinstance(context, UBatchContext): # if isinstance(context, UBatchContext):
print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape} use_dummy_input {use_dummy_input} token_slice {token_slice}") # print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape} use_dummy_input {use_dummy_input} token_slice {token_slice}")
model_output = self.model( model_output = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
if isinstance(context, UBatchContext): # if isinstance(context, UBatchContext):
print(f"Ran ubatch {context.id}putput {model_output.shape}") # print(f"Ran ubatch {context.id}putput {model_output.shape}")
if isinstance(context, UBatchContext): if isinstance(context, UBatchContext):
# Clone before we leave the ubatch context # Clone before we leave the ubatch context
model_output = model_output.clone() model_output = model_output.clone()
@ -1342,7 +1342,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
print("ubatch", i, "tokens slice", tokens_slice, "is dummy ubatch", is_dummy_ubatch, "is dummy run", is_dummy_run) # print("ubatch", i, "tokens slice", tokens_slice, "is dummy ubatch", is_dummy_ubatch, "is dummy run", is_dummy_run)
num_tokens = num_dummy_tokens if is_dummy_ubatch or is_dummy_run else (tokens_slice.stop - tokens_slice.start) num_tokens = num_dummy_tokens if is_dummy_ubatch or is_dummy_run else (tokens_slice.stop - tokens_slice.start)
ubatch_ctxs[i].forward_context = create_forward_context( ubatch_ctxs[i].forward_context = create_forward_context(
@ -1363,6 +1363,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for thread in ubatch_threads: for thread in ubatch_threads:
thread.join() thread.join()
torch.cuda.synchronize()
torch.cuda.set_stream(root_stream) torch.cuda.set_stream(root_stream)
return torch.cat(results, dim=0) return torch.cat(results, dim=0)

View File

@ -289,8 +289,10 @@ class Worker(WorkerBase):
if self.profiler is None: if self.profiler is None:
raise RuntimeError("Profiler is not enabled.") raise RuntimeError("Profiler is not enabled.")
if is_start: if is_start:
assert False
self.profiler.start() self.profiler.start()
else: else:
assert False
self.profiler.stop() self.profiler.stop()
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:

View File

@ -70,9 +70,10 @@ class UBatchContext:
torch.cuda.set_stream(self.current_stream) torch.cuda.set_stream(self.current_stream)
def ctx_valid_state(self): def ctx_valid_state(self):
assert forward_context._forward_context == self.forward_context # assert forward_context._forward_context == self.forward_context
assert current_stream() == self.current_stream # assert current_stream() == self.current_stream
assert not self.cpu_wait_event.is_set() # assert not self.cpu_wait_event.is_set()
pass
def _signal_comm_done(self): def _signal_comm_done(self):
self.ctx_valid_state() self.ctx_valid_state()
self.gpu_comm_done_event.record(self.comm_stream) self.gpu_comm_done_event.record(self.comm_stream)