cleanup code

This commit is contained in:
kijai 2024-11-03 01:47:34 +02:00
parent 5ec01cbff4
commit 3cf9289e08
2 changed files with 29 additions and 52 deletions

View File

@ -62,28 +62,6 @@ class TimestepEmbedder(nn.Module):
return t_emb return t_emb
class PooledCaptionEmbedder(nn.Module):
    """Two-layer MLP embedder: ``Linear -> SiLU -> Linear``.

    Maps inputs whose last dimension is ``caption_feature_dim`` to outputs
    whose last dimension is ``hidden_size``.
    NOTE(review): judging by the name, the input is presumably a pooled
    caption feature vector -- confirm against the caller.

    Args:
        caption_feature_dim: Input feature width of the first linear layer.
        hidden_size: Output width of both linear layers.
        bias: Whether both linear layers carry a bias term (keyword-only).
        device: Device passed to both linear layers at construction
            (keyword-only); ``None`` uses the PyTorch default.
    """

    def __init__(
        self,
        caption_feature_dim: int,
        hidden_size: int,
        *,
        bias: bool = True,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        # Keep the configured sizes on the module for later inspection.
        self.caption_feature_dim = caption_feature_dim
        self.hidden_size = hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
        )

    def forward(self, x):
        """Return ``self.mlp(x)``."""
        return self.mlp(x)
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__( def __init__(
self, self,

View File

@ -286,41 +286,40 @@ class T2VSynthMochiModel:
sample_null["y_mask"], **latent_dims sample_null["y_mask"], **latent_dims
) )
def model_fn(*, z, sigma, cfg_scale): self.dit.to(self.device)
self.dit.to(self.device) if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul: autocast_dtype = torch.float16
autocast_dtype = torch.float16 else:
else: autocast_dtype = torch.bfloat16
autocast_dtype = torch.bfloat16
nonlocal sample, sample_null
with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
if cfg_scale > 1.0:
out_cond = self.dit(z, sigma, **sample)
out_uncond = self.dit(z, sigma, **sample_null)
else:
out_cond = self.dit(z, sigma, **sample)
return out_cond
def model_fn(*, z, sigma, cfg_scale):
nonlocal sample, sample_null
if cfg_scale > 1.0:
out_cond = self.dit(z, sigma, **sample)
out_uncond = self.dit(z, sigma, **sample_null)
else:
out_cond = self.dit(z, sigma, **sample)
return out_cond
return out_uncond + cfg_scale * (out_cond - out_uncond) return out_uncond + cfg_scale * (out_cond - out_uncond)
comfy_pbar = ProgressBar(sample_steps) comfy_pbar = ProgressBar(sample_steps)
for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps): with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
sigma = sigma_schedule[i] for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
dsigma = sigma - sigma_schedule[i + 1] sigma = sigma_schedule[i]
dsigma = sigma - sigma_schedule[i + 1]
# `pred` estimates `z_0 - eps`. # `pred` estimates `z_0 - eps`.
pred = model_fn( pred = model_fn(
z=z, z=z,
sigma=torch.full([B], sigma, device=z.device), sigma=torch.full([B], sigma, device=z.device),
cfg_scale=cfg_schedule[i], cfg_scale=cfg_schedule[i],
) )
pred = pred.to(z) pred = pred.to(z)
z = z + dsigma * pred z = z + dsigma * pred
if callback is not None: if callback is not None:
callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps) callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
else: else:
comfy_pbar.update(1) comfy_pbar.update(1)
self.dit.to(self.offload_device) self.dit.to(self.offload_device)
logging.info(f"samples shape: {z.shape}") logging.info(f"samples shape: {z.shape}")