added initial code for cuda graph capturing ubatches
Signed-off-by: Sage Moore <sage@neuralmagic.com>
commit 96c0c4ea66
parent 930efd02ab
@@ -1649,7 +1649,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if ubatch_slices is not None:
             # num_tokens = ubatch_slices[1][1].stop
             # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
-            assert not is_dummy_run
+            # assert not is_dummy_run
             model_output = _run_ubatches(ubatch_slices, attn_metadata,
                                          is_dummy_run, num_tokens_across_dp=num_tokens_across_dp)
         # run single batch
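
Note: _run_ubatches itself is outside this diff, but the hunk implies its shape: it takes a list of per-microbatch slices and runs the model once per slice. A minimal sketch of that control flow, assuming the second element of each slice tuple indexes the flattened token batch (the helper below is illustrative, not the function this commit calls):

import torch

def _run_ubatches_sketch(ubatch_slices, run_model, inputs: torch.Tensor):
    # Run the model once per microbatch, each over its own span of the
    # flattened token batch, then stitch the outputs back together.
    outputs = [run_model(inputs[token_slice])
               for _req_slice, token_slice in ubatch_slices]
    return torch.cat(outputs, dim=0)
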
@@ -2224,10 +2224,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         allow_microbatching: bool = False,
     ) -> torch.Tensor:
 
+        should_microbatch = False
+        if allow_microbatching:
+            logger.info("ATTEMPTING TO UBATCH THE DUMMY RUN")
+
+        # TODO(Sage) We need some more code to properly handle
+        # mixing normal and dummy runs. The DP padding needs to
+        # be properly setup. Since we only support microbatching
+        # in cuda graph capture it's fine to ignore the DP padding
+        # for now.
+        should_ubatch = num_tokens >= \
+            self.parallel_config.microbatching_token_threshold and \
+            allow_microbatching
         # _dummy_run doesn't go through _prepare_inputs so
         # we synchronize with other DP ranks here
-        self.should_ubatch(should_microbatch)
+        should_ubatch = self.should_ubatch(allow_microbatching)
         assert not should_ubatch
         # Padding for DP
         # logger.info("PADDING DUMMY")
         num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
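
Note: the new predicate parses as intended (>= binds tighter than "and"), but the backslash continuations make that easy to misread. Here is a standalone restatement with parentheses; only the field name microbatching_token_threshold comes from the diff, while the config class and its default value are stand-ins:

from dataclasses import dataclass

@dataclass
class ParallelConfig:
    # Field name from the diff; the default value is an assumption.
    microbatching_token_threshold: int = 8192

def decide_ubatch(num_tokens: int, allow_microbatching: bool,
                  cfg: ParallelConfig) -> bool:
    # Equivalent to the diff's backslash-continued expression.
    return (allow_microbatching
            and num_tokens >= cfg.microbatching_token_threshold)

assert decide_ubatch(16384, True, ParallelConfig())
assert not decide_ubatch(16, True, ParallelConfig())

Per the in-code comment, self.should_ubatch(...) then synchronizes the local decision with the other DP ranks so that all ranks agree before the batch is split; how that synchronization works is not shown in this diff.
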
@@ -2278,19 +2290,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 for layer_name in kv_cache_group_spec.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i
 
-        # should_microbatch = (
-        #     allow_microbatching
-        #     and self.vllm_config.parallel_config.enable_microbatching
-        #     and self.vllm_config.parallel_config.always_microbatch_if_enabled)
-        # dummy_microbatches = [(slice(0, 0), slice(0, 0)),
-        #                       (slice(0, 0), slice(0, 0))]
+        dummy_microbatches = None
+        # We currently only microbatch if the number of tokens is
+        # over a certain threshold.
+        if should_ubatch:
+            assert num_tokens % 2 == 0
+            # TODO (Sage) Add actual slices here
+            assert False
+            dummy_microbatches = [(slice(0, 0), slice(0, 0)),
+                                  (slice(0, 0), slice(0, 0))]
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             outputs = self._run_model(
                 attn_metadata,
                 num_tokens,
-                ubatch_slices=None,
+                ubatch_slices=dummy_microbatches,
                 is_dummy_run=True,
                 num_tokens_across_dp=num_tokens_across_dp
             )
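
Note: the assert False guards the unfinished "TODO (Sage) Add actual slices here". Given the evenness assert just above it, the placeholder slices are presumably meant to become an even two-way split. A hypothetical version, assuming purely for illustration that request and token slices coincide in a dummy run (they may not in the final code):

def make_dummy_microbatches(num_tokens: int):
    # Mirrors the evenness assert in the diff.
    assert num_tokens % 2 == 0
    half = num_tokens // 2
    return [(slice(0, half), slice(0, half)),
            (slice(half, num_tokens), slice(half, num_tokens))]
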
@@ -2488,8 +2503,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                           total=len(self.cudagraph_batch_sizes)):
                 for _ in range(
                         self.compilation_config.cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg)
-                self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg)
+                    self._dummy_run(num_tokens,
+                                    capture_attn_cudagraph=full_cg,
+                                    allow_microbatching=allow_microbatching)
+                self._dummy_run(num_tokens,
+                                capture_attn_cudagraph=full_cg,
+                                allow_microbatching=allow_microbatching)
 
         logger.info("CAPTURE MODEL END")
         end_time = time.perf_counter()
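
Note: the loop above follows PyTorch's standard warmup-then-capture discipline for CUDA graphs: run the workload a few times on a side stream so allocations and lazy initialization settle, then record it once for replay. A generic sketch of that idiom (the underlying PyTorch pattern, not vLLM's capture machinery):

import torch

def warmup_and_capture(fn, static_input: torch.Tensor, num_warmups: int = 2):
    # Warm up on a side stream so capture sees a quiesced allocator.
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        for _ in range(num_warmups):
            static_output = fn(static_input)
    torch.cuda.current_stream().wait_stream(stream)

    # The final run is recorded into a replayable graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = fn(static_input)
    return graph, static_output
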
@@ -319,7 +319,7 @@ class Worker(WorkerBase):
 
     def execute_dummy_batch(self) -> None:
         # TODO: adding allow_microbatching will break non-gpu backends
-        self.model_runner._dummy_run(1, allow_microbatching=True)
+        self.model_runner._dummy_run(1)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_runner.add_lora(lora_request)
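
Note: the keyword argument is backed out here because, per the TODO, non-GPU model runners do not accept allow_microbatching. One hedged way to restore it without breaking those backends would be to pass the kwarg only when the runner's signature declares it (a sketch, not this commit's approach):

import inspect

def execute_dummy_batch(model_runner) -> None:
    kwargs = {}
    if "allow_microbatching" in inspect.signature(
            model_runner._dummy_run).parameters:
        kwargs["allow_microbatching"] = True
    model_runner._dummy_run(1, **kwargs)
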