Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git, synced 2025-12-08 20:34:23 +08:00
parent 8f9fa07455, commit e8bc2fd052
@@ -20,6 +20,7 @@ from torch import nn

import torch.nn.functional as F

import numpy as np
from einops import rearrange

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import is_torch_version, logging
@@ -276,6 +277,8 @@ class CogVideoXBlock(nn.Module):
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+       video_flow_feature: Optional[torch.Tensor] = None,
+       fuser=None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

@@ -284,6 +287,28 @@ class CogVideoXBlock(nn.Module):
            hidden_states, encoder_hidden_states, temb
        )

+       # Motion-guidance Fuser: inject trajectory flow features into the normed hidden states
+       if video_flow_feature is not None:
+           H, W = video_flow_feature.shape[-2:]
+           T = norm_hidden_states.shape[1] // H // W
+
+           h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W).to(torch.float16)
+           h = fuser(h, video_flow_feature.to(h), T=T)
+           norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
+           del h, fuser

        # attention
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
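For orientation: the fuser operates on a spatial [(B*T), C, H, W] layout while the transformer keeps tokens flat, so the block round-trips through einops. A minimal sketch of the shape bookkeeping (toy sizes; not from the commit):

import torch
from einops import rearrange

B, T, H, W, C = 1, 4, 6, 8, 16                  # toy sizes
tokens = torch.randn(B, T * H * W, C)           # transformer token layout
h = rearrange(tokens, "B (T H W) C -> (B T) C H W", H=H, W=W)
# h is [B*T, C, H, W]; this is where fuser(h, video_flow_feature.to(h), T=T) runs
tokens_back = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
assert tokens_back.shape == tokens.shape        # lossless round-trip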
@@ -458,6 +483,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):

        self.gradient_checkpointing = False

+       self.fuser_list = None
+
    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

@@ -570,6 +597,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        controlnet_states: torch.Tensor = None,
        controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
+       video_flow_features: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        batch_size, num_frames, channels, height, width = hidden_states.shape
@@ -594,30 +622,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
-           if self.training and self.gradient_checkpointing:
-
-               def create_custom_forward(module):
-                   def custom_forward(*inputs):
-                       return module(*inputs)
-
-                   return custom_forward
-
-               ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-               hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                   create_custom_forward(block),
-                   hidden_states,
-                   encoder_hidden_states,
-                   emb,
-                   image_rotary_emb,
-                   **ckpt_kwargs,
-               )
-           else:
-               hidden_states, encoder_hidden_states = block(
-                   hidden_states=hidden_states,
-                   encoder_hidden_states=encoder_hidden_states,
-                   temb=emb,
-                   image_rotary_emb=image_rotary_emb,
-               )
+           hidden_states, encoder_hidden_states = block(
+               hidden_states=hidden_states,
+               encoder_hidden_states=encoder_hidden_states,
+               temb=emb,
+               image_rotary_emb=image_rotary_emb,
+               video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
+               fuser=self.fuser_list[i] if self.fuser_list is not None else None,
+           )

            if (controlnet_states is not None) and (i < len(controlnet_states)):
                controlnet_states_block = controlnet_states[i]

nodes.py (147 changed lines)
@@ -254,6 +254,7 @@ class DownloadAndLoadCogVideoModel:
            "bertjiazheng/KoolCogVideoX-5b",
            "kijai/CogVideoX-Fun-2b",
            "kijai/CogVideoX-Fun-5b",
+           "kijai/CogVideoX-5b-Tora",
            "alibaba-pai/CogVideoX-Fun-V1.1-2b-InP",
            "alibaba-pai/CogVideoX-Fun-V1.1-5b-InP",
            "alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose",
@@ -409,6 +410,38 @@ class DownloadAndLoadCogVideoModel:
            fuse_qkv_projections=True if pab_config is None else False,
        )

+       if "Tora" in model:
+           import torch.nn as nn
+           from .tora.traj_module import MGF
+
+           hidden_size = 3072
+           num_layers = transformer.num_layers
+           pipe.transformer.fuser_list = nn.ModuleList([MGF(128, hidden_size) for _ in range(num_layers)])
+           fuser_sd = load_torch_file(os.path.join(base_path, "fuser", "fuser.safetensors"))
+           pipe.transformer.fuser_list.load_state_dict(fuser_sd)
+           for module in transformer.fuser_list:
+               for param in module.parameters():
+                   param.data = param.data.to(torch.float16).to(device)
+           del fuser_sd
+
+           from .tora.traj_module import TrajExtractor
+           traj_extractor = TrajExtractor(
+               vae_downsize=(4, 8, 8),
+               patch_size=2,
+               nums_rb=2,
+               cin=vae.config.latent_channels,
+               channels=[128] * transformer.num_layers,
+               sk=True,
+               use_conv=False,
+           )
+
+           traj_sd = load_torch_file(os.path.join(base_path, "traj_extractor", "traj_extractor.safetensors"))
+           traj_extractor.load_state_dict(traj_sd)
+           traj_extractor.to(torch.float32).to(device)
+
+           pipe.traj_extractor = traj_extractor

        pipeline = {
            "pipe": pipe,
            "dtype": dtype,
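The wiring this block sets up: TrajExtractor emits one 128-channel feature map per transformer block, and each block gets its own MGF fuser whose output width matches the transformer hidden size. A minimal sketch of that one-fuser-per-block invariant (toy depth; not from the commit):

import torch.nn as nn
from .tora.traj_module import MGF   # same import as in the block above

num_layers = 4                      # toy depth; transformer.num_layers in practice
fuser_list = nn.ModuleList(MGF(128, 3072) for _ in range(num_layers))
assert len(fuser_list) == num_layers    # one fuser per transformer block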
@@ -950,6 +983,108 @@ class CogVideoImageInterpolationEncode:

        return ({"samples": final_latents}, )


# NOTE: dead code: the class of the same name defined below shadows this one at import time.
class ToraEncodeTrajectory:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "pipeline": ("COGVIDEOPIPE",),
            "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
            "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
            "num_frames": ("INT", {"default": 49, "min": 16, "max": 1024, "step": 1}),
            },
        }

    RETURN_TYPES = ("TORATRAJLIST",)
    RETURN_NAMES = ("tora_traj_list",)
    FUNCTION = "encode"
    CATEGORY = "CogVideoWrapper"

    def encode(self, pipeline, width, height, num_frames):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        generator = torch.Generator(device=device).manual_seed(0)

        transformer = pipeline["pipe"].transformer
        vae = pipeline["pipe"].vae
        vae.enable_slicing()

        canvas_width, canvas_height = 256, 256
        traj_list = PROVIDED_TRAJS["infinity"]
        traj_list_range_256 = scale_traj_list_to_256(traj_list, canvas_width, canvas_height)

        return (traj_list_range_256, )

from .tora.traj_utils import process_traj, scale_traj_list_to_256, PROVIDED_TRAJS
from torchvision.utils import flow_to_image

class ToraEncodeTrajectory:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "pipeline": ("COGVIDEOPIPE",),
            "coordinates": ("STRING", {"forceInput": True}),
            "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
            "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
            "num_frames": ("INT", {"default": 49, "min": 16, "max": 1024, "step": 1}),
            },
        }

    RETURN_TYPES = ("TORAFEATURES",)
    RETURN_NAMES = ("tora_trajectory",)
    FUNCTION = "encode"
    CATEGORY = "CogVideoWrapper"

    def encode(self, pipeline, width, height, num_frames, coordinates):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        generator = torch.Generator(device=device).manual_seed(0)

        traj_extractor = pipeline["pipe"].traj_extractor
        vae = pipeline["pipe"].vae
        vae.enable_slicing()

        canvas_width, canvas_height = 256, 256
        coordinates = json.loads(coordinates.replace("'", '"'))
        coordinates = [(coord['x'], coord['y']) for coord in coordinates]

        traj_list_range_256 = scale_traj_list_to_256(coordinates, canvas_width, canvas_height)

        check_diffusers_version()
        vae._clear_fake_context_parallel_cache()

        total_num_frames = num_frames

        video_flow, points = process_traj(traj_list_range_256, total_num_frames, (height, width), device=device)
        video_flow = video_flow.unsqueeze_(0)

        tmp = rearrange(video_flow[0], "T H W C -> T C H W")
        video_flow = flow_to_image(tmp).unsqueeze_(0).to(device)  # [1, T, C, H, W]
        del tmp

        video_flow = (
            rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
        )
        torch.cuda.empty_cache()
        video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous()  # batch of 2 for the unconditional pass

        if not pipeline["cpu_offloading"]:
            vae.to(device)

        video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
        video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()

        vae.to(offload_device)

        video_flow = rearrange(video_flow, "b t d h w -> b d t h w")
        video_flow_features = traj_extractor(video_flow.to(torch.float32))
        video_flow_features = torch.stack(video_flow_features)

        return (video_flow_features, )

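The "coordinates" input is a JSON-like string of point dicts, parsed above by swapping single quotes for double quotes before json.loads. An illustrative value (made-up points; not from the commit):

import json

coords = "[{'x': 60, 'y': 141}, {'x': 92, 'y': 120}, {'x': 130, 'y': 145}]"
parsed = json.loads(coords.replace("'", '"'))
points = [(c['x'], c['y']) for c in parsed]
print(points)  # [(60, 141), (92, 120), (130, 145)]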
class CogVideoSampler:
    @classmethod
    def INPUT_TYPES(s):
@@ -975,6 +1110,7 @@ class CogVideoSampler:
                "image_cond_latents": ("LATENT", ),
                "context_options": ("COGCONTEXT", ),
                "controlnet": ("COGVIDECONTROLNET",),
+               "tora_trajectory": ("TORAFEATURES", ),
            }
        }

@@ -984,7 +1120,7 @@ class CogVideoSampler:
    CATEGORY = "CogVideoWrapper"

    def process(self, pipeline, positive, negative, steps, cfg, seed, height, width, num_frames, scheduler, samples=None,
-               denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None):
+               denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None):
        mm.soft_empty_cache()

        base_path = pipeline["base_path"]
@@ -1042,7 +1178,8 @@ class CogVideoSampler:
                context_stride=context_stride,
                context_overlap=context_overlap,
                freenoise=context_options["freenoise"] if context_options is not None else None,
-               controlnet=controlnet
+               controlnet=controlnet,
+               video_flow_features=tora_trajectory,
            )
        if not pipeline["cpu_offloading"]:
            pipe.transformer.to(offload_device)
@@ -1586,7 +1723,8 @@ NODE_CLASS_MAPPINGS = {
    "CogVideoLoraSelect": CogVideoLoraSelect,
    "CogVideoContextOptions": CogVideoContextOptions,
    "CogVideoControlNet": CogVideoControlNet,
-   "DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet
+   "DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
+   "ToraEncodeTrajectory": ToraEncodeTrajectory,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -1606,5 +1744,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "CogVideoControlImageEncode": "CogVideo Control ImageEncode",
    "CogVideoLoraSelect": "CogVideo LoraSelect",
    "CogVideoContextOptions": "CogVideo Context Options",
-   "DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet"
+   "DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
+   "ToraEncodeTrajectory": "Tora Encode Trajectory",
}

@@ -161,6 +161,8 @@ class CogVideoXPipeline(VideoSysPipeline):
        self.original_mask = original_mask
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

+       self.traj_extractor = None
+
        if pab_config is not None:
            set_pab_manager(pab_config)

@@ -388,6 +390,7 @@ class CogVideoXPipeline(VideoSysPipeline):
        context_overlap: Optional[int] = None,
        freenoise: Optional[bool] = True,
        controlnet: Optional[dict] = None,
+       video_flow_features: Optional[torch.Tensor] = None,
    ):
        """
@@ -848,7 +851,8 @@ class CogVideoXPipeline(VideoSysPipeline):
                if isinstance(controlnet_states, (tuple, list)):
                    controlnet_states = [x.to(dtype=self.vae.dtype) for x in controlnet_states]
+               else:
+                   controlnet_states = controlnet_states.to(dtype=self.vae.dtype)
-               controlnet_states = controlnet_states.to(dtype=self.vae.dtype)

                # predict noise model_output
                noise_pred = self.transformer(
@@ -859,6 +863,7 @@ class CogVideoXPipeline(VideoSysPipeline):
                    return_dict=False,
                    controlnet_states=controlnet_states,
                    controlnet_weights=control_weights,
+                   video_flow_features=video_flow_features,
                )[0]
                noise_pred = noise_pred.float()

tora/traj_module.py (new file, 297 lines)
@@ -0,0 +1,297 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResnetBlock(nn.Module):
    def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
        super().__init__()
        ps = ksize // 2
        if in_c != out_c or not sk:
            self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
        else:
            self.in_conv = None
        self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.bn2 = nn.BatchNorm2d(out_c)
        if not sk:
            self.skep = nn.Conv2d(out_c, out_c, ksize, 1, ps)
        else:
            self.skep = None

        self.down = down
        if self.down:
            self.down_opt = Downsample(in_c, use_conv=use_conv)

    def forward(self, x):
        if self.down:
            x = self.down_opt(x)
        if self.in_conv is not None:
            x = self.in_conv(x)

        h = self.bn1(x)
        h = self.act(h)
        h = self.block1(h)
        h = self.bn2(h)
        h = self.act(h)
        h = self.block2(h)
        if self.skep is not None:
            return h + self.skep(x)
        else:
            return h + x


class VAESpatialEmulator(nn.Module):
    def __init__(self, kernel_size=(8, 8)):
        super().__init__()
        self.kernel_size = kernel_size

    def forward(self, x):
        """
        x: torch.Tensor of shape [B C T H W]
        """
        Hp, Wp = self.kernel_size
        H, W = x.shape[-2], x.shape[-1]
        valid_h = H - H % Hp
        valid_w = W - W % Wp
        x = x[..., :valid_h, :valid_w]
        x = rearrange(
            x,
            "B C T (Nh Hp) (Nw Wp) -> B (Hp Wp C) T Nh Nw",
            Hp=Hp,
            Wp=Wp,
        )
        return x


class VAETemporalEmulator(nn.Module):
    def __init__(self, micro_frame_size, kernel_size=4):
        super().__init__()
        self.micro_frame_size = micro_frame_size
        self.kernel_size = kernel_size

    def forward(self, x_z):
        """
        x_z: torch.Tensor of shape [B C T H W]
        """
        z_list = []
        for i in range(0, x_z.shape[2], self.micro_frame_size):
            x_z_bs = x_z[:, :, i : i + self.micro_frame_size]
            z_list.append(x_z_bs[:, :, 0:1])
            x_z_bs = x_z_bs[:, :, 1:]
            t_valid = x_z_bs.shape[2] - x_z_bs.shape[2] % self.kernel_size
            x_z_bs = x_z_bs[:, :, :t_valid]
            x_z_bs = reduce(x_z_bs, "B C (T n) H W -> B C T H W", n=self.kernel_size, reduction="mean")
            z_list.append(x_z_bs)
        z = torch.cat(z_list, dim=2)
        return z


class TrajExtractor(nn.Module):
    def __init__(
        self,
        vae_downsize=(4, 8, 8),
        patch_size=2,
        channels=[320, 640, 1280, 1280],
        nums_rb=3,
        cin=2,
        ksize=3,
        sk=False,
        use_conv=True,
    ):
        super(TrajExtractor, self).__init__()
        self.vae_downsize = vae_downsize
        self.downsize_patchify = nn.PixelUnshuffle(patch_size)
        self.patch_size = (1, patch_size, patch_size)
        self.channels = channels
        self.nums_rb = nums_rb
        self.body = []
        for i in range(len(channels)):
            for j in range(nums_rb):
                if (i != 0) and (j == 0):
                    self.body.append(
                        ResnetBlock(
                            channels[i - 1],
                            channels[i],
                            down=False,
                            ksize=ksize,
                            sk=sk,
                            use_conv=use_conv,
                        )
                    )
                else:
                    self.body.append(
                        ResnetBlock(
                            channels[i],
                            channels[i],
                            down=False,
                            ksize=ksize,
                            sk=sk,
                            use_conv=use_conv,
                        )
                    )
        self.body = nn.ModuleList(self.body)
        cin_ = cin * patch_size**2
        self.conv_in = nn.Conv2d(cin_, channels[0], 3, 1, 1)

        # Initialize weights
        def conv_init(module):
            if isinstance(module, (nn.Conv2d, nn.Conv1d)):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(conv_init)

    def forward(self, x):
        """
        x: torch.Tensor of shape [B C T H W]
        """
        # pad T/H/W up to multiples of the patch size, then fold patches into channels
        T, H, W = x.shape[-3:]
        if W % self.patch_size[2] != 0:
            x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
        if H % self.patch_size[1] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
        if T % self.patch_size[0] != 0:
            x = F.pad(
                x,
                (0, 0, 0, 0, 0, self.patch_size[0] - T % self.patch_size[0]),
            )
        x = rearrange(x, "B C T H W -> (B T) C H W")
        x = self.downsize_patchify(x)

        # extract features: one feature map per ResnetBlock
        features = []
        x = self.conv_in(x)
        for i in range(len(self.channels)):
            for j in range(self.nums_rb):
                idx = i * self.nums_rb + j
                x = self.body[idx](x)
                features.append(x)

        return features


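A minimal sketch (toy sizes; not from the commit) that exercises TrajExtractor as nodes.py configures it and checks the invariant that it yields one feature map per ResnetBlock, i.e. len(channels) * nums_rb:

import torch

extractor = TrajExtractor(patch_size=2, nums_rb=2, cin=16, channels=[128, 128], sk=True, use_conv=False)
latent_flow = torch.randn(1, 16, 13, 60, 90)    # [B C T H W], toy latent-flow sizes
feats = extractor(latent_flow)
assert len(feats) == 2 * 2                      # len(channels) * nums_rb
print(feats[0].shape)                           # torch.Size([13, 128, 30, 45]): (B*T) C H/2 W/2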
class FloatGroupNorm(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.to(self.bias.dtype)).type(x.dtype)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class MGF(nn.Module):
    def __init__(self, flow_in_channel=128, out_channels=1152):
        super().__init__()
        self.out_channels = out_channels
        self.flow_gamma_spatial = nn.Conv2d(flow_in_channel, self.out_channels // 4, 3, padding=1)
        self.flow_gamma_temporal = zero_module(
            nn.Conv1d(
                self.out_channels // 4,
                self.out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="replicate",
            )
        )
        self.flow_beta_spatial = nn.Conv2d(flow_in_channel, self.out_channels // 4, 3, padding=1)
        self.flow_beta_temporal = zero_module(
            nn.Conv1d(
                self.out_channels // 4,
                self.out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="replicate",
            )
        )
        self.flow_cond_norm = FloatGroupNorm(32, self.out_channels)

    def forward(self, h, flow, T):
        if flow is not None:
            gamma_flow = self.flow_gamma_spatial(flow)
            beta_flow = self.flow_beta_spatial(flow)
            _, _, hh, wh = beta_flow.shape
            gamma_flow = rearrange(gamma_flow, "(b f) c h w -> (b h w) c f", f=T)
            beta_flow = rearrange(beta_flow, "(b f) c h w -> (b h w) c f", f=T)
            gamma_flow = self.flow_gamma_temporal(gamma_flow)
            beta_flow = self.flow_beta_temporal(beta_flow)
            gamma_flow = rearrange(gamma_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
            beta_flow = rearrange(beta_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
            h = h + self.flow_cond_norm(h) * gamma_flow + beta_flow
        return h
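MGF is a FiLM-style modulation: spatial convs map the flow feature to gamma/beta, zero-initialized temporal convs mix them across frames, and the normalized hidden states are scaled and shifted. Because of the zero init, the fuser is an exact no-op until pretrained weights are loaded. A minimal sketch (toy sizes; not from the commit):

import torch

fuser = MGF(flow_in_channel=128, out_channels=64)  # toy width; 3072 in the 5b model
h = torch.randn(2 * 4, 64, 8, 8)                   # [(B*T), C, H, W] hidden states
flow = torch.randn(2 * 4, 128, 8, 8)               # matching flow features
out = fuser(h, flow, T=4)
assert torch.equal(out, h)                         # zero-initialized temporal convs => identity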
tora/traj_utils.py (new file, 670 lines)
@@ -0,0 +1,670 @@
import numpy as np
import cv2
import torch

# Note that the coordinates passed to the model must not exceed 256.
# x/y values are in the 0-256 range.
PROVIDED_TRAJS = {
    "circle1": [
        [120, 194], [144, 193], [155, 189], [158, 170], [160, 153], [159, 123], [152, 113], [136, 100],
        [124, 100], [108, 100], [101, 106], [90, 110], [84, 129], [79, 146], [78, 165], [83, 182],
        [87, 189], [94, 192], [100, 194], [106, 194], [112, 194], [118, 195],
    ],
    "circle2": [
        [100, 127], [105, 117], [122, 117], [132, 129], [133, 158], [125, 181], [108, 189], [92, 185],
        [84, 179], [79, 163], [75, 142], [73, 118], [75, 82], [91, 63], [115, 52], [139, 46],
        [154, 55], [167, 93], [175, 112], [177, 137], [177, 158], [177, 171], [175, 188], [173, 204],
    ],
    "coaster": [
        [40, 208], [40, 148], [40, 100], [52, 58], [60, 57], [74, 68], [78, 90], [84, 123],
        [88, 148], [96, 168], [100, 181], [102, 188], [105, 192], [113, 118], [119, 80], [128, 68],
        [145, 109], [149, 155], [157, 175], [161, 184], [164, 184], [172, 166], [183, 107], [189, 84],
        [198, 76],
    ],
    "dance": [
        [81, 112], [86, 112], [92, 112], [100, 113], [102, 114], [97, 115], [92, 114], [86, 112],
        [81, 112], [80, 112], [84, 113], [89, 114], [95, 114], [101, 114], [102, 114], [103, 124],
        [105, 137], [109, 156], [114, 172], [119, 180], [124, 184], [131, 181], [140, 168], [146, 152],
        [150, 128], [151, 117], [152, 116], [156, 116], [163, 115], [169, 116], [175, 116], [173, 116],
        [167, 116], [162, 114], [157, 114], [152, 115], [156, 115], [163, 115], [168, 115], [174, 116],
        [175, 116], [168, 116], [162, 116], [152, 114], [149, 134], [145, 156], [139, 168], [130, 183],
        [118, 180], [112, 170], [107, 151], [102, 128], [103, 117], [96, 113], [88, 113], [83, 112],
        [80, 112],
    ],
    "infinity": [
        [60, 141], [71, 127], [92, 120], [112, 123], [130, 145], [145, 163], [167, 178], [189, 187],
        [206, 176], [213, 147], [208, 124], [190, 112], [176, 111], [158, 124], [145, 147], [125, 172],
        [104, 189], [72, 189], [59, 184], [55, 153], [57, 140], [75, 119], [112, 118], [129, 142],
        [149, 163], [168, 180], [194, 186], [206, 175], [211, 159], [212, 149], [212, 134], [206, 122],
        [180, 112], [163, 116], [149, 138], [128, 170], [108, 184], [86, 190], [63, 181], [57, 152],
        [57, 139],
    ],
    "pause": [
        [98, 186], [100, 188], [98, 186], [100, 188], [101, 187], [104, 187], [111, 184], [116, 176],
        [125, 162], [132, 140], [136, 119], [137, 104], [138, 96], [139, 94], [140, 94], [140, 96],
        [138, 98], [138, 96], [136, 94], [137, 92], [140, 92], [144, 92], [149, 92], [152, 92],
        [151, 92], [147, 92], [142, 92], [140, 92], [139, 95], [139, 105], [141, 122], [142, 143],
        [140, 167], [136, 184], [135, 188], [132, 195], [132, 192], [131, 192], [131, 192], [130, 192],
        [130, 195],
    ],
    "shake": [
        [103, 89], [104, 89], [106, 89], [107, 89], [108, 89], [109, 89], [110, 89], [111, 89],
        [112, 89], [113, 89], [114, 89], [115, 89], [116, 89], [117, 89], [118, 89], [119, 89],
        [120, 89], [122, 89], [123, 89], [124, 89], [125, 89], [126, 89], [127, 88], [128, 88],
        [129, 88], [130, 88], [131, 88], [133, 87], [136, 86], [137, 86], [138, 86], [139, 86],
        [140, 86], [141, 86], [142, 86], [143, 86], [144, 86], [145, 86], [146, 87], [147, 87],
        [148, 87], [149, 87], [148, 87], [146, 87], [145, 88], [144, 88], [142, 89], [141, 89],
        [140, 90], [140, 91], [138, 91], [137, 92], [136, 92], [136, 93], [135, 93], [134, 93],
        [133, 93], [132, 93], [131, 93], [130, 93], [129, 93], [128, 93], [127, 92], [125, 92],
        [124, 92], [123, 92], [122, 92], [121, 92], [120, 92], [119, 92], [118, 92], [117, 92],
        [116, 92], [115, 92], [113, 92], [112, 92], [111, 92], [110, 92], [109, 92], [108, 92],
        [108, 91], [108, 90], [109, 90], [110, 90], [111, 89], [112, 89], [113, 89], [114, 89],
        [115, 89], [115, 88], [116, 88], [117, 88], [118, 88], [118, 87], [119, 87], [120, 87],
        [121, 87], [122, 86], [123, 86], [124, 86], [125, 86], [126, 85], [127, 85], [128, 85],
        [129, 85], [130, 85], [131, 85], [132, 85], [133, 85], [134, 85], [135, 85], [136, 85],
        [137, 85], [138, 85], [139, 85], [140, 85], [141, 85], [142, 85], [143, 85], [143, 84],
        [144, 84], [145, 84], [146, 84], [147, 84], [148, 84], [149, 84], [148, 84], [147, 84],
        [145, 84], [144, 84], [143, 84], [142, 84], [141, 84], [140, 85], [139, 85], [138, 85],
        [137, 86], [136, 86], [136, 87], [135, 87], [134, 87], [133, 87], [132, 88], [131, 88],
        [130, 88], [129, 88], [129, 89], [128, 89], [127, 89], [126, 89], [125, 89], [124, 90],
        [123, 90], [122, 90], [121, 90], [120, 91], [119, 91], [118, 91], [117, 91], [116, 91],
        [115, 91], [114, 91], [113, 91], [112, 91], [111, 91], [110, 91], [109, 91], [109, 90],
        [108, 90], [110, 90], [111, 90], [113, 90], [114, 90], [115, 90], [116, 90], [118, 90],
        [120, 90], [121, 90], [122, 90], [123, 90], [124, 90], [126, 90], [127, 90], [128, 90],
        [129, 90], [130, 90], [131, 90], [132, 90], [133, 90], [134, 90], [135, 90], [136, 90],
        [137, 90], [138, 90], [139, 90], [140, 90], [141, 89], [142, 89], [143, 89], [144, 89],
        [145, 89], [146, 89], [147, 89], [147, 89], [147, 89],
    ],
    "spiral": [
        [16, 152], [23, 138], [39, 122], [54, 115], [75, 118], [88, 130], [93, 150], [89, 176],
        [75, 184], [63, 177], [65, 152], [77, 135], [98, 121], [116, 120], [135, 127], [148, 136],
        [156, 145], [160, 165], [158, 176], [138, 187], [133, 185], [129, 148], [140, 133], [156, 120],
        [177, 118], [197, 118], [214, 119], [225, 118],
    ],
}

def pdf2(sigma_matrix, grid):
    """Calculate the PDF of a bivariate Gaussian distribution.

    Args:
        sigma_matrix (ndarray): with the shape (2, 2)
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        kernel (ndarray): un-normalized kernel.
    """
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
    return kernel


def mesh_grid(kernel_size):
    """Generate the mesh grid, centering at zero.

    Args:
        kernel_size (int):

    Returns:
        xy (ndarray): with the shape (kernel_size, kernel_size, 2)
        xx (ndarray): with the shape (kernel_size, kernel_size)
        yy (ndarray): with the shape (kernel_size, kernel_size)
    """
    ax = np.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0)
    xx, yy = np.meshgrid(ax, ax)
    xy = np.hstack(
        (
            xx.reshape((kernel_size * kernel_size, 1)),
            yy.reshape(kernel_size * kernel_size, 1),
        )
    ).reshape(kernel_size, kernel_size, 2)
    return xy, xx, yy


def sigma_matrix2(sig_x, sig_y, theta):
    """Calculate the rotated sigma matrix (two dimensional matrix).

    Args:
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.

    Returns:
        ndarray: Rotated sigma matrix.
    """
    d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
    u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))


def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
    """Generate a bivariate isotropic or anisotropic Gaussian kernel.

    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool):

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    kernel = pdf2(sigma_matrix, grid)
    kernel = kernel / np.sum(kernel)
    return kernel

size = 99
sigma = 10
blur_kernel = bivariate_Gaussian(size, sigma, sigma, 0, grid=None, isotropic=True)
blur_kernel = blur_kernel / blur_kernel[size // 2, size // 2]  # rescale so the kernel peak is 1

canvas_width, canvas_height = 256, 256

def get_flow(points, optical_flow, video_len):
    for i in range(video_len - 1):
        p = points[i]
        p1 = points[i + 1]
        optical_flow[i + 1, p[1], p[0], 0] = p1[0] - p[0]
        optical_flow[i + 1, p[1], p[0], 1] = p1[1] - p[1]

    return optical_flow


def process_points(points, frames=49):
    default_points = [[128, 128]] * frames

    if len(points) < 2:
        return default_points
    elif len(points) >= frames:
        skip = len(points) // frames
        return points[::skip][: frames - 1] + points[-1:]
    else:
        insert_num = frames - len(points)
        insert_num_dict = {}
        interval = len(points) - 1
        n = insert_num // interval
        m = insert_num % interval
        for i in range(interval):
            insert_num_dict[i] = n
        for i in range(m):
            insert_num_dict[i] += 1

        res = []
        for i in range(interval):
            insert_points = []
            x0, y0 = points[i]
            x1, y1 = points[i + 1]

            delta_x = x1 - x0
            delta_y = y1 - y0
            for j in range(insert_num_dict[i]):
                x = x0 + (j + 1) / (insert_num_dict[i] + 1) * delta_x
                y = y0 + (j + 1) / (insert_num_dict[i] + 1) * delta_y
                insert_points.append([int(x), int(y)])

            res += points[i : i + 1] + insert_points
        res += points[-1:]
        return res


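process_points interpolates or subsamples so the returned list always has `frames` entries. An illustrative run (values computed by hand from the code above; not from the commit):

pts = process_points([[0, 0], [4, 0], [4, 4]], frames=6)
print(pts)  # [[0, 0], [1, 0], [2, 0], [4, 0], [4, 2], [4, 4]]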
def read_points_from_list(traj_list, video_len=16, reverse=False):
    points = []
    for point in traj_list:
        if isinstance(point, str):
            x, y = point.strip().split(",")
        else:
            x, y = point[0], point[1]
        points.append((int(x), int(y)))
    if reverse:
        points = points[::-1]

    if len(points) > video_len:
        skip = len(points) // video_len
        points = points[::skip]
    points = points[:video_len]

    return points


def read_points_from_file(file, video_len=16, reverse=False):
    with open(file, "r") as f:
        lines = f.readlines()
    points = []
    for line in lines:
        x, y = line.strip().split(",")
        points.append((int(x), int(y)))
    if reverse:
        points = points[::-1]

    if len(points) > video_len:
        skip = len(points) // video_len
        points = points[::skip]
    points = points[:video_len]

    return points


def process_traj(trajs_list, num_frames, video_size, device="cpu"):
    # accept a single trajectory (a flat list of points) as well as a list of trajectories
    if trajs_list and trajs_list[0] and (not isinstance(trajs_list[0][0], (list, tuple))):
        trajs_list = [trajs_list]

    optical_flow = np.zeros((num_frames, video_size[0], video_size[1], 2), dtype=np.float32)
    processed_points = []
    for traj_list in trajs_list:
        points = read_points_from_list(traj_list, video_len=num_frames)
        xy_range = 256
        h, w = video_size
        points = process_points(points, num_frames)
        points = [[int(w * x / xy_range), int(h * y / xy_range)] for x, y in points]
        optical_flow = get_flow(points, optical_flow, video_len=num_frames)
        processed_points.append(points)

    print(f"received {len(trajs_list)} trajectory(ies)")

    for i in range(1, num_frames):
        optical_flow[i] = cv2.filter2D(optical_flow[i], -1, blur_kernel)

    optical_flow = torch.tensor(optical_flow).to(device)

    return optical_flow, processed_points


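An illustrative usage (not from the commit): build the sparse, per-frame-blurred flow volume from one provided trajectory; the shapes follow directly from the code above.

flow, pts = process_traj(PROVIDED_TRAJS["infinity"], num_frames=49, video_size=(480, 720))
print(flow.shape)   # torch.Size([49, 480, 720, 2]) -- [T H W 2]
print(len(pts[0]))  # 49 resampled points, already in video coordinates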
def add_provided_traj(traj_name):
    global traj_list
    traj_list = PROVIDED_TRAJS[traj_name]
    traj_str = [f"{traj}" for traj in traj_list]
    return ", ".join(traj_str)


def scale_traj_list_to_256(traj_list, canvas_width, canvas_height):
    scale_x = 256 / canvas_width
    scale_y = 256 / canvas_height
    scaled_traj_list = [[int(x * scale_x), int(y * scale_y)] for x, y in traj_list]
    return scaled_traj_list
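An illustrative call (not from the commit): points drawn on a 720x480 canvas are rescaled into the 0-256 range that the trajectory code expects.

print(scale_traj_list_to_256([[720, 480], [360, 240]], 720, 480))  # [[256, 256], [128, 128]]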