misc fixes

This commit is contained in:
Sage Moore 2025-05-29 00:09:25 +00:00
parent f0b66d6929
commit 5cc573e791
4 changed files with 52 additions and 17 deletions

View File

@ -67,18 +67,18 @@ def _dump_engine_exception(config: VllmConfig,
scheduler_stats: Optional[SchedulerStats]):
logger.error("Dumping input data")
logger.error(
"V1 LLM engine (v%s) with config: %s, ",
VLLM_VERSION,
config,
)
# logger.error(
# "V1 LLM engine (v%s) with config: %s, ",
# VLLM_VERSION,
# config,
# )
try:
dump_obj = prepare_object_to_dump(scheduler_output)
logger.error("Dumping scheduler output for model execution:")
logger.error(dump_obj)
if scheduler_stats:
logger.error(scheduler_stats)
except BaseException as exception:
logger.error("Error preparing object to dump")
logger.error(repr(exception))
# try:
# dump_obj = prepare_object_to_dump(scheduler_output)
# logger.error("Dumping scheduler output for model execution:")
# logger.error(dump_obj)
# if scheduler_stats:
# logger.error(scheduler_stats)
# except BaseException as exception:
# logger.error("Error preparing object to dump")
# logger.error(repr(exception))

View File

@ -481,6 +481,7 @@ class MLACommonMetadataBuilder(Generic[M]):
device = self.runner.device
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[req_slice]
print(f"num_reqs: {num_reqs} bloc_table_shape: {block_table_tensor.shape}")
slot_mapping = block_table.slot_mapping_cpu[token_slice].to(
device, non_blocking=True).long()

View File

@ -1358,6 +1358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
))
ubatch_threads.append(thread)
thread.start()
ubatch_ctxs[0].cpu_wait_event.set()
for thread in ubatch_threads:
thread.join()

View File

