use selected load device as LoRA load device too

This commit is contained in:
kijai 2024-11-21 02:46:07 +02:00
parent e52dc36bc5
commit 276a045a57
3 changed files with 42 additions and 41 deletions

View File

@ -406,8 +406,8 @@ def merge_lora(transformer, lora_path, multiplier, device='cpu', dtype=torch.flo
else: else:
temp_name = layer_infos.pop(0) temp_name = layer_infos.pop(0)
weight_up = elems['lora_up.weight'].to(dtype) weight_up = elems['lora_up.weight'].to(dtype).to(device)
weight_down = elems['lora_down.weight'].to(dtype) weight_down = elems['lora_down.weight'].to(dtype).to(device)
if 'alpha' in elems.keys(): if 'alpha' in elems.keys():
alpha = elems['alpha'].item() / weight_up.shape[1] alpha = elems['alpha'].item() / weight_up.shape[1]
else: else:

View File

@ -267,7 +267,7 @@ class DownloadAndLoadCogVideoModel:
try: #Fun trainer LoRAs are loaded differently try: #Fun trainer LoRAs are loaded differently
from .lora_utils import merge_lora from .lora_utils import merge_lora
log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}") log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
transformer = merge_lora(transformer, l["path"], l["strength"]) pipe.transformer = merge_lora(pipe.transformer, l["path"], l["strength"], device=transformer_load_device, state_dict=lora_sd)
except: except:
raise ValueError(f"Can't recognize LoRA {l['path']}") raise ValueError(f"Can't recognize LoRA {l['path']}")
if adapter_list: if adapter_list:
@ -281,11 +281,11 @@ class DownloadAndLoadCogVideoModel:
if "fused" in attention_mode: if "fused" in attention_mode:
from diffusers.models.attention import Attention from diffusers.models.attention import Attention
transformer.fuse_qkv_projections = True pipe.transformer.fuse_qkv_projections = True
for module in transformer.modules(): for module in pipe.transformer.modules():
if isinstance(module, Attention): if isinstance(module, Attention):
module.fuse_projections(fuse=True) module.fuse_projections(fuse=True)
transformer.attention_mode = attention_mode pipe.transformer.attention_mode = attention_mode
if compile_args is not None: if compile_args is not None:
pipe.transformer.to(memory_format=torch.channels_last) pipe.transformer.to(memory_format=torch.channels_last)

View File

@ -602,6 +602,7 @@ class CogVideoSampler:
def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None, def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None): denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None):
mm.unload_all_models()
mm.soft_empty_cache() mm.soft_empty_cache()
model_name = model.get("model_name", "") model_name = model.get("model_name", "")