PAB for I2V as well

This commit is contained in:
kijai 2024-09-23 00:39:17 +03:00
parent d9abc00d3b
commit e03afa778e
5 changed files with 20 additions and 640 deletions

View File

@ -14,7 +14,7 @@ import torch
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed, CogVideoXPatchEmbed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import is_torch_version
@ -24,7 +24,7 @@ from torch import nn
from .core.pab_mgr import enable_pab, if_broadcast_spatial
from .modules.embeddings import apply_rotary_emb
from .modules.embeddings import CogVideoXPatchEmbed
#from .modules.embeddings import CogVideoXPatchEmbed
from .modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
@ -407,6 +407,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
use_learned_positional_embeddings: bool = False,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
@ -417,7 +418,22 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
self.patch_embed = CogVideoXPatchEmbed(
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
text_embed_dim=text_embed_dim,
bias=True,
sample_width=sample_width,
sample_height=sample_height,
sample_frames=sample_frames,
temporal_compression_ratio=temporal_compression_ratio,
max_text_seq_length=max_text_seq_length,
spatial_interpolation_scale=spatial_interpolation_scale,
temporal_interpolation_scale=temporal_interpolation_scale,
use_positional_embeddings=not use_rotary_positional_embeddings,
use_learned_positional_embeddings=use_learned_positional_embeddings,
)
self.embedding_dropout = nn.Dropout(dropout)
# 2. 3D positional embeddings
@ -590,7 +606,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
#if self.parallel_manager.cp_size > 1:

View File

