2025-02-12 00:50:38 +02:00

56 lines
1.6 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from .core.spatracker.spatracker import SpaTracker
def build_spatracker(
    checkpoint: str,
    seq_length: int = 8,
):
    """Build a SpaTracker model and load its weights from ``checkpoint``.

    Convenience entry point that forwards both arguments to
    ``build_spatracker_from_cfg``.

    Args:
        checkpoint: Path to the checkpoint file to load.
        seq_length: Sequence length forwarded to the model builder.

    Returns:
        Whatever ``build_spatracker_from_cfg`` returns for these arguments.
    """
    return build_spatracker_from_cfg(checkpoint=checkpoint, seq_length=seq_length)
# model used to produce the results in the paper
def build_spatracker_from_cfg(checkpoint=None, seq_length=8):
    """Build SpaTracker with the fixed configuration used by this module.

    The stride is pinned to 4; the remaining values are passed through.

    Args:
        checkpoint: Optional checkpoint path, forwarded unchanged.
        seq_length: Forwarded to ``_build_spatracker`` as ``sequence_len``.

    Returns:
        The result of ``_build_spatracker`` for this configuration.
    """
    paper_cfg = {
        "stride": 4,
        "sequence_len": seq_length,
        "checkpoint": checkpoint,
    }
    return _build_spatracker(**paper_cfg)
def _build_spatracker(
    stride,
    sequence_len,
    checkpoint=None,
):
    """Instantiate ``SpaTracker`` and optionally load weights from a checkpoint.

    Args:
        stride: Forwarded to ``SpaTracker`` as ``stride``.
        sequence_len: Forwarded to ``SpaTracker`` as ``S``.
        checkpoint: Optional path to a weights file. A path whose name ends
            with ``.safetensors`` is read with ``safetensors.torch.load_file``;
            any other path is read with ``torch.load`` onto the CPU with
            ``weights_only=True``. When ``None``, no weights are loaded.

    Returns:
        The constructed ``SpaTracker`` instance.
    """
    spatracker = SpaTracker(
        stride=stride,
        S=sequence_len,
        add_space_attn=True,
        space_depth=6,
        time_depth=6,
    )
    if checkpoint is None:
        return spatracker

    checkpoint_path = str(checkpoint)
    # Decide the format from the file suffix, not from a substring anywhere in
    # the path, so a directory called e.g. "safetensors_ckpts" cannot misroute
    # a regular torch checkpoint.
    if checkpoint_path.lower().endswith(".safetensors"):
        from safetensors.torch import load_file

        # load_file takes the path itself; no file handle is needed here.
        state_dict = load_file(checkpoint_path)
    else:
        with open(checkpoint_path, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)

    if "model" in state_dict:
        # The weights are nested under "model": start from the model's current
        # parameters and overwrite only the keys the model actually has, so
        # extra checkpoint entries are ignored and missing ones keep their
        # initial values.
        model_paras = spatracker.state_dict()
        model_paras.update(
            {k: v for k, v in state_dict["model"].items() if k in model_paras}
        )
        state_dict = model_paras
    spatracker.load_state_dict(state_dict)
    return spatracker