Skip the unconditional pass when cfg is 1.0; set steps from the sigma schedule
This commit is contained in:
parent
4348d1ed20
commit
195da244df
@ -237,7 +237,8 @@ class AsymmetricAttention(nn.Module):
|
||||
return out.view(total, local_dim)
|
||||
|
||||
def sdpa_attention(self, qkv):
|
||||
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
|
||||
q, k, v = qkv.unbind(dim=1)
|
||||
q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
|
||||
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
|
||||
with sdpa_kernel(backends):
|
||||
out = F.scaled_dot_product_attention(
|
||||
@ -248,10 +249,13 @@ class AsymmetricAttention(nn.Module):
|
||||
dropout_p=0.0,
|
||||
is_causal=False
|
||||
)
|
||||
return rearrange(out, 'b h s d -> s (b h d)')
|
||||
return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
|
||||
|
||||
def sage_attention(self, qkv):
|
||||
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
|
||||
#q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
|
||||
q, k, v = qkv.unbind(dim=1)
|
||||
q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
|
||||
|
||||
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
|
||||
out = sageattn(
|
||||
q,
|
||||
@ -261,11 +265,14 @@ class AsymmetricAttention(nn.Module):
|
||||
dropout_p=0.0,
|
||||
is_causal=False
|
||||
)
|
||||
return rearrange(out, 'b h s d -> s (b h d)')
|
||||
#print(out.shape)
|
||||
#out = rearrange(out, 'b h s d -> s (b h d)')
|
||||
return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
|
||||
|
||||
def comfy_attention(self, qkv):
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
|
||||
q, k, v = qkv.unbind(dim=1)
|
||||
q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
|
||||
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
|
||||
out = optimized_attention(
|
||||
q,
|
||||
|
||||
@ -325,13 +325,15 @@ class T2VSynthMochiModel:
|
||||
out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
|
||||
else:
|
||||
nonlocal sample, sample_null
|
||||
|
||||
with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
|
||||
out_cond = self.dit(z, sigma, **sample)
|
||||
out_uncond = self.dit(z, sigma, **sample_null)
|
||||
if cfg_scale > 1.0:
|
||||
out_cond = self.dit(z, sigma, **sample)
|
||||
out_uncond = self.dit(z, sigma, **sample_null)
|
||||
else:
|
||||
out_cond = self.dit(z, sigma, **sample)
|
||||
return out_cond
|
||||
|
||||
assert out_cond.shape == out_uncond.shape
|
||||
return out_uncond + cfg_scale * (out_cond - out_uncond), out_cond
|
||||
return out_uncond + cfg_scale * (out_cond - out_uncond)
|
||||
|
||||
comfy_pbar = ProgressBar(sample_steps)
|
||||
for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
|
||||
@ -339,7 +341,7 @@ class T2VSynthMochiModel:
|
||||
dsigma = sigma - sigma_schedule[i + 1]
|
||||
|
||||
# `pred` estimates `z_0 - eps`.
|
||||
pred, output_cond = model_fn(
|
||||
pred = model_fn(
|
||||
z=z,
|
||||
sigma=torch.full(
|
||||
[B] if not batch_cfg else [B * 2], sigma, device=z.device
|
||||
@ -347,8 +349,6 @@ class T2VSynthMochiModel:
|
||||
cfg_scale=cfg_schedule[i],
|
||||
)
|
||||
pred = pred.to(z)
|
||||
output_cond = output_cond.to(z)
|
||||
|
||||
z = z + dsigma * pred
|
||||
comfy_pbar.update(1)
|
||||
|
||||
@ -356,7 +356,6 @@ class T2VSynthMochiModel:
|
||||
z = z[:B]
|
||||
|
||||
self.dit.to(self.offload_device)
|
||||
|
||||
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
|
||||
logging.info(f"samples shape: {samples.shape}")
|
||||
return samples
|
||||
|
||||
@ -141,6 +141,7 @@ class WQLinear_GGUF(nn.Module):
|
||||
dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
|
||||
else:
|
||||
raise ValueError(f"Unknown qtype: {self.qtype}")
|
||||
|
||||
return self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user