@ -1,406 +0,0 @@
from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from torch.distributed import ProcessGroup
# ======================================================
# Model
# ======================================================
def model_sharding(model: torch.nn.Module):
global_rank = dist.get_rank()
world_size = dist.get_world_size()
for _, param in model.named_parameters():
padding_size = (world_size - param.numel() % world_size) % world_size
if padding_size > 0:
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
else:
padding_param = param.data.view(-1)
splited_params = padding_param.split(padding_param.numel() // world_size)
splited_params = splited_params[global_rank]
param.data = splited_params
# ======================================================
# AllGather & ReduceScatter
# ======================================================
class AsyncAllGatherForTwo(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
inputs: Tensor,
weight: Tensor,
bias: Tensor,
sp_rank: int,
sp_size: int,
group: Optional[ProcessGroup] = None,
) -> Tuple[Tensor, Any]:
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
from torch.distributed._functional_collectives import all_gather_tensor
ctx.group = group
ctx.sp_rank = sp_rank
ctx.sp_size = sp_size
# all gather inputs
all_inputs = all_gather_tensor(inputs.unsqueeze(0), 0, group)
# compute local qkv
local_qkv = F.linear(inputs, weight, bias).unsqueeze(0)
# remote compute
remote_inputs = all_inputs[1 - sp_rank].view(list(local_qkv.shape[:-1]) + [-1])
# compute remote qkv
remote_qkv = F.linear(remote_inputs, weight, bias)
# concat local and remote qkv
if sp_rank == 0:
qkv = torch.cat([local_qkv, remote_qkv], dim=0)
else:
qkv = torch.cat([remote_qkv, local_qkv], dim=0)
qkv = rearrange(qkv, "sp b n c -> b (sp n) c")
ctx.save_for_backward(inputs, weight, remote_inputs)
return qkv
@staticmethod
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
from torch.distributed._functional_collectives import reduce_scatter_tensor
group = ctx.group
sp_rank = ctx.sp_rank
sp_size = ctx.sp_size
inputs, weight, remote_inputs = ctx.saved_tensors
# split qkv_grad
qkv_grad = grad_outputs[0]
qkv_grad = rearrange(qkv_grad, "b (sp n) c -> sp b n c", sp=sp_size)
qkv_grad = torch.chunk(qkv_grad, 2, dim=0)
if sp_rank == 0:
local_qkv_grad, remote_qkv_grad = qkv_grad
else:
remote_qkv_grad, local_qkv_grad = qkv_grad
# compute remote grad
remote_inputs_grad = torch.matmul(remote_qkv_grad, weight).squeeze(0)
weight_grad = torch.matmul(remote_qkv_grad.transpose(-1, -2), remote_inputs).squeeze(0).sum(0)
bias_grad = remote_qkv_grad.squeeze(0).sum(0).sum(0)
# launch async reduce scatter
remote_inputs_grad_zero = torch.zeros_like(remote_inputs_grad)
if sp_rank == 0:
remote_inputs_grad = torch.cat([remote_inputs_grad_zero, remote_inputs_grad], dim=0)
else:
remote_inputs_grad = torch.cat([remote_inputs_grad, remote_inputs_grad_zero], dim=0)
remote_inputs_grad = reduce_scatter_tensor(remote_inputs_grad, "sum", 0, group)
# compute local grad and wait for reduce scatter
local_input_grad = torch.matmul(local_qkv_grad, weight).squeeze(0)
weight_grad += torch.matmul(local_qkv_grad.transpose(-1, -2), inputs).squeeze(0).sum(0)
bias_grad += local_qkv_grad.squeeze(0).sum(0).sum(0)
# sum remote and local grad
inputs_grad = remote_inputs_grad + local_input_grad
return inputs_grad, weight_grad, bias_grad, None, None, None
class AllGather(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
inputs: Tensor,
group: Optional[ProcessGroup] = None,
overlap: bool = False,
) -> Tuple[Tensor, Any]:
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
assert ctx is not None or not overlap
if ctx is not None:
ctx.comm_grp = group
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.unsqueeze(0), None
buffer_shape = (comm_size,) + inputs.shape
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
if not overlap:
dist.all_gather(buffer_list, inputs, group=group)
return outputs, None
else:
handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
return outputs, handle
@staticmethod
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
return (
ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
None,
None,
)
class ReduceScatter(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
inputs: Tensor,
group: ProcessGroup,
overlap: bool = False,
) -> Tuple[Tensor, Any]:
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
assert ctx is not None or not overlap
if ctx is not None:
ctx.comm_grp = group
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.squeeze(0), None
if not inputs.is_contiguous():
inputs = inputs.contiguous()
output_shape = inputs.shape[1:]
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
if not overlap:
dist.reduce_scatter(outputs, buffer_list, group=group)
return outputs, None
else:
handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
return outputs, handle
@staticmethod
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
# TODO: support async backward
return (
AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
None,
None,
)
# ======================================================
# AlltoAll
# ======================================================
def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim):
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
dist.all_to_all(output_list, input_list, group=group)
return torch.cat(output_list, dim=gather_dim).contiguous()
class _AllToAll(torch.autograd.Function):
"""All-to-all communication.
Args:
input_: input matrix
process_group: communication group
scatter_dim: scatter dimension
gather_dim: gather dimension
"""
@staticmethod
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
ctx.process_group = process_group
ctx.scatter_dim = scatter_dim
ctx.gather_dim = gather_dim
world_size = dist.get_world_size(process_group)
return _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim)
@staticmethod
def backward(ctx, *grad_output):
process_group = ctx.process_group
scatter_dim = ctx.gather_dim
gather_dim = ctx.scatter_dim
return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
return (return_grad, None, None, None)
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
# ======================================================
# Sequence Gather & Split
# ======================================================
def _split_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
# skip if only one rank involved
world_size = dist.get_world_size(pg)
rank = dist.get_rank(pg)
if world_size == 1:
return input_
if pad > 0:
pad_size = list(input_.shape)
pad_size[dim] = pad
input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim)
dim_size = input_.size(dim)
assert dim_size % world_size == 0, f"dim_size ({dim_size}) is not divisible by world_size ({world_size})"
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
output = tensor_list[rank].contiguous()
return output
def _gather_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
# skip if only one rank involved
input_ = input_.contiguous()
world_size = dist.get_world_size(pg)
dist.get_rank(pg)
if world_size == 1:
return input_
# all gather
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
assert input_.device.type == "cuda"
torch.distributed.all_gather(tensor_list, input_, group=pg)
# concat
output = torch.cat(tensor_list, dim=dim)
if pad > 0:
output = output.narrow(dim, 0, output.size(dim) - pad)
return output
class _GatherForwardSplitBackward(torch.autograd.Function):
"""
Gather the input sequence.
Args:
input_: input matrix.
process_group: process group.
dim: dimension
"""
@staticmethod
def symbolic(graph, input_):
return _gather_sequence_func(input_)
@staticmethod
def forward(ctx, input_, process_group, dim, grad_scale, pad):
ctx.process_group = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
ctx.pad = pad
return _gather_sequence_func(input_, process_group, dim, pad)
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale == "up":
grad_output = grad_output * dist.get_world_size(ctx.process_group)
elif ctx.grad_scale == "down":
grad_output = grad_output / dist.get_world_size(ctx.process_group)
return _split_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split sequence.
Args:
input_: input matrix.
process_group: parallel mode.
dim: dimension
"""
@staticmethod
def symbolic(graph, input_):
return _split_sequence_func(input_)
@staticmethod
def forward(ctx, input_, process_group, dim, grad_scale, pad):
ctx.process_group = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
ctx.pad = pad
return _split_sequence_func(input_, process_group, dim, pad)
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale == "up":
grad_output = grad_output * dist.get_world_size(ctx.process_group)
elif ctx.grad_scale == "down":
grad_output = grad_output / dist.get_world_size(ctx.process_group)
return _gather_sequence_func(grad_output, ctx.process_group, ctx.pad), None, None, None, None
def split_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale, pad)
def gather_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale, pad)
# ==============================
# Pad
# ==============================
PAD_DICT = {}
def set_pad(name: str, dim_size: int, parallel_group: dist.ProcessGroup):
sp_size = dist.get_world_size(parallel_group)
pad = (sp_size - (dim_size % sp_size)) % sp_size
global PAD_DICT
PAD_DICT[name] = pad
def get_pad(name) -> int:
return PAD_DICT[name]
def all_to_all_with_pad(
input_: torch.Tensor,
process_group: dist.ProcessGroup,
scatter_dim: int = 2,
gather_dim: int = 1,
scatter_pad: int = 0,
gather_pad: int = 0,
):
if scatter_pad > 0:
pad_shape = list(input_.shape)
pad_shape[scatter_dim] = scatter_pad
pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype)
input_ = torch.cat([input_, pad_tensor], dim=scatter_dim)
assert (
input_.shape[scatter_dim] % dist.get_world_size(process_group) == 0
), f"Dimension to scatter ({input_.shape[scatter_dim]}) is not divisible by world size ({dist.get_world_size(process_group)})"
input_ = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
if gather_pad > 0:
input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)
return input_

View File

@ -1,11 +1,8 @@
import inspect
from abc import abstractmethod
from dataclasses import dataclass
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput
class VideoSysPipeline(DiffusionPipeline):
def __init__(self):
@ -45,8 +42,3 @@ class VideoSysPipeline(DiffusionPipeline):
optional_parameters.remove(name)
return expected_modules, optional_parameters
@dataclass
class VideoSysPipelineOutput(BaseOutput):
video: torch.Tensor

View File

@ -1,205 +0,0 @@
from dataclasses import dataclass
from typing import Iterable, List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from videosys.models.modules.normalization import LlamaRMSNorm
class OpenSoraAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = LlamaRMSNorm,
enable_flash_attn: bool = False,
rope=None,
qk_norm_legacy: bool = False,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.enable_flash_attn = enable_flash_attn
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.qk_norm_legacy = qk_norm_legacy
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.rope = False
if rope is not None:
self.rope = True
self.rotary_emb = rope
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
# flash attn is not memory efficient for small sequences, this is empirical
enable_flash_attn = self.enable_flash_attn and (N > B)
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
if self.qk_norm_legacy:
# WARNING: this may be a bug
if self.rope:
q = self.rotary_emb(q)
k = self.rotary_emb(k)
q, k = self.q_norm(q), self.k_norm(k)
else:
q, k = self.q_norm(q), self.k_norm(k)
if self.rope:
q = self.rotary_emb(q)
k = self.rotary_emb(k)
if enable_flash_attn:
from flash_attn import flash_attn_func
# (B, #heads, N, #dim) -> (B, N, #heads, #dim)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
x = flash_attn_func(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
softmax_scale=self.scale,
)
else:
x = F.scaled_dot_product_attention(q, k, v)
x_output_shape = (B, N, C)
if not enable_flash_attn:
x = x.transpose(1, 2)
x = x.reshape(x_output_shape)
x = self.proj(x)
x = self.proj_drop(x)
return x
class OpenSoraMultiHeadCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, enable_flash_attn=False):
super(OpenSoraMultiHeadCrossAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = nn.Linear(d_model, d_model)
self.kv_linear = nn.Linear(d_model, d_model * 2)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model)
self.proj_drop = nn.Dropout(proj_drop)
self.enable_flash_attn = enable_flash_attn
def forward(self, x, cond, mask=None):
# query/value: img tokens; key: condition; mask: if padding tokens
B, N, C = x.shape
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
if self.enable_flash_attn:
x = self.flash_attn_impl(q, k, v, mask, B, N, C)
else:
x = self.torch_impl(q, k, v, mask, B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def flash_attn_impl(self, q, k, v, mask, B, N, C):
from flash_attn import flash_attn_varlen_func
q_seqinfo = _SeqLenInfo.from_seqlens([N] * B)
k_seqinfo = _SeqLenInfo.from_seqlens(mask)
x = flash_attn_varlen_func(
q.view(-1, self.num_heads, self.head_dim),
k.view(-1, self.num_heads, self.head_dim),
v.view(-1, self.num_heads, self.head_dim),
cu_seqlens_q=q_seqinfo.seqstart.cuda(),
cu_seqlens_k=k_seqinfo.seqstart.cuda(),
max_seqlen_q=q_seqinfo.max_seqlen,
max_seqlen_k=k_seqinfo.max_seqlen,
dropout_p=self.attn_drop.p if self.training else 0.0,
)
x = x.view(B, N, C)
return x
def torch_impl(self, q, k, v, mask, B, N, C):
q = q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
attn_mask = torch.zeros(B, 1, N, k.shape[2], dtype=torch.bool, device=q.device)
for i, m in enumerate(mask):
attn_mask[i, :, :, :m] = -1e9
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
x = out.transpose(1, 2).contiguous().view(B, N, C)
return x
@dataclass
class _SeqLenInfo:
"""
from xformers
(Internal) Represents the division of a dimension into blocks.
For example, to represents a dimension of length 7 divided into
three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`.
The members will be:
max_seqlen: 3
min_seqlen: 2
seqstart_py: [0, 2, 5, 7]
seqstart: torch.IntTensor([0, 2, 5, 7])
"""
seqstart: torch.Tensor
max_seqlen: int
min_seqlen: int
seqstart_py: List[int]
def to(self, device: torch.device) -> None:
self.seqstart = self.seqstart.to(device, non_blocking=True)
def intervals(self) -> Iterable[Tuple[int, int]]:
yield from zip(self.seqstart_py, self.seqstart_py[1:])
@classmethod
def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
"""
Input tensors are assumed to be in shape [B, M, *]
"""
assert not isinstance(seqlens, torch.Tensor)
seqstart_py = [0]
max_seqlen = -1
min_seqlen = -1
for seqlen in seqlens:
min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
max_seqlen = max(max_seqlen, seqlen)
seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
return cls(
max_seqlen=max_seqlen,
min_seqlen=min_seqlen,
seqstart=seqstart,
seqstart_py=seqstart_py,
)

View File

@ -4,23 +4,6 @@ import torch
import torch.nn as nn
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class CogVideoXLayerNormZero(nn.Module):
def __init__(
self,