Merge pull request #155 from yondonfu/taesd-compile

Support taesd in TorchCompileVAE
This commit is contained in:
Jukka Seppänen 2025-01-31 12:25:09 +02:00 committed by GitHub
commit df7edf7893
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -375,15 +375,41 @@ class TorchCompileVAE:
def compile(self, vae, backend, mode, fullgraph, compile_encoder, compile_decoder):
if compile_encoder:
if not self._compiled_encoder:
encoder_name = "encoder"
if hasattr(vae.first_stage_model, "taesd_encoder"):
encoder_name = "taesd_encoder"
try:
vae.first_stage_model.encoder = torch.compile(vae.first_stage_model.encoder, mode=mode, fullgraph=fullgraph, backend=backend)
setattr(
vae.first_stage_model,
encoder_name,
torch.compile(
getattr(vae.first_stage_model, encoder_name),
mode=mode,
fullgraph=fullgraph,
backend=backend,
),
)
self._compiled_encoder = True
except:
raise RuntimeError("Failed to compile model")
if compile_decoder:
if not self._compiled_decoder:
decoder_name = "decoder"
if hasattr(vae.first_stage_model, "taesd_decoder"):
decoder_name = "taesd_decoder"
try:
vae.first_stage_model.decoder = torch.compile(vae.first_stage_model.decoder, mode=mode, fullgraph=fullgraph, backend=backend)
setattr(
vae.first_stage_model,
decoder_name,
torch.compile(
getattr(vae.first_stage_model, decoder_name),
mode=mode,
fullgraph=fullgraph,
backend=backend,
),
)
self._compiled_decoder = True
except:
raise RuntimeError("Failed to compile model")