[7/N] torch.compile, reduce compilation time (#10460)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-11-20 11:20:38 -08:00 committed by GitHub
parent 5f1d6af2b6
commit 0cd3d9717e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 27 additions and 16 deletions

View File

@ -79,7 +79,7 @@ def test_simple_piecewise_compile():
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, level=CompilationLevel.PIECEWISE,
use_cudagraph=True, use_cudagraph=True,
non_cudagraph_ops=["silly.attention"], splitting_ops=["silly.attention"],
cudagraph_copy_inputs=True, cudagraph_copy_inputs=True,
)) ))
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):

View File

@ -258,7 +258,7 @@ def run_model(llama_config,
use_cudagraph=True, use_cudagraph=True,
) )
if split_attn: if split_attn:
compilation_config.non_cudagraph_ops = ["silly.attention"] compilation_config.splitting_ops = ["silly.attention"]
else: else:
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, ) level=CompilationLevel.NO_COMPILATION, )
@ -378,7 +378,7 @@ def benchmark():
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE, level=CompilationLevel.PIECEWISE,
use_cudagraph=True, use_cudagraph=True,
non_cudagraph_ops=["silly.attention"], splitting_ops=["silly.attention"],
) )
else: else:
compilation_config = CompilationConfig( compilation_config = CompilationConfig(

View File

@ -447,7 +447,7 @@ class VllmBackend:
self.add_passes_to_config() self.add_passes_to_config()
self.split_gm, self.piecewise_graphs = split_graph( self.split_gm, self.piecewise_graphs = split_graph(
graph, self.compilation_configs.non_cudagraph_ops) graph, self.compilation_configs.splitting_ops)
from torch._dynamo.utils import lazy_format_graph_code from torch._dynamo.utils import lazy_format_graph_code
logger.debug("%s", lazy_format_graph_code("before split", self.graph)) logger.debug("%s", lazy_format_graph_code("before split", self.graph))

View File

@ -2089,13 +2089,15 @@ class CompilationConfig(BaseModel):
- 'none,+op1,+op2' to enable only op1 and op2 - 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor By default, all custom ops are enabled when running without Inductor
and disabled when running with Inductor (compile_level >= Inductor). and disabled when running with Inductor (compile_level >= Inductor).
- splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation.
- CudaGraph capture: - CudaGraph capture:
- use_cudagraph: whether to use cudagraph inside compilation. - use_cudagraph: whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used. - False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires - True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses. that all input buffers have fixed addresses, and all
Note that this is orthogonal to the cudagraph capture out splitting ops write their outputs to input buffers.
side of compilation. Note that this is orthogonal to the cudagraph capture logic
outside of compilation.
TODO: move outside cudagraph logic into compilation. TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future. torch.compile will handle cudagraph capture logic in the future.
- cudagraph_capture_sizes: sizes to capture cudagraph. - cudagraph_capture_sizes: sizes to capture cudagraph.
@ -2149,6 +2151,11 @@ class CompilationConfig(BaseModel):
level: int = 0 level: int = 0
backend: str = "" backend: str = ""
custom_ops: List[str] = Field(default_factory=list) custom_ops: List[str] = Field(default_factory=list)
splitting_ops: List[str] = Field(default_factory=lambda: [
"vllm.unified_flash_attention",
"vllm.unified_flash_infer",
"vllm.unified_v1_flash_attention",
])
use_inductor: bool = True use_inductor: bool = True
inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
@ -2157,7 +2164,6 @@ class CompilationConfig(BaseModel):
inductor_passes: Dict[str, str] = Field(default_factory=dict) inductor_passes: Dict[str, str] = Field(default_factory=dict)
use_cudagraph: bool = False use_cudagraph: bool = False
non_cudagraph_ops: List[str] = Field(default_factory=list)
cudagraph_num_of_warmups: int = 0 cudagraph_num_of_warmups: int = 0
cudagraph_capture_sizes: Optional[List[int]] = None cudagraph_capture_sizes: Optional[List[int]] = None
cudagraph_copy_inputs: bool = False cudagraph_copy_inputs: bool = False
@ -2348,9 +2354,6 @@ class VllmConfig:
# and avoid any potential issues with the inductor. # and avoid any potential issues with the inductor.
self.compilation_config.custom_ops = ["none"] self.compilation_config.custom_ops = ["none"]
self.compilation_config.use_cudagraph = True self.compilation_config.use_cudagraph = True
self.compilation_config.non_cudagraph_ops = [
"vllm.unified_v1_flash_attention"
]
self.compilation_config.use_inductor = True self.compilation_config.use_inductor = True
self.compilation_config.enable_fusion = False self.compilation_config.enable_fusion = False

View File

@ -1,6 +1,7 @@
"""A GPU worker class.""" """A GPU worker class."""
import gc import gc
import os import os
import time
from typing import Dict, List, Optional, Set, Tuple, Type, Union from typing import Dict, List, Optional, Set, Tuple, Type, Union
import torch import torch
@ -189,6 +190,7 @@ class Worker(LocalOrDistributedWorkerBase):
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info() free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
start_time = time.time()
# Execute a forward pass with dummy inputs to profile the memory usage # Execute a forward pass with dummy inputs to profile the memory usage
# of the model. # of the model.
@ -229,12 +231,18 @@ class Worker(LocalOrDistributedWorkerBase):
num_gpu_blocks = max(num_gpu_blocks, 0) num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0)
end_time = time.time()
logger.info( logger.info(
"Memory profiling results: total_gpu_memory=%.2fGiB" "Memory profiling results: "
" initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB" "duration=%.2f seconds, "
" memory_usage_post_profile=%.2fGiB" "total_gpu_memory=%.2fGiB, "
" non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB" "initial_memory_usage=%.2fGiB, "
" gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3), "peak_torch_memory=%.2fGiB, "
"memory_usage_post_profile=%.2fGiB, "
"non_torch_memory=%.2fGiB, "
"kv_cache_size=%.2fGiB, "
"gpu_memory_utilization=%.2f.", end_time - start_time,
total_gpu_memory / (1024**3),
(total_gpu_memory - free_memory_pre_profile) / (1024**3), (total_gpu_memory - free_memory_pre_profile) / (1024**3),
(peak_memory - non_torch_allocations) / (1024**3), (peak_memory - non_torch_allocations) / (1024**3),
total_allocated_bytes / (1024**3), total_allocated_bytes / (1024**3),