factored out some of the context creation code, along with some miscellaneous commented-out infra
Signed-off-by: Sage Moore <sage@neuralmagic.com>
parent 44a2b3494e
commit e2ba707d64
@@ -76,7 +76,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
             )
 
         n = num_splits.size(0)
-        logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
+        # logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
         if self.runner.full_cuda_graph and (n-1) <= self.runner.cudagraph_batch_sizes[-1] // 2:
             # First time around (CUDAGraph capture), allocate the static buffer
             if self.cg_buf_tile_scheduler_metadata is None:
@@ -95,6 +95,7 @@ import dataclasses
 @dataclasses.dataclass
 class CUDAGraphMetaData:
     cudagraph: torch.cuda.CUDAGraph
+    using_ubatching: bool
     outputs: Optional[Any] = None
 
 class GPUModelRunner(LoRAModelRunnerMixin):
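
Note: the new `using_ubatching` field tags each captured graph with the batching mode it was captured under, so a later replay can be checked against the current mode (see the asserts in the commented-out replay path further down). A minimal sketch of the capture/replay bookkeeping this dataclass supports; the torch calls are real APIs, the surrounding names are illustrative:

    import dataclasses
    from typing import Any, Optional

    import torch

    @dataclasses.dataclass
    class GraphEntry:  # hypothetical stand-in for CUDAGraphMetaData
        cudagraph: torch.cuda.CUDAGraph
        using_ubatching: bool
        outputs: Optional[Any] = None

    def capture(fn, using_ubatching: bool) -> GraphEntry:
        # Warm up on a side stream so lazy initialization happens
        # outside of graph capture.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            fn()
        torch.cuda.current_stream().wait_stream(s)

        entry = GraphEntry(torch.cuda.CUDAGraph(), using_ubatching)
        with torch.cuda.graph(entry.cudagraph):
            entry.outputs = fn()  # outputs become the static result buffer
        return entry

    def replay(entry: GraphEntry, using_ubatching: bool) -> Any:
        assert entry.using_ubatching == using_ubatching
        entry.cudagraph.replay()  # rerun the captured kernels
        return entry.outputs      # results land in the static buffer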
@@ -220,7 +221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.use_cuda_graph = (self.compilation_config.level
                                == CompilationLevel.PIECEWISE
                                and not self.model_config.enforce_eager)
-        self.use_cuda_graph = True
+        # self.use_cuda_graph = True
         logger.info(f"self.use_cuda_graph {self.use_cuda_graph}")
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
@@ -230,7 +231,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             reversed(self.compilation_config.cudagraph_capture_sizes))
         logger.info(f"cudagraph capture sizes {self.cudagraph_batch_sizes}")
         self.full_cuda_graph = self.compilation_config.full_cuda_graph
-        self.full_cuda_graph = True
+        # self.full_cuda_graph = True
         logger.info(f"full_cuda_graph {self.full_cuda_graph}")
 
         # Cache the device properties.
@@ -1567,8 +1568,47 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                  skip_cuda_graphs: bool = False,
                  build_cuda_graph: bool = False):
 
+        @dataclasses.dataclass
+        class UbatchMetadata:
+            ubatch_id: int
+            context: UBatchContext
+            ubatch_slice: UbatchSlice
+
+            input_ids: torch.Tensor
+            positions: torch.Tensor
+            inputs_embeds: Optional[torch.Tensor]
+            intermediate_tensors: Optional[IntermediateTensors]
+
+
         num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1
 
+        def _make_ubatch_contexts(ubatch_slices,
+                                  attn_metadata,
+                                  is_dummy_run,
+                                  num_tokens_across_dp) -> list[UBatchContext]:
+            ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
+                                               compute_stream=current_stream(),
+                                               device=self.device)
+
+            for i, (_, tokens_slice) in enumerate(ubatch_slices):
+                is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
+                assert not is_dummy_ubatch or i == len(
+                    ubatch_slices) - 1 or is_dummy_run
+
+                num_tokens = num_dummy_tokens if is_dummy_ubatch or \
+                    is_dummy_run else (tokens_slice.stop - tokens_slice.start)
+                # TODO (Sage) Instead of using this setter we should be able
+                # to just create the forward context in advance and pass it
+                # to the UBatchContext's __init__ method
+                ubatch_ctxs[i].forward_context = create_forward_context(
+                    attn_metadata[i]
+                    if attn_metadata is not None else None,
+                    self.vllm_config,
+                    num_tokens=num_tokens,
+                    num_tokens_across_dp=num_tokens_across_dp,
+                    skip_cuda_graphs=skip_cuda_graphs)
+            return ubatch_ctxs
+
         def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
             if use_dummy_input:
                 # print("MAKING DUMMY BATCH")
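
The factored-out `_make_ubatch_contexts` above preserves the dummy-ubatch convention from the original loop: an empty slice (`stop <= start`) marks a padding micro-batch that still has to run with at least one token so data-parallel ranks stay in lock-step. A self-contained illustration of that sizing rule (function name is illustrative):

    def ubatch_num_tokens(tokens_slice: slice, num_dummy_tokens: int = 1) -> int:
        # An empty slice marks a dummy (padding) micro-batch; it still runs
        # with num_dummy_tokens tokens so collectives see every rank.
        if tokens_slice.stop <= tokens_slice.start:
            return num_dummy_tokens
        return tokens_slice.stop - tokens_slice.start

    assert ubatch_num_tokens(slice(0, 128)) == 128   # real micro-batch
    assert ubatch_num_tokens(slice(128, 128)) == 1   # dummy keeps ranks in step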
@@ -1580,41 +1620,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         def _run(token_slice: slice,
                  context,
-                 use_dummy_input: bool = False,
-                 build_cuda_graph: bool = False):
+                 use_dummy_input: bool = False):
             input_ids, positions, inputs_embeds, intermediate_tensors = \
                 model_inputs(token_slice, use_dummy_input)
             with context:
-                # model_output = self.model(
-                #     input_ids=input_ids,
-                #     positions=positions,
-                #     intermediate_tensors=intermediate_tensors,
-                #     inputs_embeds=inputs_embeds,
-                # )
-                num_tokens = token_slice.stop - token_slice.start
-                if build_cuda_graph and num_tokens not in self.cudagraphs:
-                    print(f"Capturing for {num_tokens}")
-                    assert use_dummy_input
-                    self.cudagraphs[num_tokens] = CUDAGraphMetaData(cudagraph=torch.cuda.CUDAGraph())
-                    with torch.cuda.graph(self.cudagraphs[num_tokens].cudagraph):
-                        model_output = self.model(
-                            input_ids=input_ids,
-                            positions=positions,
-                            intermediate_tensors=intermediate_tensors,
-                            inputs_embeds=inputs_embeds,
-                        )
-                    self.cudagraphs[num_tokens].outputs = model_output
-                elif num_tokens in self.cudagraphs and not skip_cuda_graphs:
-                    logger.info("GRAPH REPLAY")
-                    self.cudagraphs[num_tokens].cudagraph.replay()
-                    model_output = self.cudagraphs[num_tokens].outputs
-                else:
-                    model_output = self.model(
-                        input_ids=input_ids,
-                        positions=positions,
-                        intermediate_tensors=intermediate_tensors,
-                        inputs_embeds=inputs_embeds,
-                    )
+                model_output = self.model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds,
+                )
                 if isinstance(context, UBatchContext):
                     # Clone before we leave the ubatch context
                     model_output = model_output.clone()
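
With the capture/replay branch removed, `_run` is a plain forward pass again; the one subtlety that stays is the clone before leaving the ubatch context. One plausible reading, assuming the context's output storage may be reused once it exits (the real `UBatchContext` may differ):

    def run_in_context(model_fn, ctx, is_ubatch: bool):
        with ctx:
            out = model_fn()
            if is_ubatch:
                # Copy into fresh storage while the context is still alive,
                # so the result survives any buffer reuse after exit.
                out = out.clone()
        return out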
@@ -1623,45 +1638,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         @torch.inference_mode()
         def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
-                           use_dummy_input, build_cuda_graph):
+                           use_dummy_input):
             # print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
-            model_output = _run(token_slice, ubatch_ctx, use_dummy_input, build_cuda_graph)
+            model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
 
             if save_results:
                 results.append((ubatch_ctx.id, model_output))
             # print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
 
         def _run_ubatches(ubatch_slices, attn_metadata,
-                          is_dummy_run, num_tokens_across_dp, build_cuda_graph) -> torch.Tensor:
+                          is_dummy_run, num_tokens_across_dp) -> torch.Tensor:
             results: list[tuple[int, torch.Tensor]] = []
             assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
             root_stream = current_stream()
 
-            ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
-                                               compute_stream=root_stream,
-                                               device=self.device)
+            ubatch_ctxs = _make_ubatch_contexts(ubatch_slices=ubatch_slices,
+                                                attn_metadata=attn_metadata,
+                                                is_dummy_run=is_dummy_run,
+                                                num_tokens_across_dp=num_tokens_across_dp)
 
             # Ubatches will manually manage the forward context, so we override
             # it to None here so we can have it restored correctly later
             with override_forward_context(None):
                 ubatch_threads = []
                 for i, (_, tokens_slice) in enumerate(ubatch_slices):
+                    # TODO (Sage) Consolidate all of this is_dummy_run
+                    # is_dummy_ubatch, is attn_metadata==None, num_tokens==0
+                    # nonsense into some unified structure. It's way to hard
+                    # to keep track of and keep consistent right now.
                     is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
                     assert not is_dummy_ubatch or i == len(
                         ubatch_slices) - 1 or is_dummy_run
 
-                    num_tokens = num_dummy_tokens if is_dummy_ubatch or \
-                        is_dummy_run else (tokens_slice.stop - tokens_slice.start)
-                    # if num_tokens_across_dp is None:
-                    #     print(f"GOING TO CALL AR: {i}")
-                    ubatch_ctxs[i].forward_context = create_forward_context(
-                        attn_metadata[i]
-                        if attn_metadata is not None else None,
-                        self.vllm_config,
-                        num_tokens=num_tokens,
-                        num_tokens_across_dp=num_tokens_across_dp,
-                        skip_cuda_graphs=skip_cuda_graphs)
-
                     thread = threading.Thread(target=_ubatch_thread,
                                               args=(
                                                   ubatch_ctxs[i],
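
The "override it to None" comment above describes an override-then-restore pattern: a process-wide current forward context is nulled while each ubatch thread installs its own, and the previous value comes back on exit. A hedged sketch of that shape; vLLM ships its own helper, so this standalone version is purely illustrative:

    import contextlib

    _forward_context = None  # stand-in for the module-level current context

    @contextlib.contextmanager
    def override_forward_context(ctx):
        global _forward_context
        prev = _forward_context
        _forward_context = ctx
        try:
            yield
        finally:
            _forward_context = prev  # restored correctly later, per the comment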
@@ -1670,8 +1678,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                                   not is_dummy_ubatch
                                                   or is_dummy_run,
                                                   is_dummy_ubatch
-                                                  or is_dummy_run,
-                                                  build_cuda_graph
+                                                  or is_dummy_run
                                               ))
                     ubatch_threads.append(thread)
                     thread.start()
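
For reference, the thread-per-micro-batch shape that `_ubatch_thread` and `_run_ubatches` implement: one worker per slice, each result tagged with its ubatch id, then re-sorted so the final concatenation preserves request order. A minimal standalone sketch with illustrative names:

    import threading

    import torch

    def run_ubatches(ubatch_fns) -> torch.Tensor:
        results: list[tuple[int, torch.Tensor]] = []

        def worker(ubatch_id: int, fn) -> None:
            results.append((ubatch_id, fn()))  # list.append is thread-safe

        threads = [threading.Thread(target=worker, args=(i, fn))
                   for i, fn in enumerate(ubatch_fns)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Re-sort by ubatch id so outputs concatenate in request order.
        return torch.cat([out for _, out in sorted(results)], dim=0)

    outs = run_ubatches([lambda: torch.ones(2, 4), lambda: torch.zeros(3, 4)])
    assert outs.shape == (5, 4)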
@@ -1685,6 +1692,84 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             sorted_results = [value for position, value in sorted(results)]
             return torch.cat(sorted_results, dim=0)
 
+        # def _run_for_real(input_ids,
+        #                   positions,
+        #                   intermediate_tensors,
+        #                   input_embeds,
+        #                   attn_metadata,
+        #                   num_tokens_across_dp,
+        #                   token_slices,
+        #                   skip_cuda_graphs):
+        #     # run micro-batched
+        #     if len(token_slices) > 1:
+        #         assert len(token_slices) == 2
+        #         # num_tokens = ubatch_slices[1][1].stop
+        #         # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
+        #         # assert not is_dummy_run
+        #         model_output = _run_ubatches(token_slices,
+        #                                      attn_metadata,
+        #                                      is_dummy_run,
+        #                                      num_tokens_across_dp=num_tokens_across_dp)
+        #     # run single batch
+        #     else:
+        #         # print("RUN NORMAL")
+        #         num_tokens = token_slices[0].stop - token_slices[0].start
+        #         if num_tokens == 0:
+        #             num_tokens = 1
+        #         model_output = _run(
+        #             token_slices[0],
+        #             set_forward_context(attn_metadata,
+        #                                 vllm_config=self.vllm_config,
+        #                                 num_tokens=num_tokens,
+        #                                 num_tokens_across_dp=num_tokens_across_dp,
+        #                                 skip_cuda_graphs=skip_cuda_graphs),
+        #             is_dummy_run)
+        #     return model_output
+
+        # num_tokens = token_slice[0].stop - token_slice[1].start
+        # # We have multiple sets of inputs here which is a bummer.
+        # # We'll need to pass around a bunch of lists which super sucks.
+        # input_ids, positions, inputs_embeds, intermediate_tensors = \
+        #     model_inputs(token_slice, use_dummy_input)
+        # if build_cuda_graph and num_tokens not in self.cudagraphs:
+        #     print(f"Capturing for {num_tokens}")
+        #     # assert use_dummy_input
+        #     using_ubaching = ubatch_slices is not None
+        #     assert using_ubaching
+        #     self.cudagraphs[num_tokens] = CUDAGraphMetaData(cudagraph=torch.cuda.CUDAGraph(),
+        #                                                     using_ubatching=using_ubaching)
+        #     with torch.cuda.graph(self.cudagraphs[num_tokens].cudagraph):
+        #         # TODO (Sage) I assume we can just get these before calling this function
+        #         # Args to delete:
+        #         #     attn_metadata
+        #         #     skip_cudagraphs
+        #         model_output = self._run_for_real(
+        #             input_ids=input_ids,
+        #             positions=positions,
+        #             intermediate_tensors=intermediate_tensors,
+        #             inputs_embeds=inputs_embeds,
+        #             attn_metadata=attn_metadata,
+        #             num_tokens_across_dp=num_tokens_across_dp,
+        #             skip_cuda_graphs=skip_cuda_graphs
+        #         )
+        #         self.cudagraphs[num_tokens].outputs = model_output
+        # elif num_tokens in self.cudagraphs and not skip_cuda_graphs:
+        #     logger.info("GRAPH REPLAY")
+        #     assert self.cudagraphs[num_tokens].using_ubatching == ubatch_slices is not None
+        #     self.cudagraphs[num_tokens].cudagraph.replay()
+        #     model_output = self.cudagraphs[num_tokens].outputs
+        # else:
+        #     # TODO (Sage) We need to figure out how to move some of this context management
+        #     # logic outside of the graph capture
+        #     model_output = self._run_for_real(
+        #         input_ids=input_ids,
+        #         positions=positions,
+        #         intermediate_tensors=intermediate_tensors,
+        #         inputs_embeds=inputs_embeds,
+        #         attn_metadata=attn_metadata,
+        #         num_tokens_across_dp=num_tokens_across_dp,
+        #         skip_cuda_graphs=skip_cuda_graphs
+        #     )
         # run micro-batched
         if ubatch_slices is not None:
             # num_tokens = ubatch_slices[1][1].stop
@@ -1693,8 +1778,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             model_output = _run_ubatches(ubatch_slices,
                                          attn_metadata,
                                          is_dummy_run,
-                                         num_tokens_across_dp=num_tokens_across_dp,
-                                         build_cuda_graph=build_cuda_graph)
+                                         num_tokens_across_dp=num_tokens_across_dp)
         # run single batch
         else:
             # print("RUN NORMAL")
@@ -1705,8 +1789,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                     num_tokens=num_scheduled_tokens or 1,
                                     num_tokens_across_dp=num_tokens_across_dp,
                                     skip_cuda_graphs=skip_cuda_graphs),
-                is_dummy_run,
-                build_cuda_graph=build_cuda_graph)
+                is_dummy_run)
 
         return model_output