fix non-accelerate model loading
This commit is contained in:
parent
fae591800c
commit
a32064eefb
@ -2,7 +2,6 @@ import json
|
|||||||
import random
|
import random
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -163,7 +162,7 @@ class T2VSynthMochiModel:
|
|||||||
else:
|
else:
|
||||||
print("Loading state_dict without accelerate...")
|
print("Loading state_dict without accelerate...")
|
||||||
model.load_state_dict(dit_sd)
|
model.load_state_dict(dit_sd)
|
||||||
for name, param in self.dit.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if not any(keyword in name for keyword in params_to_keep):
|
if not any(keyword in name for keyword in params_to_keep):
|
||||||
param.data = param.data.to(weight_dtype)
|
param.data = param.data.to(weight_dtype)
|
||||||
else:
|
else:
|
||||||
@ -180,7 +179,7 @@ class T2VSynthMochiModel:
|
|||||||
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
|
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
|
||||||
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
|
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
|
||||||
|
|
||||||
#t.print_stats()
|
t.print_stats()
|
||||||
|
|
||||||
def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
|
def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
|
||||||
B = len(prompts)
|
B = len(prompts)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user