mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 16:56:07 +08:00
[Bugfix] [pytorch] Patch AOTAutogradCache._get_shape_env (#17142)
Signed-off-by: James Wu <jjwu@meta.com>
This commit is contained in:
parent
5e83a7277f
commit
a6e72e1e4f
@ -195,7 +195,6 @@ class InductorAdaptor(CompilerInterface):
|
|||||||
hash_str, file_path = None, None
|
hash_str, file_path = None, None
|
||||||
from torch._inductor.codecache import (FxGraphCache,
|
from torch._inductor.codecache import (FxGraphCache,
|
||||||
compiled_fx_graph_hash)
|
compiled_fx_graph_hash)
|
||||||
|
|
||||||
if torch.__version__.startswith("2.5"):
|
if torch.__version__.startswith("2.5"):
|
||||||
original_load = FxGraphCache.load
|
original_load = FxGraphCache.load
|
||||||
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
|
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
|
||||||
@ -280,6 +279,16 @@ class InductorAdaptor(CompilerInterface):
|
|||||||
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
|
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||||
_get_shape_env))
|
_get_shape_env))
|
||||||
|
|
||||||
|
from torch._functorch._aot_autograd.autograd_cache import (
|
||||||
|
AOTAutogradCache)
|
||||||
|
|
||||||
|
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
|
||||||
|
if hasattr(AOTAutogradCache, "_get_shape_env"):
|
||||||
|
stack.enter_context(
|
||||||
|
patch(
|
||||||
|
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
|
||||||
|
_get_shape_env))
|
||||||
|
|
||||||
# for forcing the graph to be cached
|
# for forcing the graph to be cached
|
||||||
stack.enter_context(
|
stack.enter_context(
|
||||||
patch(
|
patch(
|
||||||
@ -325,11 +334,19 @@ class InductorAdaptor(CompilerInterface):
|
|||||||
assert isinstance(handle[1], str)
|
assert isinstance(handle[1], str)
|
||||||
hash_str = handle[0]
|
hash_str = handle[0]
|
||||||
|
|
||||||
|
from torch._functorch._aot_autograd.autograd_cache import (
|
||||||
|
AOTAutogradCache)
|
||||||
from torch._inductor.codecache import FxGraphCache
|
from torch._inductor.codecache import FxGraphCache
|
||||||
with ExitStack() as exit_stack:
|
with ExitStack() as exit_stack:
|
||||||
exit_stack.enter_context(
|
exit_stack.enter_context(
|
||||||
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
|
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||||
lambda *args, **kwargs: AlwaysHitShapeEnv()))
|
lambda *args, **kwargs: AlwaysHitShapeEnv()))
|
||||||
|
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
|
||||||
|
if hasattr(AOTAutogradCache, "_get_shape_env"):
|
||||||
|
exit_stack.enter_context(
|
||||||
|
patch(
|
||||||
|
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
|
||||||
|
lambda *args, **kwargs: AlwaysHitShapeEnv()))
|
||||||
|
|
||||||
# Dynamo metrics context, see method for more details.
|
# Dynamo metrics context, see method for more details.
|
||||||
exit_stack.enter_context(self.metrics_context())
|
exit_stack.enter_context(self.metrics_context())
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user