From 6f88198cc548b94413a72333982244f76d515875 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 4 Apr 2024 01:02:12 +0300 Subject: [PATCH] Add CameraPoseVisualizer --- nodes.py | 131 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 129 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index c7644bb..80dd68f 100644 --- a/nodes.py +++ b/nodes.py @@ -3953,7 +3953,7 @@ class RemapImageRange: RETURN_TYPES = ("IMAGE",) FUNCTION = "remap" - CATEGORY = "Marigold" + CATEGORY = "KJNodes" def remap(self, image, min, max, clamp): if image.dtype == torch.float16: @@ -3963,6 +3963,131 @@ class RemapImageRange: image = torch.clamp(image, min=0.0, max=1.0) return (image, ) +class CameraPoseVisualizer: + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "pose_file_path": ("STRING", {"default": 'pose file path here', "multiline": False}), + "sample_stride": ("INT", {"default": 1,"min": 0, "max": 100, "step": 1}), + "frames": ("INT", {"default": 16,"min": 0, "max": 100, "step": 1}), + "base_xval": ("FLOAT", {"default": 0.5,"min": 0, "max": 100, "step": 0.01}), + "zval": ("FLOAT", {"default": 2.0,"min": 0, "max": 100, "step": 0.01}), + "use_exact_fx": ("BOOLEAN", {"default": True}), + "relative_c2w": ("BOOLEAN", {"default": True}), + "x_min": ("FLOAT", {"default": -5.0,"min": -100, "max": 100, "step": 0.01}), + "x_max": ("FLOAT", {"default": 5.0,"min": -100, "max": 100, "step": 0.01}), + "y_min": ("FLOAT", {"default": -5.0,"min": -100, "max": 100, "step": 0.01}), + "y_max": ("FLOAT", {"default": 5.0,"min": -100, "max": 100, "step": 0.01}), + "z_min": ("FLOAT", {"default": -5.0,"min": -100, "max": 100, "step": 0.01}), + "z_max": ("FLOAT", {"default": 5.0,"min": -100, "max": 100, "step": 0.01}), + "use_viewer": ("BOOLEAN", {"default": False}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "plot" + + CATEGORY = "KJNodes" + + def plot(self, pose_file_path, sample_stride, frames, base_xval, zval, use_exact_fx, relative_c2w, x_min, x_max, y_min, y_max, z_min, z_max, use_viewer): + import matplotlib as mpl + import matplotlib.pyplot as plt + import io + from torchvision.transforms import ToTensor + self.fig = plt.figure(figsize=(18, 7)) + self.ax = self.fig.add_subplot(projection='3d') + self.plotly_data = None # plotly data traces + self.ax.set_aspect("auto") + self.ax.set_xlim(x_min, x_max) + self.ax.set_ylim(y_min, y_max) + self.ax.set_zlim(z_min, z_max) + self.ax.set_xlabel('x') + self.ax.set_ylabel('y') + self.ax.set_zlabel('z') + print('initialize camera pose visualizer') + with open(pose_file_path, 'r') as f: + poses = f.readlines() + w2cs = [np.asarray([float(p) for p in pose.strip().split(' ')[7:]]).reshape(3, 4) for pose in poses[1:]] + fxs = [float(pose.strip().split(' ')[1]) for pose in poses[1:]] + + cropped_length = frames * sample_stride + total_frames = len(w2cs) + start_frame_ind = random.randint(0, max(0, total_frames - cropped_length - 1)) + end_frame_ind = min(start_frame_ind + cropped_length, total_frames) + frame_ind = np.linspace(start_frame_ind, end_frame_ind - 1, frames, dtype=int) + w2cs = [w2cs[x] for x in frame_ind] + transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]).reshape(4, 4) + last_row = np.zeros((1, 4)) + last_row[0, -1] = 1.0 + w2cs = [np.concatenate((w2c, last_row), axis=0) for w2c in w2cs] + c2ws = self.get_c2w(w2cs, transform_matrix, relative_c2w) + + for frame_idx, c2w in enumerate(c2ws): + self.extrinsic2pyramid(c2w, frame_idx / frames, hw_ratio=1/1, base_xval=base_xval, + zval=(fxs[frame_idx] if use_exact_fx else zval)) + + cmap = mpl.cm.rainbow + norm = mpl.colors.Normalize(vmin=0, vmax=frames) + self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=self.ax, orientation='vertical', label='Frame Number') + plt.title('Extrinsic Parameters') + plt.draw() + buf = io.BytesIO() + plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0) + buf.seek(0) + img = Image.open(buf) + tensor_img = ToTensor()(img) + buf.close() + tensor_img = tensor_img.permute(1, 2, 0).unsqueeze(0) + if use_viewer: + time.sleep(1) + plt.show() + return (tensor_img,) + + def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=1/1, base_xval=1, zval=3): + from mpl_toolkits.mplot3d.art3d import Poly3DCollection + vertex_std = np.array([[0, 0, 0, 1], + [base_xval, -base_xval * hw_ratio, zval, 1], + [base_xval, base_xval * hw_ratio, zval, 1], + [-base_xval, base_xval * hw_ratio, zval, 1], + [-base_xval, -base_xval * hw_ratio, zval, 1]]) + vertex_transformed = vertex_std @ extrinsic.T + meshes = [[vertex_transformed[0, :-1], vertex_transformed[1][:-1], vertex_transformed[2, :-1]], + [vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]], + [vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]], + [vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]], + [vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]]] + + color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map) + + self.ax.add_collection3d( + Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35)) + + def customize_legend(self, list_label): + from matplotlib.patches import Patch + list_handle = [] + for idx, label in enumerate(list_label): + color = plt.cm.rainbow(idx / len(list_label)) + patch = Patch(color=color, label=label) + list_handle.append(patch) + plt.legend(loc='right', bbox_to_anchor=(1.8, 0.5), handles=list_handle) + + def get_c2w(self, w2cs, transform_matrix, relative_c2w): + if relative_c2w: + target_cam_c2w = np.array([ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1] + ]) + abs2rel = target_cam_c2w @ w2cs[0] + ret_poses = [target_cam_c2w, ] + [abs2rel @ np.linalg.inv(w2c) for w2c in w2cs[1:]] + else: + ret_poses = [np.linalg.inv(w2c) for w2c in w2cs] + ret_poses = [transform_matrix @ x for x in ret_poses] + return np.array(ret_poses, dtype=np.float32) + + NODE_CLASS_MAPPINGS = { "INTConstant": INTConstant, "FloatConstant": FloatConstant, @@ -4034,7 +4159,8 @@ NODE_CLASS_MAPPINGS = { "RemapMaskRange": RemapMaskRange, "LoadResAdapterNormalization": LoadResAdapterNormalization, "Superprompt": Superprompt, - "RemapImageRange": RemapImageRange + "RemapImageRange": RemapImageRange, + "CameraPoseVisualizer": CameraPoseVisualizer } NODE_DISPLAY_NAME_MAPPINGS = { "INTConstant": "INT Constant", @@ -4107,4 +4233,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LoadResAdapterNormalization": "LoadResAdapterNormalization", "Superprompt": "Superprompt", "RemapImageRange": "RemapImageRange", + "CameraPoseVisualizer": "CameraPoseVisualizer", } \ No newline at end of file