From 617fb893d5df97de41b76037244500013066de45 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 26 Jul 2024 19:29:36 -0700 Subject: [PATCH] add compile --- vllm/worker/model_runner.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 86d26b4a84c36..77d5a9c735e22 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -787,6 +787,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): "provided. Defaulting to scaling factors of 1.0. " "This may lead to less accurate results!") + count = 0 + + def backend(gm, input): + nonlocal count + count += 1 + print(count) + return gm.forward + + self.model = torch.compile(self.model, backend=backend, fullgraph=True) + def save_sharded_state( self, path: str,