[7/N] torch.compile, reduce compilation time (#10460)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 5f1d6af2b6
commit 0cd3d9717e
@@ -79,7 +79,7 @@ def test_simple_piecewise_compile():
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.PIECEWISE,
         use_cudagraph=True,
-        non_cudagraph_ops=["silly.attention"],
+        splitting_ops=["silly.attention"],
         cudagraph_copy_inputs=True,
     ))
     with set_current_vllm_config(vllm_config):
@@ -258,7 +258,7 @@ def run_model(llama_config,
             use_cudagraph=True,
         )
         if split_attn:
-            compilation_config.non_cudagraph_ops = ["silly.attention"]
+            compilation_config.splitting_ops = ["silly.attention"]
     else:
         compilation_config = CompilationConfig(
             level=CompilationLevel.NO_COMPILATION, )
@@ -378,7 +378,7 @@ def benchmark():
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            non_cudagraph_ops=["silly.attention"],
+            splitting_ops=["silly.attention"],
         )
     else:
         compilation_config = CompilationConfig(
@@ -447,7 +447,7 @@ class VllmBackend:
         self.add_passes_to_config()
 
         self.split_gm, self.piecewise_graphs = split_graph(
-            graph, self.compilation_configs.non_cudagraph_ops)
+            graph, self.compilation_configs.splitting_ops)
 
         from torch._dynamo.utils import lazy_format_graph_code
         logger.debug("%s", lazy_format_graph_code("before split", self.graph))
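The idea behind split_graph is to scan the FX graph for the ops listed in splitting_ops and break the graph around them, so the pieces in between can be compiled and cudagraph-captured separately. A minimal illustrative sketch of that scan, not vLLM's actual split_graph; the helper find_splitting_nodes and the Toy module are made up for this example:

from typing import List

import torch
from torch.fx import symbolic_trace


def find_splitting_nodes(gm: torch.fx.GraphModule, splitting_ops: List[str]):
    """Return call_function nodes whose target name matches a splitting op."""
    return [
        node for node in gm.graph.nodes
        if node.op == "call_function"
        and getattr(node.target, "__name__", "") in splitting_ops
    ]


class Toy(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        x = torch.relu(x)  # stand-in for an attention op named in splitting_ops
        return x * 2


gm = symbolic_trace(Toy())
print(find_splitting_nodes(gm, ["relu"]))  # -> [relu]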
@@ -2089,13 +2089,15 @@ class CompilationConfig(BaseModel):
         - 'none,+op1,+op2' to enable only op1 and op2
         By default, all custom ops are enabled when running without Inductor
         and disabled when running with Inductor (compile_level >= Inductor).
+    - splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation.
     - CudaGraph capture:
         - use_cudagraph: whether to use cudagraph inside compilation.
             - False: cudagraph inside compilation is not used.
             - True: cudagraph inside compilation is used. It requires
-                that all input buffers have fixed addresses.
-                Note that this is orthogonal to the cudagraph capture out
-                side of compilation.
+                that all input buffers have fixed addresses, and all
+                splitting ops write their outputs to input buffers.
+                Note that this is orthogonal to the cudagraph capture logic
+                outside of compilation.
                 TODO: move outside cudagraph logic into compilation.
                 torch.compile will handle cudagraph capture logic in the future.
         - cudagraph_capture_sizes: sizes to capture cudagraph.
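As a concrete reading of the updated docstring, this is roughly what a piecewise config looks like after the rename. A sketch mirroring the test at the top of this diff; "silly.attention" is the test suite's toy op, and the import paths here are assumptions that may differ at this commit:

from vllm.config import CompilationConfig, CompilationLevel, VllmConfig

vllm_config = VllmConfig(compilation_config=CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_cudagraph=True,
    splitting_ops=["silly.attention"],
    cudagraph_copy_inputs=True,
))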
@@ -2149,6 +2151,11 @@ class CompilationConfig(BaseModel):
     level: int = 0
     backend: str = ""
     custom_ops: List[str] = Field(default_factory=list)
+    splitting_ops: List[str] = Field(default_factory=lambda: [
+        "vllm.unified_flash_attention",
+        "vllm.unified_flash_infer",
+        "vllm.unified_v1_flash_attention",
+    ])
 
     use_inductor: bool = True
     inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
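The lambda in default_factory is there so every CompilationConfig instance gets its own fresh copy of the default list. A small pydantic sketch of that behavior; MiniCompilationConfig and "my.custom_attention" are hypothetical names used only for this example:

from typing import List

from pydantic import BaseModel, Field


class MiniCompilationConfig(BaseModel):
    splitting_ops: List[str] = Field(default_factory=lambda: [
        "vllm.unified_flash_attention",
        "vllm.unified_flash_infer",
        "vllm.unified_v1_flash_attention",
    ])


a = MiniCompilationConfig()
b = MiniCompilationConfig(splitting_ops=["silly.attention"])
a.splitting_ops.append("my.custom_attention")  # mutates only a's list
assert b.splitting_ops == ["silly.attention"]
assert "my.custom_attention" not in MiniCompilationConfig().splitting_ops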
@@ -2157,7 +2164,6 @@ class CompilationConfig(BaseModel):
     inductor_passes: Dict[str, str] = Field(default_factory=dict)
 
     use_cudagraph: bool = False
-    non_cudagraph_ops: List[str] = Field(default_factory=list)
     cudagraph_num_of_warmups: int = 0
     cudagraph_capture_sizes: Optional[List[int]] = None
     cudagraph_copy_inputs: bool = False
@@ -2348,9 +2354,6 @@ class VllmConfig:
             # and avoid any potential issues with the inductor.
             self.compilation_config.custom_ops = ["none"]
             self.compilation_config.use_cudagraph = True
-            self.compilation_config.non_cudagraph_ops = [
-                "vllm.unified_v1_flash_attention"
-            ]
             self.compilation_config.use_inductor = True
             self.compilation_config.enable_fusion = False
 
@@ -1,6 +1,7 @@
 """A GPU worker class."""
 import gc
 import os
+import time
 from typing import Dict, List, Optional, Set, Tuple, Type, Union
 
 import torch
@@ -189,6 +190,7 @@ class Worker(LocalOrDistributedWorkerBase):
         torch.cuda.reset_peak_memory_stats()
 
         free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
+        start_time = time.time()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
@@ -229,12 +231,18 @@ class Worker(LocalOrDistributedWorkerBase):
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
 
+        end_time = time.time()
         logger.info(
-            "Memory profiling results: total_gpu_memory=%.2fGiB"
-            " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
-            " memory_usage_post_profile=%.2fGiB"
-            " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
-            " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
+            "Memory profiling results: "
+            "duration=%.2f seconds, "
+            "total_gpu_memory=%.2fGiB, "
+            "initial_memory_usage=%.2fGiB, "
+            "peak_torch_memory=%.2fGiB, "
+            "memory_usage_post_profile=%.2fGiB, "
+            "non_torch_memory=%.2fGiB, "
+            "kv_cache_size=%.2fGiB, "
+            "gpu_memory_utilization=%.2f.", end_time - start_time,
+            total_gpu_memory / (1024**3),
             (total_gpu_memory - free_memory_pre_profile) / (1024**3),
             (peak_memory - non_torch_allocations) / (1024**3),
             total_allocated_bytes / (1024**3),
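The only functional change here is that the profiling run is now timed and the duration is logged alongside the memory numbers. A standalone sketch of the same pattern on a CUDA machine; variable names follow the diff and the dummy forward pass is elided:

import time

import torch

free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
start_time = time.time()
# ... dummy forward pass that profiles the model's memory usage goes here ...
end_time = time.time()

print("Memory profiling results: duration=%.2f seconds, "
      "total_gpu_memory=%.2fGiB, initial_memory_usage=%.2fGiB" %
      (end_time - start_time,
       total_gpu_memory / (1024**3),
       (total_gpu_memory - free_memory_pre_profile) / (1024**3)))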