[bugfix] fix weak ref in piecewise cudagraph and tractable test (#10048)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Parent: 235366fe2e
Commit: ca9844b340

@@ -1,6 +1,10 @@
 """
 Test the piecewise compilation with a simple model, comparing the output
 with and without the piecewise compilation.
+
+This is a tractable model, the weights and computation are specially designed
+if the config `tractable_init` is set to True. Otherwise, the weights are
+initialized randomly with a fixed seed.
 """
 import os
 from dataclasses import dataclass
@@ -49,6 +53,12 @@ class LlamaConfig:
     mlp_size: int = 256
     vocab_size: int = 128
     num_layers: int = 2
+    init_value: float = 1.0
+    tractable_init: bool = False
+    random_seed: int = 0
+
+    def __post_init__(self):
+        assert self.mlp_size >= self.hidden_size
 
 
 class LlamaMLP(nn.Module):
@@ -66,10 +76,23 @@ class LlamaMLP(nn.Module):
             bias=False,
         )
 
-        self.gate_up_projection.weight.data.fill_(0.0)
-        self.down_projection.weight.data.fill_(0.0)
+        if config.tractable_init:
+            nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size])
+            nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:])
+            nn.init.eye_(self.down_projection.weight.data)
+        else:
+            nn.init.xavier_normal_(self.gate_up_projection.weight.data,
+                                   generator=torch.Generator().manual_seed(
+                                       config.random_seed),
+                                   gain=0.001)
+            nn.init.xavier_normal_(self.down_projection.weight.data,
+                                   generator=torch.Generator().manual_seed(
+                                       config.random_seed),
+                                   gain=0.001)
 
     def forward(self, x):
+        # for tractable_init and positive input, this is
+        # essentially an elementwise-square
        x = self.gate_up_projection(x)
        x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
            x[:, x.size(1) // 2:])
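
Side note (not part of the commit): with the eye_ initialization above, gate_up_projection stacks two copies of its input, so for strictly positive inputs the gated activation collapses to an elementwise square. A minimal sketch, assuming only PyTorch on CPU:

import torch

x = torch.rand(4, 8) + 0.1          # strictly positive input
gate_up = torch.cat([x, x], dim=1)  # what an identity-initialized gate_up projection yields
d = gate_up.size(1) // 2
out = gate_up[:, :d] * torch.nn.functional.relu(gate_up[:, d:])
assert torch.allclose(out, x * x)   # elementwise square, as the comment says
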
@@ -84,21 +107,39 @@ class LlamaAttention(nn.Module):
         self.qkv_projection = nn.Linear(
             in_features=config.hidden_size,
             out_features=config.hidden_size * 3,
+            bias=False,
         )
 
         self.output_projection = nn.Linear(
             in_features=config.hidden_size,
             out_features=config.hidden_size,
+            bias=False,
         )
 
-        self.qkv_projection.weight.data.fill_(0.0)
-        self.output_projection.weight.data.fill_(0.0)
+        if config.tractable_init:
+            nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size])
+            nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 *
+                                                         config.hidden_size])
+            nn.init.eye_(self.qkv_projection.weight.data[2 *
+                                                         config.hidden_size:])
+            nn.init.eye_(self.output_projection.weight.data)
+        else:
+            nn.init.xavier_normal_(self.qkv_projection.weight.data,
+                                   generator=torch.Generator().manual_seed(
+                                       config.random_seed),
+                                   gain=0.001)
+            nn.init.xavier_normal_(self.output_projection.weight.data,
+                                   generator=torch.Generator().manual_seed(
+                                       config.random_seed),
+                                   gain=0.001)
 
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
+        # for tractable_init, this is:
+        # output = (hidden_states * 3 + positions * 2)
         qkv = self.qkv_projection(hidden_states)
         hidden_size = qkv.size(-1) // 3
         q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)
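
Side note (not part of the commit): an eye_-initialized qkv projection yields three copies of the hidden states after the split, which is what makes the hidden_states * 3 term in the comment above work out. A small sketch with plain PyTorch:

import torch

hidden_size = 16
h = torch.rand(4, hidden_size)
# three stacked identity blocks, as the nn.init.eye_ calls above produce
w_qkv = torch.cat([torch.eye(hidden_size)] * 3, dim=0)
qkv = h @ w_qkv.t()
q, k, v = qkv.split([hidden_size] * 3, dim=-1)
assert torch.allclose(q, h) and torch.allclose(k, h) and torch.allclose(v, h)
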
@@ -126,20 +167,29 @@ class LlamaDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        For tractable computation:
+        - if residual is None, the outputs are:
+            - residual = (hidden_states + 1) * 3 + positions * 2 + hidden_states = hidden_states * 4 + positions * 2 + 3
+            - hidden_states = (residual + 1) ** 2
+        - if residual is not None, the outputs are:
+            - residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
+            - hidden_states = (residual + 1) ** 2
+        """ # noqa
         if residual is None:
             residual = hidden_states
-            hidden_states = hidden_states / 2
+            hidden_states = hidden_states + 1
         else:
             hidden_states = hidden_states + residual
             residual = hidden_states
-            hidden_states = hidden_states / 2
+            hidden_states = hidden_states + 1
 
         hidden_states = self.self_attention(positions=positions,
                                             hidden_states=hidden_states)
 
         hidden_states = hidden_states + residual
         residual = hidden_states
-        hidden_states = hidden_states / 2
+        hidden_states = hidden_states + 1
         hidden_states = self.mlp(hidden_states)
 
         return hidden_states, residual
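
Side note (not part of the commit): the first residual formula in the docstring can be checked symbolically, using the tractable attention rule output = input * 3 + positions * 2 from the previous hunk. A sketch assuming sympy is installed:

import sympy as sp

h, p = sp.symbols("h p")
attn_in = h + 1                  # hidden_states after the "+ 1" shift
attn_out = attn_in * 3 + p * 2   # tractable attention
residual_out = attn_out + h      # residual connection adds the saved input back
assert sp.simplify(residual_out - (4 * h + 2 * p + 3)) == 0
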
@@ -156,7 +206,8 @@ class LlamaModel(nn.Module):
         self.layers = nn.ModuleList(
             [LlamaDecoderLayer(config) for _ in range(config.num_layers)])
 
-        self.embedding_tokens.weight.data.fill_(0.0)
+        # this is the initial value of the hidden states
+        self.embedding_tokens.weight.data.fill_(config.init_value)
 
     def forward(
         self,
@@ -170,6 +221,28 @@ class LlamaModel(nn.Module):
         return hidden_states
 
 
+def tractable_computation(input_ids: torch.Tensor,
+                          positions: torch.Tensor,
+                          config: LlamaConfig,
+                          init_value: float = 1.0) -> torch.Tensor:
+    hidden_states = torch.ones(input_ids.size(0),
+                               config.hidden_size,
+                               device=input_ids.device,
+                               dtype=input_ids.dtype) * init_value
+
+    # first layer
+    residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
+    hidden_states = (residual + 1)**2
+
+    # following layers
+    for _ in range(config.num_layers - 1):
+        hidden_states = hidden_states + residual
+        residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
+        hidden_states = (residual + 1)**2
+
+    return hidden_states
+
+
 @torch.inference_mode
 def run_model(llama_config,
               use_compile: bool,
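
Side note (not part of the commit), a hand check of this closed form: with init_value = 1 and position 0, the first layer gives residual = 1 * 4 + 0 * 2 + 3 = 7 and hidden_states = (7 + 1) ** 2 = 64:

init_value, position = 1.0, 0.0
residual = init_value * 4 + position * 2 + 3   # 7.0
hidden_states = (residual + 1) ** 2            # 64.0
assert (residual, hidden_states) == (7.0, 64.0)
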
@@ -213,7 +286,15 @@ def run_model(llama_config,
         del os.environ["VLLM_TORCH_COMPILE_LEVEL"]
         set_compilation_config(None)
 
-    return output.cpu()
+    output = output.cpu()
+
+    if llama_config.tractable_init:
+        expected_output = tractable_computation(input_ids[:2], positions[:2],
+                                                llama_config).cpu()
+
+        assert torch.allclose(output, expected_output)
+    else:
+        return output.cpu()
 
 
 def test_toy_llama():
@@ -222,7 +303,13 @@ def test_toy_llama():
     llama_config = LlamaConfig(hidden_size=128,
                                mlp_size=256,
                                vocab_size=128,
-                               num_layers=2)
+                               num_layers=12)
+
+    tractable_config = LlamaConfig(hidden_size=128,
+                                   mlp_size=256,
+                                   vocab_size=128,
+                                   num_layers=2,
+                                   tractable_init=True)
 
     outputs = []
     with compilation_counter.expect(
@@ -233,6 +320,8 @@ def test_toy_llama():
             num_cudagraph_caputured=0,
     ):
         outputs.append(run_model(llama_config, use_compile=False))
+    run_model(tractable_config, use_compile=False)
+
     with compilation_counter.expect(
             num_graphs_seen=1,  # one graph for the model
             num_piecewise_graphs_seen=1,
@@ -242,6 +331,7 @@ def test_toy_llama():
             2,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
         outputs.append(run_model(llama_config, use_compile=True))
+    run_model(tractable_config, use_compile=True)
 
     with compilation_counter.expect(
             num_graphs_seen=1,  # one graph for the model
@@ -257,6 +347,7 @@ def test_toy_llama():
     ):
         outputs.append(
             run_model(llama_config, use_compile=True, split_attn=True))
+    run_model(tractable_config, use_compile=True, split_attn=True)
 
     for i in range(1, len(outputs)):
         assert torch.allclose(outputs[0], outputs[i])

The remaining hunks touch a second file in this commit, the piecewise compilation backend (SplitItem, PiecewiseCompileInterpreter, VllmBackend, PiecewiseBackend).

@@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 import torch
 import torch.fx as fx
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import weak_ref_tensors
 
@@ -193,6 +194,7 @@ def wrap_inductor(graph,
 @dataclasses.dataclass
 class SplitItem:
     submod_name: str
+    graph_id: int
     is_splitting_graph: bool
     graph: fx.GraphModule
 
@@ -226,9 +228,7 @@ def split_graph(graph: fx.GraphModule,
 
     outputs = []
 
-    # sort the names to make sure the order is deterministic
     names = [name for (name, module) in split_gm.named_modules()]
-    names.sort()
 
     for name in names:
         if "." in name or name == "":
@@ -238,7 +238,11 @@ def split_graph(graph: fx.GraphModule,
         module = getattr(split_gm, name)
 
         graph_id = int(name.replace("submod_", ""))
-        outputs.append(SplitItem(name, graph_id in split_op_graphs, module))
+        outputs.append(
+            SplitItem(name, graph_id, (graph_id in split_op_graphs), module))
+
+    # sort by intetger graph_id, rather than string name
+    outputs.sort(key=lambda x: x.graph_id)
 
     return split_gm, outputs
 
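
Side note (not part of the commit): sorting by the integer graph_id matters because lexicographic sorting of submodule names breaks once there are ten or more pieces, which is presumably also why the test above bumps num_layers from 2 to 12:

names = ["submod_2", "submod_10", "submod_1"]
print(sorted(names))
# ['submod_1', 'submod_10', 'submod_2']   -- string order, wrong execution order
print(sorted(names, key=lambda n: int(n.replace("submod_", ""))))
# ['submod_1', 'submod_2', 'submod_10']   -- integer graph_id order
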
@@ -252,6 +256,11 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
     It runs the given graph with fake inputs, and compile some
     submodules specified by `compile_submod_names` with the given
     compilation configs.
+
+    NOTE: the order in `compile_submod_names` matters, because
+    it will be used to determine the order of the compiled piecewise
+    graphs. The first graph will handle logging, and the last graph
+    has some special cudagraph output handling.
     """
 
     def __init__(self, module: torch.fx.GraphModule,
@@ -263,7 +272,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
         self.compile_submod_names = compile_submod_names
         self.compilation_configs = compilation_configs
         self.graph_pool = graph_pool
-        self.have_seen_first_graph = False
 
     def run(self, *args):
         fake_args = [
@@ -279,6 +287,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
         output = super().call_module(target, args, kwargs)
 
         if target in self.compile_submod_names:
+            index = self.compile_submod_names.index(target)
             submod = self.fetch_attr(target)
             sym_shape_indices = [
                 i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
@@ -288,15 +297,14 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
                 args,
                 self.compilation_configs.inductor_compile_config,
                 runtime_shape=None,
-                do_logging=not self.have_seen_first_graph,
+                do_logging=index == 0,
                 use_inductor=self.compilation_configs.use_inductor)
 
             self.module.__dict__[target] = PiecewiseBackend(
-                submod, self.compilation_configs, self.graph_pool,
-                not self.have_seen_first_graph, sym_shape_indices,
+                submod, self.compilation_configs, self.graph_pool, index,
+                len(self.compile_submod_names), sym_shape_indices,
                 compiled_graph_for_general_shape)
 
-            self.have_seen_first_graph = True
             compilation_counter.num_piecewise_capturable_graphs_seen += 1
 
         return output
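
Side note (not part of the commit): replacing the mutable have_seen_first_graph flag with the position of each submodule in compile_submod_names lets every PiecewiseBackend know both whether it is the first piece (which does the logging) and whether it is the last piece (which gets the special output handling). Sketch with hypothetical submodule names:

compile_submod_names = ["submod_0", "submod_2", "submod_4"]   # hypothetical
for target in compile_submod_names:
    index = compile_submod_names.index(target)
    is_first_graph = index == 0
    is_last_graph = index == len(compile_submod_names) - 1
    print(target, is_first_graph, is_last_graph)
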
@@ -352,8 +360,9 @@ class VllmBackend:
             graph, self.compilation_configs.non_cudagraph_ops)
 
         from torch._dynamo.utils import lazy_format_graph_code
-        logger.debug("%s",
-                     lazy_format_graph_code("stiching module", self.split_gm))
+        logger.debug("%s", lazy_format_graph_code("before split", self.graph))
+        logger.debug("%s", lazy_format_graph_code("after split",
+                                                   self.split_gm))
 
         compilation_counter.num_piecewise_graphs_seen += len(
             self.piecewise_graphs)
@@ -385,12 +394,17 @@ class ConcreteSizeEntry:
     cudagraph: Optional[torch.cuda.CUDAGraph] = None
     output: Optional[Any] = None
 
+    # for cudagraph debugging, track the input addresses
+    # during capture, and check if they are the same during replay
+    input_addresses: Optional[List[int]] = None
 
 
 class PiecewiseBackend:
 
     def __init__(self, graph: fx.GraphModule,
                  compilation_configs: CompilationConfig, graph_pool: Any,
-                 is_first_graph: bool, sym_shape_indices: List[int],
+                 piecewise_compile_index: int, total_piecewise_compiles: int,
+                 sym_shape_indices: List[int],
                  compiled_graph_for_general_shape: Callable):
         """
         The backend for piecewise compilation.
@@ -408,7 +422,12 @@ class PiecewiseBackend:
         self.graph = graph
         self.compilation_configs = compilation_configs
         self.graph_pool = graph_pool
-        self.is_first_graph = is_first_graph
+        self.piecewise_compile_index = piecewise_compile_index
+        self.total_piecewise_compiles = total_piecewise_compiles
+
+        self.is_first_graph = piecewise_compile_index == 0
+        self.is_last_graph = (
+            piecewise_compile_index == total_piecewise_compiles - 1)
 
         self.compile_sizes: Set[int] = set(
             self.compilation_configs.compile_sizes)
@@ -422,6 +441,8 @@ class PiecewiseBackend:
 
         self.sym_shape_indices = sym_shape_indices
 
+        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
+
         # the entries for different shapes that we need to either
         # compile or capture cudagraph
         self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
@@ -476,14 +497,45 @@ class PiecewiseBackend:
                 logger.info("Capturing a cudagraph for shape %s",
                             runtime_shape)
 
+            input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            entry.input_addresses = input_addresses
             cudagraph = torch.cuda.CUDAGraph()
+
+            # mind-exploding: carefully manage the reference and memory.
             with torch.cuda.graph(cudagraph, pool=self.graph_pool):
-                entry.output = weak_ref_tensors(entry.runnable(*args))
+                # `output` is managed by pytorch's cudagraph pool
+                output = entry.runnable(*args)
+                if self.is_last_graph:
+                    # by converting it to weak ref,
+                    # the original `output` will immediately be released
+                    # to save memory. It is only safe to do this for
+                    # the last graph, because the output of the last graph
+                    # will not be used by any other cuda graph.
+                    output = weak_ref_tensors(output)
+
+            # here we always use weak ref for the output
+            # to save memory
+            entry.output = weak_ref_tensors(output)
+            entry.cudagraph = cudagraph
 
             compilation_counter.num_cudagraph_caputured += 1
 
-            entry.cudagraph = cudagraph
-            return entry.output
+            # important: we need to return the output, rather than
+            # the weak ref of the output, so that pytorch can correctly
+            # manage the memory during cuda graph capture
+            return output
+
+        if self.is_debugging_mode:
+            # check if the input addresses are the same
+            new_input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            assert new_input_addresses == entry.input_addresses, (
+                "Input addresses for cudagraphs are different during replay."
+                f" Expected {entry.input_addresses}, got {new_input_addresses}"
+            )
 
         entry.cudagraph.replay()
         return entry.output
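
Side note (not part of the commit): the weak-ref handling above relies on the fact that tensors produced inside torch.cuda.graph(...) live in buffers owned by the graph's memory pool, and replay() rewrites those same buffers in place; that is also why input addresses are checked during replay in debug mode. A minimal capture/replay sketch, assuming a CUDA device is available (illustrative only):

import torch

static_x = torch.zeros(8, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y = static_x * 2 + 1    # output buffer is owned by the graph's pool

static_x.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
g.replay()                         # recomputes into static_y in place
print(static_y)                    # reflects the new contents of static_x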