[Model Runner V2] Minor fix for cudagraph_utils (#29256)

parent 389aa1b2eb
commit 20ee418adc
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
from contextlib import contextmanager
from unittest.mock import patch

import numpy as np
import torch
@@ -140,6 +139,7 @@ class CudaGraphManager:
            attn_metadata,
            self.vllm_config,
            num_tokens=batch_size,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            num_tokens_across_dp=num_tokens_across_dp,
        ):
            hidden_states = model(
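Aside, not part of the diff: set_forward_context installs per-step state (attention metadata, the vLLM config, token counts) that layers read back during the forward pass instead of receiving it as arguments. A minimal sketch of that context-manager pattern, using illustrative names that are not vLLM's:

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class ForwardContext:
    attn_metadata: Any
    num_tokens: int

_current_ctx: Optional[ForwardContext] = None

@contextmanager
def set_ctx(attn_metadata: Any, num_tokens: int):
    """Install a forward context for the duration of one model call."""
    global _current_ctx
    prev, _current_ctx = _current_ctx, ForwardContext(attn_metadata, num_tokens)
    try:
        yield
    finally:
        _current_ctx = prev  # restore whatever context was active before

def get_ctx() -> ForwardContext:
    assert _current_ctx is not None, "must be called inside set_ctx()"
    return _current_ctx

with set_ctx(attn_metadata={"seq_lens": [3, 5]}, num_tokens=8):
    # Model layers would call get_ctx() here rather than take extra arguments.
    print(get_ctx().num_tokens)
```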
@@ -148,15 +148,16 @@ class CudaGraphManager:
            )
        if self.hidden_states is None:
            self.hidden_states = torch.empty_like(hidden_states)
        torch.cuda.synchronize()

        # Capture the graph.
        graph = torch.cuda.CUDAGraph()
        with (
            patch("torch.cuda.empty_cache", lambda: None),
            set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=batch_size,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                num_tokens_across_dp=num_tokens_across_dp,
            ),
            torch.cuda.graph(graph, self.pool),
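Aside, not part of the diff: the block above follows PyTorch's standard CUDA graph recipe: run the model once to warm up, keep persistent output buffers, then record the kernels so later replays write into those same buffers. Patching torch.cuda.empty_cache to a no-op presumably keeps anything inside the capture region from releasing cached allocator blocks mid-capture. A standalone sketch of the capture-and-replay pattern (illustrative only, not vLLM's implementation):

```python
import torch

def capture_linear(weight: torch.Tensor, batch: int, pool=None):
    """Capture y = x @ weight for a fixed batch size; return graph + static buffers."""
    static_x = torch.zeros(batch, weight.shape[0], device="cuda")
    # Warm-up run outside the graph so lazy allocations happen before capture.
    static_y = static_x @ weight
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=pool):
        static_y = static_x @ weight  # kernels recorded; outputs land in static_y
    return graph, static_x, static_y

if torch.cuda.is_available():
    w = torch.randn(64, 32, device="cuda")
    pool = torch.cuda.graph_pool_handle()      # share one memory pool across graphs
    graph, x_buf, y_buf = capture_linear(w, batch=8, pool=pool)

    x_buf.copy_(torch.randn(8, 64, device="cuda"))  # write new inputs in place
    graph.replay()                                   # re-run the captured kernels
    torch.cuda.synchronize()
    print(y_buf.shape)                               # results are in the static buffer
```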
@@ -183,7 +184,7 @@ class CudaGraphManager:
        if is_global_first_rank():
            sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")

        with freeze_gc(), graph_capture(device=self.device):
        with graph_capture(device=self.device):
            for batch_size in sizes_to_capture:
                self.capture_graph(
                    batch_size,
@@ -199,13 +200,3 @@ class CudaGraphManager:
        self.graphs[batch_size].replay()
        assert self.hidden_states is not None
        return self.hidden_states[:batch_size]


@contextmanager
def freeze_gc():
    gc.collect()
    gc.freeze()
    try:
        yield
    finally:
        gc.unfreeze()
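Aside, not part of the diff: graphs are captured only for a fixed set of batch sizes, so at run time a batch is typically replayed through the smallest captured graph that fits it, and hidden_states[:batch_size] above slices the persistent output buffer down to the rows that correspond to real tokens. A toy illustration of that size bucketing (names are made up):

```python
import bisect

def pick_captured_size(captured_sizes: list[int], num_tokens: int) -> int:
    """Smallest captured batch size that can hold num_tokens (sizes sorted ascending)."""
    idx = bisect.bisect_left(captured_sizes, num_tokens)
    if idx == len(captured_sizes):
        raise ValueError(f"{num_tokens} tokens exceeds the largest captured size")
    return captured_sizes[idx]

sizes = [1, 2, 4, 8, 16, 32]           # e.g. sizes_to_capture, sorted ascending
padded = pick_captured_size(sizes, 5)  # -> 8: replay the size-8 graph
print(padded)                          # the caller then keeps only output[:5]
```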
@@ -298,6 +298,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            return 0

        start_time = time.perf_counter()
        torch.cuda.empty_cache()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        with self.maybe_setup_dummy_loras(self.lora_config):
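Aside, not part of the diff: the GPUModelRunner hunk estimates how much device memory the profiled workload consumes by emptying the allocator cache first and then diffing torch.cuda.mem_get_info() readings taken before and after. A small sketch of that measurement pattern (illustrative, not the vLLM implementation):

```python
import time
import torch

def measure_gpu_usage(fn):
    """Return (result, bytes of GPU memory consumed, elapsed seconds) for fn()."""
    torch.cuda.empty_cache()                 # drop cached blocks so "free" is honest
    start_time = time.perf_counter()
    free_before, _total = torch.cuda.mem_get_info()
    result = fn()
    torch.cuda.synchronize()
    free_after, _total = torch.cuda.mem_get_info()
    return result, free_before - free_after, time.perf_counter() - start_time

if torch.cuda.is_available():
    buf, used, secs = measure_gpu_usage(lambda: torch.empty(1 << 22, device="cuda"))
    print(f"allocated ~{used / 1e6:.1f} MB in {secs:.3f} s")
```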