mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 01:37:02 +08:00
cleanup
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
3a41a3dcff
commit
06cc133a63
@ -605,7 +605,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens_after_padding = None
|
||||
ubatch_bailout = False
|
||||
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
|
||||
logger.info(f"num_tokens {scheduler_output.total_num_scheduled_tokens} num_pad_tokens {num_pad_tokens} num_toknes_after {num_tokens_after_padding}")
|
||||
if num_pad_tokens > 0:
|
||||
if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
|
||||
self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens)
|
||||
@ -613,7 +612,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# We bail out of ubatching here. This accounts for the case where
|
||||
# the padding would result in an "empty" second ubatch.
|
||||
# TODO: just make the second ubatch a dummy ubatch
|
||||
# logger.info("FALLING BACK AND DISABLING UBATCHING")
|
||||
ubatch_bailout = True
|
||||
|
||||
# Note that if we are attempting to ubatch by this point then we know that no
|
||||
@ -1344,7 +1342,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
num_tokens_padded = num_tokens_unpadded
|
||||
|
||||
# logger.info(f"num tokens unpadded: {num_tokens_unpadded} cudagraphs: {self.cudagraph_batch_sizes}")
|
||||
if (self.use_cuda_graph
|
||||
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
|
||||
# Use piecewise CUDA graphs.
|
||||
@ -1548,7 +1545,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
|
||||
if use_dummy_input:
|
||||
logger.info(f"NUM DUMMY TOKENS: {num_dummy_tokens} token slize: {tokens_slice}")
|
||||
assert num_dummy_tokens == tokens_slice.stop - tokens_slice.start
|
||||
return self._get_dummy_model_inputs(num_dummy_tokens)
|
||||
else:
|
||||
@ -1624,7 +1620,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# run micro-batched
|
||||
if ubatch_slices is not None:
|
||||
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
|
||||
print(f"RUNNING UBATCH {ubatch_slices} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
|
||||
|
||||
compute_stream = torch.cuda.current_stream()
|
||||
ubatch_metadata = _make_ubatch_metadata(
|
||||
@ -1640,7 +1635,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
|
||||
logger.info(f"NORMAL RUN {num_scheduled_tokens}")
|
||||
with set_forward_context(attn_metadata,
|
||||
vllm_config=self.vllm_config,
|
||||
num_tokens=num_scheduled_tokens or 1,
|
||||
@ -1723,9 +1717,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_input_tokens += num_pad_tokens
|
||||
self.pad_out_ubatch_second_stage(ubatch_slices, num_input_tokens)
|
||||
elif ubatch_slices is None:
|
||||
# logger.info("ATTEMPTING TO PAD NORMAL BATCH")
|
||||
num_pad, num_tokens_after_padding = self.get_padding(num_input_tokens)
|
||||
# logger.info("NORMAL BATCH DONE")
|
||||
num_input_tokens += num_pad
|
||||
|
||||
# Some attention backends only support CUDA Graphs in pure decode.
|
||||
@ -1733,7 +1725,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# compiled with full CUDA graphs, we have to skip them entirely.
|
||||
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
|
||||
|
||||
# logger.info("RUNNING MODEL")
|
||||
# Run the decoder.
|
||||
# Use persistent buffers for CUDA graphs.
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
@ -2347,7 +2338,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# We currently only microbatch if the number of tokens is
|
||||
# over a certain threshold.
|
||||
# logger.info("PADDING DUMMY DONE")
|
||||
attn_metadata: Optional[dict[str, Any]] = None
|
||||
if capture_attn_cudagraph:
|
||||
attn_metadata = {}
|
||||
|
||||
@ -50,7 +50,6 @@ class UBatchContext:
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
global _CURRENT_CONTEXT
|
||||
_CURRENT_CONTEXT[threading.get_ident()] = None
|
||||
# print("Finishing ubatch %d\n" % self.id, flush=True)
|
||||
self.cpu_signal_event.set()
|
||||
self.cpu_wait_event.clear()
|
||||
self.current_stream = self.compute_stream
|
||||
@ -80,16 +79,12 @@ class UBatchContext:
|
||||
self.gpu_compute_done_event.record(self.compute_stream)
|
||||
|
||||
def _wait_compute_done(self):
    """Make the comm stream wait on this context's compute-done event.

    Checks the context with ``ctx_valid_state()`` and then calls
    ``comm_stream.wait_event(gpu_compute_done_event)``.
    """
    self.ctx_valid_state()
    # NOTE(review): presumably a CUDA stream/event pair, in which case this
    # only orders GPU work and does not block the CPU -- confirm.
    self.comm_stream.wait_event(self.gpu_compute_done_event)
|
||||
|
||||
def _wait_comm_done(self):
    """Make the compute stream wait on this context's comm-done event.

    Checks the context with ``ctx_valid_state()`` and then calls
    ``compute_stream.wait_event(gpu_comm_done_event)``.
    """
    self.ctx_valid_state()
    # NOTE(review): presumably a CUDA stream/event pair, in which case this
    # only orders GPU work and does not block the CPU -- confirm.
    self.compute_stream.wait_event(self.gpu_comm_done_event)
|
||||
|
||||
def stream_string(self):
|
||||
if current_stream() == self.compute_stream:
|
||||
@ -100,43 +95,31 @@ class UBatchContext:
|
||||
return "COMM"
|
||||
|
||||
def _cpu_yield(self):
    """Signal ``cpu_signal_event`` and block until ``cpu_wait_event`` is set.

    After waking, the wait event is cleared so it can be waited on again,
    ``_restore_context()`` is called, and the context is re-validated.
    """
    self.ctx_valid_state()
    # Set our signal before blocking.
    # NOTE(review): presumably another ubatch thread waits on this event
    # and sets ``cpu_wait_event`` in turn -- confirm against the caller.
    self.cpu_signal_event.set()
    self.cpu_wait_event.wait()
    self.cpu_wait_event.clear()
    self._restore_context()
    self.ctx_valid_state()
|
||||
|
||||
def yield_and_switch_from_compute_to_comm(self):
    """Yield the CPU, then continue this context on the comm stream.

    Must be called while ``current_stream()`` is ``compute_stream``.
    Steps, in order: signal compute-done, yield the CPU via
    ``_cpu_yield()``, switch to ``comm_stream`` with ``update_stream()``,
    and finally make the comm stream wait for the compute-done event.
    """
    assert current_stream() == self.compute_stream
    self.ctx_valid_state()
    self._signal_compute_done()
    self._cpu_yield()
    self.ctx_valid_state()
    # After resuming, this context's recorded stream must still be compute.
    assert self.current_stream == self.compute_stream
    self.update_stream(self.comm_stream)
    self._wait_compute_done()
|
||||
|
||||
def yield_and_switch_from_comm_to_compute(self):
    """Yield the CPU, then continue this context on the compute stream.

    Must be called while ``current_stream()`` is ``comm_stream``.
    Steps, in order: signal comm-done, yield the CPU via ``_cpu_yield()``,
    switch to ``compute_stream`` with ``update_stream()``, and finally
    make the compute stream wait for the comm-done event.
    """
    assert current_stream() == self.comm_stream
    self.ctx_valid_state()
    self._signal_comm_done()
    self._cpu_yield()
    self.ctx_valid_state()
    # After resuming, this context's recorded stream must still be comm.
    assert self.current_stream == self.comm_stream
    self.update_stream(self.compute_stream)
    self._wait_comm_done()
|
||||
|
||||
|
||||
@ -154,7 +137,6 @@ def get_current_ubatch_context() -> Optional[UBatchContext]:
|
||||
def yield_and_switch_from_compute_to_comm(schedule="default"):
    """Yield and switch compute -> comm on the current ubatch context.

    Looks up the context with ``get_current_ubatch_context()``. Does
    nothing when that returns ``None`` or when the context's ``schedule``
    differs from ``schedule``.

    Args:
        schedule: Schedule name the context must match for the switch to
            happen.
    """
    # Perform the barrier if a context exists for this thread
    ctx = get_current_ubatch_context()
    if ctx is not None and ctx.schedule == schedule:
        ctx.yield_and_switch_from_compute_to_comm()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user