@ -32,6 +32,7 @@ class UBatchContext:
self.forward_context = None #fwd_ctx
self.cpu_wait_event = cpu_wait_event
self.cpu_signal_event = cpu_signal_event
self.current_stream = compute_stream
self.gpu_comm_done_event = gpu_comm_done_event
self.gpu_compute_done_event = gpu_compute_done_event
self.schedule = schedule
@ -39,6 +40,10 @@ class UBatchContext:
def __enter__(self):
global _CURRENT_CONTEXT
_CURRENT_CONTEXT[threading.get_ident()] = self
# self.cpu_wait_event.clear()
self.cpu_wait_event.wait()
self.cpu_wait_event.clear()
self._restore_context()
# Assume we start on the compute stream
assert current_stream() == self.compute_stream, \
@ -50,42 +55,70 @@ class UBatchContext:
_CURRENT_CONTEXT[threading.get_ident()] = None
print("Finishing ubatch %d\n" % self.id)
self.cpu_signal_event.set()
torch.cuda.set_stream(self.compute_stream)
self.cpu_wait_event.clear()
self.current_stream = self.compute_stream
torch.cuda.set_stream(self.original_stream)
return False
def _restore_context(self):
# Re-install this ubatch's saved state: point the forward_context module's
# _forward_context at self.forward_context, then make self.current_stream
# the active CUDA stream for the calling thread.
forward_context._forward_context = self.forward_context
torch.cuda.set_stream(self.current_stream)
def update_stream(self, stream):
# Record `stream` as this context's current stream and activate it, so the
# tracked value (self.current_stream) and the active CUDA stream change
# together.
self.current_stream = stream
torch.cuda.set_stream(self.current_stream)
def ctx_valid_state(self):
# Sanity-check that this context is the one currently installed.
# These are plain `assert`s, so they are skipped when Python runs with -O.
# 1) the module-level forward context is this context's forward context
assert forward_context._forward_context == self.forward_context
# 2) current_stream() (helper defined outside this view) agrees with the
#    stream tracked on this context
assert current_stream() == self.current_stream
# 3) the wake-up event has already been consumed (cleared) by this context
assert not self.cpu_wait_event.is_set()
def _signal_comm_done(self):
# Validate the context, then record gpu_comm_done_event on comm_stream.
# presumably a torch.cuda.Event recorded on a torch.cuda.Stream - confirm
# against where these attributes are constructed.
self.ctx_valid_state()
self.gpu_comm_done_event.record(self.comm_stream)
def _signal_compute_done(self):
# Validate the context, then record gpu_compute_done_event on
# compute_stream.
# presumably a torch.cuda.Event recorded on a torch.cuda.Stream - confirm
# against where these attributes are constructed.
self.ctx_valid_state()
self.gpu_compute_done_event.record(self.compute_stream)
def _wait_compute_done(self):
# Make comm_stream wait on gpu_compute_done_event, after validating the
# context. Progress is reported with bare print() calls.
print("Waiting on compute stream")
self.ctx_valid_state()
self.comm_stream.wait_event(self.gpu_compute_done_event)
# NOTE(review): if comm_stream is a torch.cuda.Stream, wait_event() only
# enqueues the dependency and returns immediately, so the message below
# does not mean the compute work has finished - confirm.
print("Compute stream done")
def _wait_comm_done(self):
# Make compute_stream wait on gpu_comm_done_event, after validating the
# context. Progress is reported with bare print() calls.
print("Waiting on comm stream")
self.ctx_valid_state()
self.compute_stream.wait_event(self.gpu_comm_done_event)
# NOTE(review): if compute_stream is a torch.cuda.Stream, wait_event() only
# enqueues the dependency and returns immediately, so the message below
# does not mean the communication work has finished - confirm.
print("Comm stream done")
def _cpu_yield(self):
# Hand the CPU over and block until this context is woken again.
# presumably cpu_signal_event / cpu_wait_event are threading.Event objects
# shared with another ubatch thread - confirm at the construction site.
print("UBatchContext: %d yielding CPU\n" % self.id)
self.ctx_valid_state()
# Signal via cpu_signal_event, then block on our own wake-up event.
self.cpu_signal_event.set()
self.cpu_wait_event.wait()
# Consume the wake-up so the next wait() blocks again; ctx_valid_state()
# asserts that this event is not set.
self.cpu_wait_event.clear()
# Re-install our forward context and stream (see _restore_context) before
# validating again.
self._restore_context()
self.ctx_valid_state()
print("UBatchContext: %d resuming CPU\n" % self.id)
def yield_and_switch_from_compute_to_comm(self):
# Sequence: record the compute-done event, yield the CPU, then move this
# context onto comm_stream and make comm_stream wait for the compute-done
# event.
print("Yield and switch from compute")
self.ctx_valid_state()
self._signal_compute_done()
self._cpu_yield()
# NOTE(review): the next line changes the active CUDA stream without
# updating self.current_stream, yet the ctx_valid_state() call right after
# it asserts current_stream() == self.current_stream, and the assert below
# expects self.current_stream to still be compute_stream. This text looks
# like a rendered diff with the +/- markers lost, so this bare set_stream
# is probably the removed pre-change line - confirm against the real file.
torch.cuda.set_stream(self.comm_stream)
self.ctx_valid_state()
assert self.current_stream == self.compute_stream
# Switch both the tracked stream and the active CUDA stream to comm_stream.
self.update_stream(self.comm_stream)
self._wait_compute_done()
def yield_and_switch_from_comm_to_compute(self):
# Sequence: record the comm-done event, yield the CPU, then move this
# context onto compute_stream and make compute_stream wait for the
# comm-done event.
self.ctx_valid_state()
self._signal_comm_done()
self._cpu_yield()
# NOTE(review): the next line changes the active CUDA stream without
# updating self.current_stream, yet the ctx_valid_state() call right after
# it asserts current_stream() == self.current_stream, and the assert below
# expects self.current_stream to still be comm_stream. This text looks
# like a rendered diff with the +/- markers lost, so this bare set_stream
# is probably the removed pre-change line - confirm against the real file.
torch.cuda.set_stream(self.compute_stream)
self.ctx_valid_state()
assert self.current_stream == self.comm_stream
# Switch both the tracked stream and the active CUDA stream to
# compute_stream.
self.update_stream(self.compute_stream)
self._wait_comm_done()