diff --git a/configs/vae_stats.json b/configs/vae_stats.json deleted file mode 100644 index e3278af..0000000 --- a/configs/vae_stats.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "mean": [-0.06730895953510081, -0.038011381506090416, -0.07477820912866141, -0.05565264470995561, 0.012767231469026969, -0.04703542746246419, 0.043896967884726704, -0.09346305707025976, -0.09918314763016893, -0.008729793427399178, -0.011931556316503654, -0.0321993391887285], - "std": [0.9263795028493863, 0.9248894543193766, 0.9393059390890617, 0.959253732819592, 0.8244560132752793, 0.917259975397747, 0.9294154431013696, 1.3720942357788521, 0.881393668867029, 0.9168315692124348, 0.9185249279345552, 0.9274757570805041] -} diff --git a/mochi_preview/dit/joint_model/context_parallel.py b/mochi_preview/dit/joint_model/context_parallel.py deleted file mode 100644 index d93145d..0000000 --- a/mochi_preview/dit/joint_model/context_parallel.py +++ /dev/null @@ -1,163 +0,0 @@ -import torch -import torch.distributed as dist -from einops import rearrange - -_CONTEXT_PARALLEL_GROUP = None -_CONTEXT_PARALLEL_RANK = None -_CONTEXT_PARALLEL_GROUP_SIZE = None -_CONTEXT_PARALLEL_GROUP_RANKS = None - - -def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor: - if not _CONTEXT_PARALLEL_GROUP: - return x - - cp_rank, cp_size = get_cp_rank_size() - return x.tensor_split(cp_size, dim=dim)[cp_rank] - - -def set_cp_group(cp_group, ranks, global_rank): - global \ - _CONTEXT_PARALLEL_GROUP, \ - _CONTEXT_PARALLEL_RANK, \ - _CONTEXT_PARALLEL_GROUP_SIZE, \ - _CONTEXT_PARALLEL_GROUP_RANKS - if _CONTEXT_PARALLEL_GROUP is not None: - raise RuntimeError("CP group already initialized.") - _CONTEXT_PARALLEL_GROUP = cp_group - _CONTEXT_PARALLEL_RANK = dist.get_rank(cp_group) - _CONTEXT_PARALLEL_GROUP_SIZE = dist.get_world_size(cp_group) - _CONTEXT_PARALLEL_GROUP_RANKS = ranks - - assert ( - _CONTEXT_PARALLEL_RANK == ranks.index(global_rank) - ), f"Rank mismatch: {global_rank} in {ranks} does not have position {_CONTEXT_PARALLEL_RANK} " - assert _CONTEXT_PARALLEL_GROUP_SIZE == len( - ranks - ), f"Group size mismatch: {_CONTEXT_PARALLEL_GROUP_SIZE} != len({ranks})" - - -def get_cp_group(): - if _CONTEXT_PARALLEL_GROUP is None: - raise RuntimeError("CP group not initialized") - return _CONTEXT_PARALLEL_GROUP - - -def is_cp_active(): - return _CONTEXT_PARALLEL_GROUP is not None - - -def get_cp_rank_size(): - if _CONTEXT_PARALLEL_GROUP: - return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE - else: - return 0, 1 - - -class AllGatherIntoTensorFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup): - ctx.reduce_dtype = reduce_dtype - ctx.group = group - ctx.batch_size = x.size(0) - group_size = dist.get_world_size(group) - - x = x.contiguous() - output = torch.empty( - group_size * x.size(0), *x.shape[1:], dtype=x.dtype, device=x.device - ) - dist.all_gather_into_tensor(output, x, group=group) - return output - - -def all_gather(tensor: torch.Tensor) -> torch.Tensor: - if not _CONTEXT_PARALLEL_GROUP: - return tensor - - return AllGatherIntoTensorFunction.apply( - tensor, torch.float32, _CONTEXT_PARALLEL_GROUP - ) - - -@torch.compiler.disable() -def _all_to_all_single(output, input, group): - # Disable compilation since torch compile changes contiguity. - assert input.is_contiguous(), "Input tensor must be contiguous." - assert output.is_contiguous(), "Output tensor must be contiguous." - return dist.all_to_all_single(output, input, group=group) - - -class CollectTokens(torch.autograd.Function): - @staticmethod - def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int): - """Redistribute heads and receive tokens. - - Args: - qkv: query, key or value. Shape: [B, M, 3 * num_heads * head_dim] - - Returns: - qkv: shape: [3, B, N, local_heads, head_dim] - - where M is the number of local tokens, - N = cp_size * M is the number of global tokens, - local_heads = num_heads // cp_size is the number of local heads. - """ - ctx.group = group - ctx.num_heads = num_heads - cp_size = dist.get_world_size(group) - assert num_heads % cp_size == 0 - ctx.local_heads = num_heads // cp_size - - qkv = rearrange( - qkv, - "B M (qkv G h d) -> G M h B (qkv d)", - qkv=3, - G=cp_size, - h=ctx.local_heads, - ).contiguous() - - output_chunks = torch.empty_like(qkv) - _all_to_all_single(output_chunks, qkv, group=group) - - return rearrange(output_chunks, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3) - - -def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int) -> torch.Tensor: - if not _CONTEXT_PARALLEL_GROUP: - # Move QKV dimension to the front. - # B M (3 H d) -> 3 B M H d - B, M, _ = x.size() - x = x.view(B, M, 3, num_heads, -1) - return x.permute(2, 0, 1, 3, 4) - - return CollectTokens.apply(x, _CONTEXT_PARALLEL_GROUP, num_heads) - - -class CollectHeads(torch.autograd.Function): - @staticmethod - def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup): - """Redistribute tokens and receive heads. - - Args: - x: Output of attention. Shape: [B, N, local_heads, head_dim] - - Returns: - Shape: [B, M, num_heads * head_dim] - """ - ctx.group = group - ctx.local_heads = x.size(2) - ctx.head_dim = x.size(3) - group_size = dist.get_world_size(group) - x = rearrange(x, "B (G M) h D -> G h M B D", G=group_size).contiguous() - output = torch.empty_like(x) - _all_to_all_single(output, x, group=group) - del x - return rearrange(output, "G h M B D -> B M (G h D)") - - -def all_to_all_collect_heads(x: torch.Tensor) -> torch.Tensor: - if not _CONTEXT_PARALLEL_GROUP: - # Merge heads. - return x.view(x.size(0), x.size(1), x.size(2) * x.size(3)) - - return CollectHeads.apply(x, _CONTEXT_PARALLEL_GROUP) diff --git a/mochi_preview/dit/joint_model/layers.py b/mochi_preview/dit/joint_model/layers.py index aa40a67..9d66921 100644 --- a/mochi_preview/dit/joint_model/layers.py +++ b/mochi_preview/dit/joint_model/layers.py @@ -62,28 +62,6 @@ class TimestepEmbedder(nn.Module): return t_emb -class PooledCaptionEmbedder(nn.Module): - def __init__( - self, - caption_feature_dim: int, - hidden_size: int, - *, - bias: bool = True, - device: Optional[torch.device] = None, - ): - super().__init__() - self.caption_feature_dim = caption_feature_dim - self.hidden_size = hidden_size - self.mlp = nn.Sequential( - nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=bias, device=device), - ) - - def forward(self, x): - return self.mlp(x) - - class FeedForward(nn.Module): def __init__( self, @@ -152,8 +130,6 @@ class PatchEmbed(nn.Module): x = F.pad(x, (0, pad_w, 0, pad_h)) x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T) - #print("x",x.dtype, x.device) - #print(self.proj.weight.dtype, self.proj.weight.device) x = self.proj(x) # Flatten temporal and spatial dimensions. diff --git a/mochi_preview/dit/joint_model/rope_mixed.py b/mochi_preview/dit/joint_model/rope_mixed.py index f2952bd..d0102dc 100644 --- a/mochi_preview/dit/joint_model/rope_mixed.py +++ b/mochi_preview/dit/joint_model/rope_mixed.py @@ -1,4 +1,4 @@ -import functools +#import functools import math import torch @@ -21,7 +21,7 @@ def centers(start: float, stop, num, dtype=None, device=None): return (edges[:-1] + edges[1:]) / 2 -@functools.lru_cache(maxsize=1) +#@functools.lru_cache(maxsize=1) def create_position_matrix( T: int, pH: int, diff --git a/mochi_preview/dit/joint_model/utils.py b/mochi_preview/dit/joint_model/utils.py index 85fd2df..0bcfbd3 100644 --- a/mochi_preview/dit/joint_model/utils.py +++ b/mochi_preview/dit/joint_model/utils.py @@ -28,95 +28,4 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch. mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1) pooled = (x * mask).sum(dim=1, keepdim=keepdim) return pooled - - -class PadSplitXY(torch.autograd.Function): - """ - Merge heads, pad and extract visual and text tokens, - and split along the sequence length. - """ - - @staticmethod - def forward( - ctx, - xy: torch.Tensor, - indices: torch.Tensor, - B: int, - N: int, - L: int, - dtype: torch.dtype, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim). - indices: Valid token indices out of unpacked tensor. Shape: (total,) - - Returns: - x: Visual tokens. Shape: (B, N, num_heads * head_dim). - y: Text tokens. Shape: (B, L, num_heads * head_dim). - """ - ctx.save_for_backward(indices) - ctx.B, ctx.N, ctx.L = B, N, L - D = xy.size(1) - - # Pad sequences to (B, N + L, dim). - assert indices.ndim == 1 - output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype) - indices = indices.unsqueeze(1).expand( - -1, D - ) # (total,) -> (total, num_heads * head_dim) - output.scatter_(0, indices, xy) - xy = output.view(B, N + L, D) - - # Split visual and text tokens along the sequence length. - return torch.tensor_split(xy, (N,), dim=1) - - -def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]: - return PadSplitXY.apply(xy, indices, B, N, L, dtype) - - -class UnifyStreams(torch.autograd.Function): - """Unify visual and text streams.""" - - @staticmethod - def forward( - ctx, - q_x: torch.Tensor, - k_x: torch.Tensor, - v_x: torch.Tensor, - q_y: torch.Tensor, - k_y: torch.Tensor, - v_y: torch.Tensor, - indices: torch.Tensor, - ): - """ - Args: - q_x: (B, N, num_heads, head_dim) - k_x: (B, N, num_heads, head_dim) - v_x: (B, N, num_heads, head_dim) - q_y: (B, L, num_heads, head_dim) - k_y: (B, L, num_heads, head_dim) - v_y: (B, L, num_heads, head_dim) - indices: (total <= B * (N + L)) - - Returns: - qkv: (total <= B * (N + L), 3, num_heads, head_dim) - """ - ctx.save_for_backward(indices) - B, N, num_heads, head_dim = q_x.size() - ctx.B, ctx.N, ctx.L = B, N, q_y.size(1) - D = num_heads * head_dim - - q = torch.cat([q_x, q_y], dim=1) - k = torch.cat([k_x, k_y], dim=1) - v = torch.cat([v_x, v_y], dim=1) - qkv = torch.stack([q, k, v], dim=2).view(B * (N + ctx.L), 3, D) - - indices = indices[:, None, None].expand(-1, 3, D) - qkv = torch.gather(qkv, 0, indices) # (total, 3, num_heads * head_dim) - return qkv.unflatten(2, (num_heads, head_dim)) - - -def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor: - return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices) + \ No newline at end of file diff --git a/mochi_preview/vae/cp_conv.py b/mochi_preview/vae/cp_conv.py deleted file mode 100644 index e5e96de..0000000 --- a/mochi_preview/vae/cp_conv.py +++ /dev/null @@ -1,152 +0,0 @@ -from typing import Tuple, Union - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -from ..dit.joint_model.context_parallel import get_cp_group, get_cp_rank_size - - -def cast_tuple(t, length=1): - return t if isinstance(t, tuple) else ((t,) * length) - - -def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor: - """ - Forward pass that handles communication between ranks for inference. - Args: - x: Tensor of shape (B, C, T, H, W) - frames_to_send: int, number of frames to communicate between ranks - Returns: - output: Tensor of shape (B, C, T', H, W) - """ - cp_rank, cp_world_size = cp.get_cp_rank_size() - if frames_to_send == 0 or cp_world_size == 1: - return x - - group = get_cp_group() - global_rank = dist.get_rank() - - # Send to next rank - if cp_rank < cp_world_size - 1: - assert x.size(2) >= frames_to_send - tail = x[:, :, -frames_to_send:].contiguous() - dist.send(tail, global_rank + 1, group=group) - - # Receive from previous rank - if cp_rank > 0: - B, C, _, H, W = x.shape - recv_buffer = torch.empty( - (B, C, frames_to_send, H, W), - dtype=x.dtype, - device=x.device, - ) - dist.recv(recv_buffer, global_rank - 1, group=group) - x = torch.cat([recv_buffer, x], dim=2) - - return x - - -def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor: - if max_T > x.size(2): - pad_T = max_T - x.size(2) - pad_dims = (0, 0, 0, 0, 0, pad_T) - return F.pad(x, pad_dims) - return x - - -def gather_all_frames(x: torch.Tensor) -> torch.Tensor: - """ - Gathers all frames from all processes for inference. - Args: - x: Tensor of shape (B, C, T, H, W) - Returns: - output: Tensor of shape (B, C, T_total, H, W) - """ - cp_rank, cp_size = get_cp_rank_size() - cp_group = get_cp_group() - - # Ensure the tensor is contiguous for collective operations - x = x.contiguous() - - # Get the local time dimension size - local_T = x.size(2) - local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64) - - # Gather all T sizes from all processes - all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)] - dist.all_gather(all_T, local_T_tensor, group=cp_group) - all_T = [t.item() for t in all_T] - - # Pad the tensor at the end of the time dimension to match max_T - max_T = max(all_T) - x = _pad_to_max(x, max_T).contiguous() - - # Prepare a list to hold the gathered tensors - gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)] - - # Perform the all_gather operation - dist.all_gather(gathered_x, x, group=cp_group) - - # Slice each gathered tensor back to its original T size - for idx, t_size in enumerate(all_T): - gathered_x[idx] = gathered_x[idx][:, :, :t_size] - - return torch.cat(gathered_x, dim=2) - - -def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool: - """Estimate memory usage based on input tensor size and data type.""" - element_size = input.element_size() # Size in bytes of each element - memory_bytes = input.numel() * element_size - memory_gb = memory_bytes / 1024**3 - return memory_gb > max_gb - - -class ContextParallelCausalConv3d(torch.nn.Conv3d): - def __init__( - self, - in_channels, - out_channels, - kernel_size: Union[int, Tuple[int, int, int]], - stride: Union[int, Tuple[int, int, int]], - **kwargs, - ): - kernel_size = cast_tuple(kernel_size, 3) - stride = cast_tuple(stride, 3) - height_pad = (kernel_size[1] - 1) // 2 - width_pad = (kernel_size[2] - 1) // 2 - - super().__init__( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - dilation=(1, 1, 1), - padding=(0, height_pad, width_pad), - **kwargs, - ) - - def forward(self, x: torch.Tensor): - cp_rank, cp_world_size = get_cp_rank_size() - - context_size = self.kernel_size[0] - 1 - if cp_rank == 0: - mode = "constant" if self.padding_mode == "zeros" else self.padding_mode - x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode) - - if cp_world_size == 1: - return super().forward(x) - - if all(s == 1 for s in self.stride): - # Receive some frames from previous rank. - x = cp_pass_frames(x, context_size) - return super().forward(x) - - # Less efficient implementation for strided convs. - # All gather x, infer and chunk. - x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W] - x = super().forward(x) - x_chunks = x.tensor_split(cp_world_size, dim=2) - assert len(x_chunks) == cp_world_size - return x_chunks[cp_rank] diff --git a/mochi_preview/vae/model.py b/mochi_preview/vae/model.py index 6cfeeae..e26add7 100644 --- a/mochi_preview/vae/model.py +++ b/mochi_preview/vae/model.py @@ -6,8 +6,6 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange -#from ..dit.joint_model.context_parallel import get_cp_rank_size -#from ..vae.cp_conv import cp_pass_frames, gather_all_frames from .latent_dist import LatentDistribution def cast_tuple(t, length=1): @@ -96,6 +94,14 @@ class StridedSafeConv3d(torch.nn.Conv3d): raise NotImplementedError +def mps_safe_pad(input, pad, mode): + if input.device.type == "mps" and input.numel() >= 2 ** 16: + device = input.device + input = input.to(device="cpu") + output = F.pad(input, pad, mode=mode) + return output.to(device=device) + else: + return F.pad(input, pad, mode=mode) class ContextParallelConv3d(SafeConv3d): def __init__( @@ -138,9 +144,9 @@ class ContextParallelConv3d(SafeConv3d): # Apply padding. mode = "constant" if self.padding_mode == "zeros" else self.padding_mode if self.context_parallel: - x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode) + x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode) else: - x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode) + x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode) return super().forward(x) diff --git a/nodes.py b/nodes.py index 65fd380..1546747 100644 --- a/nodes.py +++ b/nodes.py @@ -59,7 +59,7 @@ class MochiSigmaSchedule: RETURN_NAMES = ("sigmas",) FUNCTION = "loadmodel" CATEGORY = "MochiWrapper" - DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended" + DESCRIPTION = "Sigma schedule to use with mochi wrapper sampler" def loadmodel(self, num_steps, threshold_noise, denoise, linear_steps=None): total_steps = num_steps @@ -105,6 +105,7 @@ class DownloadAndLoadMochiModel: "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}), "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}), "cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops for the GGUF models, for more info:'https://github.com/aredden/torch-cublas-hgemm'",}), + "rms_norm_func": (["default", "flash_attn_triton", "flash_attn", "apex"],{"tooltip": "RMSNorm function to use, flash_attn if available seems to be faster, apex untested",}), }, } @@ -114,7 +115,7 @@ class DownloadAndLoadMochiModel: CATEGORY = "MochiWrapper" DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface" - def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False): + def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False, rms_norm_func="default"): device = mm.get_torch_device() offload_device = mm.unet_offload_device() @@ -154,11 +155,11 @@ class DownloadAndLoadMochiModel: model = T2VSynthMochiModel( device=device, offload_device=offload_device, - vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"), dit_checkpoint_path=model_path, weight_dtype=dtype, fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False, attention_mode=attention_mode, + rms_norm_func=rms_norm_func, compile_args=compile_args, cublas_ops=cublas_ops ) @@ -180,7 +181,7 @@ class DownloadAndLoadMochiModel: vae_sd = load_torch_file(vae_path) if is_accelerate_available: for key in vae_sd: - set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key]) + set_module_tensor_to_device(vae, key, dtype=torch.bfloat16, device=offload_device, value=vae_sd[key]) else: vae.load_state_dict(vae_sd, strict=True) vae.eval().to(torch.bfloat16).to("cpu") @@ -201,6 +202,7 @@ class MochiModelLoader: "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}), "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}), "cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops for the GGUF models, for more info:'https://github.com/aredden/torch-cublas-hgemm'",}), + "rms_norm_func": (["default", "flash_attn_triton", "flash_attn", "apex"],{"tooltip": "RMSNorm function to use, flash_attn if available seems to be faster, apex untested",}), }, } @@ -209,7 +211,7 @@ class MochiModelLoader: FUNCTION = "loadmodel" CATEGORY = "MochiWrapper" - def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False): + def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False, rms_norm_func="default"): device = mm.get_torch_device() offload_device = mm.unet_offload_device() @@ -221,11 +223,11 @@ class MochiModelLoader: model = T2VSynthMochiModel( device=device, offload_device=offload_device, - vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"), dit_checkpoint_path=model_path, weight_dtype=dtype, fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False, attention_mode=attention_mode, + rms_norm_func=rms_norm_func, compile_args=compile_args, cublas_ops=cublas_ops ) @@ -473,6 +475,7 @@ class MochiSampler: CATEGORY = "MochiWrapper" def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None): + mm.unload_all_models() mm.soft_empty_cache() if opt_sigmas is not None: @@ -630,7 +633,7 @@ class MochiDecode: return torch.cat(result_rows, dim=3) vae.to(device) - with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16): + with torch.autocast(mm.get_autocast_device(device), dtype=vae.dtype): if enable_vae_tiling and frame_batch_size > T: logging.warning(f"Frame batch size is larger than the number of samples, setting to {T}") frame_batch_size = T @@ -748,10 +751,16 @@ class MochiImageEncode: from .mochi_preview.vae.model import apply_tiled B, H, W, C = images.shape - images = images.unsqueeze(0) * 2 - 1 - images = rearrange(images, "t b h w c -> t c b h w") - images = images.to(device) - print(images.shape) + import torchvision.transforms as transforms + normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + input_image_tensor = rearrange(images, 'b h w c -> b c h w') + input_image_tensor = normalize(input_image_tensor).unsqueeze(0) + input_image_tensor = rearrange(input_image_tensor, 'b t c h w -> b c t h w', t=B) + + #images = images.unsqueeze(0).sub_(0.5).div_(0.5) + #images = rearrange(input_image_tensor, "b c t h w -> t c b h w") + images = input_image_tensor.to(device) + encoder.to(device) print("images before encoding", images.shape) with torch.autocast(mm.get_autocast_device(device), dtype=encoder.dtype):