From d089395bb55d5f1f8cf98374d074c336b9223e88 Mon Sep 17 00:00:00 2001 From: Yondon Fu Date: Wed, 11 Dec 2024 17:46:13 -0500 Subject: [PATCH] Support taesd in TorchCompileVAE --- nodes/model_optimization_nodes.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index 277bea5..fcc6f98 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -342,15 +342,41 @@ class TorchCompileVAE: def compile(self, vae, backend, mode, fullgraph, compile_encoder, compile_decoder): if compile_encoder: if not self._compiled_encoder: + encoder_name = "encoder" + if hasattr(vae.first_stage_model, "taesd_encoder"): + encoder_name = "taesd_encoder" + try: - vae.first_stage_model.encoder = torch.compile(vae.first_stage_model.encoder, mode=mode, fullgraph=fullgraph, backend=backend) + setattr( + vae.first_stage_model, + encoder_name, + torch.compile( + getattr(vae.first_stage_model, encoder_name), + mode=mode, + fullgraph=fullgraph, + backend=backend, + ), + ) self._compiled_encoder = True except: raise RuntimeError("Failed to compile model") if compile_decoder: if not self._compiled_decoder: + decoder_name = "decoder" + if hasattr(vae.first_stage_model, "taesd_decoder"): + decoder_name = "taesd_decoder" + try: - vae.first_stage_model.decoder = torch.compile(vae.first_stage_model.decoder, mode=mode, fullgraph=fullgraph, backend=backend) + setattr( + vae.first_stage_model, + decoder_name, + torch.compile( + getattr(vae.first_stage_model, decoder_name), + mode=mode, + fullgraph=fullgraph, + backend=backend, + ), + ) self._compiled_decoder = True except: raise RuntimeError("Failed to compile model")