mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-01-18 15:44:21 +08:00
Merge pull request #155 from yondonfu/taesd-compile
Support taesd in TorchCompileVAE
This commit is contained in:
commit
df7edf7893
@ -375,15 +375,41 @@ class TorchCompileVAE:
|
||||
def compile(self, vae, backend, mode, fullgraph, compile_encoder, compile_decoder):
|
||||
if compile_encoder:
|
||||
if not self._compiled_encoder:
|
||||
encoder_name = "encoder"
|
||||
if hasattr(vae.first_stage_model, "taesd_encoder"):
|
||||
encoder_name = "taesd_encoder"
|
||||
|
||||
try:
|
||||
vae.first_stage_model.encoder = torch.compile(vae.first_stage_model.encoder, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
setattr(
|
||||
vae.first_stage_model,
|
||||
encoder_name,
|
||||
torch.compile(
|
||||
getattr(vae.first_stage_model, encoder_name),
|
||||
mode=mode,
|
||||
fullgraph=fullgraph,
|
||||
backend=backend,
|
||||
),
|
||||
)
|
||||
self._compiled_encoder = True
|
||||
except:
|
||||
raise RuntimeError("Failed to compile model")
|
||||
if compile_decoder:
|
||||
if not self._compiled_decoder:
|
||||
decoder_name = "decoder"
|
||||
if hasattr(vae.first_stage_model, "taesd_decoder"):
|
||||
decoder_name = "taesd_decoder"
|
||||
|
||||
try:
|
||||
vae.first_stage_model.decoder = torch.compile(vae.first_stage_model.decoder, mode=mode, fullgraph=fullgraph, backend=backend)
|
||||
setattr(
|
||||
vae.first_stage_model,
|
||||
decoder_name,
|
||||
torch.compile(
|
||||
getattr(vae.first_stage_model, decoder_name),
|
||||
mode=mode,
|
||||
fullgraph=fullgraph,
|
||||
backend=backend,
|
||||
),
|
||||
)
|
||||
self._compiled_decoder = True
|
||||
except:
|
||||
raise RuntimeError("Failed to compile model")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user