mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-16 23:35:44 +08:00
separate gpu wait
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
a8439e2fd4
commit
00f526f55b
@ -37,7 +37,8 @@ class UBatchContext:
|
|||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
global _CURRENT_CONTEXT
|
global _CURRENT_CONTEXT
|
||||||
_CURRENT_CONTEXT[threading.get_ident()] = self
|
_CURRENT_CONTEXT[threading.get_ident()] = self
|
||||||
self._wait()
|
self._cpu_wait()
|
||||||
|
self.gpu_stream_wait()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
@ -53,9 +54,29 @@ class UBatchContext:
|
|||||||
torch.cuda.set_stream(self.stream)
|
torch.cuda.set_stream(self.stream)
|
||||||
forward_context._forward_context = self.forward_context
|
forward_context._forward_context = self.forward_context
|
||||||
|
|
||||||
def yield_(self):
|
# Separate GPU wait so we can do
|
||||||
|
# ubatch0
|
||||||
|
# 1) work
|
||||||
|
# 2) dispatch
|
||||||
|
# 3) yield
|
||||||
|
# ubatch1
|
||||||
|
# 1) work
|
||||||
|
# 2) gpu wait
|
||||||
|
# 3) dispatch
|
||||||
|
# 4) yield
|
||||||
|
#
|
||||||
|
# This way we can have the CPU schedule ubatch1-dispatch while ubatch0 is still running
|
||||||
|
# before yielding back to ubatch1 but ensure we won't start the dispatch
|
||||||
|
# until ubatch0-dispatch is done avoiding overlapping dispatches that
|
||||||
|
# might share underlying buffers
|
||||||
|
def gpu_stream_wait(self):
|
||||||
|
self.stream.wait_event(self.gpu_wait_event)
|
||||||
|
|
||||||
|
def yield_(self, gpu_wait: bool = True):
|
||||||
self._signal()
|
self._signal()
|
||||||
self._wait()
|
self._cpu_wait()
|
||||||
|
if gpu_wait:
|
||||||
|
self.gpu_stream_wait()
|
||||||
|
|
||||||
def _signal(self):
|
def _signal(self):
|
||||||
# Wait for the next batch to signal back
|
# Wait for the next batch to signal back
|
||||||
@ -63,10 +84,9 @@ class UBatchContext:
|
|||||||
# Signal that this batch reached the barrier
|
# Signal that this batch reached the barrier
|
||||||
self.cpu_signal_event.set()
|
self.cpu_signal_event.set()
|
||||||
|
|
||||||
def _wait(self):
|
def _cpu_wait(self):
|
||||||
self.cpu_wait_event.wait()
|
self.cpu_wait_event.wait()
|
||||||
self.cpu_wait_event.clear()
|
self.cpu_wait_event.clear()
|
||||||
self.stream.wait_event(self.gpu_wait_event)
|
|
||||||
self._restore_context()
|
self._restore_context()
|
||||||
|
|
||||||
_CURRENT_CONTEXT: dict = {}
|
_CURRENT_CONTEXT: dict = {}
|
||||||
@ -121,7 +141,7 @@ def make_ubatch_context_chain(
|
|||||||
ctxs.append(ctx)
|
ctxs.append(ctx)
|
||||||
|
|
||||||
def start_hook(from_stream: torch.cuda.Stream):
|
def start_hook(from_stream: torch.cuda.Stream):
|
||||||
ctxs[0].cpu_wait_event.set()
|
ctxs[0].gpu_wait_event.record(from_stream)
|
||||||
ctxs[0].gpu_wait_event.record(from_stream)
|
ctxs[0].cpu_wait_event.set()
|
||||||
|
|
||||||
return ctxs, start_hook
|
return ctxs, start_hook
|
||||||
Loading…
x
Reference in New Issue
Block a user