mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 17:27:08 +08:00
seperate gpu wait
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
a8439e2fd4
commit
00f526f55b
@ -37,7 +37,8 @@ class UBatchContext:
|
||||
def __enter__(self):
|
||||
global _CURRENT_CONTEXT
|
||||
_CURRENT_CONTEXT[threading.get_ident()] = self
|
||||
self._wait()
|
||||
self._cpu_wait()
|
||||
self.gpu_stream_wait()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
@ -53,9 +54,29 @@ class UBatchContext:
|
||||
torch.cuda.set_stream(self.stream)
|
||||
forward_context._forward_context = self.forward_context
|
||||
|
||||
def yield_(self):
|
||||
# Seperate GPU wait so we can do
|
||||
# ubatch0
|
||||
# 1) work
|
||||
# 2) dispatch
|
||||
# 3) yield
|
||||
# ubatch1
|
||||
# 1) work
|
||||
# 2) gpu wait
|
||||
# 3) dispatch
|
||||
# 4) yield
|
||||
#
|
||||
# This way we can have the CPU schedule ubatch1-dispatch while ubatch0
|
||||
# before yielding back to ubatch1 but ensure we wont start the dispatch
|
||||
# until ubatch0-dispatch is done avoiding overlapping dispatches that
|
||||
# might share underlying buffers
|
||||
def gpu_stream_wait(self):
|
||||
self.stream.wait_event(self.gpu_wait_event)
|
||||
|
||||
def yield_(self, gpu_wait: bool = True):
|
||||
self._signal()
|
||||
self._wait()
|
||||
self._cpu_wait()
|
||||
if gpu_wait:
|
||||
self.gpu_stream_wait()
|
||||
|
||||
def _signal(self):
|
||||
# Wait for the next batch to signal back
|
||||
@ -63,10 +84,9 @@ class UBatchContext:
|
||||
# Signal that this batch reached the barrier
|
||||
self.cpu_signal_event.set()
|
||||
|
||||
def _wait(self):
|
||||
def _cpu_wait(self):
|
||||
self.cpu_wait_event.wait()
|
||||
self.cpu_wait_event.clear()
|
||||
self.stream.wait_event(self.gpu_wait_event)
|
||||
self._restore_context()
|
||||
|
||||
_CURRENT_CONTEXT: dict = {}
|
||||
@ -121,7 +141,7 @@ def make_ubatch_context_chain(
|
||||
ctxs.append(ctx)
|
||||
|
||||
def start_hook(from_stream: torch.cuda.Stream):
|
||||
ctxs[0].cpu_wait_event.set()
|
||||
ctxs[0].gpu_wait_event.record(from_stream)
|
||||
ctxs[0].gpu_wait_event.record(from_stream)
|
||||
ctxs[0].cpu_wait_event.set()
|
||||
|
||||
return ctxs, start_hook
|
||||
Loading…
x
Reference in New Issue
Block a user