mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-31 22:27:16 +08:00
Merge pull request #155 from yondonfu/taesd-compile
Support taesd in TorchCompileVAE
This commit is contained in:
commit
df7edf7893
@ -375,15 +375,41 @@ class TorchCompileVAE:
|
|||||||
def compile(self, vae, backend, mode, fullgraph, compile_encoder, compile_decoder):
|
def compile(self, vae, backend, mode, fullgraph, compile_encoder, compile_decoder):
|
||||||
if compile_encoder:
|
if compile_encoder:
|
||||||
if not self._compiled_encoder:
|
if not self._compiled_encoder:
|
||||||
|
encoder_name = "encoder"
|
||||||
|
if hasattr(vae.first_stage_model, "taesd_encoder"):
|
||||||
|
encoder_name = "taesd_encoder"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
vae.first_stage_model.encoder = torch.compile(vae.first_stage_model.encoder, mode=mode, fullgraph=fullgraph, backend=backend)
|
setattr(
|
||||||
|
vae.first_stage_model,
|
||||||
|
encoder_name,
|
||||||
|
torch.compile(
|
||||||
|
getattr(vae.first_stage_model, encoder_name),
|
||||||
|
mode=mode,
|
||||||
|
fullgraph=fullgraph,
|
||||||
|
backend=backend,
|
||||||
|
),
|
||||||
|
)
|
||||||
self._compiled_encoder = True
|
self._compiled_encoder = True
|
||||||
except:
|
except:
|
||||||
raise RuntimeError("Failed to compile model")
|
raise RuntimeError("Failed to compile model")
|
||||||
if compile_decoder:
|
if compile_decoder:
|
||||||
if not self._compiled_decoder:
|
if not self._compiled_decoder:
|
||||||
|
decoder_name = "decoder"
|
||||||
|
if hasattr(vae.first_stage_model, "taesd_decoder"):
|
||||||
|
decoder_name = "taesd_decoder"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
vae.first_stage_model.decoder = torch.compile(vae.first_stage_model.decoder, mode=mode, fullgraph=fullgraph, backend=backend)
|
setattr(
|
||||||
|
vae.first_stage_model,
|
||||||
|
decoder_name,
|
||||||
|
torch.compile(
|
||||||
|
getattr(vae.first_stage_model, decoder_name),
|
||||||
|
mode=mode,
|
||||||
|
fullgraph=fullgraph,
|
||||||
|
backend=backend,
|
||||||
|
),
|
||||||
|
)
|
||||||
self._compiled_decoder = True
|
self._compiled_decoder = True
|
||||||
except:
|
except:
|
||||||
raise RuntimeError("Failed to compile model")
|
raise RuntimeError("Failed to compile model")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user