Update nodes.py

This commit is contained in:
Kijai 2024-04-22 13:18:04 +03:00
parent a1d759c3cd
commit 76c536d156

View File

@ -4352,19 +4352,12 @@ class CameraPoseVisualizer:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"pose_file_path": ("STRING", {"default": 'pose file path here', "multiline": False}),
"sample_stride": ("INT", {"default": 1,"min": 0, "max": 100, "step": 1}),
"frames": ("INT", {"default": 16,"min": 0, "max": 100, "step": 1}),
"base_xval": ("FLOAT", {"default": 0.5,"min": 0, "max": 100, "step": 0.01}),
"zval": ("FLOAT", {"default": 2.0,"min": 0, "max": 100, "step": 0.01}),
"use_exact_fx": ("BOOLEAN", {"default": True}),
"pose_file_path": ("STRING", {"default": '', "multiline": False}),
"base_xval": ("FLOAT", {"default": 0.2,"min": 0, "max": 100, "step": 0.01}),
"zval": ("FLOAT", {"default": 0.3,"min": 0, "max": 100, "step": 0.01}),
"scale": ("FLOAT", {"default": 1.0,"min": 0.01, "max": 10.0, "step": 0.01}),
"use_exact_fx": ("BOOLEAN", {"default": False}),
"relative_c2w": ("BOOLEAN", {"default": True}),
"x_min": ("FLOAT", {"default": -5.0,"min": -100, "max": 100, "step": 0.01}),
"x_max": ("FLOAT", {"default": 5.0,"min": -100, "max": 100, "step": 0.01}),
"y_min": ("FLOAT", {"default": -5.0,"min": -100, "max": 100, "step": 0.01}),
"y_max": ("FLOAT", {"default": 5.0,"min": -100, "max": 100, "step": 0.01}),
"z_min": ("FLOAT", {"default": -5.0,"min": -100, "max": 100, "step": 0.01}),
"z_max": ("FLOAT", {"default": 5.0,"min": -100, "max": 100, "step": 0.01}),
"use_viewer": ("BOOLEAN", {"default": False}),
},
"optional": {
@ -4376,26 +4369,40 @@ class CameraPoseVisualizer:
FUNCTION = "plot"
CATEGORY = "KJNodes/misc"
DESCRIPTION = """
Visualizes the camera poses from a .txt file with
RealEstate camera intrinsics and coordinates in a 3D plot.
Visualizes the camera poses, from Animatediff-Evolved CameraCtrl Pose
or a .txt file with RealEstate camera intrinsics and coordinates, in a 3D plot.
"""
def plot(self, pose_file_path, sample_stride, frames, base_xval, zval, use_exact_fx, relative_c2w, x_min, x_max, y_min, y_max, z_min, z_max, use_viewer, cameractrl_poses=None):
def plot(self, pose_file_path, scale, base_xval, zval, use_exact_fx, relative_c2w, use_viewer, cameractrl_poses=None):
import matplotlib as mpl
import matplotlib.pyplot as plt
import io
from torchvision.transforms import ToTensor
x_min = -2.0 * scale
x_max = 2.0 * scale
y_min = -2.0 * scale
y_max = 2.0 * scale
z_min = -2.0 * scale
z_max = 2.0 * scale
plt.rcParams['text.color'] = '#999999'
self.fig = plt.figure(figsize=(18, 7))
self.fig.patch.set_facecolor('#353535')
self.ax = self.fig.add_subplot(projection='3d')
self.ax.set_facecolor('#353535') # Set the background color here
self.ax.grid(color='#999999', linestyle='-', linewidth=0.5)
self.plotly_data = None # plotly data traces
self.ax.set_aspect("auto")
self.ax.set_xlim(x_min, x_max)
self.ax.set_ylim(y_min, y_max)
self.ax.set_zlim(z_min, z_max)
self.ax.set_xlabel('x')
self.ax.set_ylabel('y')
self.ax.set_zlabel('z')
self.ax.set_xlabel('x', color='#999999')
self.ax.set_ylabel('y', color='#999999')
self.ax.set_zlabel('z', color='#999999')
for text in self.ax.get_xticklabels() + self.ax.get_yticklabels() + self.ax.get_zticklabels():
text.set_color('#999999')
print('initialize camera pose visualizer')
if pose_file_path != "":
with open(pose_file_path, 'r') as f:
poses = f.readlines()
@ -4409,26 +4416,35 @@ RealEstate camera intrinsics and coordinates in a 3D plot.
else:
raise ValueError("Please provide either pose_file_path or cameractrl_poses")
cropped_length = frames * sample_stride
total_frames = len(w2cs)
start_frame_ind = random.randint(0, max(0, total_frames - cropped_length - 1))
end_frame_ind = min(start_frame_ind + cropped_length, total_frames)
frame_ind = np.linspace(start_frame_ind, end_frame_ind - 1, frames, dtype=int)
w2cs = [w2cs[x] for x in frame_ind]
transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]).reshape(4, 4)
last_row = np.zeros((1, 4))
last_row[0, -1] = 1.0
w2cs = [np.concatenate((w2c, last_row), axis=0) for w2c in w2cs]
c2ws = self.get_c2w(w2cs, transform_matrix, relative_c2w)
for frame_idx, c2w in enumerate(c2ws):
self.extrinsic2pyramid(c2w, frame_idx / frames, hw_ratio=1/1, base_xval=base_xval,
zval=(fxs[frame_idx] if use_exact_fx else zval))
self.extrinsic2pyramid(c2w, frame_idx / total_frames, hw_ratio=1/1, base_xval=base_xval,
zval=(fxs[frame_idx] if use_exact_fx else zval))
# Create the colorbar
cmap = mpl.cm.rainbow
norm = mpl.colors.Normalize(vmin=0, vmax=frames)
self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=self.ax, orientation='vertical', label='Frame Number')
plt.title('Extrinsic Parameters')
norm = mpl.colors.Normalize(vmin=0, vmax=total_frames)
colorbar = self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=self.ax, orientation='vertical')
# Change the colorbar label
colorbar.set_label('Frame', color='#999999') # Change the label and its color
# Change the tick colors
colorbar.ax.yaxis.set_tick_params(colors='#999999') # Change the tick color
# Change the tick frequency
# Assuming you want to set the ticks at every 10th frame
ticks = np.arange(0, total_frames, 10)
colorbar.ax.yaxis.set_ticks(ticks)
plt.title('')
plt.draw()
buf = io.BytesIO()
plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
@ -4459,7 +4475,7 @@ RealEstate camera intrinsics and coordinates in a 3D plot.
color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map)
self.ax.add_collection3d(
Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35))
Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.25))
def customize_legend(self, list_label):
from matplotlib.patches import Patch