Support taesd in TorchCompileVAE

This commit is contained in:
Yondon Fu 2024-12-11 17:46:13 -05:00
parent 8c590fd5a0
commit d089395bb5

View File

@ -342,15 +342,41 @@ class TorchCompileVAE:
def compile(self, vae, backend, mode, fullgraph, compile_encoder, compile_decoder):
if compile_encoder:
if not self._compiled_encoder:
encoder_name = "encoder"
if hasattr(vae.first_stage_model, "taesd_encoder"):
encoder_name = "taesd_encoder"
try:
vae.first_stage_model.encoder = torch.compile(vae.first_stage_model.encoder, mode=mode, fullgraph=fullgraph, backend=backend)
setattr(
vae.first_stage_model,
encoder_name,
torch.compile(
getattr(vae.first_stage_model, encoder_name),
mode=mode,
fullgraph=fullgraph,
backend=backend,
),
)
self._compiled_encoder = True
except:
raise RuntimeError("Failed to compile model")
if compile_decoder:
if not self._compiled_decoder:
decoder_name = "decoder"
if hasattr(vae.first_stage_model, "taesd_decoder"):
decoder_name = "taesd_decoder"
try:
vae.first_stage_model.decoder = torch.compile(vae.first_stage_model.decoder, mode=mode, fullgraph=fullgraph, backend=backend)
setattr(
vae.first_stage_model,
decoder_name,
torch.compile(
getattr(vae.first_stage_model, decoder_name),
mode=mode,
fullgraph=fullgraph,
backend=backend,
),
)
self._compiled_decoder = True
except:
raise RuntimeError("Failed to compile model")