mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 10:29:50 +08:00
[v1][torch.compile] support managing cudagraph buffer (#10203)
Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
d7a4f2207b
commit
330e82d34a
@ -1,4 +1,5 @@
|
|||||||
{
|
{
|
||||||
"use_cudagraph": true,
|
"use_cudagraph": true,
|
||||||
"non_cudagraph_ops": ["silly.attention"]
|
"non_cudagraph_ops": ["silly.attention"],
|
||||||
|
"cudagraph_copy_inputs": true
|
||||||
}
|
}
|
||||||
@ -80,7 +80,7 @@ def test_simple_piecewise_compile():
|
|||||||
config = os.path.join(directory, "piecewise_compilation_config.json")
|
config = os.path.join(directory, "piecewise_compilation_config.json")
|
||||||
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
|
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
|
||||||
|
|
||||||
input_buffer = torch.randn(100).cuda()
|
inputs = torch.randn(100).cuda()
|
||||||
|
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=1, # one graph for the model
|
num_graphs_seen=1, # one graph for the model
|
||||||
@ -92,15 +92,15 @@ def test_simple_piecewise_compile():
|
|||||||
):
|
):
|
||||||
|
|
||||||
with set_compile_context([1, 2]):
|
with set_compile_context([1, 2]):
|
||||||
model(input_buffer)
|
model(inputs)
|
||||||
|
|
||||||
model(input_buffer[:2])
|
model(torch.randn(2).cuda())
|
||||||
model(input_buffer[:1])
|
model(torch.randn(1).cuda())
|
||||||
|
|
||||||
input_buffer[:2].zero_()
|
input = torch.zeros(2).cuda()
|
||||||
global global_counter
|
global global_counter
|
||||||
global_counter = 0
|
global_counter = 0
|
||||||
output = model(input_buffer[:2])
|
output = model(input)
|
||||||
assert global_counter == 2
|
assert global_counter == 2
|
||||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||||
|
|
||||||
|
|||||||
@ -389,6 +389,8 @@ class VllmBackend:
|
|||||||
returned_callable: Callable
|
returned_callable: Callable
|
||||||
# Inductor passes to run on the graph pre-defunctionalization
|
# Inductor passes to run on the graph pre-defunctionalization
|
||||||
post_grad_passes: Sequence[Callable]
|
post_grad_passes: Sequence[Callable]
|
||||||
|
sym_tensor_indices: List[int]
|
||||||
|
input_buffers: List[torch.Tensor]
|
||||||
|
|
||||||
def __init__(self, post_grad_passes: Sequence[Callable] = ()):
|
def __init__(self, post_grad_passes: Sequence[Callable] = ()):
|
||||||
global global_graph_pool
|
global global_graph_pool
|
||||||
@ -401,6 +403,9 @@ class VllmBackend:
|
|||||||
self.graph_pool = global_graph_pool
|
self.graph_pool = global_graph_pool
|
||||||
self.post_grad_passes = post_grad_passes
|
self.post_grad_passes = post_grad_passes
|
||||||
|
|
||||||
|
self.sym_tensor_indices = []
|
||||||
|
self.input_buffers = []
|
||||||
|
|
||||||
# `torch.compile` is JIT compiled, so we don't need to
|
# `torch.compile` is JIT compiled, so we don't need to
|
||||||
# do anything here
|
# do anything here
|
||||||
|
|
||||||
@ -461,7 +466,46 @@ class VllmBackend:
|
|||||||
|
|
||||||
self._called = True
|
self._called = True
|
||||||
|
|
||||||
return self.split_gm
|
if not self.compilation_configs.use_cudagraph or \
|
||||||
|
not self.compilation_configs.cudagraph_copy_inputs:
|
||||||
|
return self.split_gm
|
||||||
|
|
||||||
|
# if we need to copy input buffers for cudagraph
|
||||||
|
from torch._guards import detect_fake_mode
|
||||||
|
fake_mode = detect_fake_mode()
|
||||||
|
fake_args = [
|
||||||
|
fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
||||||
|
for t in example_inputs
|
||||||
|
]
|
||||||
|
|
||||||
|
# index of tensors that have symbolic shapes (batch size)
|
||||||
|
self.sym_tensor_indices = [
|
||||||
|
i for i, x in enumerate(fake_args)
|
||||||
|
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
|
||||||
|
]
|
||||||
|
|
||||||
|
# compiler managed cudagraph input buffers
|
||||||
|
# we assume the first run with symbolic shapes
|
||||||
|
# has the maximum size among all the tensors
|
||||||
|
self.input_buffers = [
|
||||||
|
example_inputs[x].clone() for x in self.sym_tensor_indices
|
||||||
|
]
|
||||||
|
|
||||||
|
def copy_and_call(*args):
|
||||||
|
list_args = list(args)
|
||||||
|
for i, index in enumerate(self.sym_tensor_indices):
|
||||||
|
runtime_tensor = list_args[index]
|
||||||
|
runtime_shape = runtime_tensor.shape[0]
|
||||||
|
static_tensor = self.input_buffers[i][:runtime_shape]
|
||||||
|
|
||||||
|
# copy the tensor to the static buffer
|
||||||
|
static_tensor.copy_(runtime_tensor)
|
||||||
|
|
||||||
|
# replace the tensor in the list_args to the static buffer
|
||||||
|
list_args[index] = static_tensor
|
||||||
|
return self.split_gm(*list_args)
|
||||||
|
|
||||||
|
return copy_and_call
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
|||||||
@ -32,6 +32,11 @@ class CompilationConfig(BaseModel):
|
|||||||
It means the first several runs will be treated as warmup runs.
|
It means the first several runs will be treated as warmup runs.
|
||||||
Only after that, the execution will be recorded, and the recorded
|
Only after that, the execution will be recorded, and the recorded
|
||||||
cudagraph will be used for subsequent runs.
|
cudagraph will be used for subsequent runs.
|
||||||
|
- cudagraph_copy_inputs: whether to copy input tensors for
|
||||||
|
cudagraph. If the caller can guarantee that the same input buffers
|
||||||
|
are always used, it can set this to False. Otherwise, it should
|
||||||
|
set this to True, and the compiler will copy the input to an
|
||||||
|
internally managed buffer. Default is False.
|
||||||
- Inductor compilation:
|
- Inductor compilation:
|
||||||
- use_inductor: whether to use inductor compilation.
|
- use_inductor: whether to use inductor compilation.
|
||||||
- False: inductor compilation is not used. graph runs in eager.
|
- False: inductor compilation is not used. graph runs in eager.
|
||||||
@ -78,6 +83,7 @@ class CompilationConfig(BaseModel):
|
|||||||
non_cudagraph_ops: List[str] = Field(default_factory=list)
|
non_cudagraph_ops: List[str] = Field(default_factory=list)
|
||||||
cudagraph_num_of_warmups: int = 0
|
cudagraph_num_of_warmups: int = 0
|
||||||
cudagraph_capture_sizes: Optional[List[int]] = None
|
cudagraph_capture_sizes: Optional[List[int]] = None
|
||||||
|
cudagraph_copy_inputs: bool = False
|
||||||
|
|
||||||
dump_graph_stages: List[str] = Field(default_factory=list)
|
dump_graph_stages: List[str] = Field(default_factory=list)
|
||||||
dump_graph_dir: Path = Field(default=Path("."))
|
dump_graph_dir: Path = Field(default=Path("."))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user