From 195da244dffd5f13a94013c9db4f03099ac7aa7d Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 27 Oct 2024 19:52:16 +0200
Subject: [PATCH] make cfg 1.0 not do uncond, set steps by sigma schedule

---
 .../dit/joint_model/asymm_models_joint.py | 17 ++++++++++++-----
 mochi_preview/t2v_synth_mochi.py          | 17 ++++++++---------
 mz_gguf_loader.py                         |  1 +
 nodes.py                                  |  2 ++
 4 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index 81c2332..f4fdf6a 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -237,7 +237,8 @@ class AsymmetricAttention(nn.Module):
         return out.view(total, local_dim)

     def sdpa_attention(self, qkv):
-        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
+        q, k, v = qkv.unbind(dim=1)
+        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             with sdpa_kernel(backends):
                 out = F.scaled_dot_product_attention(
@@ -248,10 +249,13 @@ class AsymmetricAttention(nn.Module):
                     dropout_p=0.0,
                     is_causal=False
                 )
-        return rearrange(out, 'b h s d -> s (b h d)')
+        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)

     def sage_attention(self, qkv):
-        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
+        #q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
+        q, k, v = qkv.unbind(dim=1)
+        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
+
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out = sageattn(
                 q,
@@ -261,11 +265,14 @@ class AsymmetricAttention(nn.Module):
                 dropout_p=0.0,
                 is_causal=False
             )
-        return rearrange(out, 'b h s d -> s (b h d)')
+        #print(out.shape)
+        #out = rearrange(out, 'b h s d -> s (b h d)')
+        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)

     def comfy_attention(self, qkv):
         from comfy.ldm.modules.attention import optimized_attention
-        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
+        q, k, v = qkv.unbind(dim=1)
+        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out = optimized_attention(
                 q,
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index 6b598b7..18d67e6 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -325,13 +325,15 @@ class T2VSynthMochiModel:
                 out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
             else:
                 nonlocal sample, sample_null
-                with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
-                    out_cond = self.dit(z, sigma, **sample)
-                    out_uncond = self.dit(z, sigma, **sample_null)
+                if cfg_scale > 1.0:
+                    out_cond = self.dit(z, sigma, **sample)
+                    out_uncond = self.dit(z, sigma, **sample_null)
+                else:
+                    out_cond = self.dit(z, sigma, **sample)
+                    return out_cond

-            assert out_cond.shape == out_uncond.shape
-            return out_uncond + cfg_scale * (out_cond - out_uncond), out_cond
+            return out_uncond + cfg_scale * (out_cond - out_uncond)

         comfy_pbar = ProgressBar(sample_steps)
         for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
             dsigma = sigma - sigma_schedule[i + 1]

             # `pred` estimates `z_0 - eps`.
-            pred, output_cond = model_fn(
+            pred = model_fn(
                 z=z,
                 sigma=torch.full(
                     [B] if not batch_cfg else [B * 2], sigma, device=z.device
@@ -347,8 +349,6 @@
                 ),
                 cfg_scale=cfg_schedule[i],
             )
             pred = pred.to(z)
-            output_cond = output_cond.to(z)
-
             z = z + dsigma * pred
             comfy_pbar.update(1)
@@ -356,7 +356,6 @@
             z = z[:B]
         self.dit.to(self.offload_device)
-
         samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
         logging.info(f"samples shape: {samples.shape}")
         return samples
diff --git a/mz_gguf_loader.py b/mz_gguf_loader.py
index a01791c..ca3a913 100644
--- a/mz_gguf_loader.py
+++ b/mz_gguf_loader.py
@@ -141,6 +141,7 @@ class WQLinear_GGUF(nn.Module):
             dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
         else:
             raise ValueError(f"Unknown qtype: {self.qtype}")
+
         return self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
diff --git a/nodes.py b/nodes.py
index f8abed5..358941b 100644
--- a/nodes.py
+++ b/nodes.py
@@ -373,6 +373,8 @@ class MochiSampler:
         if opt_sigmas is not None:
             sigma_schedule = opt_sigmas.tolist()
+            sigma_schedule.extend([1.0])
+            steps = len(sigma_schedule)
             logging.info(f"Using sigma_schedule: {sigma_schedule}")
         else:
             sigma_schedule = linear_quadratic_schedule(steps, 0.025)
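
Note on the cfg change in t2v_synth_mochi.py: classifier-free guidance combines the two passes as uncond + cfg_scale * (cond - uncond), which collapses to the conditional output when cfg_scale is 1.0, so the unconditional forward pass can be skipped. A minimal sketch of the idea, with a hypothetical `model` callable standing in for the self.dit call:

    import torch

    def cfg_combine(cond: torch.Tensor, uncond: torch.Tensor, scale: float) -> torch.Tensor:
        # uncond + scale * (cond - uncond); at scale == 1.0 this equals cond exactly.
        return uncond + scale * (cond - uncond)

    def guided_pred(model, z, sigma, cond_kwargs, uncond_kwargs, cfg_scale):
        cond = model(z, sigma, **cond_kwargs)
        if cfg_scale <= 1.0:
            # Skip the unconditional pass: its contribution is multiplied away anyway,
            # so dropping it roughly halves the per-step cost.
            return cond
        uncond = model(z, sigma, **uncond_kwargs)
        return cfg_combine(cond, uncond, cfg_scale)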
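
Note on the attention helpers in asymm_models_joint.py: the einops rearrange calls are replaced with plain unbind/permute/reshape. Assuming the packed qkv layout implied by the '(b s) t h d' pattern with b=1, i.e. (sequence, 3, heads, head_dim), the two formulations produce identical tensors; a quick self-contained check (concrete sizes are illustrative assumptions):

    import torch
    from einops import rearrange

    qkv = torch.randn(128, 3, 24, 64)  # (seq, 3, heads, head_dim), assumed layout

    # Old path (einops) vs. new path (plain tensor ops).
    q_old, k_old, v_old = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
    q_new, k_new, v_new = [x.permute(1, 0, 2).unsqueeze(0) for x in qkv.unbind(dim=1)]
    assert torch.equal(q_old, q_new) and torch.equal(k_old, k_new) and torch.equal(v_old, v_new)

    # Attention output: 'b h s d -> s (b h d)' vs. permute + reshape.
    out = torch.randn(1, 24, 128, 64)
    assert torch.equal(
        rearrange(out, 'b h s d -> s (b h d)'),
        out.permute(2, 0, 1, 3).reshape(out.shape[2], -1),
    )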
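
Note on "set steps by sigma schedule": the sampling loop in t2v_synth_mochi.py walks consecutive pairs of the schedule (dsigma = sigma - sigma_schedule[i + 1]), so the step count follows from the schedule's length. A rough sketch of that loop shape, with a hypothetical dit_fn standing in for the DiT call and all conditioning handling omitted:

    import torch

    def euler_like_sample(dit_fn, z: torch.Tensor, sigma_schedule: list) -> torch.Tensor:
        # A schedule of N + 1 sigmas drives N update steps.
        for i in range(len(sigma_schedule) - 1):
            sigma = sigma_schedule[i]
            dsigma = sigma - sigma_schedule[i + 1]
            pred = dit_fn(z, sigma)  # `pred` estimates `z_0 - eps`
            z = z + dsigma * pred.to(z)
        return z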