diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py index ce482572b401b..fc02a08f74265 100644 --- a/vllm/compilation/caching.py +++ b/vllm/compilation/caching.py @@ -104,6 +104,7 @@ class VllmSerializableFunction(SerializableCallable): state = pickle.loads(data) fake_mode = FakeTensorMode(shape_env=ShapeEnv()) state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode) + state["graph_module"].recompile() state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode) vllm_backend = VllmBackend(get_current_vllm_config(), state["prefix"])