make cfg 1.0 skip the uncond pass, set steps from the sigma schedule

commit 195da244df (parent 4348d1ed20)
Author: kijai
Date: 2024-10-27 19:52:16 +02:00

4 changed files with 23 additions and 14 deletions


@@ -237,7 +237,8 @@ class AsymmetricAttention(nn.Module):
         return out.view(total, local_dim)

     def sdpa_attention(self, qkv):
-        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
+        q, k, v = qkv.unbind(dim=1)
+        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             with sdpa_kernel(backends):
                 out = F.scaled_dot_product_attention(
@@ -248,10 +249,13 @@ class AsymmetricAttention(nn.Module):
                     dropout_p=0.0,
                     is_causal=False
                 )
-        return rearrange(out, 'b h s d -> s (b h d)')
+        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)

     def sage_attention(self, qkv):
-        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
+        #q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
+        q, k, v = qkv.unbind(dim=1)
+        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out = sageattn(
                 q,
@@ -261,11 +265,14 @@ class AsymmetricAttention(nn.Module):
                 dropout_p=0.0,
                 is_causal=False
             )
-        return rearrange(out, 'b h s d -> s (b h d)')
+        #print(out.shape)
+        #out = rearrange(out, 'b h s d -> s (b h d)')
+        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)

     def comfy_attention(self, qkv):
         from comfy.ldm.modules.attention import optimized_attention
-        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
+        q, k, v = qkv.unbind(dim=1)
+        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out = optimized_attention(
                 q,

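Note: all three attention paths (sdpa_attention, sage_attention, comfy_attention) get the same refactor: the einops rearrange on the fused qkv tensor is replaced with plain torch unbind/permute calls. The two are equivalent; the sketch below is a standalone sanity check with assumed shapes ((b*s, 3, heads, head_dim) with b = 1), not code from the repo:

    import torch
    from einops import rearrange

    s, h, d = 16, 8, 64                      # tokens, heads, head dim (b = 1)
    qkv = torch.randn(s, 3, h, d)            # fused q/k/v, assumed layout

    # Old path: one einops call yielding three (1, h, s, d) tensors.
    q0, k0, v0 = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)

    # New path: plain torch ops.
    q1, k1, v1 = qkv.unbind(dim=1)           # each (s, h, d)
    q1, k1, v1 = [x.permute(1, 0, 2).unsqueeze(0) for x in (q1, k1, v1)]

    assert all(torch.equal(a, b) for a, b in [(q0, q1), (k0, k1), (v0, v1)])

    # Output side: 'b h s d -> s (b h d)' vs. permute + reshape.
    out = torch.randn(1, h, s, d)
    assert torch.equal(rearrange(out, 'b h s d -> s (b h d)'),
                       out.permute(2, 0, 1, 3).reshape(out.shape[2], -1))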

@@ -325,13 +325,15 @@ class T2VSynthMochiModel:
                out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
            else:
                nonlocal sample, sample_null
                with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
-                    out_cond = self.dit(z, sigma, **sample)
-                    out_uncond = self.dit(z, sigma, **sample_null)
+                    if cfg_scale > 1.0:
+                        out_cond = self.dit(z, sigma, **sample)
+                        out_uncond = self.dit(z, sigma, **sample_null)
+                    else:
+                        out_cond = self.dit(z, sigma, **sample)
+                        return out_cond

            assert out_cond.shape == out_uncond.shape
-            return out_uncond + cfg_scale * (out_cond - out_uncond), out_cond
+            return out_uncond + cfg_scale * (out_cond - out_uncond)

        comfy_pbar = ProgressBar(sample_steps)
        for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
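Note: this hunk is the first half of the commit message. Classifier-free guidance blends the two predictions as uncond + cfg_scale * (cond - uncond); at cfg_scale == 1.0 the uncond terms cancel exactly and the result is just the conditional output, so the second DiT forward pass (and the assert, which the early return now skips) would be pure waste. A minimal standalone check of that identity:

    import torch

    def cfg_combine(cond, uncond, cfg_scale):
        # Classifier-free guidance: uncond + scale * (cond - uncond)
        return uncond + cfg_scale * (cond - uncond)

    cond, uncond = torch.randn(2, 4), torch.randn(2, 4)
    # At scale 1.0 computing the uncond pass is wasted work:
    assert torch.allclose(cfg_combine(cond, uncond, 1.0), cond)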
@@ -339,7 +341,7 @@ class T2VSynthMochiModel:
            dsigma = sigma - sigma_schedule[i + 1]

            # `pred` estimates `z_0 - eps`.
-            pred, output_cond = model_fn(
+            pred = model_fn(
                z=z,
                sigma=torch.full(
                    [B] if not batch_cfg else [B * 2], sigma, device=z.device
@@ -347,8 +349,6 @@ class T2VSynthMochiModel:
                cfg_scale=cfg_schedule[i],
            )
            pred = pred.to(z)
-            output_cond = output_cond.to(z)

            z = z + dsigma * pred
            comfy_pbar.update(1)
@@ -356,7 +356,6 @@ class T2VSynthMochiModel:
            z = z[:B]
        self.dit.to(self.offload_device)

        samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
-        logging.info(f"samples shape: {samples.shape}")

        return samples
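Note: model_fn's second return value (output_cond) had no consumer in the sampling loop, so the tuple return is dropped along with the matching unpack and .to(z) cast, and a leftover debug log goes with it. The loop itself is a first-order Euler integrator over the sigma schedule; a stripped-down sketch of the structure (hypothetical model_fn, no autocast or progress-bar plumbing):

    import torch

    def sample(model_fn, z, sigma_schedule, cfg_schedule):
        # One Euler step per adjacent pair of sigmas; `pred` estimates z_0 - eps.
        for i in range(len(sigma_schedule) - 1):
            sigma = sigma_schedule[i]
            dsigma = sigma - sigma_schedule[i + 1]
            pred = model_fn(z=z, sigma=sigma, cfg_scale=cfg_schedule[i])
            z = z + dsigma * pred
        return z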


@@ -141,6 +141,7 @@ class WQLinear_GGUF(nn.Module):
            dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
        else:
            raise ValueError(f"Unknown qtype: {self.qtype}")
+        return self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)

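Note: this forward previously dequantized the weight but never returned anything; the added line completes it by passing the dequantized matrix (plus the bias, cast to the activation dtype) into linear_ops. A rough sketch of the same dequantize-then-linear pattern, assuming GGUF Q8_0's block layout of 32 int8 values with one fp16 scale each (names here are illustrative, not the repo's):

    import torch
    import torch.nn.functional as F

    def dequantize_q8_0(scales, quants, dtype):
        # scales: (n_blocks, 1) fp16, quants: (n_blocks, 32) int8
        return (scales.to(dtype) * quants.to(dtype)).reshape(-1)

    def gguf_linear(x, scales, quants, out_features, bias=None):
        # Dequantize on the fly, then run an ordinary linear.
        w = dequantize_q8_0(scales, quants, x.dtype).reshape(out_features, -1)
        return F.linear(x, w, bias.to(x.dtype) if bias is not None else None)

    x = torch.randn(2, 64)
    scales = torch.randn(64 * 64 // 32, 1, dtype=torch.float16)
    quants = torch.randint(-128, 128, (64 * 64 // 32, 32), dtype=torch.int8)
    y = gguf_linear(x, scales, quants, out_features=64)   # (2, 64)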

@@ -373,6 +373,8 @@ class MochiSampler:
        if opt_sigmas is not None:
            sigma_schedule = opt_sigmas.tolist()
            sigma_schedule.extend([1.0])
+            steps = len(sigma_schedule)
+            logging.info(f"Using sigma_schedule: {sigma_schedule}")
        else:
            sigma_schedule = linear_quadratic_schedule(steps, 0.025)
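Note: this is the second half of the commit message. When the node receives a custom sigma list, the step count is now derived from the schedule itself rather than the user's steps input, so the two can no longer disagree. A sketch of the custom-sigma branch in isolation (mirroring the commit; the linear_quadratic_schedule fallback is omitted):

    import torch

    def resolve_schedule(opt_sigmas):
        # Step count follows the supplied schedule, not the node's `steps` input.
        sigma_schedule = opt_sigmas.tolist()
        sigma_schedule.extend([1.0])          # terminal value, as in the commit
        return sigma_schedule, len(sigma_schedule)

    schedule, steps = resolve_schedule(torch.linspace(1.0, 0.05, 10))
    print(steps)  # 11 -- the node's requested steps value is ignored here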