Mirror of https://git.datalinker.icu/vllm-project/vllm.git

Remove graph_pool as member of VllmBackend and argument to CUDAGraphWrapper (#23385)

Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>

parent 6fd45e7b8a
commit 6fad29b11b
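
In short, the diff below removes the `graph_pool` attribute from `VllmBackend`, drops the corresponding constructor argument from `PiecewiseCompileInterpreter`, `AbstractStaticGraphWrapper`, and `CUDAGraphWrapper`, and has the wrapper always fetch the shared pool itself via `current_platform.get_global_graph_pool()`. A minimal before/after sketch of a wrapper call site (the import path of `CUDAGraphWrapper` and the surrounding `wrap_piecewise` helper are assumptions for illustration, not part of this diff):

    from vllm.config import CUDAGraphMode  # per the docstring below, CUDAGraphMode lives in vllm/config.py
    from vllm.compilation.cuda_graph import CUDAGraphWrapper  # import path assumed

    def wrap_piecewise(runnable, vllm_config):
        # Before #23385 (hypothetical call site): pass an explicit pool
        # handle, or None to fall back to the global pool in the wrapper:
        #     CUDAGraphWrapper(runnable, vllm_config,
        #                      runtime_mode=CUDAGraphMode.PIECEWISE,
        #                      graph_pool=None)
        # After #23385: the wrapper always uses
        # current_platform.get_global_graph_pool() internally, so the
        # argument is simply dropped at every call site:
        return CUDAGraphWrapper(runnable, vllm_config,
                                runtime_mode=CUDAGraphMode.PIECEWISE)
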
@@ -294,13 +294,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
 
     def __init__(self, module: torch.fx.GraphModule,
                  compile_submod_names: list[str], vllm_config: VllmConfig,
-                 graph_pool, vllm_backend: "VllmBackend"):
+                 vllm_backend: "VllmBackend"):
         super().__init__(module)
         from torch._guards import detect_fake_mode
         self.fake_mode = detect_fake_mode()
         self.compile_submod_names = compile_submod_names
         self.compilation_config = vllm_config.compilation_config
-        self.graph_pool = graph_pool
         self.vllm_config = vllm_config
         self.vllm_backend = vllm_backend
         # When True, it annoyingly dumps the torch.fx.Graph on errors.

@@ -359,7 +358,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
                 runnable=piecewise_backend,
                 vllm_config=self.vllm_config,
                 runtime_mode=CUDAGraphMode.PIECEWISE,
-                graph_pool=self.graph_pool,
                 cudagraph_options=CUDAGraphOptions(
                     debug_log_enable=piecewise_backend.is_first_graph,
                     gc_disable=not piecewise_backend.is_first_graph,

@@ -405,7 +403,6 @@ class VllmBackend:
 
     vllm_config: VllmConfig
     compilation_config: CompilationConfig
-    graph_pool: Any
     _called: bool = False
     # the graph we compiled
     graph: fx.GraphModule

@@ -433,13 +430,6 @@ class VllmBackend:
         # them, e.g. backbone (default), eagle_head, etc.
         self.prefix = prefix or model_tag
 
-        global_graph_pool = current_platform.get_global_graph_pool()
-
-        # TODO: in the future, if we want to use multiple
-        # streams, it might not be safe to share a global pool.
-        # only investigate this when we use multiple streams
-        self.graph_pool = global_graph_pool
-
         # Passes to run on the graph post-grad.
         self.post_grad_pass_manager = PostGradPassManager()
 

@@ -586,7 +576,7 @@ class VllmBackend:
         # propagate the split graph to the piecewise backend,
         # compile submodules with symbolic shapes
         PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
-                                    self.vllm_config, self.graph_pool,
+                                    self.vllm_config,
                                     self).run(*example_inputs)
 
         graph_path = os.path.join(local_cache_dir, "computation_graph.py")
@@ -13,7 +13,7 @@ class AbstractStaticGraphWrapper(Protocol):
     """
 
     def __init__(self, runnable: Callable, vllm_config: VllmConfig,
-                 runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs):
+                 runtime_mode: CUDAGraphMode, **kwargs):
         """
         Initializes the StaticGraphWrapper class with graph capturing and
         execution-related configurations.

@@ -25,9 +25,6 @@ class AbstractStaticGraphWrapper(Protocol):
                 graph runtime. See CUDAGraphMode in vllm/config.py.
                 Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
                 are used as concrete runtime mode for cudagraph dispatching.
-            graph_pool (Any):
-                Graph memory pool handle, e.g.,
-                `torch.cuda.graph_pool_handle()`.
         Keyword Args:
             kwargs: Additional keyword arguments for platform-specific
                 configurations.

@@ -67,11 +67,9 @@ class CUDAGraphWrapper:
                  runnable: Callable,
                  vllm_config: VllmConfig,
                  runtime_mode: CUDAGraphMode,
-                 graph_pool: Any = None,
                  cudagraph_options: Optional[CUDAGraphOptions] = None):
         self.runnable = runnable
         self.vllm_config = vllm_config
-        self.graph_pool = graph_pool
         self.runtime_mode = runtime_mode
         self.compilation_config = vllm_config.compilation_config
 

@@ -81,7 +79,9 @@ class CUDAGraphWrapper:
         # assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
         # need to initialize a CUDAGraphWrapper.
         assert self.runtime_mode != CUDAGraphMode.NONE
-        if self.graph_pool is None:
-            self.graph_pool = current_platform.get_global_graph_pool()
+        # TODO: in the future, if we want to use multiple
+        # streams, it might not be safe to share a global pool.
+        # only investigate this when we use multiple streams
+        self.graph_pool = current_platform.get_global_graph_pool()
 
         if cudagraph_options is None:
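
For background on what the removed handle did: CUDA graphs captured into the same memory pool share backing allocations, which is why a single global pool is kept at all. A standalone PyTorch sketch (not vLLM code; requires a CUDA build, and shown only to illustrate the `torch.cuda.graph_pool_handle()` mentioned in the removed docstring):

    import torch

    pool = torch.cuda.graph_pool_handle()   # one shared memory pool
    x = torch.zeros(8, device="cuda")
    g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()

    with torch.cuda.graph(g1, pool=pool):
        y = x + 1    # capture; y's storage is owned by `pool`
    with torch.cuda.graph(g2, pool=pool):
        z = x * 2    # second graph reuses the same pool as g1

    g1.replay()      # re-run the captured kernels
    g2.replay()

Centralizing the pool lookup inside `CUDAGraphWrapper` keeps this sharing while removing a pass-through argument from every backend; the TODO about multiple streams moves into the wrapper with it.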