2024-10-23 15:34:22 +03:00

153 lines
4.8 KiB
Python

from typing import Tuple, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
from ..dit.joint_model.context_parallel import get_cp_group, get_cp_rank_size
def cast_tuple(t, length=1):
    """Return ``t`` as a tuple, repeating a scalar value ``length`` times.

    Tuples are returned unchanged and lists are converted to tuples, so
    per-axis sizes such as ``[3, 3, 3]`` are not mistaken for a scalar.
    Any other value is treated as a scalar and repeated ``length`` times.
    The length of a tuple/list argument is not checked against ``length``.
    """
    if isinstance(t, tuple):
        return t
    if isinstance(t, list):
        return tuple(t)
    return (t,) * length
def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor:
    """
    Prepend the previous context-parallel rank's trailing frames to ``x``.

    Every CP rank except the last sends its last ``frames_to_send`` frames
    (along dim 2) to the next CP rank. Every CP rank except the first receives
    that many frames from the previous CP rank and concatenates them in front
    of its own frames. Intended for inference.

    Args:
        x: Tensor of shape (B, C, T, H, W).
        frames_to_send: Number of frames exchanged between neighbouring
            ranks. Must be non-negative; 0 is a no-op.
    Returns:
        Tensor of shape (B, C, T', H, W), where T' = T on CP rank 0 and
        T' = T + frames_to_send on every other rank. ``x`` itself is returned
        when there is nothing to exchange (``frames_to_send == 0`` or a
        single CP rank).
    Raises:
        ValueError: if ``frames_to_send`` is negative.
    """
    if frames_to_send < 0:
        # A negative count would turn ``x[:, :, -frames_to_send:]`` into a
        # slice of the *leading* frames, silently sending the wrong data.
        raise ValueError(f"frames_to_send must be non-negative, got {frames_to_send}")
    if frames_to_send == 0:
        return x

    cp_rank, cp_world_size = get_cp_rank_size()
    if cp_world_size == 1:
        return x

    group = get_cp_group()
    # NOTE(review): the neighbour's global rank is computed as global_rank +/- 1,
    # which assumes the ranks of one CP group are consecutive global ranks —
    # confirm against how the CP group is constructed.
    global_rank = dist.get_rank()

    # Send this rank's trailing frames to the next rank.
    if cp_rank < cp_world_size - 1:
        assert x.size(2) >= frames_to_send, (
            f"need at least {frames_to_send} local frames to send, got {x.size(2)}"
        )
        tail = x[:, :, -frames_to_send:].contiguous()
        dist.send(tail, global_rank + 1, group=group)

    # Receive the previous rank's trailing frames and prepend them.
    if cp_rank > 0:
        B, C, _, H, W = x.shape
        recv_buffer = torch.empty(
            (B, C, frames_to_send, H, W),
            dtype=x.dtype,
            device=x.device,
        )
        dist.recv(recv_buffer, global_rank - 1, group=group)
        x = torch.cat([recv_buffer, x], dim=2)
    return x
def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor:
if max_T > x.size(2):
pad_T = max_T - x.size(2)
pad_dims = (0, 0, 0, 0, 0, pad_T)
return F.pad(x, pad_dims)
return x
def gather_all_frames(x: torch.Tensor) -> torch.Tensor:
    """
    Gathers all frames from all context-parallel ranks for inference.

    Ranks may hold different numbers of frames. Each rank therefore zero-pads
    its tensor along dim 2 to the largest local length before the all-gather,
    and the padding is sliced off again afterwards.

    Args:
        x: Tensor of shape (B, C, T, H, W); T may differ between ranks.
    Returns:
        output: Tensor of shape (B, C, T_total, H, W). T_total is the sum of
        every rank's T, with the frames concatenated in CP-rank order.
    """
    _, cp_size = get_cp_rank_size()
    cp_group = get_cp_group()

    # Collective operations require contiguous tensors.
    x = x.contiguous()

    # Exchange local time lengths so every rank knows how to undo the padding.
    local_T_tensor = torch.tensor([x.size(2)], device=x.device, dtype=torch.int64)
    all_T_tensors = [torch.empty_like(local_T_tensor) for _ in range(cp_size)]
    dist.all_gather(all_T_tensors, local_T_tensor, group=cp_group)
    # A single device->host transfer instead of one `.item()` sync per rank.
    all_T = torch.cat(all_T_tensors).tolist()

    # all_gather needs identical shapes on every rank: pad T up to the maximum.
    x = _pad_to_max(x, max(all_T)).contiguous()
    # all_gather overwrites every output buffer, so they need no zero-init.
    gathered_x = [torch.empty_like(x) for _ in range(cp_size)]
    dist.all_gather(gathered_x, x, group=cp_group)

    # Strip each rank's padding and stitch the frames together in rank order.
    return torch.cat(
        [chunk[:, :, :t_size] for chunk, t_size in zip(gathered_x, all_T)],
        dim=2,
    )
def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool:
    """Return True if ``input``'s elements take more than ``max_gb`` GiB.

    The size is estimated as ``numel() * element_size()`` bytes and converted
    to GiB (1024**3 bytes). The comparison is strict, so exactly ``max_gb``
    GiB is not considered excessive.
    """
    bytes_per_gib = 1024**3
    total_bytes = input.element_size() * input.numel()
    return total_bytes / bytes_per_gib > max_gb
class ContextParallelCausalConv3d(torch.nn.Conv3d):
    """3-D convolution that is causal along time and context-parallel aware.

    Temporal padding is not done by ``Conv3d``. Instead, ``forward`` pads
    ``kernel_size[0] - 1`` frames at the front of the time axis on CP rank 0.
    Other ranks obtain that leading context from the previous rank, or from
    an all-gather when the convolution is strided. Height and width are
    padded symmetrically by ``(k - 1) // 2`` through the regular ``Conv3d``
    padding. Dilation is fixed to 1 on every axis.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]],
        **kwargs,
    ):
        ks = cast_tuple(kernel_size, 3)
        strides = cast_tuple(stride, 3)
        # Zero temporal padding here; only the spatial axes use Conv3d padding.
        spatial_pad = tuple((ks[axis] - 1) // 2 for axis in (1, 2))
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=ks,
            stride=strides,
            padding=(0,) + spatial_pad,
            dilation=(1, 1, 1),
            **kwargs,
        )

    def forward(self, x: torch.Tensor):
        """Convolve this rank's shard of frames.

        Assumes ``x`` is (B, C, T, H, W) with T sharded across CP ranks —
        TODO confirm against callers.
        """
        rank, world_size = get_cp_rank_size()
        pad_frames = self.kernel_size[0] - 1

        if rank == 0:
            # Conv3d calls zero padding "zeros"; F.pad calls it "constant".
            pad_mode = self.padding_mode if self.padding_mode != "zeros" else "constant"
            x = F.pad(x, (0, 0, 0, 0, pad_frames, 0), mode=pad_mode)

        if world_size == 1:
            return super().forward(x)

        if any(s != 1 for s in self.stride):
            # Less efficient fallback for strided convs: gather every rank's
            # frames, convolve the whole sequence, then keep this rank's chunk.
            full = gather_all_frames(x)  # [B, C, k - 1 + global_T, H, W]
            out = super().forward(full)
            pieces = out.tensor_split(world_size, dim=2)
            assert len(pieces) == world_size
            return pieces[rank]

        # Unit stride: borrow the previous rank's last k - 1 frames as context.
        return super().forward(cp_pass_frames(x, pad_frames))