[v1][torch.compile] support managing cudagraph buffer (#10203)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Authored by youkaichao on 2024-11-11 11:10:27 -08:00, committed by GitHub
parent d7a4f2207b
commit 330e82d34a
4 changed files with 59 additions and 8 deletions


@@ -1,4 +1,5 @@
 {
     "use_cudagraph": true,
-    "non_cudagraph_ops": ["silly.attention"]
+    "non_cudagraph_ops": ["silly.attention"],
+    "cudagraph_copy_inputs": true
 }
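This JSON is the piecewise compilation config consumed by the test below; the test points vLLM at it through the VLLM_TORCH_COMPILE_CONFIG environment variable. A minimal sketch of that wiring (only the environment variable name and the JSON keys come from this commit; the temporary-file handling is illustrative):

import json
import os
import tempfile

# Write the same config the test uses; "cudagraph_copy_inputs" is the new knob.
config = {
    "use_cudagraph": True,
    "non_cudagraph_ops": ["silly.attention"],
    "cudagraph_copy_inputs": True,
}
config_path = os.path.join(tempfile.mkdtemp(), "piecewise_compilation_config.json")
with open(config_path, "w") as f:
    json.dump(config, f)

# The piecewise backend picks the config up from this environment variable,
# exactly as the test below does.
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config_path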


@@ -80,7 +80,7 @@ def test_simple_piecewise_compile():
     config = os.path.join(directory, "piecewise_compilation_config.json")
     os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
 
-    input_buffer = torch.randn(100).cuda()
+    inputs = torch.randn(100).cuda()
 
     with compilation_counter.expect(
         num_graphs_seen=1,  # one graph for the model
@@ -92,15 +92,15 @@ def test_simple_piecewise_compile():
     ):
 
         with set_compile_context([1, 2]):
-            model(input_buffer)
+            model(inputs)
 
-            model(input_buffer[:2])
-            model(input_buffer[:1])
+            model(torch.randn(2).cuda())
+            model(torch.randn(1).cuda())
 
-        input_buffer[:2].zero_()
+        input = torch.zeros(2).cuda()
         global global_counter
         global_counter = 0
-        output = model(input_buffer[:2])
+        output = model(input)
         assert global_counter == 2
         assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
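With cudagraph_copy_inputs enabled in the config above, the test no longer has to reuse slices of a single caller-owned input_buffer: it can pass freshly allocated tensors of any batch size, and the backend copies them into its own static buffers. The underlying constraint is that a captured CUDA graph replays against fixed memory addresses, so new data must be copied into the tensors that were captured. A standalone sketch of that constraint in plain PyTorch (not vLLM code; requires a GPU):

import torch

g = torch.cuda.CUDAGraph()
static_input = torch.zeros(2, device="cuda")

# Warm up on a side stream before capture, as recommended for CUDA graphs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output = static_input * 2 + 1
torch.cuda.current_stream().wait_stream(s)

# Capture: the graph records the device addresses of static_input/static_output.
with torch.cuda.graph(g):
    static_output = static_input * 2 + 1

# A brand-new tensor lives at a different address, so its values must be
# copied into the captured buffer before replaying the graph.
fresh = torch.randn(2, device="cuda")
static_input.copy_(fresh)
g.replay()
assert torch.allclose(static_output, fresh * 2 + 1)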


@@ -389,6 +389,8 @@ class VllmBackend:
     returned_callable: Callable
     # Inductor passes to run on the graph pre-defunctionalization
     post_grad_passes: Sequence[Callable]
+    sym_tensor_indices: List[int]
+    input_buffers: List[torch.Tensor]
 
     def __init__(self, post_grad_passes: Sequence[Callable] = ()):
         global global_graph_pool
@@ -401,6 +403,9 @@ class VllmBackend:
         self.graph_pool = global_graph_pool
         self.post_grad_passes = post_grad_passes
 
+        self.sym_tensor_indices = []
+        self.input_buffers = []
+
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
@@ -461,7 +466,46 @@ class VllmBackend:
         self._called = True
 
-        return self.split_gm
+        if not self.compilation_configs.use_cudagraph or \
+            not self.compilation_configs.cudagraph_copy_inputs:
+            return self.split_gm
+
+        # if we need to copy input buffers for cudagraph
+        from torch._guards import detect_fake_mode
+        fake_mode = detect_fake_mode()
+        fake_args = [
+            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
+            for t in example_inputs
+        ]
+
+        # index of tensors that have symbolic shapes (batch size)
+        self.sym_tensor_indices = [
+            i for i, x in enumerate(fake_args)
+            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
+        ]
+
+        # compiler managed cudagraph input buffers
+        # we assume the first run with symbolic shapes
+        # has the maximum size among all the tensors
+        self.input_buffers = [
+            example_inputs[x].clone() for x in self.sym_tensor_indices
+        ]
+
+        def copy_and_call(*args):
+            list_args = list(args)
+            for i, index in enumerate(self.sym_tensor_indices):
+                runtime_tensor = list_args[index]
+                runtime_shape = runtime_tensor.shape[0]
+                static_tensor = self.input_buffers[i][:runtime_shape]
+
+                # copy the tensor to the static buffer
+                static_tensor.copy_(runtime_tensor)
+
+                # replace the tensor in the list_args to the static buffer
+                list_args[index] = static_tensor
+            return self.split_gm(*list_args)
+
+        return copy_and_call
 
 
 @dataclasses.dataclass
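The wrapper returned above keeps one max-size buffer per symbolically shaped input (cloned from the first dynamic-shape run) and, on every call, copies the caller's tensor into a slice of that buffer before invoking the split graph. Because the slice is a view into the managed buffer, the data always lands at the same device addresses the cudagraphs were captured on. A toy illustration of the slice-and-copy trick in plain PyTorch (names here are illustrative, not vLLM code):

import torch

max_batch = 100
input_buffer = torch.empty(max_batch, device="cuda")  # compiler-managed buffer

def to_static(runtime_tensor: torch.Tensor) -> torch.Tensor:
    # Slice the leading rows of the managed buffer; the slice is a view, so it
    # shares storage (and device addresses) with input_buffer.
    n = runtime_tensor.shape[0]
    static_view = input_buffer[:n]
    static_view.copy_(runtime_tensor)  # overwrite with this call's data
    return static_view

x = torch.randn(2, device="cuda")
y = to_static(x)
assert y.data_ptr() == input_buffer.data_ptr()  # same underlying storage

Note the assumption spelled out in the code comment above: the first run with symbolic shapes must see the maximum size, since later calls can only be copied into a slice of that initial clone.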


@@ -32,6 +32,11 @@ class CompilationConfig(BaseModel):
             It means the first several runs will be treated as warmup runs.
             Only after that, the execution will be recorded, and the recorded
             cudagraph will be used for subsequent runs.
+        - cudagraph_copy_inputs: whether to copy input tensors for
+            cudagraph. If the caller can guarantee that the same input buffers
+            are always used, it can set this to False. Otherwise, it should
+            set this to True, and the compiler will copy the input to an
+            internally managed buffer. Default is False.
     - Inductor compilation:
         - use_inductor: whether to use inductor compilation.
             - False: inductor compilation is not used. graph runs in eager.
@@ -78,6 +83,7 @@ class CompilationConfig(BaseModel):
     non_cudagraph_ops: List[str] = Field(default_factory=list)
     cudagraph_num_of_warmups: int = 0
     cudagraph_capture_sizes: Optional[List[int]] = None
+    cudagraph_copy_inputs: bool = False
 
     dump_graph_stages: List[str] = Field(default_factory=list)
     dump_graph_dir: Path = Field(default=Path("."))
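Since CompilationConfig is a Pydantic model, the new field can also be set directly in Python. A sketch using only field names that appear in this diff (the import path is an assumption and may differ across vLLM versions; the tests above instead supply the config as a JSON file via VLLM_TORCH_COMPILE_CONFIG):

# Import path is an assumption; CompilationConfig may live elsewhere in your
# vLLM version.
from vllm.compilation.config import CompilationConfig

config = CompilationConfig(
    use_cudagraph=True,
    non_cudagraph_ops=["silly.attention"],
    cudagraph_copy_inputs=True,  # compiler copies inputs into managed buffers
)
assert config.cudagraph_copy_inputs is True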