fix non-accelerate model loading

This commit is contained in:
kijai 2024-10-23 20:40:42 +03:00
parent fae591800c
commit a32064eefb

View File

@ -2,7 +2,6 @@ import json
import random
from typing import Dict, List
from safetensors.torch import load_file
import numpy as np
import torch
import torch.nn as nn
@ -163,7 +162,7 @@ class T2VSynthMochiModel:
else:
print("Loading state_dict without accelerate...")
model.load_state_dict(dit_sd)
for name, param in self.dit.named_parameters():
for name, param in model.named_parameters():
if not any(keyword in name for keyword in params_to_keep):
param.data = param.data.to(weight_dtype)
else:
@ -180,7 +179,7 @@ class T2VSynthMochiModel:
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
#t.print_stats()
t.print_stats()
def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
B = len(prompts)