[7/N] torch.compile, reduce compilation time (#10460)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Authored by youkaichao on 2024-11-20 11:20:38 -08:00; committed via GitHub.
parent 5f1d6af2b6
commit 0cd3d9717e
5 changed files with 27 additions and 16 deletions

View File

@@ -79,7 +79,7 @@ def test_simple_piecewise_compile():
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
-non_cudagraph_ops=["silly.attention"],
+splitting_ops=["silly.attention"],
cudagraph_copy_inputs=True,
))
with set_current_vllm_config(vllm_config):

View File

@@ -258,7 +258,7 @@ def run_model(llama_config,
use_cudagraph=True,
)
if split_attn:
-compilation_config.non_cudagraph_ops = ["silly.attention"]
+compilation_config.splitting_ops = ["silly.attention"]
else:
compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, )
@@ -378,7 +378,7 @@ def benchmark():
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
-non_cudagraph_ops=["silly.attention"],
+splitting_ops=["silly.attention"],
)
else:
compilation_config = CompilationConfig(

View File

@@ -447,7 +447,7 @@ class VllmBackend:
self.add_passes_to_config()
self.split_gm, self.piecewise_graphs = split_graph(
-graph, self.compilation_configs.non_cudagraph_ops)
+graph, self.compilation_configs.splitting_ops)
from torch._dynamo.utils import lazy_format_graph_code
logger.debug("%s", lazy_format_graph_code("before split", self.graph))
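The split_graph call above is where the backend carves the traced FX graph into piecewise subgraphs at the configured splitting_ops. As a rough, standalone sketch of the idea (using torch.fx.passes.split_module directly, with torch.relu standing in for an attention op; this is not vLLM's actual implementation), one can assign partition ids that change at every splitting op:

import torch
from torch.fx.passes.split_module import split_module

class ToyModel(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        x = torch.relu(x)  # stand-in for a splitting op such as attention
        return x * 2

model = ToyModel()
gm = torch.fx.symbolic_trace(model)

# Give every node a partition id; bump the id when we reach the splitting op
# so it lands in its own subgraph, with separate subgraphs before and after.
partition = 0
node_to_partition = {}
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.relu:
        partition += 1
        node_to_partition[node] = partition
        partition += 1
    else:
        node_to_partition[node] = partition

split_gm = split_module(gm, model, lambda n: node_to_partition[n])
print(split_gm)  # submod_0 (x + 1), submod_1 (relu), submod_2 (x * 2)

The subgraphs that do not contain a splitting op can then be compiled and cudagraph-captured independently, which is why the splitting ops default to the attention ops in the config changes below.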

View File

@@ -2089,13 +2089,15 @@ class CompilationConfig(BaseModel):
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor
and disabled when running with Inductor (compile_level >= Inductor).
+- splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation.
- CudaGraph capture:
- use_cudagraph: whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
-that all input buffers have fixed addresses.
-Note that this is orthogonal to the cudagraph capture out
-side of compilation.
+that all input buffers have fixed addresses, and all
+splitting ops write their outputs to input buffers.
+Note that this is orthogonal to the cudagraph capture logic
+outside of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future.
- cudagraph_capture_sizes: sizes to capture cudagraph.
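The "fixed addresses" requirement above is the standard CUDA graph constraint: a captured graph replays against the exact memory it recorded, so new data must be copied into the captured buffers rather than passed in as fresh tensors. A minimal plain-PyTorch sketch of that pattern (unrelated to vLLM's own capture code):

import torch

static_input = torch.zeros(5, device="cuda")  # captured buffer; its address must not change

# Warm up on a side stream before capture, per the torch.cuda.graph docs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = static_input * 2
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = static_input * 2          # records reads/writes on these addresses

static_input.copy_(torch.full((5,), 3.0, device="cuda"))  # reuse the buffer, don't replace it
g.replay()
print(static_output)                          # tensor of 6.0s

This is also the motivation for the added requirement that splitting ops write their outputs into input buffers: the captured pieces on either side of a split need stable addresses for the tensors crossing it.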
@@ -2149,6 +2151,11 @@
level: int = 0
backend: str = ""
custom_ops: List[str] = Field(default_factory=list)
+splitting_ops: List[str] = Field(default_factory=lambda: [
+"vllm.unified_flash_attention",
+"vllm.unified_flash_infer",
+"vllm.unified_v1_flash_attention",
+])
use_inductor: bool = True
inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
@@ -2157,7 +2164,6 @@
inductor_passes: Dict[str, str] = Field(default_factory=dict)
use_cudagraph: bool = False
-non_cudagraph_ops: List[str] = Field(default_factory=list)
cudagraph_num_of_warmups: int = 0
cudagraph_capture_sizes: Optional[List[int]] = None
cudagraph_copy_inputs: bool = False
@@ -2348,9 +2354,6 @@ class VllmConfig:
# and avoid any potential issues with the inductor.
self.compilation_config.custom_ops = ["none"]
self.compilation_config.use_cudagraph = True
-self.compilation_config.non_cudagraph_ops = [
-"vllm.unified_v1_flash_attention"
-]
self.compilation_config.use_inductor = True
self.compilation_config.enable_fusion = False
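Net effect of the config changes: splitting_ops now defaults to the unified attention ops, so the hard-coded non_cudagraph_ops assignment in VllmConfig above is no longer needed. A minimal usage sketch, assuming the import locations used by the test files in this commit (CompilationConfig, CompilationLevel, and VllmConfig from vllm.config):

from vllm.config import CompilationConfig, CompilationLevel, VllmConfig

# splitting_ops keeps its new default (the vllm.unified_*attention ops),
# so the graph is split around attention with no extra configuration.
compilation_config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_cudagraph=True,
)
vllm_config = VllmConfig(compilation_config=compilation_config)

# Overriding the default is still possible, as the compile tests above do
# for their toy "silly.attention" op:
compilation_config.splitting_ops = ["silly.attention"]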

View File

@@ -1,6 +1,7 @@
"""A GPU worker class."""
import gc
import os
+import time
from typing import Dict, List, Optional, Set, Tuple, Type, Union
import torch
@@ -189,6 +190,7 @@ class Worker(LocalOrDistributedWorkerBase):
torch.cuda.reset_peak_memory_stats()
free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
+start_time = time.time()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
@@ -229,12 +231,18 @@
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
+end_time = time.time()
logger.info(
"Memory profiling results: total_gpu_memory=%.2fGiB"
" initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
" memory_usage_post_profile=%.2fGiB"
" non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
" gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
"Memory profiling results: "
"duration=%.2f seconds, "
"total_gpu_memory=%.2fGiB, "
"initial_memory_usage=%.2fGiB, "
"peak_torch_memory=%.2fGiB, "
"memory_usage_post_profile=%.2fGiB, "
"non_torch_memory=%.2fGiB, "
"kv_cache_size=%.2fGiB, "
"gpu_memory_utilization=%.2f.", end_time - start_time,
total_gpu_memory / (1024**3),
(total_gpu_memory - free_memory_pre_profile) / (1024**3),
(peak_memory - non_torch_allocations) / (1024**3),
total_allocated_bytes / (1024**3),