reintegrate full cudagraphs
Signed-off-by: Sage Moore <sage@neuralmagic.com>
commit af68574e3d
parent 78228a67ce
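The commit re-enables "full" CUDA graph capture, i.e. recording the entire forward pass into a single torch.cuda.CUDAGraph per batch size and replaying it, rather than relying only on piecewise graphs. For reference, the basic capture/replay mechanics in PyTorch look roughly like the sketch below; the toy model and buffer names are illustrative, not vLLM code.

import torch

# Hypothetical toy model standing in for the real forward pass.
model = torch.nn.Linear(64, 64).cuda()

# Static input buffer: the captured graph always reads from this address,
# so later inputs must be copied into it rather than passed as new tensors.
static_input = torch.zeros(8, 64, device="cuda")

# Warm up on a side stream before capture so lazy-init kernels
# don't end up inside the graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture the whole forward pass into one graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: copy new data into the static buffer, then rerun the recorded kernels.
new_batch = torch.randn(8, 64, device="cuda")
static_input.copy_(new_batch)
graph.replay()
print(static_output[0, :4])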
@@ -47,6 +47,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
    def embedding(self, layer: torch.nn.Module,
                  input_: torch.Tensor) -> torch.Tensor:
        print("SHOULDNT BE HERE DURING CAPTURE")
        return F.embedding(input_, layer.weight)

@@ -75,15 +75,15 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
            1,  # MQA for the decode path
        )

        n = num_splits.size(0)
        # logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
        if self.runner.full_cuda_graph and (n-1) <= self.runner.cudagraph_batch_sizes[-1] // 2:
        if self.runner.full_cuda_graph:
            n = num_splits.size(0)
            # First time around (CUDAGraph capture), allocate the static buffer
            if self.cg_buf_tile_scheduler_metadata is None:
                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
            if self.cg_buf_num_splits is None:
                self.cg_buf_num_splits = num_splits
            else:
                assert self.cg_buf_num_splits is not None
                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
            elif n <= self.cg_buf_num_splits.size(0):
                assert self.cg_buf_tile_scheduler_metadata is not None

                # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
                assert (self.cg_buf_tile_scheduler_metadata.size() ==

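The FlashMLA change above follows the standard trick for using attention metadata inside a captured graph: on the first (capture-time) call the freshly computed tile_scheduler_metadata / num_splits tensors are stashed as persistent cg_buf_* attributes, and later calls must write into those same buffers, since the graph has baked in their device addresses. A rough sketch of that pattern, with an illustrative class and field names rather than the actual vLLM builder API:

import torch

class GraphSafeMetadataBuffer:
    """Keeps a persistent buffer so a captured CUDA graph always reads the same address."""

    def __init__(self):
        self.cg_buf = None  # allocated lazily on the first (capture-time) call

    def update(self, new_metadata: torch.Tensor) -> torch.Tensor:
        if self.cg_buf is None:
            # First call, typically during graph capture: keep this tensor alive
            # so its storage becomes the address recorded in the graph.
            self.cg_buf = new_metadata
        else:
            # Later calls: copy into the existing buffer in place; allocating a
            # fresh tensor here would leave the graph reading stale memory.
            n = new_metadata.size(0)
            assert n <= self.cg_buf.size(0), "batch larger than the captured buffer"
            self.cg_buf[:n].copy_(new_metadata)
        return self.cg_buf

buf = GraphSafeMetadataBuffer()
first = buf.update(torch.arange(4))        # capture-time: buffer adopted
later = buf.update(torch.tensor([7, 8]))   # replay-time: copied in place
print(later[:2])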
@@ -221,7 +221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        self.use_cuda_graph = (self.compilation_config.level
                               == CompilationLevel.PIECEWISE
                               and not self.model_config.enforce_eager)
        # self.use_cuda_graph = True
        self.use_cuda_graph = True
        logger.info(f"self.use_cuda_graph {self.use_cuda_graph}")
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different.

@@ -1691,89 +1691,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            for thread in ubatch_threads:
                thread.join()

            torch.cuda.synchronize()
            # torch.cuda.synchronize()
            torch.cuda.set_stream(root_stream)
            sorted_results = [value for position, value in sorted(results)]
            return torch.cat(sorted_results, dim=0)

        # def _run_for_real(input_ids,
        #                   positions,
        #                   intermediate_tensors,
        #                   input_embeds,
        #                   attn_metadata,
        #                   num_tokens_across_dp,
        #                   token_slices,
        #                   skip_cuda_graphs):
        #     # run micro-batched
        #     if len(token_slices) > 1:
        #         assert len(token_slices) == 2
        #         # num_tokens = ubatch_slices[1][1].stop
        #         # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
        #         # assert not is_dummy_run
        #         model_output = _run_ubatches(token_slices,
        #                                      attn_metadata,
        #                                      is_dummy_run,
        #                                      num_tokens_across_dp=num_tokens_across_dp)
        #     # run single batch
        #     else:
        #         # print("RUN NORMAL")
        #         num_tokens = token_slices[0].stop - token_slices[0].start
        #         if num_tokens == 0:
        #             num_tokens = 1
        #         model_output = _run(
        #             token_slices[0],
        #             set_forward_context(attn_metadata,
        #                                 vllm_config=self.vllm_config,
        #                                 num_tokens=num_tokens,
        #                                 num_tokens_across_dp=num_tokens_across_dp,
        #                                 skip_cuda_graphs=skip_cuda_graphs),
        #             is_dummy_run)
        #     return model_output

        # num_tokens = token_slice[0].stop - token_slice[1].start
        # # We have multiple sets of inputs here which is a bummer.
        # # We'll need to pass around a bunch of lists which super sucks.
        # input_ids, positions, inputs_embeds, intermediate_tensors = \
        #     model_inputs(token_slice, use_dummy_input)
        # if build_cuda_graph and num_tokens not in self.cudagraphs:
        #     print(f"Capturing for {num_tokens}")
        #     # assert use_dummy_input
        #     using_ubaching = ubatch_slices is not None
        #     assert using_ubaching
        #     self.cudagraphs[num_tokens] = CUDAGraphMetaData(cudagraph=torch.cuda.CUDAGraph(),
        #                                                     using_ubatching=using_ubaching)
        #     with torch.cuda.graph(self.cudagraphs[num_tokens].cudagraph):
        #         # TODO (Sage) I assume we can just get these before calling this function
        #         # Args to delete:
        #         #     attn_metadata
        #         #     skip_cudagraphs
        #         model_output = self._run_for_real(
        #             input_ids=input_ids,
        #             positions=positions,
        #             intermediate_tensors=intermediate_tensors,
        #             inputs_embeds=inputs_embeds,
        #             attn_metadata=attn_metadata,
        #             num_tokens_across_dp=num_tokens_across_dp,
        #             skip_cuda_graphs=skip_cuda_graphs
        #         )
        #     self.cudagraphs[num_tokens].outputs = model_output
        # elif num_tokens in self.cudagraphs and not skip_cuda_graphs:
        #     logger.info("GRAPH REPLAY")
        #     assert self.cudagraphs[num_tokens].using_ubatching == ubatch_slices is not None
        #     self.cudagraphs[num_tokens].cudagraph.replay()
        #     model_output = self.cudagraphs[num_tokens].outputs
        # else:
        #     # TODO (Sage) We need to figure out how to move some of this context management
        #     # logic outside of the graph capture
        #     model_output = self._run_for_real(
        #         input_ids=input_ids,
        #         positions=positions,
        #         intermediate_tensors=intermediate_tensors,
        #         inputs_embeds=inputs_embeds,
        #         attn_metadata=attn_metadata,
        #         num_tokens_across_dp=num_tokens_across_dp,
        #         skip_cuda_graphs=skip_cuda_graphs
        #     )
        # run micro-batched
        if ubatch_slices is not None:
            assert len(ubatch_slices) == 2, "Only two ubatches has been tested"

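The lines that survive this hunk join the micro-batch worker threads, synchronize, restore the root stream, and concatenate the per-ubatch results back in submission order. The sketch below shows just that join-and-reassemble step with made-up worker logic; the CUDA stream handling from the real code is omitted.

import threading
import torch

def run_ubatch(position: int, chunk: torch.Tensor, results: list) -> None:
    # Stand-in for running one micro-batch through the model.
    results.append((position, chunk * 2))

tokens = torch.arange(10).float()
slices = [slice(0, 6), slice(6, 10)]          # two micro-batches
results: list = []

threads = [
    threading.Thread(target=run_ubatch, args=(i, tokens[s], results))
    for i, s in enumerate(slices)
]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Threads may finish out of order, so sort by position before concatenating.
sorted_results = [value for position, value in sorted(results)]
output = torch.cat(sorted_results, dim=0)
print(output)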
@@ -1786,25 +1708,63 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                is_dummy_run=is_dummy_run,
                num_tokens_across_dp=num_tokens_across_dp
            )
            model_output = _run_ubatches(ubatch_metadata)
            if num_scheduled_tokens not in self.cudagraphs \
                    and not skip_cuda_graphs and build_cuda_graph:
                # DO capture
                self.cudagraphs[num_scheduled_tokens] = \
                    CUDAGraphMetaData(
                        cudagraph=torch.cuda.CUDAGraph(),
                        using_ubatching=True
                    )
                with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
                    model_output = _run_ubatches(ubatch_metadata)
                self.cudagraphs[num_scheduled_tokens].outputs = model_output
                return self.cudagraphs[num_scheduled_tokens].outputs
            elif num_scheduled_tokens in self.cudagraphs:
                self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
                return self.cudagraphs[num_scheduled_tokens].outputs
            else:
                return _run_ubatches(ubatch_metadata)
        # run single batch
        else:
            # print("RUN NORMAL")
            input_ids, positions, inputs_embeds, intermediate_tensors = \
                model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
            model_output = _run(
                context=set_forward_context(attn_metadata,
                                            vllm_config=self.vllm_config,
                                            num_tokens=num_scheduled_tokens or 1,
                                            num_tokens_across_dp=num_tokens_across_dp,
                                            skip_cuda_graphs=skip_cuda_graphs),
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
                intermediate_tensors=intermediate_tensors
            )

            return model_output
            if num_scheduled_tokens not in self.cudagraphs \
                    and not skip_cuda_graphs and build_cuda_graph:
                self.cudagraphs[num_scheduled_tokens] = \
                    CUDAGraphMetaData(
                        cudagraph=torch.cuda.CUDAGraph(),
                        using_ubatching=False
                    )
                with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
                    model_output = _run(
                        context=set_forward_context(attn_metadata,
                                                    vllm_config=self.vllm_config,
                                                    num_tokens=num_scheduled_tokens or 1,
                                                    num_tokens_across_dp=num_tokens_across_dp,
                                                    skip_cuda_graphs=skip_cuda_graphs),
                        input_ids=input_ids,
                        positions=positions,
                        inputs_embeds=inputs_embeds,
                        intermediate_tensors=intermediate_tensors
                    )
                self.cudagraphs[num_scheduled_tokens].outputs = model_output
                return self.cudagraphs[num_scheduled_tokens].outputs
            elif num_scheduled_tokens in self.cudagraphs:
                self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
                return self.cudagraphs[num_scheduled_tokens].outputs
            else:
                return _run(
                    context=set_forward_context(attn_metadata,
                                                vllm_config=self.vllm_config,
                                                num_tokens=num_scheduled_tokens or 1,
                                                num_tokens_across_dp=num_tokens_across_dp,
                                                skip_cuda_graphs=skip_cuda_graphs),
                    input_ids=input_ids,
                    positions=positions,
                    inputs_embeds=inputs_embeds,
                    intermediate_tensors=intermediate_tensors
                )

    @torch.inference_mode()
    def execute_model(
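Both the micro-batched and single-batch branches above implement the same capture-or-replay policy: the first time a num_scheduled_tokens value is seen (and capture is allowed), the forward pass is recorded into a new torch.cuda.CUDAGraph and its outputs cached in self.cudagraphs; afterwards the stored graph is replayed and the cached outputs returned. A condensed sketch of that dispatch follows; the CUDAGraphMetaData fields and the run_fn callable are assumptions for illustration, not the exact interfaces from this commit.

from dataclasses import dataclass
from typing import Callable, Dict, Optional

import torch

@dataclass
class CUDAGraphMetaData:
    cudagraph: torch.cuda.CUDAGraph
    using_ubatching: bool
    outputs: Optional[torch.Tensor] = None

cudagraphs: Dict[int, CUDAGraphMetaData] = {}

def run_or_replay(num_scheduled_tokens: int,
                  run_fn: Callable[[], torch.Tensor],
                  build_cuda_graph: bool,
                  skip_cuda_graphs: bool,
                  using_ubatching: bool) -> torch.Tensor:
    if (num_scheduled_tokens not in cudagraphs
            and not skip_cuda_graphs and build_cuda_graph):
        # Capture path: record the forward pass once for this token count
        # (warm-up iterations before capture are omitted for brevity).
        meta = CUDAGraphMetaData(cudagraph=torch.cuda.CUDAGraph(),
                                 using_ubatching=using_ubatching)
        cudagraphs[num_scheduled_tokens] = meta
        with torch.cuda.graph(meta.cudagraph):
            meta.outputs = run_fn()
        return meta.outputs
    elif num_scheduled_tokens in cudagraphs:
        # Replay path: rerun the recorded kernels and return the cached output tensors.
        cudagraphs[num_scheduled_tokens].cudagraph.replay()
        return cudagraphs[num_scheduled_tokens].outputs
    else:
        # Eager fallback when capture is disabled for this step.
        return run_fn()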