fix non-accelerate model loading
This commit is contained in:
parent
fae591800c
commit
a32064eefb
@ -2,7 +2,6 @@ import json
|
||||
import random
|
||||
from typing import Dict, List
|
||||
|
||||
from safetensors.torch import load_file
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -163,7 +162,7 @@ class T2VSynthMochiModel:
|
||||
else:
|
||||
print("Loading state_dict without accelerate...")
|
||||
model.load_state_dict(dit_sd)
|
||||
for name, param in self.dit.named_parameters():
|
||||
for name, param in model.named_parameters():
|
||||
if not any(keyword in name for keyword in params_to_keep):
|
||||
param.data = param.data.to(weight_dtype)
|
||||
else:
|
||||
@ -180,7 +179,7 @@ class T2VSynthMochiModel:
|
||||
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
|
||||
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
|
||||
|
||||
#t.print_stats()
|
||||
t.print_stats()
|
||||
|
||||
def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
|
||||
B = len(prompts)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user