cleanup code
parent 5ec01cbff4
commit 3cf9289e08
@@ -62,28 +62,6 @@ class TimestepEmbedder(nn.Module):
         return t_emb
 
 
-class PooledCaptionEmbedder(nn.Module):
-    def __init__(
-        self,
-        caption_feature_dim: int,
-        hidden_size: int,
-        *,
-        bias: bool = True,
-        device: Optional[torch.device] = None,
-    ):
-        super().__init__()
-        self.caption_feature_dim = caption_feature_dim
-        self.hidden_size = hidden_size
-        self.mlp = nn.Sequential(
-            nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device),
-            nn.SiLU(),
-            nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
-        )
-
-    def forward(self, x):
-        return self.mlp(x)
-
-
 class FeedForward(nn.Module):
     def __init__(
         self,
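For reference, the removed PooledCaptionEmbedder is a plain two-layer MLP (Linear -> SiLU -> Linear) that projects a pooled caption feature vector into the model's hidden size. A minimal self-contained sketch of the same module, runnable outside the repo; the 4096/3072 shapes in the usage lines are illustrative placeholders, not values from the model config:

from typing import Optional

import torch
import torch.nn as nn

class PooledCaptionEmbedder(nn.Module):
    """Two-layer MLP: caption_feature_dim -> hidden_size -> hidden_size."""
    def __init__(
        self,
        caption_feature_dim: int,
        hidden_size: int,
        *,
        bias: bool = True,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.caption_feature_dim = caption_feature_dim
        self.hidden_size = hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
        )

    def forward(self, x):
        return self.mlp(x)

# Usage with placeholder shapes:
embedder = PooledCaptionEmbedder(caption_feature_dim=4096, hidden_size=3072)
pooled = torch.randn(2, 4096)   # (batch, caption_feature_dim)
out = embedder(pooled)          # -> (2, 3072)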
@@ -286,41 +286,40 @@ class T2VSynthMochiModel:
                 sample_null["y_mask"], **latent_dims
             )
 
-        def model_fn(*, z, sigma, cfg_scale):
-            self.dit.to(self.device)
-            if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
-                autocast_dtype = torch.float16
-            else:
-                autocast_dtype = torch.bfloat16
+        self.dit.to(self.device)
+        if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
+            autocast_dtype = torch.float16
+        else:
+            autocast_dtype = torch.bfloat16
 
+        def model_fn(*, z, sigma, cfg_scale):
             nonlocal sample, sample_null
-            with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
-                if cfg_scale > 1.0:
-                    out_cond = self.dit(z, sigma, **sample)
-                    out_uncond = self.dit(z, sigma, **sample_null)
-                else:
-                    out_cond = self.dit(z, sigma, **sample)
-                    return out_cond
+            if cfg_scale > 1.0:
+                out_cond = self.dit(z, sigma, **sample)
+                out_uncond = self.dit(z, sigma, **sample_null)
+            else:
+                out_cond = self.dit(z, sigma, **sample)
+                return out_cond
             return out_uncond + cfg_scale * (out_cond - out_uncond)
 
         comfy_pbar = ProgressBar(sample_steps)
-        for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
-            sigma = sigma_schedule[i]
-            dsigma = sigma - sigma_schedule[i + 1]
-
-            # `pred` estimates `z_0 - eps`.
-            pred = model_fn(
-                z=z,
-                sigma=torch.full([B], sigma, device=z.device),
-                cfg_scale=cfg_schedule[i],
-            )
-            pred = pred.to(z)
-            z = z + dsigma * pred
-            if callback is not None:
-                callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
-            else:
-                comfy_pbar.update(1)
+        with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
+            for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
+                sigma = sigma_schedule[i]
+                dsigma = sigma - sigma_schedule[i + 1]
+
+                # `pred` estimates `z_0 - eps`.
+                pred = model_fn(
+                    z=z,
+                    sigma=torch.full([B], sigma, device=z.device),
+                    cfg_scale=cfg_schedule[i],
+                )
+                pred = pred.to(z)
+                z = z + dsigma * pred
+                if callback is not None:
+                    callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
+                else:
+                    comfy_pbar.update(1)
 
         self.dit.to(self.offload_device)
         logging.info(f"samples shape: {z.shape}")
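The second hunk rearranges code without changing the sampling math: each step still applies classifier-free guidance, out_uncond + cfg_scale * (out_cond - out_uncond), then takes one Euler step z <- z + dsigma * pred along the sigma schedule, where pred estimates z_0 - eps. A sketch of that loop with a stand-in model; DummyDiT, the linear sigma schedule, and the constant cfg schedule are placeholders, not the repo's actual model or schedules:

import torch

torch.manual_seed(0)

class DummyDiT(torch.nn.Module):
    """Stand-in for the DiT: any callable mapping (z, sigma) -> a prediction."""
    def forward(self, z, sigma, cond=True):
        # Fake denoiser; the conditional branch is just scaled differently.
        return torch.tanh(z) * (1.5 if cond else 1.0)

dit = DummyDiT()

def model_fn(*, z, sigma, cfg_scale):
    # Classifier-free guidance: blend conditional and unconditional outputs.
    if cfg_scale > 1.0:
        out_cond = dit(z, sigma, cond=True)
        out_uncond = dit(z, sigma, cond=False)
        return out_uncond + cfg_scale * (out_cond - out_uncond)
    return dit(z, sigma, cond=True)  # guidance disabled: conditional pass only

B, sample_steps = 1, 8
z = torch.randn(B, 12, 4, 8, 8)  # toy (B, C, T, H, W) latent
sigma_schedule = torch.linspace(1.0, 0.0, sample_steps + 1).tolist()  # placeholder
cfg_schedule = [4.5] * sample_steps  # constant guidance, illustrative only

for i in range(sample_steps):
    sigma = sigma_schedule[i]
    dsigma = sigma - sigma_schedule[i + 1]
    # `pred` estimates `z_0 - eps`, so z + dsigma * pred is one Euler step
    # from the current sigma toward the next one.
    pred = model_fn(
        z=z,
        sigma=torch.full([B], sigma),
        cfg_scale=cfg_schedule[i],
    )
    z = z + dsigma * pred

print(z.shape)  # torch.Size([1, 12, 4, 8, 8])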
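What the hunk does change is where the housekeeping runs: self.dit.to(self.device) and the float16-vs-bfloat16 selection used to execute on every model_fn call, and the autocast context was re-entered for every forward pass; after the refactor they run once, with a single torch.autocast region wrapping the whole sampling loop. A sketch of the resulting pattern; use_fp16 stands in for the dit.cublas_half_matmul check and step for the guided DiT prediction:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
use_fp16 = False  # stand-in for the `cublas_half_matmul` flag in the diff
autocast_dtype = torch.float16 if use_fp16 else torch.bfloat16

def step(z):
    return z * -0.1  # placeholder for the guided DiT prediction

z = torch.randn(1, 12, 4, 8, 8, device=device)

# One autocast region around the whole loop, instead of re-selecting the
# dtype and re-entering autocast inside every model_fn invocation.
with torch.autocast(device_type=device, dtype=autocast_dtype):
    for _ in range(8):
        z = z + 0.125 * step(z)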