mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 22:57:09 +08:00
factored out some of the context creation code along with misc commeted infra
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
44a2b3494e
commit
e2ba707d64
@ -76,7 +76,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
)
|
||||
|
||||
n = num_splits.size(0)
|
||||
logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
|
||||
# logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
|
||||
if self.runner.full_cuda_graph and (n-1) <= self.runner.cudagraph_batch_sizes[-1] // 2:
|
||||
# First time around (CUDAGraph capture), allocate the static buffer
|
||||
if self.cg_buf_tile_scheduler_metadata is None:
|
||||
|
||||
@ -95,6 +95,7 @@ import dataclasses
|
||||
@dataclasses.dataclass
|
||||
class CUDAGraphMetaData:
|
||||
cudagraph: torch.cuda.CUDAGraph
|
||||
using_ubatching: bool
|
||||
outputs: Optional[Any] = None
|
||||
|
||||
class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
@ -220,7 +221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.use_cuda_graph = (self.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE
|
||||
and not self.model_config.enforce_eager)
|
||||
self.use_cuda_graph = True
|
||||
# self.use_cuda_graph = True
|
||||
logger.info(f"self.use_cuda_graph {self.use_cuda_graph}")
|
||||
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
|
||||
# The convention is different.
|
||||
@ -230,7 +231,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
reversed(self.compilation_config.cudagraph_capture_sizes))
|
||||
logger.info(f"cudagraph capture sizes {self.cudagraph_batch_sizes}")
|
||||
self.full_cuda_graph = self.compilation_config.full_cuda_graph
|
||||
self.full_cuda_graph = True
|
||||
# self.full_cuda_graph = True
|
||||
logger.info(f"full_cuda_graph {self.full_cuda_graph}")
|
||||
|
||||
# Cache the device properties.
|
||||
@ -1567,8 +1568,47 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
skip_cuda_graphs: bool = False,
|
||||
build_cuda_graph: bool = False):
|
||||
|
||||
@dataclasses.dataclass
|
||||
class UbatchMetadata:
|
||||
ubatch_id: int
|
||||
context: UBatchContext
|
||||
ubatch_slice: UbatchSlice
|
||||
|
||||
input_ids: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
inputs_embeds: Optional[torch.Tensor]
|
||||
intermediate_tensors: Optional[IntermediateTensors]
|
||||
|
||||
|
||||
num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1
|
||||
|
||||
def _make_ubatch_contexts(ubatch_slices,
|
||||
attn_metadata,
|
||||
is_dummy_run,
|
||||
num_tokens_across_dp) -> list[UBatchContext]:
|
||||
ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
|
||||
compute_stream=current_stream(),
|
||||
device=self.device)
|
||||
|
||||
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||
is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
|
||||
assert not is_dummy_ubatch or i == len(
|
||||
ubatch_slices) - 1 or is_dummy_run
|
||||
|
||||
num_tokens = num_dummy_tokens if is_dummy_ubatch or \
|
||||
is_dummy_run else (tokens_slice.stop - tokens_slice.start)
|
||||
# TODO (Sage) Instead of using this setter we should be able
|
||||
# to just create the forward context in advance and pass it
|
||||
# to the UBatchContext's __init__ method
|
||||
ubatch_ctxs[i].forward_context = create_forward_context(
|
||||
attn_metadata[i]
|
||||
if attn_metadata is not None else None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
skip_cuda_graphs=skip_cuda_graphs)
|
||||
return ubatch_ctxs
|
||||
|
||||
def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
|
||||
if use_dummy_input:
|
||||
# print("MAKING DUMMY BATCH")
|
||||
@ -1580,41 +1620,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
def _run(token_slice: slice,
|
||||
context,
|
||||
use_dummy_input: bool = False,
|
||||
build_cuda_graph: bool = False):
|
||||
use_dummy_input: bool = False):
|
||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
model_inputs(token_slice, use_dummy_input)
|
||||
with context:
|
||||
# model_output = self.model(
|
||||
# input_ids=input_ids,
|
||||
# positions=positions,
|
||||
# intermediate_tensors=intermediate_tensors,
|
||||
# inputs_embeds=inputs_embeds,
|
||||
# )
|
||||
num_tokens = token_slice.stop - token_slice.start
|
||||
if build_cuda_graph and num_tokens not in self.cudagraphs:
|
||||
print(f"Capturing for {num_tokens}")
|
||||
assert use_dummy_input
|
||||
self.cudagraphs[num_tokens] = CUDAGraphMetaData(cudagraph=torch.cuda.CUDAGraph())
|
||||
with torch.cuda.graph(self.cudagraphs[num_tokens].cudagraph):
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
self.cudagraphs[num_tokens].outputs = model_output
|
||||
elif num_tokens in self.cudagraphs and not skip_cuda_graphs:
|
||||
logger.info("GRAPH REPLAY")
|
||||
self.cudagraphs[num_tokens].cudagraph.replay()
|
||||
model_output = self.cudagraphs[num_tokens].outputs
|
||||
else:
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if isinstance(context, UBatchContext):
|
||||
# Clone before we leave the ubatch context
|
||||
model_output = model_output.clone()
|
||||
@ -1623,45 +1638,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
@torch.inference_mode()
|
||||
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
|
||||
use_dummy_input, build_cuda_graph):
|
||||
use_dummy_input):
|
||||
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||
model_output = _run(token_slice, ubatch_ctx, use_dummy_input, build_cuda_graph)
|
||||
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
|
||||
|
||||
if save_results:
|
||||
results.append((ubatch_ctx.id, model_output))
|
||||
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||
|
||||
def _run_ubatches(ubatch_slices, attn_metadata,
|
||||
is_dummy_run, num_tokens_across_dp, build_cuda_graph) -> torch.Tensor:
|
||||
is_dummy_run, num_tokens_across_dp) -> torch.Tensor:
|
||||
results: list[tuple[int, torch.Tensor]] = []
|
||||
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
|
||||
root_stream = current_stream()
|
||||
|
||||
ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
|
||||
compute_stream=root_stream,
|
||||
device=self.device)
|
||||
ubatch_ctxs = _make_ubatch_contexts(ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
is_dummy_run=is_dummy_run,
|
||||
num_tokens_across_dp=num_tokens_across_dp)
|
||||
|
||||
# Ubatches will manually manage the forward context, so we override
|
||||
# it to None here so we can have it restored correctly later
|
||||
with override_forward_context(None):
|
||||
ubatch_threads = []
|
||||
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||
# TODO (Sage) Consolidate all of this is_dummy_run
|
||||
# is_dummy_ubatch, is attn_metadata==None, num_tokens==0
|
||||
# nonsense into some unified structure. It's way to hard
|
||||
# to keep track of and keep consistent right now.
|
||||
is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
|
||||
assert not is_dummy_ubatch or i == len(
|
||||
ubatch_slices) - 1 or is_dummy_run
|
||||
|
||||
num_tokens = num_dummy_tokens if is_dummy_ubatch or \
|
||||
is_dummy_run else (tokens_slice.stop - tokens_slice.start)
|
||||
# if num_tokens_across_dp is None:
|
||||
# print(f"GOING TO CALL AR: {i}")
|
||||
ubatch_ctxs[i].forward_context = create_forward_context(
|
||||
attn_metadata[i]
|
||||
if attn_metadata is not None else None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
skip_cuda_graphs=skip_cuda_graphs)
|
||||
|
||||
thread = threading.Thread(target=_ubatch_thread,
|
||||
args=(
|
||||
ubatch_ctxs[i],
|
||||
@ -1670,8 +1678,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
not is_dummy_ubatch
|
||||
or is_dummy_run,
|
||||
is_dummy_ubatch
|
||||
or is_dummy_run,
|
||||
build_cuda_graph
|
||||
or is_dummy_run
|
||||
))
|
||||
ubatch_threads.append(thread)
|
||||
thread.start()
|
||||
@ -1685,6 +1692,84 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
sorted_results = [value for position, value in sorted(results)]
|
||||
return torch.cat(sorted_results, dim=0)
|
||||
|
||||
# def _run_for_real(input_ids,
|
||||
# positions,
|
||||
# intermediate_tensors,
|
||||
# input_embeds,
|
||||
# attn_metadata,
|
||||
# num_tokens_across_dp,
|
||||
# token_slices,
|
||||
# skip_cuda_graphs):
|
||||
# # run micro-batched
|
||||
# if len(token_slices) > 1:
|
||||
# assert len(token_slices) == 2
|
||||
# # num_tokens = ubatch_slices[1][1].stop
|
||||
# # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
|
||||
# # assert not is_dummy_run
|
||||
# model_output = _run_ubatches(token_slices,
|
||||
# attn_metadata,
|
||||
# is_dummy_run,
|
||||
# num_tokens_across_dp=num_tokens_across_dp)
|
||||
# # run single batch
|
||||
# else:
|
||||
# # print("RUN NORMAL")
|
||||
# num_tokens = token_slices[0].stop - token_slices[0].start
|
||||
# if num_tokens == 0:
|
||||
# num_tokens = 1
|
||||
# model_output = _run(
|
||||
# token_slices[0],
|
||||
# set_forward_context(attn_metadata,
|
||||
# vllm_config=self.vllm_config,
|
||||
# num_tokens=num_tokens,
|
||||
# num_tokens_across_dp=num_tokens_across_dp,
|
||||
# skip_cuda_graphs=skip_cuda_graphs),
|
||||
# is_dummy_run)
|
||||
# return model_output
|
||||
|
||||
# num_tokens = token_slice[0].stop - token_slice[1].start
|
||||
# # We have multiple sets of inputs here which is a bummer.
|
||||
# # We'll need to pass around a bunch of lists which super sucks.
|
||||
# input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
# model_inputs(token_slice, use_dummy_input)
|
||||
# if build_cuda_graph and num_tokens not in self.cudagraphs:
|
||||
# print(f"Capturing for {num_tokens}")
|
||||
# # assert use_dummy_input
|
||||
# using_ubaching = ubatch_slices is not None
|
||||
# assert using_ubaching
|
||||
# self.cudagraphs[num_tokens] = CUDAGraphMetaData(cudagraph=torch.cuda.CUDAGraph(),
|
||||
# using_ubatching=using_ubaching)
|
||||
# with torch.cuda.graph(self.cudagraphs[num_tokens].cudagraph):
|
||||
# # TODO (Sage) I assume we can just get these before calling this function
|
||||
# # Args to delete:
|
||||
# # attn_metadata
|
||||
# # skip_cudagraphs
|
||||
# model_output = self._run_for_real(
|
||||
# input_ids=input_ids,
|
||||
# positions=positions,
|
||||
# intermediate_tensors=intermediate_tensors,
|
||||
# inputs_embeds=inputs_embeds,
|
||||
# attn_metadata=attn_metadata,
|
||||
# num_tokens_across_dp=num_tokens_across_dp,
|
||||
# skip_cuda_graphs=skip_cuda_graphs
|
||||
# )
|
||||
# self.cudagraphs[num_tokens].outputs = model_output
|
||||
# elif num_tokens in self.cudagraphs and not skip_cuda_graphs:
|
||||
# logger.info("GRAPH REPLAY")
|
||||
# assert self.cudagraphs[num_tokens].using_ubatching == ubatch_slices is not None
|
||||
# self.cudagraphs[num_tokens].cudagraph.replay()
|
||||
# model_output = self.cudagraphs[num_tokens].outputs
|
||||
# else:
|
||||
# # TODO (Sage) We need to figure out how to move some of this context management
|
||||
# # logic outside of the graph capture
|
||||
# model_output = self._run_for_real(
|
||||
# input_ids=input_ids,
|
||||
# positions=positions,
|
||||
# intermediate_tensors=intermediate_tensors,
|
||||
# inputs_embeds=inputs_embeds,
|
||||
# attn_metadata=attn_metadata,
|
||||
# num_tokens_across_dp=num_tokens_across_dp,
|
||||
# skip_cuda_graphs=skip_cuda_graphs
|
||||
# )
|
||||
# run micro-batched
|
||||
if ubatch_slices is not None:
|
||||
# num_tokens = ubatch_slices[1][1].stop
|
||||
@ -1693,8 +1778,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
model_output = _run_ubatches(ubatch_slices,
|
||||
attn_metadata,
|
||||
is_dummy_run,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
build_cuda_graph=build_cuda_graph)
|
||||
num_tokens_across_dp=num_tokens_across_dp)
|
||||
# run single batch
|
||||
else:
|
||||
# print("RUN NORMAL")
|
||||
@ -1705,8 +1789,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens=num_scheduled_tokens or 1,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
skip_cuda_graphs=skip_cuda_graphs),
|
||||
is_dummy_run,
|
||||
build_cuda_graph=build_cuda_graph)
|
||||
is_dummy_run)
|
||||
|
||||
return model_output
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user