Compare commits


No commits in common. "main" and "fastercache" have entirely different histories.

5 changed files with 22 additions and 88 deletions

View File

@@ -31,28 +31,17 @@ class LatentPreviewer:
 class Latent2RGBPreviewer(LatentPreviewer):
     def __init__(self):
         #latent_rgb_factors = [[0.05389399697934166, 0.025018778505575393, -0.009193515248318657], [0.02318250640590553, -0.026987363837713156, 0.040172639061236956], [0.046035451343323666, -0.02039565868920197, 0.01275569344290342], [-0.015559161155025095, 0.051403973219861246, 0.03179031307996347], [-0.02766167769640129, 0.03749545161530447, 0.003335141009473408], [0.05824598730479011, 0.021744367381243884, -0.01578925627951616], [0.05260929401500947, 0.0560165014956886, -0.027477296572565126], [0.018513891242931686, 0.041961785217662514, 0.004490763489747966], [0.024063060899760215, 0.065082853069653, 0.044343437673514896], [0.05250992323006226, 0.04361117432588933, 0.01030076055524387], [0.0038921710021782366, -0.025299228133723792, 0.019370764014574535], [-0.00011950534333568519, 0.06549370069727675, -0.03436712163379723], [-0.026020578032683626, -0.013341758571090847, -0.009119046570271953], [0.024412451175602937, 0.030135064560817174, -0.008355486384198006], [0.04002209845752687, -0.017341304390739463, 0.02818338690302971], [-0.032575108695213684, -0.009588338926775117, -0.03077312160940468]]
-        #latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
-        latent_rgb_factors =[
-            [-0.0069, -0.0045, 0.0018],
-            [ 0.0154, -0.0692, -0.0274],
-            [ 0.0333, 0.0019, 0.0206],
-            [-0.1390, 0.0628, 0.1678],
-            [-0.0725, 0.0134, -0.1898],
-            [ 0.0074, -0.0270, -0.0209],
-            [-0.0176, -0.0277, -0.0221],
-            [ 0.5294, 0.5204, 0.3852],
-            [-0.0326, -0.0446, -0.0143],
-            [-0.0659, 0.0153, -0.0153],
-            [ 0.0185, -0.0217, 0.0014],
-            [-0.0396, -0.0495, -0.0281]
-        ]
+        latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
         self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
-        self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
-        # if latent_rgb_factors_bias is not None:
-        #     self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
+        self.latent_rgb_factors_bias = None
     def decode_latent_to_preview(self, x0):
         self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
         if self.latent_rgb_factors_bias is not None:
-            self.latent_rgb_factors_bias = torch.tensor(self.latent_rgb_factors_bias, device="cpu").to(dtype=x0.dtype, device=x0.device)
+            self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
         latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors,
                                                   bias=self.latent_rgb_factors_bias)
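
Note on the hunk above: both branches build the RGB preview the same way, as a per-pixel linear map from the latent channels to RGB; only the factor matrix and bias differ. A minimal, self-contained sketch of that projection, where the shapes and factor values are placeholders rather than either branch's calibrated constants:

    import torch

    latent = torch.randn(1, 12, 60, 106)       # (B, C_latent, H, W), illustrative shape
    factors = torch.randn(12, 3)                # per-channel RGB weights (placeholder values)
    bias = torch.zeros(3)                       # per-channel RGB offset (placeholder)

    x = latent[0].permute(1, 2, 0)              # (H, W, C_latent)
    # linear() expects weight as (out_features, in_features), hence the transpose,
    # matching the .transpose(0, 1) applied to latent_rgb_factors in the diff.
    rgb = torch.nn.functional.linear(x, factors.transpose(0, 1), bias)        # (H, W, 3)
    rgb = ((rgb - rgb.min()) / (rgb.max() - rgb.min())).clamp(0, 1)           # scale for display
    print(rgb.shape)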

View File

@@ -165,7 +165,6 @@ class AsymmetricAttention(nn.Module):
             raise ImportError("Flash RMSNorm not available.")
         elif rms_norm_func == "apex":
             from apex.normalization import FusedRMSNorm as ApexRMSNorm
-            @torch.compiler.disable()
             class RMSNorm(ApexRMSNorm):
                 pass
         else:
@@ -238,6 +237,7 @@ class AsymmetricAttention(nn.Module):
             skip_reshape=True
         )
         return out
     def run_attention(
         self,
         q,
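
The decorator removed in the first hunk is the stock way to keep a module out of a compiled graph: anything wrapped in torch.compiler.disable() is skipped by Dynamo and runs eagerly, even when its caller is compiled. A small sketch of the same idea on a hand-written RMSNorm (the module body is illustrative, not the repo's Apex-backed class, and the decorator is shown on forward here whereas the diff applies it at class level):

    import torch
    import torch.nn as nn

    class RMSNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-5, device=None):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size, device=device))
            self.eps = eps

        @torch.compiler.disable()
        def forward(self, x):
            # Excluded from Dynamo tracing; always executes eagerly.
            var = x.float().pow(2).mean(-1, keepdim=True)
            return (x.float() * torch.rsqrt(var + self.eps)).to(x.dtype) * self.weight

    # backend="eager" keeps the demo runnable without Triton or a C++ toolchain.
    norm = torch.compile(nn.Sequential(nn.Linear(64, 64), RMSNorm(64)), backend="eager")
    out = norm(torch.randn(2, 64))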

View File

@@ -140,7 +140,7 @@ class PatchEmbed(nn.Module):
         x = self.norm(x)
         return x
-@torch.compiler.disable()
 class RMSNorm(torch.nn.Module):
     def __init__(self, hidden_size, eps=1e-5, device=None):
         super().__init__()

View File

@@ -34,7 +34,6 @@ except:
 import torch
 import torch.utils.data
-import torch._dynamo
 from tqdm import tqdm
 from comfy.utils import ProgressBar, load_torch_file
@@ -162,8 +161,6 @@ class T2VSynthMochiModel:
         #torch.compile
         if compile_args is not None:
-            torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
-            log.info(f"Set dynamo cache size limit to {torch._dynamo.config.cache_size_limit}")
             if compile_args["compile_dit"]:
                 for i, block in enumerate(model.blocks):
                     model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"])
@@ -311,7 +308,7 @@ class T2VSynthMochiModel:
         comfy_pbar = ProgressBar(sample_steps)
-        if (hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul) or self.device.type == "mps":
+        if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
             autocast_dtype = torch.float16
         else:
             autocast_dtype = torch.bfloat16
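
The two lines dropped from the middle hunk are the only users of the dynamo_cache_size_limit setting that also disappears from the compile-settings node later in this diff. The surrounding logic, per-block torch.compile over the DiT blocks driven by a compile_args dict, is the same on both branches. A condensed sketch of that pattern on a stand-in model (the blocks are placeholders; the dict keys mirror the diff, while the backend string is just a typical choice, not read from the node):

    import torch
    import torch._dynamo
    import torch.nn as nn

    # Stand-in for the DiT: a handful of simple blocks instead of transformer blocks.
    model = nn.Module()
    model.blocks = nn.ModuleList(
        [nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(4)]
    )

    compile_args = {
        "backend": "inductor",
        "fullgraph": False,
        "dynamic": False,
        "compile_dit": True,
        "dynamo_cache_size_limit": 64,  # main branch only; fastercache drops this key
    }

    # Raise the recompile cache so compiling many blocks separately does not hit the
    # default limit and silently fall back to eager (the main-branch behaviour).
    torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]

    if compile_args["compile_dit"]:
        for i, block in enumerate(model.blocks):
            model.blocks[i] = torch.compile(
                block,
                fullgraph=compile_args["fullgraph"],
                dynamic=compile_args["dynamic"],
                backend=compile_args["backend"],
            )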

View File

@@ -119,10 +119,6 @@ class DownloadAndLoadMochiModel:
         mm.soft_empty_cache()
         dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
-        if "fp8" in precision:
-            vae_dtype = torch.bfloat16
-        else:
-            vae_dtype = dtype
         # Transformer model
         model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi')
@@ -178,15 +174,14 @@ class DownloadAndLoadMochiModel:
             nonlinearity="silu",
             output_nonlinearity="silu",
             causal=True,
-            dtype=vae_dtype,
         )
         vae_sd = load_torch_file(vae_path)
         if is_accelerate_available:
             for key in vae_sd:
-                set_module_tensor_to_device(vae, key, dtype=vae_dtype, device=offload_device, value=vae_sd[key])
+                set_module_tensor_to_device(vae, key, dtype=torch.bfloat16, device=offload_device, value=vae_sd[key])
         else:
             vae.load_state_dict(vae_sd, strict=True)
-        vae.eval().to(vae_dtype).to("cpu")
+        vae.eval().to(torch.bfloat16).to("cpu")
         del vae_sd
         return (model, vae,)
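
On the loading path above, the accelerate branch writes each checkpoint tensor straight into the module at the target dtype and device instead of materialising a second full-precision copy through load_state_dict. A hedged sketch of that pattern on a toy module (the module and state dict are placeholders; accelerate's set_module_tensor_to_device is the only real API used):

    import torch
    import torch.nn as nn
    from accelerate.utils import set_module_tensor_to_device

    # Placeholder stand-ins for the Mochi VAE and for load_torch_file(vae_path).
    vae = nn.Sequential(nn.Conv3d(12, 64, 1), nn.SiLU(), nn.Conv3d(64, 3, 1))
    vae_sd = {k: v.clone() for k, v in vae.state_dict().items()}

    offload_device = torch.device("cpu")
    for key in vae_sd:
        # Casts to bf16 and places the tensor on the offload device as it is assigned,
        # mirroring the fastercache side of the hunk above.
        set_module_tensor_to_device(
            vae, key, device=offload_device, dtype=torch.bfloat16, value=vae_sd[key]
        )
    vae.eval()
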
@@ -247,7 +242,6 @@ class MochiTorchCompileSettings:
                 "compile_dit": ("BOOLEAN", {"default": True, "tooltip": "Compiles all transformer blocks"}),
                 "compile_final_layer": ("BOOLEAN", {"default": True, "tooltip": "Enable compiling final layer."}),
                 "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
-                "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
             },
         }
     RETURN_TYPES = ("MOCHICOMPILEARGS",)
@@ -256,7 +250,7 @@ class MochiTorchCompileSettings:
     CATEGORY = "MochiWrapper"
     DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
-    def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer, dynamic, dynamo_cache_size_limit):
+    def loadmodel(self, backend, fullgraph, mode, compile_dit, compile_final_layer, dynamic):
         compile_args = {
             "backend": backend,
@@ -265,7 +259,6 @@ class MochiTorchCompileSettings:
             "compile_dit": compile_dit,
             "compile_final_layer": compile_final_layer,
             "dynamic": dynamic,
-            "dynamo_cache_size_limit": dynamo_cache_size_limit,
         }
         return (compile_args, )
@@ -315,16 +308,6 @@ class MochiVAELoader:
             dtype=dtype,
         )
         vae_sd = load_torch_file(vae_path)
-        #support loading from combined VAE
-        if vae_sd.get("decoder.blocks.0.0.bias") is not None:
-            new_vae_sd = {}
-            for k, v in vae_sd.items():
-                if k.startswith("decoder."):
-                    new_k = k[len("decoder."):]
-                    new_vae_sd[new_k] = v
-            vae_sd = new_vae_sd
         if is_accelerate_available:
             for name, param in vae.named_parameters():
                 set_module_tensor_to_device(vae, name, dtype=dtype, device=offload_device, value=vae_sd[name])
@@ -393,16 +376,6 @@ class MochiVAEEncoderLoader:
         )
         encoder_sd = load_torch_file(vae_path)
-        #support loading from combined VAE
-        if encoder_sd.get("encoder.layers.0.bias") is not None:
-            new_vae_sd = {}
-            for k, v in encoder_sd.items():
-                if k.startswith("encoder."):
-                    new_k = k[len("encoder."):]
-                    new_vae_sd[new_k] = v
-            encoder_sd = new_vae_sd
         if is_accelerate_available:
             for name, param in encoder.named_parameters():
                 set_module_tensor_to_device(encoder, name, dtype=dtype, device=offload_device, value=encoder_sd[name])
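
The two deleted blocks above are the same trick applied to the decoder and encoder halves: when the checkpoint is a combined VAE, keep only the keys belonging to the wanted sub-module and strip its prefix so the names line up with the standalone model. A small helper capturing that logic (the helper name is made up for illustration; the probe keys in the diff, e.g. "decoder.blocks.0.0.bias", are how main detects a combined file):

    def strip_submodule_prefix(state_dict, prefix):
        """Keep only tensors under `prefix` and drop the prefix from their keys.
        If no key carries the prefix, the dict is returned unchanged."""
        if not any(k.startswith(prefix) for k in state_dict):
            return state_dict
        return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

    # e.g. vae_sd     = strip_submodule_prefix(vae_sd, "decoder.")
    #      encoder_sd = strip_submodule_prefix(encoder_sd, "encoder.")
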
@@ -598,9 +571,6 @@ class MochiDecode:
                 "tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
                 "tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
             },
-            "optional": {
-                "unnormalize": ("BOOLEAN", {"default": False, "tooltip": "Unnormalize the latents before decoding"}),
-            },
         }
     RETURN_TYPES = ("IMAGE",)
@@ -609,13 +579,12 @@ class MochiDecode:
     CATEGORY = "MochiWrapper"
     def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height,
-               tile_overlap_factor_width, auto_tile_size, frame_batch_size, unnormalize=False):
+               tile_overlap_factor_width, auto_tile_size, frame_batch_size):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
-        if unnormalize:
-            samples = dit_latents_to_vae_latents(samples)
+        samples = dit_latents_to_vae_latents(samples)
         samples = samples.to(vae.dtype).to(device)
         B, C, T, H, W = samples.shape
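
The removed `unnormalize` switch gated dit_latents_to_vae_latents; on fastercache the conversion always runs before decoding. That conversion (and its inverse used on the encode path later in this file) is, in essence, a per-channel affine map between the sampler's normalized latent space and the VAE's native latent space, which is how Mochi's latent statistics are usually applied. A hedged sketch of that shape of function, with placeholder statistics rather than the wrapper's real per-channel constants:

    import torch

    # Placeholder per-channel statistics; the real constants live in the wrapper's
    # latent utilities (dit_latents_to_vae_latents / vae_latents_to_dit_latents).
    latents_mean = torch.zeros(1, 12, 1, 1, 1)
    latents_std = torch.ones(1, 12, 1, 1, 1)

    def dit_to_vae_latents_sketch(z):   # applied before VAE decode
        return z * latents_std.to(z.device, z.dtype) + latents_mean.to(z.device, z.dtype)

    def vae_to_dit_latents_sketch(z):   # applied after VAE encode
        return (z - latents_mean.to(z.device, z.dtype)) / latents_std.to(z.device, z.dtype)
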
@@ -730,9 +699,6 @@ class MochiDecodeSpatialTiling:
                 "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
                 "per_batch": ("INT", {"default": 6, "min": 1, "max": 256, "step": 1, "tooltip": "Number of samples per batch, in latent space (6 frames in 1 latent)"}),
             },
-            "optional": {
-                "unnormalize": ("BOOLEAN", {"default": True, "tooltip": "Unnormalize the latents before decoding"}),
-            },
         }
     RETURN_TYPES = ("IMAGE",)
@@ -741,13 +707,12 @@ class MochiDecodeSpatialTiling:
     CATEGORY = "MochiWrapper"
     def decode(self, vae, samples, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap,
-               min_block_size, per_batch, unnormalize=True):
+               min_block_size, per_batch):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
-        if unnormalize:
-            samples = dit_latents_to_vae_latents(samples)
+        samples = dit_latents_to_vae_latents(samples)
         samples = samples.to(vae.dtype).to(device)
         B, C, T, H, W = samples.shape
@@ -803,17 +768,14 @@ class MochiImageEncode:
                 "overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixel of overlap between adjacent tiles"}),
                 "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
             },
-            "optional": {
-                "normalize": ("BOOLEAN", {"default": True, "tooltip": "Normalize the images before encoding"}),
-            },
         }
     RETURN_TYPES = ("LATENT",)
     RETURN_NAMES = ("samples",)
-    FUNCTION = "encode"
+    FUNCTION = "decode"
     CATEGORY = "MochiWrapper"
-    def encode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size, normalize=True):
+    def decode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
@@ -838,8 +800,7 @@ class MochiImageEncode:
             latents = apply_tiled(encoder, video, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
         else:
             latents = encoder(video)
-        if normalize:
-            latents = vae_latents_to_dit_latents(latents)
+        latents = vae_latents_to_dit_latents(latents)
         print("encoder output",latents.shape)
         return ({"samples": latents},)
@@ -870,21 +831,8 @@ class MochiLatentPreview:
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
-        #latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
-        latent_rgb_factors =[
-            [-0.0069, -0.0045, 0.0018],
-            [ 0.0154, -0.0692, -0.0274],
-            [ 0.0333, 0.0019, 0.0206],
-            [-0.1390, 0.0628, 0.1678],
-            [-0.0725, 0.0134, -0.1898],
-            [ 0.0074, -0.0270, -0.0209],
-            [-0.0176, -0.0277, -0.0221],
-            [ 0.5294, 0.5204, 0.3852],
-            [-0.0326, -0.0446, -0.0143],
-            [-0.0659, 0.0153, -0.0153],
-            [ 0.0185, -0.0217, 0.0014],
-            [-0.0396, -0.0495, -0.0281]
-        ]
+        latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
         # import random
         # random.seed(seed)
         # latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(12)]
@@ -892,7 +840,7 @@ class MochiLatentPreview:
         # print(latent_rgb_factors)
-        latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
+        latent_rgb_factors_bias = [0,0,0]
         latent_rgb_factors = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)
         latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)