mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 22:37:09 +08:00
wip
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
18bf91e6a8
commit
2dc3b8b0a2
@ -1300,12 +1300,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
model_inputs(token_slice, use_dummy_input)
|
||||
with context:
|
||||
if isinstance(context, UBatchContext):
|
||||
print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape}")
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if isinstance(context, UBatchContext):
|
||||
print(f"Ran ubatch {context.id}putput {model_output.shape}")
|
||||
if isinstance(context, UBatchContext):
|
||||
# Clone before we leave the ubatch context
|
||||
model_output = model_output.clone()
|
||||
|
||||
@ -22,6 +22,7 @@ class UBatchContext:
|
||||
cpu_signal_event: threading.Event,
|
||||
gpu_wait_event: torch.cuda.Event,
|
||||
gpu_signal_event: torch.cuda.Event,
|
||||
gpu_wait_on_launch: bool = False,
|
||||
schedule="default"):
|
||||
self.id = id
|
||||
self.stream = stream
|
||||
@ -33,18 +34,27 @@ class UBatchContext:
|
||||
self.gpu_signal_event = gpu_signal_event
|
||||
self.schedule = schedule
|
||||
self.done_event = torch.cuda.Event()
|
||||
self.gpu_wait_on_launch = gpu_wait_on_launch
|
||||
|
||||
def __enter__(self):
|
||||
global _CURRENT_CONTEXT
|
||||
_CURRENT_CONTEXT[threading.get_ident()] = self
|
||||
self._cpu_wait()
|
||||
self.gpu_stream_wait()
|
||||
start_event = torch.cuda.Event()
|
||||
self.original_stream.record_event(start_event)
|
||||
self.stream.wait_event(start_event)
|
||||
print("Starting ubatch %d" % self.id)
|
||||
if self.gpu_wait_on_launch:
|
||||
self.gpu_stream_wait()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
global _CURRENT_CONTEXT
|
||||
_CURRENT_CONTEXT[threading.get_ident()] = None
|
||||
torch.cuda.set_stream(self.original_stream)
|
||||
print("Finishing ubatch %d" % self.id)
|
||||
self._signal()
|
||||
self._signal()
|
||||
self._signal()
|
||||
return False
|
||||
|
||||
@ -72,9 +82,11 @@ class UBatchContext:
|
||||
def gpu_stream_wait(self):
|
||||
self.stream.wait_event(self.gpu_wait_event)
|
||||
|
||||
def yield_(self, gpu_wait: bool = True):
|
||||
def _yield(self, gpu_wait: bool = True):
|
||||
print("Yielding ubatch %d" % self.id)
|
||||
self._signal()
|
||||
self._cpu_wait()
|
||||
print("Resuming ubatch %d" % self.id)
|
||||
if gpu_wait:
|
||||
self.gpu_stream_wait()
|
||||
|
||||
@ -91,11 +103,19 @@ class UBatchContext:
|
||||
|
||||
_CURRENT_CONTEXT: dict = {}
|
||||
|
||||
def yield_impl(schedule="default"):
|
||||
def get_current_ubatch_context() -> Optional[UBatchContext]:
|
||||
global _CURRENT_CONTEXT
|
||||
"""
|
||||
Get the current UBatchContext for the current thread.
|
||||
"""
|
||||
return _CURRENT_CONTEXT.get(threading.get_ident(), None)
|
||||
|
||||
def yield_impl(schedule="default", gpu_wait: bool = True):
|
||||
# Perform the barrier if a context exists for this thread
|
||||
ctx = _CURRENT_CONTEXT.get(threading.get_ident(), None)
|
||||
if ctx is not None and ctx.schedule == schedule:
|
||||
ctx.yield_()
|
||||
ctx = get_current_ubatch_context()
|
||||
print("you are in yield_impl", ctx)
|
||||
if ctx is not None:
|
||||
ctx._yield(gpu_wait=gpu_wait)
|
||||
|
||||
|
||||
# 2) Register kernel for CUDA, mark as mutating to prevent the compiler from
|
||||
@ -137,6 +157,7 @@ def make_ubatch_context_chain(
|
||||
cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
|
||||
gpu_wait_event=gpu_events[i],
|
||||
gpu_signal_event=gpu_events[(i + 1) % num_micro_batches],
|
||||
gpu_wait_on_launch=(i > 0),
|
||||
)
|
||||
ctxs.append(ctx)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user