mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-08 14:59:12 +08:00
use selected load device as LoRA load device too
This commit is contained in:
parent
e52dc36bc5
commit
276a045a57
@ -406,8 +406,8 @@ def merge_lora(transformer, lora_path, multiplier, device='cpu', dtype=torch.flo
|
|||||||
else:
|
else:
|
||||||
temp_name = layer_infos.pop(0)
|
temp_name = layer_infos.pop(0)
|
||||||
|
|
||||||
weight_up = elems['lora_up.weight'].to(dtype)
|
weight_up = elems['lora_up.weight'].to(dtype).to(device)
|
||||||
weight_down = elems['lora_down.weight'].to(dtype)
|
weight_down = elems['lora_down.weight'].to(dtype).to(device)
|
||||||
if 'alpha' in elems.keys():
|
if 'alpha' in elems.keys():
|
||||||
alpha = elems['alpha'].item() / weight_up.shape[1]
|
alpha = elems['alpha'].item() / weight_up.shape[1]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -241,51 +241,51 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
|
|
||||||
#LoRAs
|
#LoRAs
|
||||||
if lora is not None:
|
if lora is not None:
|
||||||
dimensionx_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling
|
dimensionx_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling
|
||||||
dimensionx_lora = False
|
dimensionx_lora = False
|
||||||
adapter_list = []
|
adapter_list = []
|
||||||
adapter_weights = []
|
adapter_weights = []
|
||||||
for l in lora:
|
for l in lora:
|
||||||
if any(item in l["path"].lower() for item in dimensionx_loras):
|
if any(item in l["path"].lower() for item in dimensionx_loras):
|
||||||
dimensionx_lora = True
|
dimensionx_lora = True
|
||||||
fuse = True if l["fuse_lora"] else False
|
fuse = True if l["fuse_lora"] else False
|
||||||
lora_sd = load_torch_file(l["path"])
|
lora_sd = load_torch_file(l["path"])
|
||||||
lora_rank = None
|
lora_rank = None
|
||||||
for key, val in lora_sd.items():
|
for key, val in lora_sd.items():
|
||||||
if "lora_B" in key:
|
if "lora_B" in key:
|
||||||
lora_rank = val.shape[1]
|
lora_rank = val.shape[1]
|
||||||
break
|
break
|
||||||
if lora_rank is not None:
|
if lora_rank is not None:
|
||||||
log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}")
|
log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}")
|
||||||
adapter_name = l['path'].split("/")[-1].split(".")[0]
|
adapter_name = l['path'].split("/")[-1].split(".")[0]
|
||||||
adapter_weight = l['strength']
|
adapter_weight = l['strength']
|
||||||
pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name)
|
pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name)
|
||||||
|
|
||||||
adapter_list.append(adapter_name)
|
adapter_list.append(adapter_name)
|
||||||
adapter_weights.append(adapter_weight)
|
adapter_weights.append(adapter_weight)
|
||||||
else:
|
else:
|
||||||
try: #Fun trainer LoRAs are loaded differently
|
try: #Fun trainer LoRAs are loaded differently
|
||||||
from .lora_utils import merge_lora
|
from .lora_utils import merge_lora
|
||||||
log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
|
log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
|
||||||
transformer = merge_lora(transformer, l["path"], l["strength"])
|
pipe.transformer = merge_lora(pipe.transformer, l["path"], l["strength"], device=transformer_load_device, state_dict=lora_sd)
|
||||||
except:
|
except:
|
||||||
raise ValueError(f"Can't recognize LoRA {l['path']}")
|
raise ValueError(f"Can't recognize LoRA {l['path']}")
|
||||||
if adapter_list:
|
if adapter_list:
|
||||||
pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
|
pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
|
||||||
if fuse:
|
if fuse:
|
||||||
lora_scale = 1
|
lora_scale = 1
|
||||||
if dimensionx_lora:
|
if dimensionx_lora:
|
||||||
lora_scale = lora_scale / lora_rank
|
lora_scale = lora_scale / lora_rank
|
||||||
pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])
|
pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])
|
||||||
|
|
||||||
|
|
||||||
if "fused" in attention_mode:
|
if "fused" in attention_mode:
|
||||||
from diffusers.models.attention import Attention
|
from diffusers.models.attention import Attention
|
||||||
transformer.fuse_qkv_projections = True
|
pipe.transformer.fuse_qkv_projections = True
|
||||||
for module in transformer.modules():
|
for module in pipe.transformer.modules():
|
||||||
if isinstance(module, Attention):
|
if isinstance(module, Attention):
|
||||||
module.fuse_projections(fuse=True)
|
module.fuse_projections(fuse=True)
|
||||||
transformer.attention_mode = attention_mode
|
pipe.transformer.attention_mode = attention_mode
|
||||||
|
|
||||||
if compile_args is not None:
|
if compile_args is not None:
|
||||||
pipe.transformer.to(memory_format=torch.channels_last)
|
pipe.transformer.to(memory_format=torch.channels_last)
|
||||||
|
|||||||
1
nodes.py
1
nodes.py
@ -602,6 +602,7 @@ class CogVideoSampler:
|
|||||||
|
|
||||||
def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
|
def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
|
||||||
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None):
|
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None):
|
||||||
|
mm.unload_all_models()
|
||||||
mm.soft_empty_cache()
|
mm.soft_empty_cache()
|
||||||
|
|
||||||
model_name = model.get("model_name", "")
|
model_name = model.get("model_name", "")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user