mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 18:04:27 +08:00
wip
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
18bf91e6a8
commit
2dc3b8b0a2
@ -1300,12 +1300,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||||
model_inputs(token_slice, use_dummy_input)
|
model_inputs(token_slice, use_dummy_input)
|
||||||
with context:
|
with context:
|
||||||
|
if isinstance(context, UBatchContext):
|
||||||
|
print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape}")
|
||||||
model_output = self.model(
|
model_output = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
|
if isinstance(context, UBatchContext):
|
||||||
|
print(f"Ran ubatch {context.id}putput {model_output.shape}")
|
||||||
if isinstance(context, UBatchContext):
|
if isinstance(context, UBatchContext):
|
||||||
# Clone before we leave the ubatch context
|
# Clone before we leave the ubatch context
|
||||||
model_output = model_output.clone()
|
model_output = model_output.clone()
|
||||||
|
|||||||
@ -22,6 +22,7 @@ class UBatchContext:
|
|||||||
cpu_signal_event: threading.Event,
|
cpu_signal_event: threading.Event,
|
||||||
gpu_wait_event: torch.cuda.Event,
|
gpu_wait_event: torch.cuda.Event,
|
||||||
gpu_signal_event: torch.cuda.Event,
|
gpu_signal_event: torch.cuda.Event,
|
||||||
|
gpu_wait_on_launch: bool = False,
|
||||||
schedule="default"):
|
schedule="default"):
|
||||||
self.id = id
|
self.id = id
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
@ -33,18 +34,27 @@ class UBatchContext:
|
|||||||
self.gpu_signal_event = gpu_signal_event
|
self.gpu_signal_event = gpu_signal_event
|
||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
self.done_event = torch.cuda.Event()
|
self.done_event = torch.cuda.Event()
|
||||||
|
self.gpu_wait_on_launch = gpu_wait_on_launch
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
global _CURRENT_CONTEXT
|
global _CURRENT_CONTEXT
|
||||||
_CURRENT_CONTEXT[threading.get_ident()] = self
|
_CURRENT_CONTEXT[threading.get_ident()] = self
|
||||||
self._cpu_wait()
|
self._cpu_wait()
|
||||||
self.gpu_stream_wait()
|
start_event = torch.cuda.Event()
|
||||||
|
self.original_stream.record_event(start_event)
|
||||||
|
self.stream.wait_event(start_event)
|
||||||
|
print("Starting ubatch %d" % self.id)
|
||||||
|
if self.gpu_wait_on_launch:
|
||||||
|
self.gpu_stream_wait()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
global _CURRENT_CONTEXT
|
global _CURRENT_CONTEXT
|
||||||
_CURRENT_CONTEXT[threading.get_ident()] = None
|
_CURRENT_CONTEXT[threading.get_ident()] = None
|
||||||
torch.cuda.set_stream(self.original_stream)
|
torch.cuda.set_stream(self.original_stream)
|
||||||
|
print("Finishing ubatch %d" % self.id)
|
||||||
|
self._signal()
|
||||||
|
self._signal()
|
||||||
self._signal()
|
self._signal()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -72,9 +82,11 @@ class UBatchContext:
|
|||||||
def gpu_stream_wait(self):
|
def gpu_stream_wait(self):
|
||||||
self.stream.wait_event(self.gpu_wait_event)
|
self.stream.wait_event(self.gpu_wait_event)
|
||||||
|
|
||||||
def yield_(self, gpu_wait: bool = True):
|
def _yield(self, gpu_wait: bool = True):
|
||||||
|
print("Yielding ubatch %d" % self.id)
|
||||||
self._signal()
|
self._signal()
|
||||||
self._cpu_wait()
|
self._cpu_wait()
|
||||||
|
print("Resuming ubatch %d" % self.id)
|
||||||
if gpu_wait:
|
if gpu_wait:
|
||||||
self.gpu_stream_wait()
|
self.gpu_stream_wait()
|
||||||
|
|
||||||
@ -91,11 +103,19 @@ class UBatchContext:
|
|||||||
|
|
||||||
_CURRENT_CONTEXT: dict = {}
|
_CURRENT_CONTEXT: dict = {}
|
||||||
|
|
||||||
def yield_impl(schedule="default"):
|
def get_current_ubatch_context() -> Optional[UBatchContext]:
    """Return the UBatchContext registered for the calling thread.

    The lookup key is ``threading.get_ident()`` in the module-level
    ``_CURRENT_CONTEXT`` mapping.

    Returns:
        The value stored for this thread, or ``None`` when the thread has no
        entry. ``UBatchContext.__exit__`` stores ``None`` under the thread's
        key, so ``None`` is also the result after a context has exited.
    """
    # The mapping is only read here, so no ``global`` declaration is needed;
    # keeping the docstring as the first statement makes it a real docstring.
    return _CURRENT_CONTEXT.get(threading.get_ident())
|
||||||
|
|
||||||
|
def yield_impl(schedule="default", gpu_wait: bool = True):
|
||||||
# Perform the barrier if a context exists for this thread
|
# Perform the barrier if a context exists for this thread
|
||||||
ctx = _CURRENT_CONTEXT.get(threading.get_ident(), None)
|
ctx = get_current_ubatch_context()
|
||||||
if ctx is not None and ctx.schedule == schedule:
|
print("you are in yield_impl", ctx)
|
||||||
ctx.yield_()
|
if ctx is not None:
|
||||||
|
ctx._yield(gpu_wait=gpu_wait)
|
||||||
|
|
||||||
|
|
||||||
# 2) Register kernel for CUDA, mark as mutating to prevent the compiler from
|
# 2) Register kernel for CUDA, mark as mutating to prevent the compiler from
|
||||||
@ -137,6 +157,7 @@ def make_ubatch_context_chain(
|
|||||||
cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
|
cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
|
||||||
gpu_wait_event=gpu_events[i],
|
gpu_wait_event=gpu_events[i],
|
||||||
gpu_signal_event=gpu_events[(i + 1) % num_micro_batches],
|
gpu_signal_event=gpu_events[(i + 1) % num_micro_batches],
|
||||||
|
gpu_wait_on_launch=(i > 0),
|
||||||
)
|
)
|
||||||
ctxs.append(ctx)
|
ctxs.append(ctx)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user