make cfg 1.0 not do uncond, set steps by sigma schedule
parent 4348d1ed20
commit 195da244df
@@ -237,7 +237,8 @@ class AsymmetricAttention(nn.Module):
         return out.view(total, local_dim)

     def sdpa_attention(self, qkv):
-        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
+        q, k, v = qkv.unbind(dim=1)
+        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             with sdpa_kernel(backends):
                 out = F.scaled_dot_product_attention(
@@ -248,10 +249,13 @@ class AsymmetricAttention(nn.Module):
                     dropout_p=0.0,
                     is_causal=False
                 )
-        return rearrange(out, 'b h s d -> s (b h d)')
+        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)

     def sage_attention(self, qkv):
-        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
+        #q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
+        q, k, v = qkv.unbind(dim=1)
+        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
+
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out = sageattn(
                 q,
@@ -261,11 +265,14 @@ class AsymmetricAttention(nn.Module):
                 dropout_p=0.0,
                 is_causal=False
             )
-        return rearrange(out, 'b h s d -> s (b h d)')
+        #print(out.shape)
+        #out = rearrange(out, 'b h s d -> s (b h d)')
+        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)

     def comfy_attention(self, qkv):
         from comfy.ldm.modules.attention import optimized_attention
-        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
+        q, k, v = qkv.unbind(dim=1)
+        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out = optimized_attention(
                 q,
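All three attention paths above drop the einops rearrange calls in favor of plain tensor ops. A minimal sketch, assuming qkv is packed as (seq, 3, heads, head_dim) with an implicit batch of 1 (which is what the old '(b s) t h d' pattern with b=1 implies), checking that the new unpacking and output reshape match the old rearranges; the einops import is only there for the comparison:

import torch
from einops import rearrange  # only needed to check against the old path

s, h, d = 16, 4, 8
qkv = torch.randn(s, 3, h, d)   # (b*s, t, h, d) with b = 1

# old: rearrange to (t, b, h, s, d), then unpack over t
q_old, k_old, v_old = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)

# new: unbind over t, move heads in front, add the batch dim
q_new, k_new, v_new = [x.permute(1, 0, 2).unsqueeze(0) for x in qkv.unbind(dim=1)]

assert all(torch.equal(a, b) for a, b in [(q_old, q_new), (k_old, k_new), (v_old, v_new)])

# output side: 'b h s d -> s (b h d)' equals permute + reshape when b == 1
out = torch.randn(1, h, s, d)
assert torch.equal(rearrange(out, 'b h s d -> s (b h d)'),
                   out.permute(2, 0, 1, 3).reshape(out.shape[2], -1))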
@@ -325,13 +325,15 @@ class T2VSynthMochiModel:
                 out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
             else:
                 nonlocal sample, sample_null

                 with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
-                    out_cond = self.dit(z, sigma, **sample)
-                    out_uncond = self.dit(z, sigma, **sample_null)
+                    if cfg_scale > 1.0:
+                        out_cond = self.dit(z, sigma, **sample)
+                        out_uncond = self.dit(z, sigma, **sample_null)
+                    else:
+                        out_cond = self.dit(z, sigma, **sample)
+                        return out_cond

-            assert out_cond.shape == out_uncond.shape
-            return out_uncond + cfg_scale * (out_cond - out_uncond), out_cond
+            return out_uncond + cfg_scale * (out_cond - out_uncond)

         comfy_pbar = ProgressBar(sample_steps)
         for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
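The reworked branch skips the second DiT forward pass whenever guidance cannot change the result: at cfg_scale == 1.0 the mix out_uncond + cfg_scale * (out_cond - out_uncond) collapses to out_cond, so the unconditional pass is wasted compute. A standalone sketch of the pattern (dit, cond and uncond are stand-ins, not the wrapper's exact argument layout):

def cfg_denoise(dit, z, sigma, cond, uncond, cfg_scale: float):
    # classifier-free guidance: only pay for the unconditional pass when it matters
    if cfg_scale > 1.0:
        out_cond = dit(z, sigma, **cond)
        out_uncond = dit(z, sigma, **uncond)
        return out_uncond + cfg_scale * (out_cond - out_uncond)
    # at cfg_scale == 1.0 the mix reduces to the conditional output
    # (values below 1.0 take the same shortcut here, as in the diff above)
    return dit(z, sigma, **cond)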
@@ -339,7 +341,7 @@ class T2VSynthMochiModel:
             dsigma = sigma - sigma_schedule[i + 1]

             # `pred` estimates `z_0 - eps`.
-            pred, output_cond = model_fn(
+            pred = model_fn(
                 z=z,
                 sigma=torch.full(
                     [B] if not batch_cfg else [B * 2], sigma, device=z.device
@@ -347,8 +349,6 @@ class T2VSynthMochiModel:
                 cfg_scale=cfg_schedule[i],
             )
             pred = pred.to(z)
-            output_cond = output_cond.to(z)
-
             z = z + dsigma * pred
             comfy_pbar.update(1)

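With the extra conditional output removed, the sampling loop is back to a plain Euler-style integration over the sigma schedule: pred is the model's estimate of z_0 - eps, and each step advances z by the gap between consecutive sigmas. A compact sketch of just that update rule (model_fn and the schedule are placeholders; the real call also threads the conditioning and cfg_scale through):

import torch

def euler_sigma_loop(model_fn, z: torch.Tensor, sigma_schedule):
    # the schedule carries one more entry than the number of steps:
    # step i integrates from sigma_schedule[i] to sigma_schedule[i + 1]
    for i in range(len(sigma_schedule) - 1):
        sigma = sigma_schedule[i]
        dsigma = sigma - sigma_schedule[i + 1]
        pred = model_fn(z=z, sigma=sigma)  # estimates z_0 - eps
        z = z + dsigma * pred.to(z)
    return z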
@@ -356,7 +356,6 @@ class T2VSynthMochiModel:
         z = z[:B]
-
         self.dit.to(self.offload_device)

         samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
         logging.info(f"samples shape: {samples.shape}")
         return samples
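unnormalize_latents maps the sampled latents back out of the normalized space the DiT operates in before VAE decoding. A minimal sketch of the usual inverse transform, assuming per-channel mean/std statistics (the wrapper's exact shapes and broadcasting are not visible in this diff, so treat this as illustrative):

import torch

def unnormalize_latents_sketch(z: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    # invert the (x - mean) / std normalization per channel,
    # reshaping the stats so they broadcast over (B, C, T, H, W) video latents
    std = std.to(z).reshape(1, -1, 1, 1, 1)
    mean = mean.to(z).reshape(1, -1, 1, 1, 1)
    return z * std + mean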
@@ -141,6 +141,7 @@ class WQLinear_GGUF(nn.Module):
             dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
         else:
             raise ValueError(f"Unknown qtype: {self.qtype}")

         return self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)

+
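The Q8_0 branch above delegates to dequantize_blocks_Q8_0 before the plain linear op. For orientation, a rough standalone sketch of what GGUF Q8_0 block dequantization generally looks like; this is an assumption based on the standard GGUF layout (34-byte blocks: a float16 scale followed by 32 int8 values), not the wrapper's actual implementation, and the name dequant_q8_0_sketch is made up:

import torch

def dequant_q8_0_sketch(raw: torch.Tensor, dtype=torch.float16) -> torch.Tensor:
    # raw: flat uint8 buffer of Q8_0 blocks, 34 bytes each
    blocks = raw.reshape(-1, 34)
    d = blocks[:, :2].contiguous().view(torch.float16).to(dtype)   # (n_blocks, 1) scales
    qs = blocks[:, 2:].contiguous().view(torch.int8).to(dtype)     # (n_blocks, 32) quants
    return (d * qs).reshape(-1)  # one dequantized value per stored int8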
nodes.py
@@ -373,6 +373,8 @@ class MochiSampler:

         if opt_sigmas is not None:
             sigma_schedule = opt_sigmas.tolist()
+            sigma_schedule.extend([1.0])
+            steps = len(sigma_schedule)
             logging.info(f"Using sigma_schedule: {sigma_schedule}")
         else:
             sigma_schedule = linear_quadratic_schedule(steps, 0.025)
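With this change, a connected opt_sigmas input overrides the steps widget: the schedule is extended with a terminal value and the step count is taken from its length, rather than the schedule being generated from steps. A rough sketch of that idea in isolation, mirroring the diff (the helper name and the example sigmas are made up):

import torch

def steps_from_sigmas(opt_sigmas: torch.Tensor, terminal_sigma: float = 1.0):
    # user-supplied sigmas (e.g. a ComfyUI SIGMAS tensor) drive the step count
    sigma_schedule = opt_sigmas.tolist()
    sigma_schedule.append(terminal_sigma)   # append the terminal value, matching the node's behaviour
    steps = len(sigma_schedule)             # steps now follow the schedule, not the widget
    return sigma_schedule, steps

# example: five custom sigmas plus the appended terminal value -> steps = 6
schedule, steps = steps_from_sigmas(torch.tensor([0.0, 0.25, 0.5, 0.75, 0.9]))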
|||||||
Loading…
x
Reference in New Issue
Block a user