mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 04:44:22 +08:00

commit e03afa778e (parent d9abc00d3b): PAB for I2V as well
@@ -14,7 +14,7 @@ import torch
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed, CogVideoXPatchEmbed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import is_torch_version

@@ -24,7 +24,7 @@ from torch import nn
from .core.pab_mgr import enable_pab, if_broadcast_spatial
from .modules.embeddings import apply_rotary_emb

from .modules.embeddings import CogVideoXPatchEmbed
#from .modules.embeddings import CogVideoXPatchEmbed

from .modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero

@@ -407,6 +407,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

@@ -417,7 +418,22 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames

        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            text_embed_dim=text_embed_dim,
            bias=True,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            max_text_seq_length=max_text_seq_length,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
            use_learned_positional_embeddings=use_learned_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. 3D positional embeddings

@@ -590,7 +606,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):

        # 6. Unpatchify
        p = self.config.patch_size
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        #if self.parallel_manager.cp_size > 1:
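
Note: switching to the diffusers CogVideoXPatchEmbed and passing it the sample dimensions and positional-embedding flags lets the patch embed build its own (optionally learned) positional embeddings, which the I2V checkpoints use; the local copy in .modules.embeddings is commented out rather than deleted. The unpatchify reshape also changes from `channels` to `-1`, presumably because the I2V transformer's output channel count differs from its input (image latents are concatenated on the input side). A rough standalone sketch of the same constructor call; every value below is illustrative, not taken from this commit:

from diffusers.models.embeddings import CogVideoXPatchEmbed

# Example values only; the real ones are read from the transformer config.
patch_embed = CogVideoXPatchEmbed(
    patch_size=2,
    in_channels=32,                      # I2V variants concatenate image latents (assumed here)
    embed_dim=3072,
    text_embed_dim=4096,
    bias=True,
    sample_width=90,
    sample_height=60,
    sample_frames=49,
    temporal_compression_ratio=4,
    max_text_seq_length=226,
    spatial_interpolation_scale=1.875,
    temporal_interpolation_scale=1.0,
    use_positional_embeddings=False,     # rotary embeddings handle positions in the larger models
    use_learned_positional_embeddings=True,
)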
@@ -1,406 +0,0 @@
from typing import Any, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from torch.distributed import ProcessGroup

# ======================================================
# Model
# ======================================================


def model_sharding(model: torch.nn.Module):
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()
    for _, param in model.named_parameters():
        padding_size = (world_size - param.numel() % world_size) % world_size
        if padding_size > 0:
            padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
        else:
            padding_param = param.data.view(-1)
        splited_params = padding_param.split(padding_param.numel() // world_size)
        splited_params = splited_params[global_rank]
        param.data = splited_params


# ======================================================
# AllGather & ReduceScatter
# ======================================================


class AsyncAllGatherForTwo(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        weight: Tensor,
        bias: Tensor,
        sp_rank: int,
        sp_size: int,
        group: Optional[ProcessGroup] = None,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        from torch.distributed._functional_collectives import all_gather_tensor

        ctx.group = group
        ctx.sp_rank = sp_rank
        ctx.sp_size = sp_size

        # all gather inputs
        all_inputs = all_gather_tensor(inputs.unsqueeze(0), 0, group)
        # compute local qkv
        local_qkv = F.linear(inputs, weight, bias).unsqueeze(0)

        # remote compute
        remote_inputs = all_inputs[1 - sp_rank].view(list(local_qkv.shape[:-1]) + [-1])
        # compute remote qkv
        remote_qkv = F.linear(remote_inputs, weight, bias)

        # concat local and remote qkv
        if sp_rank == 0:
            qkv = torch.cat([local_qkv, remote_qkv], dim=0)
        else:
            qkv = torch.cat([remote_qkv, local_qkv], dim=0)
        qkv = rearrange(qkv, "sp b n c -> b (sp n) c")

        ctx.save_for_backward(inputs, weight, remote_inputs)
        return qkv

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        from torch.distributed._functional_collectives import reduce_scatter_tensor

        group = ctx.group
        sp_rank = ctx.sp_rank
        sp_size = ctx.sp_size
        inputs, weight, remote_inputs = ctx.saved_tensors

        # split qkv_grad
        qkv_grad = grad_outputs[0]
        qkv_grad = rearrange(qkv_grad, "b (sp n) c -> sp b n c", sp=sp_size)
        qkv_grad = torch.chunk(qkv_grad, 2, dim=0)
        if sp_rank == 0:
            local_qkv_grad, remote_qkv_grad = qkv_grad
        else:
            remote_qkv_grad, local_qkv_grad = qkv_grad

        # compute remote grad
        remote_inputs_grad = torch.matmul(remote_qkv_grad, weight).squeeze(0)
        weight_grad = torch.matmul(remote_qkv_grad.transpose(-1, -2), remote_inputs).squeeze(0).sum(0)
        bias_grad = remote_qkv_grad.squeeze(0).sum(0).sum(0)

        # launch async reduce scatter
        remote_inputs_grad_zero = torch.zeros_like(remote_inputs_grad)
        if sp_rank == 0:
            remote_inputs_grad = torch.cat([remote_inputs_grad_zero, remote_inputs_grad], dim=0)
        else:
            remote_inputs_grad = torch.cat([remote_inputs_grad, remote_inputs_grad_zero], dim=0)
        remote_inputs_grad = reduce_scatter_tensor(remote_inputs_grad, "sum", 0, group)

        # compute local grad and wait for reduce scatter
        local_input_grad = torch.matmul(local_qkv_grad, weight).squeeze(0)
        weight_grad += torch.matmul(local_qkv_grad.transpose(-1, -2), inputs).squeeze(0).sum(0)
        bias_grad += local_qkv_grad.squeeze(0).sum(0).sum(0)

        # sum remote and local grad
        inputs_grad = remote_inputs_grad + local_input_grad
        return inputs_grad, weight_grad, bias_grad, None, None, None


class AllGather(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: Optional[ProcessGroup] = None,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)
        if comm_size == 1:
            return inputs.unsqueeze(0), None

        buffer_shape = (comm_size,) + inputs.shape
        outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
        buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
        if not overlap:
            dist.all_gather(buffer_list, inputs, group=group)
            return outputs, None
        else:
            handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
            return outputs, handle

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        return (
            ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
            None,
            None,
        )


class ReduceScatter(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: ProcessGroup,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)
        if comm_size == 1:
            return inputs.squeeze(0), None

        if not inputs.is_contiguous():
            inputs = inputs.contiguous()

        output_shape = inputs.shape[1:]
        outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
        buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
        if not overlap:
            dist.reduce_scatter(outputs, buffer_list, group=group)
            return outputs, None
        else:
            handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
            return outputs, handle

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # TODO: support async backward
        return (
            AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
            None,
            None,
        )


# ======================================================
# AlltoAll
# ======================================================


def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim):
    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()


class _AllToAll(torch.autograd.Function):
    """All-to-all communication.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        world_size = dist.get_world_size(process_group)

        return _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, *grad_output):
        process_group = ctx.process_group
        scatter_dim = ctx.gather_dim
        gather_dim = ctx.scatter_dim
        return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
        return (return_grad, None, None, None)


def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)


# ======================================================
# Sequence Gather & Split
# ======================================================


def _split_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
    # skip if only one rank involved
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)
    if world_size == 1:
        return input_

    if pad > 0:
        pad_size = list(input_.shape)
        pad_size[dim] = pad
        input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim)

    dim_size = input_.size(dim)
    assert dim_size % world_size == 0, f"dim_size ({dim_size}) is not divisible by world_size ({world_size})"

    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
    output = tensor_list[rank].contiguous()
    return output


def _gather_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
    # skip if only one rank involved
    input_ = input_.contiguous()
    world_size = dist.get_world_size(pg)
    dist.get_rank(pg)

    if world_size == 1:
        return input_

    # all gather
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    assert input_.device.type == "cuda"
    torch.distributed.all_gather(tensor_list, input_, group=pg)

    # concat
    output = torch.cat(tensor_list, dim=dim)

    if pad > 0:
        output = output.narrow(dim, 0, output.size(dim) - pad)

    return output


class _GatherForwardSplitBackward(torch.autograd.Function):
    """
    Gather the input sequence.

    Args:
        input_: input matrix.
        process_group: process group.
        dim: dimension
    """

    @staticmethod
    def symbolic(graph, input_):
        return _gather_sequence_func(input_)

    @staticmethod
    def forward(ctx, input_, process_group, dim, grad_scale, pad):
        ctx.process_group = process_group
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        ctx.pad = pad
        return _gather_sequence_func(input_, process_group, dim, pad)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scale == "up":
            grad_output = grad_output * dist.get_world_size(ctx.process_group)
        elif ctx.grad_scale == "down":
            grad_output = grad_output / dist.get_world_size(ctx.process_group)

        return _split_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None


class _SplitForwardGatherBackward(torch.autograd.Function):
    """
    Split sequence.

    Args:
        input_: input matrix.
        process_group: parallel mode.
        dim: dimension
    """

    @staticmethod
    def symbolic(graph, input_):
        return _split_sequence_func(input_)

    @staticmethod
    def forward(ctx, input_, process_group, dim, grad_scale, pad):
        ctx.process_group = process_group
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        ctx.pad = pad
        return _split_sequence_func(input_, process_group, dim, pad)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scale == "up":
            grad_output = grad_output * dist.get_world_size(ctx.process_group)
        elif ctx.grad_scale == "down":
            grad_output = grad_output / dist.get_world_size(ctx.process_group)
        return _gather_sequence_func(grad_output, ctx.process_group, ctx.pad), None, None, None, None


def split_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
    return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale, pad)


def gather_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
    return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale, pad)


# ==============================
# Pad
# ==============================

PAD_DICT = {}


def set_pad(name: str, dim_size: int, parallel_group: dist.ProcessGroup):
    sp_size = dist.get_world_size(parallel_group)
    pad = (sp_size - (dim_size % sp_size)) % sp_size
    global PAD_DICT
    PAD_DICT[name] = pad


def get_pad(name) -> int:
    return PAD_DICT[name]


def all_to_all_with_pad(
    input_: torch.Tensor,
    process_group: dist.ProcessGroup,
    scatter_dim: int = 2,
    gather_dim: int = 1,
    scatter_pad: int = 0,
    gather_pad: int = 0,
):
    if scatter_pad > 0:
        pad_shape = list(input_.shape)
        pad_shape[scatter_dim] = scatter_pad
        pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype)
        input_ = torch.cat([input_, pad_tensor], dim=scatter_dim)

    assert (
        input_.shape[scatter_dim] % dist.get_world_size(process_group) == 0
    ), f"Dimension to scatter ({input_.shape[scatter_dim]}) is not divisible by world size ({dist.get_world_size(process_group)})"
    input_ = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)

    if gather_pad > 0:
        input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)

    return input_
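
Note: this hunk deletes VideoSys' sequence-parallel communication helpers (all-gather/reduce-scatter autograd functions, all-to-all, sequence split/gather, and the padding registry), which are presumably unused in the single-GPU ComfyUI wrapper. For reference, the padding rule used by set_pad / all_to_all_with_pad simply rounds the sequence length up to a multiple of the sequence-parallel world size; a minimal single-process illustration with made-up values:

seq_len, sp_size = 226, 4
pad = (sp_size - (seq_len % sp_size)) % sp_size   # -> 2 padding tokens
assert (seq_len + pad) % sp_size == 0             # 228 splits evenly across 4 ranks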
@@ -1,11 +1,8 @@
import inspect
from abc import abstractmethod
from dataclasses import dataclass

import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput


class VideoSysPipeline(DiffusionPipeline):
    def __init__(self):

@@ -45,8 +42,3 @@ class VideoSysPipeline(DiffusionPipeline):
                optional_parameters.remove(name)

        return expected_modules, optional_parameters


@dataclass
class VideoSysPipelineOutput(BaseOutput):
    video: torch.Tensor
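
Note: the VideoSysPipelineOutput dataclass is removed here, presumably along with the imports that existed only for it. For reference, a diffusers BaseOutput subclass like the removed one allows both attribute and key access to its fields; an illustrative snippet (the class no longer exists after this commit):

out = VideoSysPipelineOutput(video=torch.randn(1, 49, 3, 480, 720))
assert out.video is out["video"]   # BaseOutput exposes fields both ways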
@@ -1,205 +0,0 @@
from dataclasses import dataclass
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from videosys.models.modules.normalization import LlamaRMSNorm


class OpenSoraAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = LlamaRMSNorm,
        enable_flash_attn: bool = False,
        rope=None,
        qk_norm_legacy: bool = False,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.enable_flash_attn = enable_flash_attn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.qk_norm_legacy = qk_norm_legacy
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.rope = False
        if rope is not None:
            self.rope = True
            self.rotary_emb = rope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        # flash attn is not memory efficient for small sequences, this is empirical
        enable_flash_attn = self.enable_flash_attn and (N > B)
        qkv = self.qkv(x)
        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)

        qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        if self.qk_norm_legacy:
            # WARNING: this may be a bug
            if self.rope:
                q = self.rotary_emb(q)
                k = self.rotary_emb(k)
            q, k = self.q_norm(q), self.k_norm(k)
        else:
            q, k = self.q_norm(q), self.k_norm(k)
            if self.rope:
                q = self.rotary_emb(q)
                k = self.rotary_emb(k)

        if enable_flash_attn:
            from flash_attn import flash_attn_func

            # (B, #heads, N, #dim) -> (B, N, #heads, #dim)
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)
            x = flash_attn_func(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                softmax_scale=self.scale,
            )
        else:
            x = F.scaled_dot_product_attention(q, k, v)

        x_output_shape = (B, N, C)
        if not enable_flash_attn:
            x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class OpenSoraMultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, enable_flash_attn=False):
        super(OpenSoraMultiHeadCrossAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model * 2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)
        self.enable_flash_attn = enable_flash_attn

    def forward(self, x, cond, mask=None):
        # query/value: img tokens; key: condition; mask: if padding tokens
        B, N, C = x.shape

        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)

        if self.enable_flash_attn:
            x = self.flash_attn_impl(q, k, v, mask, B, N, C)
        else:
            x = self.torch_impl(q, k, v, mask, B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def flash_attn_impl(self, q, k, v, mask, B, N, C):
        from flash_attn import flash_attn_varlen_func

        q_seqinfo = _SeqLenInfo.from_seqlens([N] * B)
        k_seqinfo = _SeqLenInfo.from_seqlens(mask)

        x = flash_attn_varlen_func(
            q.view(-1, self.num_heads, self.head_dim),
            k.view(-1, self.num_heads, self.head_dim),
            v.view(-1, self.num_heads, self.head_dim),
            cu_seqlens_q=q_seqinfo.seqstart.cuda(),
            cu_seqlens_k=k_seqinfo.seqstart.cuda(),
            max_seqlen_q=q_seqinfo.max_seqlen,
            max_seqlen_k=k_seqinfo.max_seqlen,
            dropout_p=self.attn_drop.p if self.training else 0.0,
        )
        x = x.view(B, N, C)
        return x

    def torch_impl(self, q, k, v, mask, B, N, C):
        q = q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)

        attn_mask = torch.zeros(B, 1, N, k.shape[2], dtype=torch.bool, device=q.device)
        for i, m in enumerate(mask):
            attn_mask[i, :, :, :m] = -1e9

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        x = out.transpose(1, 2).contiguous().view(B, N, C)
        return x


@dataclass
class _SeqLenInfo:
    """
    from xformers

    (Internal) Represents the division of a dimension into blocks.
    For example, to represents a dimension of length 7 divided into
    three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`.
    The members will be:
        max_seqlen: 3
        min_seqlen: 2
        seqstart_py: [0, 2, 5, 7]
        seqstart: torch.IntTensor([0, 2, 5, 7])
    """

    seqstart: torch.Tensor
    max_seqlen: int
    min_seqlen: int
    seqstart_py: List[int]

    def to(self, device: torch.device) -> None:
        self.seqstart = self.seqstart.to(device, non_blocking=True)

    def intervals(self) -> Iterable[Tuple[int, int]]:
        yield from zip(self.seqstart_py, self.seqstart_py[1:])

    @classmethod
    def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
        """
        Input tensors are assumed to be in shape [B, M, *]
        """
        assert not isinstance(seqlens, torch.Tensor)
        seqstart_py = [0]
        max_seqlen = -1
        min_seqlen = -1
        for seqlen in seqlens:
            min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
            max_seqlen = max(max_seqlen, seqlen)
            seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
        seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
        return cls(
            max_seqlen=max_seqlen,
            min_seqlen=min_seqlen,
            seqstart=seqstart,
            seqstart_py=seqstart_py,
        )
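
Note: these are OpenSora-specific attention modules (plus the xformers-derived _SeqLenInfo helper used to build cumulative sequence lengths for flash_attn_varlen_func) that the CogVideoX wrapper presumably does not need. For reference, from_seqlens just accumulates per-sequence lengths into offsets, exactly as its docstring describes:

info = _SeqLenInfo.from_seqlens([2, 3, 2])
# info.seqstart_py == [0, 2, 5, 7]; info.max_seqlen == 3; info.min_seqlen == 2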
@@ -4,23 +4,6 @@ import torch
import torch.nn as nn


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class CogVideoXLayerNormZero(nn.Module):
    def __init__(
        self,