mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 12:22:17 +08:00
comments cleanup etc
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
10ca263058
commit
82ae694de6
@ -593,6 +593,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
total_num_scheduled_tokens >= \
|
total_num_scheduled_tokens >= \
|
||||||
self.parallel_config.microbatching_token_threshold \
|
self.parallel_config.microbatching_token_threshold \
|
||||||
and max_num_scheduled_tokens == 1
|
and max_num_scheduled_tokens == 1
|
||||||
|
|
||||||
# Don't microbatch unless every other DP worker is also microbatching
|
# Don't microbatch unless every other DP worker is also microbatching
|
||||||
should_ubatch = self.should_ubatch(should_attempt_ubatching)
|
should_ubatch = self.should_ubatch(should_attempt_ubatching)
|
||||||
if not should_ubatch:
|
if not should_ubatch:
|
||||||
@ -611,20 +612,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
]
|
]
|
||||||
num_pad_tokens = 0
|
num_pad_tokens = 0
|
||||||
num_tokens_after_padding = None
|
num_tokens_after_padding = None
|
||||||
ubatch_bailout = False
|
ubatch_abort = False
|
||||||
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
|
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
|
||||||
if num_pad_tokens > 0:
|
if num_pad_tokens > 0:
|
||||||
|
# Check if the padding would result in an empty second ubatch.
|
||||||
|
# If so abort ubatching
|
||||||
if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
|
if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
|
||||||
self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens)
|
self.pad_out_ubatch_first_stage(ubatch_slices, num_pad_tokens)
|
||||||
else:
|
else:
|
||||||
# We bail out of ubatching here. This accounts for the case where
|
ubatch_abort = True
|
||||||
# the padding would result in an "empty" second ubatch.
|
|
||||||
# TODO: just make the second ubatch a dummy ubatch
|
|
||||||
ubatch_bailout = True
|
|
||||||
|
|
||||||
# Note that if we are attempting to ubatch by this point then we know that no
|
# Note that if we are attempting to ubatch by this point then we know that no
|
||||||
# DP ranks are doing dummy runs
|
# DP ranks are doing dummy runs. Meaning, we don't need a second call to
|
||||||
should_ubatch = self.should_ubatch(False if ubatch_bailout else True)
|
# should_ubatch in _dummy_run
|
||||||
|
should_ubatch = self.should_ubatch(False if ubatch_abort else True)
|
||||||
if not should_ubatch:
|
if not should_ubatch:
|
||||||
return (None, 0, None)
|
return (None, 0, None)
|
||||||
return (ubatch_slices, num_pad_tokens, num_tokens_after_padding)
|
return (ubatch_slices, num_pad_tokens, num_tokens_after_padding)
|
||||||
|
|||||||
@ -65,26 +65,16 @@ class UBatchContext:
|
|||||||
self.current_stream = stream
|
self.current_stream = stream
|
||||||
torch.cuda.set_stream(self.current_stream)
|
torch.cuda.set_stream(self.current_stream)
|
||||||
|
|
||||||
def ctx_valid_state(self):
|
|
||||||
assert forward_context._forward_context == self.forward_context
|
|
||||||
assert current_stream() == self.current_stream
|
|
||||||
assert not self.cpu_wait_event.is_set()
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _signal_comm_done(self):
|
def _signal_comm_done(self):
|
||||||
self.ctx_valid_state()
|
|
||||||
self.gpu_comm_done_event.record(self.comm_stream)
|
self.gpu_comm_done_event.record(self.comm_stream)
|
||||||
|
|
||||||
def _signal_compute_done(self):
|
def _signal_compute_done(self):
|
||||||
self.ctx_valid_state()
|
|
||||||
self.gpu_compute_done_event.record(self.compute_stream)
|
self.gpu_compute_done_event.record(self.compute_stream)
|
||||||
|
|
||||||
def _wait_compute_done(self):
|
def _wait_compute_done(self):
|
||||||
self.ctx_valid_state()
|
|
||||||
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
||||||
|
|
||||||
def _wait_comm_done(self):
|
def _wait_comm_done(self):
|
||||||
self.ctx_valid_state()
|
|
||||||
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
||||||
|
|
||||||
def stream_string(self):
|
def stream_string(self):
|
||||||
@ -96,29 +86,30 @@ class UBatchContext:
|
|||||||
return "COMM"
|
return "COMM"
|
||||||
|
|
||||||
def _cpu_yield(self):
|
def _cpu_yield(self):
|
||||||
self.ctx_valid_state()
|
# It is critical for correctness that only one thread is running
|
||||||
|
# at a time. These asserts just make sure that this is the only
|
||||||
|
# thread running before waking the other one up and going to sleep
|
||||||
|
assert forward_context._forward_context == self.forward_context
|
||||||
|
assert current_stream() == self.current_stream
|
||||||
|
assert not self.cpu_wait_event.is_set()
|
||||||
|
|
||||||
self.cpu_signal_event.set()
|
self.cpu_signal_event.set()
|
||||||
self.cpu_wait_event.wait()
|
self.cpu_wait_event.wait()
|
||||||
self.cpu_wait_event.clear()
|
self.cpu_wait_event.clear()
|
||||||
self._restore_context()
|
self._restore_context()
|
||||||
self.ctx_valid_state()
|
|
||||||
|
|
||||||
def yield_and_switch_from_compute_to_comm(self):
|
def yield_and_switch_from_compute_to_comm(self):
|
||||||
assert current_stream() == self.compute_stream
|
assert current_stream() == self.compute_stream
|
||||||
self.ctx_valid_state()
|
|
||||||
self._signal_compute_done()
|
self._signal_compute_done()
|
||||||
self._cpu_yield()
|
self._cpu_yield()
|
||||||
self.ctx_valid_state()
|
|
||||||
assert self.current_stream == self.compute_stream
|
assert self.current_stream == self.compute_stream
|
||||||
self.update_stream(self.comm_stream)
|
self.update_stream(self.comm_stream)
|
||||||
self._wait_compute_done()
|
self._wait_compute_done()
|
||||||
|
|
||||||
def yield_and_switch_from_comm_to_compute(self):
|
def yield_and_switch_from_comm_to_compute(self):
|
||||||
assert current_stream() == self.comm_stream
|
assert current_stream() == self.comm_stream
|
||||||
self.ctx_valid_state()
|
|
||||||
self._signal_comm_done()
|
self._signal_comm_done()
|
||||||
self._cpu_yield()
|
self._cpu_yield()
|
||||||
self.ctx_valid_state()
|
|
||||||
assert self.current_stream == self.comm_stream
|
assert self.current_stream == self.comm_stream
|
||||||
self.update_stream(self.compute_stream)
|
self.update_stream(self.compute_stream)
|
||||||
self._wait_comm_done()
|
self._wait_comm_done()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user