"""
|
|
This file is part of ComfyUI.
|
|
Copyright (C) 2024 Comfy
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
(at your option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
"""
|
|
|
|
|
|
import torch
import math
import struct
import comfy.checkpoint_pickle
import safetensors.torch
import numpy as np
from PIL import Image
import logging
import itertools
from torch.nn.functional import interpolate
from einops import rearrange
from comfy.cli_args import args
import json

MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap

ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"):  # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
    class ModelCheckpoint:
        pass
    ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"

    def scalar(*args, **kwargs):
        from numpy.core.multiarray import scalar as sc
        return sc(*args, **kwargs)
    scalar.__module__ = "numpy.core.multiarray"

    from numpy import dtype
    from numpy.dtypes import Float64DType
    from _codecs import encode

    torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
    ALWAYS_SAFE_LOAD = True
    logging.info("Checkpoint files will always be loaded safely.")
else:
    logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
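
# Illustrative sketch: how a downstream script on pytorch >= 2.4 could extend the
# same allow-list mechanism for its own pickled classes. "MyCustomMetadata" and its
# module path are hypothetical and not something this module registers itself.
def _example_register_extra_safe_global():
    class MyCustomMetadata:
        pass
    MyCustomMetadata.__module__ = "my_training_framework.metadata"
    if hasattr(torch.serialization, "add_safe_globals"):
        torch.serialization.add_safe_globals([MyCustomMetadata])
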
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
    if device is None:
        device = torch.device("cpu")
    metadata = None
    if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
        try:
            with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
                sd = {}
                for k in f.keys():
                    tensor = f.get_tensor(k)
                    if DISABLE_MMAP:  # TODO: Not sure if this is the best way to bypass the mmap issues
                        tensor = tensor.to(device=device, copy=True)
                    sd[k] = tensor
                if return_metadata:
                    metadata = f.metadata()
        except Exception as e:
            if len(e.args) > 0:
                message = e.args[0]
                if "HeaderTooLarge" in message:
                    raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.".format(message, ckpt))
                if "MetadataIncompleteBuffer" in message:
                    raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
            raise e
    else:
        torch_args = {}
        if MMAP_TORCH_FILES:
            torch_args["mmap"] = True

        if safe_load or ALWAYS_SAFE_LOAD:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
        else:
            logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
            pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            if len(pl_sd) == 1:
                key = list(pl_sd.keys())[0]
                sd = pl_sd[key]
                if not isinstance(sd, dict):
                    sd = pl_sd
            else:
                sd = pl_sd
    return (sd, metadata) if return_metadata else sd

def save_torch_file(sd, ckpt, metadata=None):
    if metadata is not None:
        safetensors.torch.save_file(sd, ckpt, metadata=metadata)
    else:
        safetensors.torch.save_file(sd, ckpt)
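
# Illustrative sketch of the load/save round trip above. The file paths are
# hypothetical; safetensors metadata, when present, is a dict of str -> str.
def _example_load_and_save_round_trip():
    sd, metadata = load_torch_file("models/checkpoints/example.safetensors", return_metadata=True)
    # Cast everything to fp16 and write it back out with the original metadata.
    sd = {k: v.half() for k, v in sd.items()}
    save_torch_file(sd, "models/checkpoints/example_fp16.safetensors", metadata=metadata)
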
def calculate_parameters(sd, prefix=""):
    params = 0
    for k in sd.keys():
        if k.startswith(prefix):
            w = sd[k]
            params += w.nelement()
    return params

def weight_dtype(sd, prefix=""):
    dtypes = {}
    for k in sd.keys():
        if k.startswith(prefix):
            w = sd[k]
            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()

    if len(dtypes) == 0:
        return None

    return max(dtypes, key=dtypes.get)
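
# Illustrative sketch: on a tiny synthetic state dict, calculate_parameters()
# counts elements under a prefix and weight_dtype() returns the dtype that
# accounts for the most elements (here torch.float16).
def _example_parameter_stats():
    sd = {
        "model.diffusion_model.a.weight": torch.zeros(4, 4, dtype=torch.float16),
        "model.diffusion_model.b.weight": torch.zeros(2, 2, dtype=torch.float32),
        "first_stage_model.c.weight": torch.zeros(8, 8, dtype=torch.float32),
    }
    assert calculate_parameters(sd, "model.diffusion_model.") == 20
    assert weight_dtype(sd, "model.diffusion_model.") == torch.float16
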
def state_dict_key_replace(state_dict, keys_to_replace):
    for x in keys_to_replace:
        if x in state_dict:
            state_dict[keys_to_replace[x]] = state_dict.pop(x)
    return state_dict

def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
    if filter_keys:
        out = {}
    else:
        out = state_dict
    for rp in replace_prefix:
        replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
        for x in replace:
            w = state_dict.pop(x[0])
            out[x[1]] = w
    return out
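
# Illustrative sketch: state_dict_prefix_replace() rewrites key prefixes; with
# filter_keys=True only the matching keys are returned in a new dict.
def _example_prefix_replace():
    sd = {"model.diffusion_model.x.weight": 1, "first_stage_model.y.weight": 2}
    out = state_dict_prefix_replace(sd, {"model.diffusion_model.": ""}, filter_keys=True)
    assert out == {"x.weight": 1}
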
def transformers_convert(sd, prefix_from, prefix_to, number):
    keys_to_replace = {
        "{}positional_embedding": "{}embeddings.position_embedding.weight",
        "{}token_embedding.weight": "{}embeddings.token_embedding.weight",
        "{}ln_final.weight": "{}final_layer_norm.weight",
        "{}ln_final.bias": "{}final_layer_norm.bias",
    }

    for k in keys_to_replace:
        x = k.format(prefix_from)
        if x in sd:
            sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)

    resblock_to_replace = {
        "ln_1": "layer_norm1",
        "ln_2": "layer_norm2",
        "mlp.c_fc": "mlp.fc1",
        "mlp.c_proj": "mlp.fc2",
        "attn.out_proj": "self_attn.out_proj",
    }

    for resblock in range(number):
        for x in resblock_to_replace:
            for y in ["weight", "bias"]:
                k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
                k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
                if k in sd:
                    sd[k_to] = sd.pop(k)

        for y in ["weight", "bias"]:
            k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
            if k_from in sd:
                weights = sd.pop(k_from)
                shape_from = weights.shape[0] // 3
                for x in range(3):
                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
                    k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
                    sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]

    return sd

def clip_text_transformers_convert(sd, prefix_from, prefix_to):
    sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)

    tp = "{}text_projection.weight".format(prefix_from)
    if tp in sd:
        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp)

    tp = "{}text_projection".format(prefix_from)
    if tp in sd:
        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous()
    return sd
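
# Illustrative sketch of the in_proj split performed above: a fused CLIP attention
# projection of shape [3 * dim, dim] is cut into equal q/k/v chunks. The tensors
# and prefixes here are synthetic stand-ins, not real checkpoint keys.
def _example_in_proj_split():
    dim = 8
    sd = {"clip.transformer.resblocks.0.attn.in_proj_weight": torch.randn(3 * dim, dim)}
    sd = transformers_convert(sd, "clip.", "out.", number=1)
    assert sd["out.encoder.layers.0.self_attn.q_proj.weight"].shape == (dim, dim)
    assert sd["out.encoder.layers.0.self_attn.k_proj.weight"].shape == (dim, dim)
    assert sd["out.encoder.layers.0.self_attn.v_proj.weight"].shape == (dim, dim)
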
UNET_MAP_ATTENTIONS = {
|
|
"proj_in.weight",
|
|
"proj_in.bias",
|
|
"proj_out.weight",
|
|
"proj_out.bias",
|
|
"norm.weight",
|
|
"norm.bias",
|
|
}
|
|
|
|
TRANSFORMER_BLOCKS = {
|
|
"norm1.weight",
|
|
"norm1.bias",
|
|
"norm2.weight",
|
|
"norm2.bias",
|
|
"norm3.weight",
|
|
"norm3.bias",
|
|
"attn1.to_q.weight",
|
|
"attn1.to_k.weight",
|
|
"attn1.to_v.weight",
|
|
"attn1.to_out.0.weight",
|
|
"attn1.to_out.0.bias",
|
|
"attn2.to_q.weight",
|
|
"attn2.to_k.weight",
|
|
"attn2.to_v.weight",
|
|
"attn2.to_out.0.weight",
|
|
"attn2.to_out.0.bias",
|
|
"ff.net.0.proj.weight",
|
|
"ff.net.0.proj.bias",
|
|
"ff.net.2.weight",
|
|
"ff.net.2.bias",
|
|
}
|
|
|
|
UNET_MAP_RESNET = {
|
|
"in_layers.2.weight": "conv1.weight",
|
|
"in_layers.2.bias": "conv1.bias",
|
|
"emb_layers.1.weight": "time_emb_proj.weight",
|
|
"emb_layers.1.bias": "time_emb_proj.bias",
|
|
"out_layers.3.weight": "conv2.weight",
|
|
"out_layers.3.bias": "conv2.bias",
|
|
"skip_connection.weight": "conv_shortcut.weight",
|
|
"skip_connection.bias": "conv_shortcut.bias",
|
|
"in_layers.0.weight": "norm1.weight",
|
|
"in_layers.0.bias": "norm1.bias",
|
|
"out_layers.0.weight": "norm2.weight",
|
|
"out_layers.0.bias": "norm2.bias",
|
|
}
|
|
|
|
UNET_MAP_BASIC = {
|
|
("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
|
|
("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
|
|
("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
|
|
("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
|
|
("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
|
|
("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
|
|
("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
|
|
("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
|
|
("input_blocks.0.0.weight", "conv_in.weight"),
|
|
("input_blocks.0.0.bias", "conv_in.bias"),
|
|
("out.0.weight", "conv_norm_out.weight"),
|
|
("out.0.bias", "conv_norm_out.bias"),
|
|
("out.2.weight", "conv_out.weight"),
|
|
("out.2.bias", "conv_out.bias"),
|
|
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
|
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
|
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
|
("time_embed.2.bias", "time_embedding.linear_2.bias")
|
|
}
|
|
|
|
def unet_to_diffusers(unet_config):
|
|
if "num_res_blocks" not in unet_config:
|
|
return {}
|
|
num_res_blocks = unet_config["num_res_blocks"]
|
|
channel_mult = unet_config["channel_mult"]
|
|
transformer_depth = unet_config["transformer_depth"][:]
|
|
transformer_depth_output = unet_config["transformer_depth_output"][:]
|
|
num_blocks = len(channel_mult)
|
|
|
|
transformers_mid = unet_config.get("transformer_depth_middle", None)
|
|
|
|
diffusers_unet_map = {}
|
|
for x in range(num_blocks):
|
|
n = 1 + (num_res_blocks[x] + 1) * x
|
|
for i in range(num_res_blocks[x]):
|
|
for b in UNET_MAP_RESNET:
|
|
diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
|
|
num_transformers = transformer_depth.pop(0)
|
|
if num_transformers > 0:
|
|
for b in UNET_MAP_ATTENTIONS:
|
|
diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
|
|
for t in range(num_transformers):
|
|
for b in TRANSFORMER_BLOCKS:
|
|
diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
|
n += 1
|
|
for k in ["weight", "bias"]:
|
|
diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)
|
|
|
|
i = 0
|
|
for b in UNET_MAP_ATTENTIONS:
|
|
diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
|
|
for t in range(transformers_mid):
|
|
for b in TRANSFORMER_BLOCKS:
|
|
diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)
|
|
|
|
for i, n in enumerate([0, 2]):
|
|
for b in UNET_MAP_RESNET:
|
|
diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
|
|
|
|
num_res_blocks = list(reversed(num_res_blocks))
|
|
for x in range(num_blocks):
|
|
n = (num_res_blocks[x] + 1) * x
|
|
l = num_res_blocks[x] + 1
|
|
for i in range(l):
|
|
c = 0
|
|
for b in UNET_MAP_RESNET:
|
|
diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
|
|
c += 1
|
|
num_transformers = transformer_depth_output.pop()
|
|
if num_transformers > 0:
|
|
c += 1
|
|
for b in UNET_MAP_ATTENTIONS:
|
|
diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
|
|
for t in range(num_transformers):
|
|
for b in TRANSFORMER_BLOCKS:
|
|
diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
|
if i == l - 1:
|
|
for k in ["weight", "bias"]:
|
|
diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
|
|
n += 1
|
|
|
|
for k in UNET_MAP_BASIC:
|
|
diffusers_unet_map[k[1]] = k[0]
|
|
|
|
return diffusers_unet_map
|
|
|
|
def swap_scale_shift(weight):
|
|
shift, scale = weight.chunk(2, dim=0)
|
|
new_weight = torch.cat([scale, shift], dim=0)
|
|
return new_weight
|
|
|
|
MMDIT_MAP_BASIC = {
|
|
("context_embedder.bias", "context_embedder.bias"),
|
|
("context_embedder.weight", "context_embedder.weight"),
|
|
("t_embedder.mlp.0.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
|
|
("t_embedder.mlp.0.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
|
|
("t_embedder.mlp.2.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
|
|
("t_embedder.mlp.2.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
|
|
("x_embedder.proj.bias", "pos_embed.proj.bias"),
|
|
("x_embedder.proj.weight", "pos_embed.proj.weight"),
|
|
("y_embedder.mlp.0.bias", "time_text_embed.text_embedder.linear_1.bias"),
|
|
("y_embedder.mlp.0.weight", "time_text_embed.text_embedder.linear_1.weight"),
|
|
("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"),
|
|
("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"),
|
|
("pos_embed", "pos_embed.pos_embed"),
|
|
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
|
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
|
("final_layer.linear.bias", "proj_out.bias"),
|
|
("final_layer.linear.weight", "proj_out.weight"),
|
|
}
|
|
|
|
MMDIT_MAP_BLOCK = {
|
|
("context_block.adaLN_modulation.1.bias", "norm1_context.linear.bias"),
|
|
("context_block.adaLN_modulation.1.weight", "norm1_context.linear.weight"),
|
|
("context_block.attn.proj.bias", "attn.to_add_out.bias"),
|
|
("context_block.attn.proj.weight", "attn.to_add_out.weight"),
|
|
("context_block.mlp.fc1.bias", "ff_context.net.0.proj.bias"),
|
|
("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
|
|
("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
|
|
("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
|
|
("context_block.attn.ln_q.weight", "attn.norm_added_q.weight"),
|
|
("context_block.attn.ln_k.weight", "attn.norm_added_k.weight"),
|
|
("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
|
|
("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
|
|
("x_block.attn.proj.bias", "attn.to_out.0.bias"),
|
|
("x_block.attn.proj.weight", "attn.to_out.0.weight"),
|
|
("x_block.attn.ln_q.weight", "attn.norm_q.weight"),
|
|
("x_block.attn.ln_k.weight", "attn.norm_k.weight"),
|
|
("x_block.attn2.proj.bias", "attn2.to_out.0.bias"),
|
|
("x_block.attn2.proj.weight", "attn2.to_out.0.weight"),
|
|
("x_block.attn2.ln_q.weight", "attn2.norm_q.weight"),
|
|
("x_block.attn2.ln_k.weight", "attn2.norm_k.weight"),
|
|
("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
|
|
("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
|
|
("x_block.mlp.fc2.bias", "ff.net.2.bias"),
|
|
("x_block.mlp.fc2.weight", "ff.net.2.weight"),
|
|
}
|
|
|
|
def mmdit_to_diffusers(mmdit_config, output_prefix=""):
|
|
key_map = {}
|
|
|
|
depth = mmdit_config.get("depth", 0)
|
|
num_blocks = mmdit_config.get("num_blocks", depth)
|
|
for i in range(num_blocks):
|
|
block_from = "transformer_blocks.{}".format(i)
|
|
block_to = "{}joint_blocks.{}".format(output_prefix, i)
|
|
|
|
offset = depth * 64
|
|
|
|
for end in ("weight", "bias"):
|
|
k = "{}.attn.".format(block_from)
|
|
qkv = "{}.x_block.attn.qkv.{}".format(block_to, end)
|
|
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
|
|
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
|
|
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
|
|
|
|
qkv = "{}.context_block.attn.qkv.{}".format(block_to, end)
|
|
key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, offset))
|
|
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
|
|
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
|
|
|
|
k = "{}.attn2.".format(block_from)
|
|
qkv = "{}.x_block.attn2.qkv.{}".format(block_to, end)
|
|
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
|
|
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
|
|
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
|
|
|
|
for k in MMDIT_MAP_BLOCK:
|
|
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
|
|
|
|
map_basic = MMDIT_MAP_BASIC.copy()
|
|
map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift))
|
|
map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift))
|
|
|
|
for k in map_basic:
|
|
if len(k) > 2:
|
|
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
|
|
else:
|
|
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
|
|
|
return key_map
|
|
|
|
PIXART_MAP_BASIC = {
|
|
("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"),
|
|
("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"),
|
|
("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"),
|
|
("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"),
|
|
("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"),
|
|
("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"),
|
|
("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"),
|
|
("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"),
|
|
("x_embedder.proj.weight", "pos_embed.proj.weight"),
|
|
("x_embedder.proj.bias", "pos_embed.proj.bias"),
|
|
("y_embedder.y_embedding", "caption_projection.y_embedding"),
|
|
("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"),
|
|
("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"),
|
|
("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"),
|
|
("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"),
|
|
("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"),
|
|
("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"),
|
|
("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"),
|
|
("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"),
|
|
("t_block.1.weight", "adaln_single.linear.weight"),
|
|
("t_block.1.bias", "adaln_single.linear.bias"),
|
|
("final_layer.linear.weight", "proj_out.weight"),
|
|
("final_layer.linear.bias", "proj_out.bias"),
|
|
("final_layer.scale_shift_table", "scale_shift_table"),
|
|
}
|
|
|
|
PIXART_MAP_BLOCK = {
|
|
("scale_shift_table", "scale_shift_table"),
|
|
("attn.proj.weight", "attn1.to_out.0.weight"),
|
|
("attn.proj.bias", "attn1.to_out.0.bias"),
|
|
("mlp.fc1.weight", "ff.net.0.proj.weight"),
|
|
("mlp.fc1.bias", "ff.net.0.proj.bias"),
|
|
("mlp.fc2.weight", "ff.net.2.weight"),
|
|
("mlp.fc2.bias", "ff.net.2.bias"),
|
|
("cross_attn.proj.weight" ,"attn2.to_out.0.weight"),
|
|
("cross_attn.proj.bias" ,"attn2.to_out.0.bias"),
|
|
}
|
|
|
|
def pixart_to_diffusers(mmdit_config, output_prefix=""):
|
|
key_map = {}
|
|
|
|
depth = mmdit_config.get("depth", 0)
|
|
offset = mmdit_config.get("hidden_size", 1152)
|
|
|
|
for i in range(depth):
|
|
block_from = "transformer_blocks.{}".format(i)
|
|
block_to = "{}blocks.{}".format(output_prefix, i)
|
|
|
|
for end in ("weight", "bias"):
|
|
s = "{}.attn1.".format(block_from)
|
|
qkv = "{}.attn.qkv.{}".format(block_to, end)
|
|
key_map["{}to_q.{}".format(s, end)] = (qkv, (0, 0, offset))
|
|
key_map["{}to_k.{}".format(s, end)] = (qkv, (0, offset, offset))
|
|
key_map["{}to_v.{}".format(s, end)] = (qkv, (0, offset * 2, offset))
|
|
|
|
s = "{}.attn2.".format(block_from)
|
|
q = "{}.cross_attn.q_linear.{}".format(block_to, end)
|
|
kv = "{}.cross_attn.kv_linear.{}".format(block_to, end)
|
|
|
|
key_map["{}to_q.{}".format(s, end)] = q
|
|
key_map["{}to_k.{}".format(s, end)] = (kv, (0, 0, offset))
|
|
key_map["{}to_v.{}".format(s, end)] = (kv, (0, offset, offset))
|
|
|
|
for k in PIXART_MAP_BLOCK:
|
|
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
|
|
|
|
for k in PIXART_MAP_BASIC:
|
|
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
|
|
|
return key_map
|
|
|
|
def auraflow_to_diffusers(mmdit_config, output_prefix=""):
|
|
n_double_layers = mmdit_config.get("n_double_layers", 0)
|
|
n_layers = mmdit_config.get("n_layers", 0)
|
|
|
|
key_map = {}
|
|
for i in range(n_layers):
|
|
if i < n_double_layers:
|
|
index = i
|
|
prefix_from = "joint_transformer_blocks"
|
|
prefix_to = "{}double_layers".format(output_prefix)
|
|
block_map = {
|
|
"attn.to_q.weight": "attn.w2q.weight",
|
|
"attn.to_k.weight": "attn.w2k.weight",
|
|
"attn.to_v.weight": "attn.w2v.weight",
|
|
"attn.to_out.0.weight": "attn.w2o.weight",
|
|
"attn.add_q_proj.weight": "attn.w1q.weight",
|
|
"attn.add_k_proj.weight": "attn.w1k.weight",
|
|
"attn.add_v_proj.weight": "attn.w1v.weight",
|
|
"attn.to_add_out.weight": "attn.w1o.weight",
|
|
"ff.linear_1.weight": "mlpX.c_fc1.weight",
|
|
"ff.linear_2.weight": "mlpX.c_fc2.weight",
|
|
"ff.out_projection.weight": "mlpX.c_proj.weight",
|
|
"ff_context.linear_1.weight": "mlpC.c_fc1.weight",
|
|
"ff_context.linear_2.weight": "mlpC.c_fc2.weight",
|
|
"ff_context.out_projection.weight": "mlpC.c_proj.weight",
|
|
"norm1.linear.weight": "modX.1.weight",
|
|
"norm1_context.linear.weight": "modC.1.weight",
|
|
}
|
|
else:
|
|
index = i - n_double_layers
|
|
prefix_from = "single_transformer_blocks"
|
|
prefix_to = "{}single_layers".format(output_prefix)
|
|
|
|
block_map = {
|
|
"attn.to_q.weight": "attn.w1q.weight",
|
|
"attn.to_k.weight": "attn.w1k.weight",
|
|
"attn.to_v.weight": "attn.w1v.weight",
|
|
"attn.to_out.0.weight": "attn.w1o.weight",
|
|
"norm1.linear.weight": "modCX.1.weight",
|
|
"ff.linear_1.weight": "mlp.c_fc1.weight",
|
|
"ff.linear_2.weight": "mlp.c_fc2.weight",
|
|
"ff.out_projection.weight": "mlp.c_proj.weight"
|
|
}
|
|
|
|
for k in block_map:
|
|
key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k])
|
|
|
|
MAP_BASIC = {
|
|
("positional_encoding", "pos_embed.pos_embed"),
|
|
("register_tokens", "register_tokens"),
|
|
("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"),
|
|
("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"),
|
|
("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"),
|
|
("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"),
|
|
("cond_seq_linear.weight", "context_embedder.weight"),
|
|
("init_x_linear.weight", "pos_embed.proj.weight"),
|
|
("init_x_linear.bias", "pos_embed.proj.bias"),
|
|
("final_linear.weight", "proj_out.weight"),
|
|
("modF.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
|
}
|
|
|
|
for k in MAP_BASIC:
|
|
if len(k) > 2:
|
|
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
|
|
else:
|
|
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
|
|
|
return key_map
|
|
|
|
def flux_to_diffusers(mmdit_config, output_prefix=""):
|
|
n_double_layers = mmdit_config.get("depth", 0)
|
|
n_single_layers = mmdit_config.get("depth_single_blocks", 0)
|
|
hidden_size = mmdit_config.get("hidden_size", 0)
|
|
|
|
key_map = {}
|
|
for index in range(n_double_layers):
|
|
prefix_from = "transformer_blocks.{}".format(index)
|
|
prefix_to = "{}double_blocks.{}".format(output_prefix, index)
|
|
|
|
for end in ("weight", "bias"):
|
|
k = "{}.attn.".format(prefix_from)
|
|
qkv = "{}.img_attn.qkv.{}".format(prefix_to, end)
|
|
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
|
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
|
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
|
|
|
k = "{}.attn.".format(prefix_from)
|
|
qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end)
|
|
key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
|
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
|
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
|
|
|
block_map = {
|
|
"attn.to_out.0.weight": "img_attn.proj.weight",
|
|
"attn.to_out.0.bias": "img_attn.proj.bias",
|
|
"norm1.linear.weight": "img_mod.lin.weight",
|
|
"norm1.linear.bias": "img_mod.lin.bias",
|
|
"norm1_context.linear.weight": "txt_mod.lin.weight",
|
|
"norm1_context.linear.bias": "txt_mod.lin.bias",
|
|
"attn.to_add_out.weight": "txt_attn.proj.weight",
|
|
"attn.to_add_out.bias": "txt_attn.proj.bias",
|
|
"ff.net.0.proj.weight": "img_mlp.0.weight",
|
|
"ff.net.0.proj.bias": "img_mlp.0.bias",
|
|
"ff.net.2.weight": "img_mlp.2.weight",
|
|
"ff.net.2.bias": "img_mlp.2.bias",
|
|
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
|
|
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
|
|
"ff_context.net.2.weight": "txt_mlp.2.weight",
|
|
"ff_context.net.2.bias": "txt_mlp.2.bias",
|
|
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
|
|
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
|
|
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
|
|
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
|
|
}
|
|
|
|
for k in block_map:
|
|
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
|
|
|
|
for index in range(n_single_layers):
|
|
prefix_from = "single_transformer_blocks.{}".format(index)
|
|
prefix_to = "{}single_blocks.{}".format(output_prefix, index)
|
|
|
|
for end in ("weight", "bias"):
|
|
k = "{}.attn.".format(prefix_from)
|
|
qkv = "{}.linear1.{}".format(prefix_to, end)
|
|
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
|
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
|
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
|
key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))
|
|
|
|
block_map = {
|
|
"norm.linear.weight": "modulation.lin.weight",
|
|
"norm.linear.bias": "modulation.lin.bias",
|
|
"proj_out.weight": "linear2.weight",
|
|
"proj_out.bias": "linear2.bias",
|
|
"attn.norm_q.weight": "norm.query_norm.scale",
|
|
"attn.norm_k.weight": "norm.key_norm.scale",
|
|
}
|
|
|
|
for k in block_map:
|
|
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
|
|
|
|
MAP_BASIC = {
|
|
("final_layer.linear.bias", "proj_out.bias"),
|
|
("final_layer.linear.weight", "proj_out.weight"),
|
|
("img_in.bias", "x_embedder.bias"),
|
|
("img_in.weight", "x_embedder.weight"),
|
|
("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
|
|
("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
|
|
("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
|
|
("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
|
|
("txt_in.bias", "context_embedder.bias"),
|
|
("txt_in.weight", "context_embedder.weight"),
|
|
("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
|
|
("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
|
|
("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
|
|
("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
|
|
("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
|
|
("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
|
|
("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
|
|
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
|
|
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
|
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
|
("pos_embed_input.bias", "controlnet_x_embedder.bias"),
|
|
("pos_embed_input.weight", "controlnet_x_embedder.weight"),
|
|
}
|
|
|
|
for k in MAP_BASIC:
|
|
if len(k) > 2:
|
|
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
|
|
else:
|
|
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
|
|
|
return key_map
|
|
|
|
def z_image_to_diffusers(mmdit_config, output_prefix=""):
|
|
n_layers = mmdit_config.get("n_layers", 0)
|
|
hidden_size = mmdit_config.get("dim", 0)
|
|
n_context_refiner = mmdit_config.get("n_refiner_layers", 2)
|
|
n_noise_refiner = mmdit_config.get("n_refiner_layers", 2)
|
|
key_map = {}
|
|
|
|
def add_block_keys(prefix_from, prefix_to, has_adaln=True):
|
|
for end in ("weight", "bias"):
|
|
k = "{}.attention.".format(prefix_from)
|
|
qkv = "{}.attention.qkv.{}".format(prefix_to, end)
|
|
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
|
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
|
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
|
|
|
block_map = {
|
|
"attention.norm_q.weight": "attention.q_norm.weight",
|
|
"attention.norm_k.weight": "attention.k_norm.weight",
|
|
"attention.to_out.0.weight": "attention.out.weight",
|
|
"attention.to_out.0.bias": "attention.out.bias",
|
|
"attention_norm1.weight": "attention_norm1.weight",
|
|
"attention_norm2.weight": "attention_norm2.weight",
|
|
"feed_forward.w1.weight": "feed_forward.w1.weight",
|
|
"feed_forward.w2.weight": "feed_forward.w2.weight",
|
|
"feed_forward.w3.weight": "feed_forward.w3.weight",
|
|
"ffn_norm1.weight": "ffn_norm1.weight",
|
|
"ffn_norm2.weight": "ffn_norm2.weight",
|
|
}
|
|
if has_adaln:
|
|
block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight"
|
|
block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias"
|
|
for k, v in block_map.items():
|
|
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v)
|
|
|
|
for i in range(n_layers):
|
|
add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i))
|
|
|
|
for i in range(n_context_refiner):
|
|
add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i))
|
|
|
|
for i in range(n_noise_refiner):
|
|
add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i))
|
|
|
|
MAP_BASIC = [
|
|
("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"),
|
|
("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"),
|
|
("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"),
|
|
("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"),
|
|
("x_embedder.weight", "all_x_embedder.2-1.weight"),
|
|
("x_embedder.bias", "all_x_embedder.2-1.bias"),
|
|
("x_pad_token", "x_pad_token"),
|
|
("cap_embedder.0.weight", "cap_embedder.0.weight"),
|
|
("cap_embedder.1.weight", "cap_embedder.1.weight"),
|
|
("cap_embedder.1.bias", "cap_embedder.1.bias"),
|
|
("cap_pad_token", "cap_pad_token"),
|
|
("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"),
|
|
("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"),
|
|
("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"),
|
|
("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"),
|
|
]
|
|
|
|
for c, diffusers in MAP_BASIC:
|
|
key_map[diffusers] = "{}{}".format(output_prefix, c)
|
|
|
|
return key_map
|
|
|
|
def repeat_to_batch_size(tensor, batch_size, dim=0):
    if tensor.shape[dim] > batch_size:
        return tensor.narrow(dim, 0, batch_size)
    elif tensor.shape[dim] < batch_size:
        return tensor.repeat(dim * [1] + [math.ceil(batch_size / tensor.shape[dim])] + [1] * (len(tensor.shape) - 1 - dim)).narrow(dim, 0, batch_size)
    return tensor

def resize_to_batch_size(tensor, batch_size):
    in_batch_size = tensor.shape[0]
    if in_batch_size == batch_size:
        return tensor

    if batch_size <= 1:
        return tensor[:batch_size]

    output = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device)
    if batch_size < in_batch_size:
        scale = (in_batch_size - 1) / (batch_size - 1)
        for i in range(batch_size):
            output[i] = tensor[min(round(i * scale), in_batch_size - 1)]
    else:
        scale = in_batch_size / batch_size
        for i in range(batch_size):
            output[i] = tensor[min(math.floor((i + 0.5) * scale), in_batch_size - 1)]

    return output

def resize_list_to_batch_size(l, batch_size):
    in_batch_size = len(l)
    if in_batch_size == batch_size or in_batch_size == 0:
        return l

    if batch_size <= 1:
        return l[:batch_size]

    output = []
    if batch_size < in_batch_size:
        scale = (in_batch_size - 1) / (batch_size - 1)
        for i in range(batch_size):
            output.append(l[min(round(i * scale), in_batch_size - 1)])
    else:
        scale = in_batch_size / batch_size
        for i in range(batch_size):
            output.append(l[min(math.floor((i + 0.5) * scale), in_batch_size - 1)])

    return output
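
# Illustrative sketch: repeat_to_batch_size() tiles or truncates along a dim,
# while resize_to_batch_size() picks evenly spaced entries so the first and
# last batch elements are preserved when shrinking.
def _example_batch_resizing():
    t = torch.arange(3).float().reshape(3, 1)
    assert repeat_to_batch_size(t, 5).flatten().tolist() == [0.0, 1.0, 2.0, 0.0, 1.0]
    assert resize_to_batch_size(t, 2).flatten().tolist() == [0.0, 2.0]
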
def convert_sd_to(state_dict, dtype):
    keys = list(state_dict.keys())
    for k in keys:
        state_dict[k] = state_dict[k].to(dtype)
    return state_dict

def safetensors_header(safetensors_path, max_size=100*1024*1024):
    with open(safetensors_path, "rb") as f:
        header = f.read(8)
        length_of_header = struct.unpack('<Q', header)[0]
        if length_of_header > max_size:
            return None
        return f.read(length_of_header)

ATTR_UNSET = {}

def set_attr(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1], ATTR_UNSET)
    if value is ATTR_UNSET:
        delattr(obj, attrs[-1])
    else:
        setattr(obj, attrs[-1], value)
    return prev

def set_attr_param(obj, attr, value):
    return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
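
# Illustrative sketch of the ATTR_UNSET sentinel: set_attr() returns the previous
# value (or ATTR_UNSET if the attribute did not exist), and passing ATTR_UNSET
# back restores that "missing" state by deleting the attribute.
def _example_attr_sentinel():
    class Holder:
        pass
    obj = Holder()
    obj.inner = Holder()
    prev = set_attr(obj, "inner.weight", torch.zeros(1))
    assert prev is ATTR_UNSET  # "weight" did not exist before
    set_attr(obj, "inner.weight", prev)  # passing the sentinel deletes it again
    assert not hasattr(obj.inner, "weight")
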
def copy_to_param(obj, attr, value):
    # inplace update tensor instead of replacing it
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    prev.data.copy_(value)

def get_attr(obj, attr: str):
    """Retrieves a nested attribute from an object using dot notation.

    Args:
        obj: The object to get the attribute from
        attr (str): The attribute path using dot notation (e.g. "model.layer.weight")

    Returns:
        The value of the requested attribute

    Example:
        model = MyModel()
        weight = get_attr(model, "layer1.conv.weight")
        # Equivalent to: model.layer1.conv.weight

    Important:
        Always prefer `comfy.model_patcher.ModelPatcher.get_model_object` when
        accessing nested model objects under `ModelPatcher.model`.
    """
    attrs = attr.split(".")
    for name in attrs:
        obj = getattr(obj, name)
    return obj

def bislerp(samples, width, height):
|
|
def slerp(b1, b2, r):
|
|
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
|
|
|
|
c = b1.shape[-1]
|
|
|
|
#norms
|
|
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
|
|
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
|
|
|
|
#normalize
|
|
b1_normalized = b1 / b1_norms
|
|
b2_normalized = b2 / b2_norms
|
|
|
|
#zero when norms are zero
|
|
b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
|
|
b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
|
|
|
|
#slerp
|
|
dot = (b1_normalized*b2_normalized).sum(1)
|
|
omega = torch.acos(dot)
|
|
so = torch.sin(omega)
|
|
|
|
#technically not mathematically correct, but more pleasing?
|
|
res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
|
|
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
|
|
|
|
#edge cases for same or polar opposites
|
|
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
|
|
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
|
|
return res
|
|
|
|
def generate_bilinear_data(length_old, length_new, device):
|
|
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
|
|
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
|
|
ratios = coords_1 - coords_1.floor()
|
|
coords_1 = coords_1.to(torch.int64)
|
|
|
|
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
|
|
coords_2[:,:,:,-1] -= 1
|
|
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
|
|
coords_2 = coords_2.to(torch.int64)
|
|
return ratios, coords_1, coords_2
|
|
|
|
orig_dtype = samples.dtype
|
|
samples = samples.float()
|
|
n,c,h,w = samples.shape
|
|
h_new, w_new = (height, width)
|
|
|
|
#linear w
|
|
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
|
|
coords_1 = coords_1.expand((n, c, h, -1))
|
|
coords_2 = coords_2.expand((n, c, h, -1))
|
|
ratios = ratios.expand((n, 1, h, -1))
|
|
|
|
pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c))
|
|
pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c))
|
|
ratios = ratios.movedim(1, -1).reshape((-1,1))
|
|
|
|
result = slerp(pass_1, pass_2, ratios)
|
|
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
|
|
|
|
#linear h
|
|
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
|
|
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
|
|
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
|
|
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
|
|
|
|
pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c))
|
|
pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c))
|
|
ratios = ratios.movedim(1, -1).reshape((-1,1))
|
|
|
|
result = slerp(pass_1, pass_2, ratios)
|
|
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
|
|
return result.to(orig_dtype)
|
|
|
|
def lanczos(samples, width, height):
|
|
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
|
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
|
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
|
|
result = torch.stack(images)
|
|
return result.to(samples.device, samples.dtype)
|
|
|
|
def common_upscale(samples, width, height, upscale_method, crop):
|
|
orig_shape = tuple(samples.shape)
|
|
if len(orig_shape) > 4:
|
|
samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
|
|
samples = samples.movedim(2, 1)
|
|
samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
|
|
if crop == "center":
|
|
old_width = samples.shape[-1]
|
|
old_height = samples.shape[-2]
|
|
old_aspect = old_width / old_height
|
|
new_aspect = width / height
|
|
x = 0
|
|
y = 0
|
|
if old_aspect > new_aspect:
|
|
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
|
elif old_aspect < new_aspect:
|
|
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
|
s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
|
|
else:
|
|
s = samples
|
|
|
|
if upscale_method == "bislerp":
|
|
out = bislerp(s, width, height)
|
|
elif upscale_method == "lanczos":
|
|
out = lanczos(s, width, height)
|
|
else:
|
|
out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
|
|
|
if len(orig_shape) == 4:
|
|
return out
|
|
|
|
out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
|
|
return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
|
|
|
|
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
    cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
    return rows * cols

@torch.inference_mode()
|
|
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
|
|
dims = len(tile)
|
|
|
|
if not (isinstance(upscale_amount, (tuple, list))):
|
|
upscale_amount = [upscale_amount] * dims
|
|
|
|
if not (isinstance(overlap, (tuple, list))):
|
|
overlap = [overlap] * dims
|
|
|
|
if index_formulas is None:
|
|
index_formulas = upscale_amount
|
|
|
|
if not (isinstance(index_formulas, (tuple, list))):
|
|
index_formulas = [index_formulas] * dims
|
|
|
|
def get_upscale(dim, val):
|
|
up = upscale_amount[dim]
|
|
if callable(up):
|
|
return up(val)
|
|
else:
|
|
return up * val
|
|
|
|
def get_downscale(dim, val):
|
|
up = upscale_amount[dim]
|
|
if callable(up):
|
|
return up(val)
|
|
else:
|
|
return val / up
|
|
|
|
def get_upscale_pos(dim, val):
|
|
up = index_formulas[dim]
|
|
if callable(up):
|
|
return up(val)
|
|
else:
|
|
return up * val
|
|
|
|
def get_downscale_pos(dim, val):
|
|
up = index_formulas[dim]
|
|
if callable(up):
|
|
return up(val)
|
|
else:
|
|
return val / up
|
|
|
|
if downscale:
|
|
get_scale = get_downscale
|
|
get_pos = get_downscale_pos
|
|
else:
|
|
get_scale = get_upscale
|
|
get_pos = get_upscale_pos
|
|
|
|
def mult_list_upscale(a):
|
|
out = []
|
|
for i in range(len(a)):
|
|
out.append(round(get_scale(i, a[i])))
|
|
return out
|
|
|
|
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
|
|
|
|
for b in range(samples.shape[0]):
|
|
s = samples[b:b+1]
|
|
|
|
# handle entire input fitting in a single tile
|
|
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
|
|
output[b:b+1] = function(s).to(output_device)
|
|
if pbar is not None:
|
|
pbar.update(1)
|
|
continue
|
|
|
|
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
|
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
|
|
|
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
|
|
|
for it in itertools.product(*positions):
|
|
s_in = s
|
|
upscaled = []
|
|
|
|
for d in range(dims):
|
|
pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
|
|
l = min(tile[d], s.shape[d + 2] - pos)
|
|
s_in = s_in.narrow(d + 2, pos, l)
|
|
upscaled.append(round(get_pos(d, pos)))
|
|
|
|
ps = function(s_in).to(output_device)
|
|
mask = torch.ones_like(ps)
|
|
|
|
for d in range(2, dims + 2):
|
|
feather = round(get_scale(d - 2, overlap[d - 2]))
|
|
if feather >= mask.shape[d]:
|
|
continue
|
|
for t in range(feather):
|
|
a = (t + 1) / feather
|
|
mask.narrow(d, t, 1).mul_(a)
|
|
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
|
|
|
|
o = out
|
|
o_d = out_div
|
|
for d in range(dims):
|
|
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
|
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
|
|
|
o.add_(ps * mask)
|
|
o_d.add_(mask)
|
|
|
|
if pbar is not None:
|
|
pbar.update(1)
|
|
|
|
output[b:b+1] = out/out_div
|
|
return output
|
|
|
|
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
    return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)

PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled):
    global PROGRESS_BAR_ENABLED
    PROGRESS_BAR_ENABLED = enabled

PROGRESS_BAR_HOOK = None
def set_progress_bar_global_hook(function):
    global PROGRESS_BAR_HOOK
    PROGRESS_BAR_HOOK = function

class ProgressBar:
    def __init__(self, total, node_id=None):
        global PROGRESS_BAR_HOOK
        self.total = total
        self.current = 0
        self.hook = PROGRESS_BAR_HOOK
        self.node_id = node_id

    def update_absolute(self, value, total=None, preview=None):
        if total is not None:
            self.total = total
        if value > self.total:
            value = self.total
        self.current = value
        if self.hook is not None:
            self.hook(self.current, self.total, preview, node_id=self.node_id)

    def update(self, value):
        self.update_absolute(self.current + value)
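
# Illustrative sketch: get_tiled_scale_steps() sizes a ProgressBar for a tiled
# operation and tiled_scale() advances it once per tile. The 1x-upscale lambda
# is a stand-in for a real per-tile function such as a VAE encode/decode call.
def _example_tiled_progress():
    samples = torch.zeros(1, 3, 128, 128)
    steps = get_tiled_scale_steps(samples.shape[-1], samples.shape[-2], tile_x=64, tile_y=64, overlap=8)
    pbar = ProgressBar(steps)
    return tiled_scale(samples, lambda a: a, tile_x=64, tile_y=64, overlap=8,
                       upscale_amount=1, out_channels=3, pbar=pbar)
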
def reshape_mask(input_mask, output_shape):
|
|
dims = len(output_shape) - 2
|
|
|
|
if dims == 1:
|
|
scale_mode = "linear"
|
|
|
|
if dims == 2:
|
|
input_mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
|
|
scale_mode = "bilinear"
|
|
|
|
if dims == 3:
|
|
if len(input_mask.shape) < 5:
|
|
input_mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
|
|
scale_mode = "trilinear"
|
|
|
|
mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
|
|
if mask.shape[1] < output_shape[1]:
|
|
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
|
|
mask = repeat_to_batch_size(mask, output_shape[0])
|
|
return mask
|
|
|
|
def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
|
|
hi, wi = img_size_in
|
|
ho, wo = img_size_out
|
|
# if it's already the correct size, no need to do anything
|
|
if (hi, wi) == (ho, wo):
|
|
return mask
|
|
if mask.ndim == 2:
|
|
mask = mask.unsqueeze(0)
|
|
if mask.ndim != 3:
|
|
raise ValueError(f"Got a mask of shape {list(mask.shape)}, expected [b, q, k] or [q, k]")
|
|
txt_tokens = mask.shape[1] - (hi * wi)
|
|
# quadrants of the mask
|
|
txt_to_txt = mask[:, :txt_tokens, :txt_tokens]
|
|
txt_to_img = mask[:, :txt_tokens, txt_tokens:]
|
|
img_to_img = mask[:, txt_tokens:, txt_tokens:]
|
|
img_to_txt = mask[:, txt_tokens:, :txt_tokens]
|
|
|
|
# convert to 1d x 2d, interpolate, then back to 1d x 1d
|
|
txt_to_img = rearrange (txt_to_img, "b t (h w) -> b t h w", h=hi, w=wi)
|
|
txt_to_img = interpolate(txt_to_img, size=img_size_out, mode="bilinear")
|
|
txt_to_img = rearrange (txt_to_img, "b t h w -> b t (h w)")
|
|
# this one is hard because we have to do it twice
|
|
# convert to 1d x 2d, interpolate, then to 2d x 1d, interpolate, then 1d x 1d
|
|
img_to_img = rearrange (img_to_img, "b hw (h w) -> b hw h w", h=hi, w=wi)
|
|
img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
|
|
img_to_img = rearrange (img_to_img, "b (hk wk) hq wq -> b (hq wq) hk wk", hk=hi, wk=wi)
|
|
img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
|
|
img_to_img = rearrange (img_to_img, "b (hq wq) hk wk -> b (hk wk) (hq wq)", hq=ho, wq=wo)
|
|
# convert to 2d x 1d, interpolate, then back to 1d x 1d
|
|
img_to_txt = rearrange (img_to_txt, "b (h w) t -> b t h w", h=hi, w=wi)
|
|
img_to_txt = interpolate(img_to_txt, size=img_size_out, mode="bilinear")
|
|
img_to_txt = rearrange (img_to_txt, "b t h w -> b (h w) t")
|
|
|
|
# reassemble the mask from blocks
|
|
out = torch.cat([
|
|
torch.cat([txt_to_txt, txt_to_img], dim=2),
|
|
torch.cat([img_to_txt, img_to_img], dim=2)],
|
|
dim=1
|
|
)
|
|
return out
|
|
|
|
def pack_latents(latents):
    latent_shapes = []
    tensors = []
    for tensor in latents:
        latent_shapes.append(tensor.shape)
        tensors.append(tensor.reshape(tensor.shape[0], 1, -1))

    latent = torch.cat(tensors, dim=-1)
    return latent, latent_shapes

def unpack_latents(combined_latent, latent_shapes):
    if len(latent_shapes) > 1:
        output_tensors = []
        for shape in latent_shapes:
            cut = math.prod(shape[1:])
            tens = combined_latent[:, :, :cut]
            combined_latent = combined_latent[:, :, cut:]
            output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
    else:
        output_tensors = combined_latent
    return output_tensors
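
# Illustrative sketch: pack_latents() flattens a list of differently shaped latents
# into one [batch, 1, N] tensor plus the shapes needed to invert it, and
# unpack_latents() restores the original shapes.
def _example_pack_unpack_latents():
    latents = [torch.randn(2, 4, 8, 8), torch.randn(2, 4, 16, 16)]
    packed, shapes = pack_latents(latents)
    assert packed.shape == (2, 1, 4 * 8 * 8 + 4 * 16 * 16)
    restored = unpack_latents(packed, shapes)
    assert [t.shape for t in restored] == [x.shape for x in latents]
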
def detect_layer_quantization(state_dict, prefix):
    for k in state_dict:
        if k.startswith(prefix) and k.endswith(".comfy_quant"):
            logging.info("Found quantization metadata version 1")
            return {"mixed_ops": True}
    return None

def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
|
if metadata is None:
|
|
metadata = {}
|
|
|
|
quant_metadata = None
|
|
if "_quantization_metadata" not in metadata:
|
|
scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
|
|
|
|
if scaled_fp8_key in state_dict:
|
|
scaled_fp8_weight = state_dict[scaled_fp8_key]
|
|
scaled_fp8_dtype = scaled_fp8_weight.dtype
|
|
if scaled_fp8_dtype == torch.float32:
|
|
scaled_fp8_dtype = torch.float8_e4m3fn
|
|
|
|
if scaled_fp8_weight.nelement() == 2:
|
|
full_precision_matrix_mult = True
|
|
else:
|
|
full_precision_matrix_mult = False
|
|
|
|
out_sd = {}
|
|
layers = {}
|
|
for k in list(state_dict.keys()):
|
|
if not k.startswith(model_prefix):
|
|
out_sd[k] = state_dict[k]
|
|
continue
|
|
k_out = k
|
|
w = state_dict.pop(k)
|
|
layer = None
|
|
if k_out.endswith(".scale_weight"):
|
|
layer = k_out[:-len(".scale_weight")]
|
|
k_out = "{}.weight_scale".format(layer)
|
|
|
|
if layer is not None:
|
|
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
|
|
if full_precision_matrix_mult:
|
|
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
|
|
layers[layer] = layer_conf
|
|
|
|
if k_out.endswith(".scale_input"):
|
|
layer = k_out[:-len(".scale_input")]
|
|
k_out = "{}.input_scale".format(layer)
|
|
if w.item() == 1.0:
|
|
continue
|
|
|
|
out_sd[k_out] = w
|
|
|
|
state_dict = out_sd
|
|
quant_metadata = {"layers": layers}
|
|
else:
|
|
quant_metadata = json.loads(metadata["_quantization_metadata"])
|
|
|
|
if quant_metadata is not None:
|
|
layers = quant_metadata["layers"]
|
|
for k, v in layers.items():
|
|
state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
|
|
|
|
return state_dict, metadata
|