mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-15 09:33:31 +08:00
Add CameraPoseVisualizer
This commit is contained in:
parent
883f2f48e1
commit
6f88198cc5
131
nodes.py
131
nodes.py
@ -3953,7 +3953,7 @@ class RemapImageRange:
|
|||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "remap"
|
FUNCTION = "remap"
|
||||||
|
|
||||||
CATEGORY = "Marigold"
|
CATEGORY = "KJNodes"
|
||||||
|
|
||||||
def remap(self, image, min, max, clamp):
|
def remap(self, image, min, max, clamp):
|
||||||
if image.dtype == torch.float16:
|
if image.dtype == torch.float16:
|
||||||
@ -3963,6 +3963,131 @@ class RemapImageRange:
|
|||||||
image = torch.clamp(image, min=0.0, max=1.0)
|
image = torch.clamp(image, min=0.0, max=1.0)
|
||||||
return (image, )
|
return (image, )
|
||||||
|
|
||||||
|
class CameraPoseVisualizer:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"pose_file_path": ("STRING", {"default": 'pose file path here', "multiline": False}),
|
||||||
|
"sample_stride": ("INT", {"default": 1,"min": 0, "max": 100, "step": 1}),
|
||||||
|
"frames": ("INT", {"default": 16,"min": 0, "max": 100, "step": 1}),
|
||||||
|
"base_xval": ("FLOAT", {"default": 0.5,"min": 0, "max": 100, "step": 0.01}),
|
||||||
|
"zval": ("FLOAT", {"default": 2.0,"min": 0, "max": 100, "step": 0.01}),
|
||||||
|
"use_exact_fx": ("BOOLEAN", {"default": True}),
|
||||||
|
"relative_c2w": ("BOOLEAN", {"default": True}),
|
||||||
|
"x_min": ("FLOAT", {"default": -5.0,"min": -100, "max": 100, "step": 0.01}),
|
||||||
|
"x_max": ("FLOAT", {"default": 5.0,"min": -100, "max": 100, "step": 0.01}),
|
||||||
|
"y_min": ("FLOAT", {"default": -5.0,"min": -100, "max": 100, "step": 0.01}),
|
||||||
|
"y_max": ("FLOAT", {"default": 5.0,"min": -100, "max": 100, "step": 0.01}),
|
||||||
|
"z_min": ("FLOAT", {"default": -5.0,"min": -100, "max": 100, "step": 0.01}),
|
||||||
|
"z_max": ("FLOAT", {"default": 5.0,"min": -100, "max": 100, "step": 0.01}),
|
||||||
|
"use_viewer": ("BOOLEAN", {"default": False}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "plot"
|
||||||
|
|
||||||
|
CATEGORY = "KJNodes"
|
||||||
|
|
||||||
|
def plot(self, pose_file_path, sample_stride, frames, base_xval, zval, use_exact_fx, relative_c2w, x_min, x_max, y_min, y_max, z_min, z_max, use_viewer):
|
||||||
|
import matplotlib as mpl
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import io
|
||||||
|
from torchvision.transforms import ToTensor
|
||||||
|
self.fig = plt.figure(figsize=(18, 7))
|
||||||
|
self.ax = self.fig.add_subplot(projection='3d')
|
||||||
|
self.plotly_data = None # plotly data traces
|
||||||
|
self.ax.set_aspect("auto")
|
||||||
|
self.ax.set_xlim(x_min, x_max)
|
||||||
|
self.ax.set_ylim(y_min, y_max)
|
||||||
|
self.ax.set_zlim(z_min, z_max)
|
||||||
|
self.ax.set_xlabel('x')
|
||||||
|
self.ax.set_ylabel('y')
|
||||||
|
self.ax.set_zlabel('z')
|
||||||
|
print('initialize camera pose visualizer')
|
||||||
|
with open(pose_file_path, 'r') as f:
|
||||||
|
poses = f.readlines()
|
||||||
|
w2cs = [np.asarray([float(p) for p in pose.strip().split(' ')[7:]]).reshape(3, 4) for pose in poses[1:]]
|
||||||
|
fxs = [float(pose.strip().split(' ')[1]) for pose in poses[1:]]
|
||||||
|
|
||||||
|
cropped_length = frames * sample_stride
|
||||||
|
total_frames = len(w2cs)
|
||||||
|
start_frame_ind = random.randint(0, max(0, total_frames - cropped_length - 1))
|
||||||
|
end_frame_ind = min(start_frame_ind + cropped_length, total_frames)
|
||||||
|
frame_ind = np.linspace(start_frame_ind, end_frame_ind - 1, frames, dtype=int)
|
||||||
|
w2cs = [w2cs[x] for x in frame_ind]
|
||||||
|
transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]).reshape(4, 4)
|
||||||
|
last_row = np.zeros((1, 4))
|
||||||
|
last_row[0, -1] = 1.0
|
||||||
|
w2cs = [np.concatenate((w2c, last_row), axis=0) for w2c in w2cs]
|
||||||
|
c2ws = self.get_c2w(w2cs, transform_matrix, relative_c2w)
|
||||||
|
|
||||||
|
for frame_idx, c2w in enumerate(c2ws):
|
||||||
|
self.extrinsic2pyramid(c2w, frame_idx / frames, hw_ratio=1/1, base_xval=base_xval,
|
||||||
|
zval=(fxs[frame_idx] if use_exact_fx else zval))
|
||||||
|
|
||||||
|
cmap = mpl.cm.rainbow
|
||||||
|
norm = mpl.colors.Normalize(vmin=0, vmax=frames)
|
||||||
|
self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=self.ax, orientation='vertical', label='Frame Number')
|
||||||
|
plt.title('Extrinsic Parameters')
|
||||||
|
plt.draw()
|
||||||
|
buf = io.BytesIO()
|
||||||
|
plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
|
||||||
|
buf.seek(0)
|
||||||
|
img = Image.open(buf)
|
||||||
|
tensor_img = ToTensor()(img)
|
||||||
|
buf.close()
|
||||||
|
tensor_img = tensor_img.permute(1, 2, 0).unsqueeze(0)
|
||||||
|
if use_viewer:
|
||||||
|
time.sleep(1)
|
||||||
|
plt.show()
|
||||||
|
return (tensor_img,)
|
||||||
|
|
||||||
|
def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=1/1, base_xval=1, zval=3):
|
||||||
|
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
||||||
|
vertex_std = np.array([[0, 0, 0, 1],
|
||||||
|
[base_xval, -base_xval * hw_ratio, zval, 1],
|
||||||
|
[base_xval, base_xval * hw_ratio, zval, 1],
|
||||||
|
[-base_xval, base_xval * hw_ratio, zval, 1],
|
||||||
|
[-base_xval, -base_xval * hw_ratio, zval, 1]])
|
||||||
|
vertex_transformed = vertex_std @ extrinsic.T
|
||||||
|
meshes = [[vertex_transformed[0, :-1], vertex_transformed[1][:-1], vertex_transformed[2, :-1]],
|
||||||
|
[vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]],
|
||||||
|
[vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]],
|
||||||
|
[vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]],
|
||||||
|
[vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]]]
|
||||||
|
|
||||||
|
color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map)
|
||||||
|
|
||||||
|
self.ax.add_collection3d(
|
||||||
|
Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35))
|
||||||
|
|
||||||
|
def customize_legend(self, list_label):
|
||||||
|
from matplotlib.patches import Patch
|
||||||
|
list_handle = []
|
||||||
|
for idx, label in enumerate(list_label):
|
||||||
|
color = plt.cm.rainbow(idx / len(list_label))
|
||||||
|
patch = Patch(color=color, label=label)
|
||||||
|
list_handle.append(patch)
|
||||||
|
plt.legend(loc='right', bbox_to_anchor=(1.8, 0.5), handles=list_handle)
|
||||||
|
|
||||||
|
def get_c2w(self, w2cs, transform_matrix, relative_c2w):
|
||||||
|
if relative_c2w:
|
||||||
|
target_cam_c2w = np.array([
|
||||||
|
[1, 0, 0, 0],
|
||||||
|
[0, 1, 0, 0],
|
||||||
|
[0, 0, 1, 0],
|
||||||
|
[0, 0, 0, 1]
|
||||||
|
])
|
||||||
|
abs2rel = target_cam_c2w @ w2cs[0]
|
||||||
|
ret_poses = [target_cam_c2w, ] + [abs2rel @ np.linalg.inv(w2c) for w2c in w2cs[1:]]
|
||||||
|
else:
|
||||||
|
ret_poses = [np.linalg.inv(w2c) for w2c in w2cs]
|
||||||
|
ret_poses = [transform_matrix @ x for x in ret_poses]
|
||||||
|
return np.array(ret_poses, dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"INTConstant": INTConstant,
|
"INTConstant": INTConstant,
|
||||||
"FloatConstant": FloatConstant,
|
"FloatConstant": FloatConstant,
|
||||||
@ -4034,7 +4159,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"RemapMaskRange": RemapMaskRange,
|
"RemapMaskRange": RemapMaskRange,
|
||||||
"LoadResAdapterNormalization": LoadResAdapterNormalization,
|
"LoadResAdapterNormalization": LoadResAdapterNormalization,
|
||||||
"Superprompt": Superprompt,
|
"Superprompt": Superprompt,
|
||||||
"RemapImageRange": RemapImageRange
|
"RemapImageRange": RemapImageRange,
|
||||||
|
"CameraPoseVisualizer": CameraPoseVisualizer
|
||||||
}
|
}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"INTConstant": "INT Constant",
|
"INTConstant": "INT Constant",
|
||||||
@ -4107,4 +4233,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"LoadResAdapterNormalization": "LoadResAdapterNormalization",
|
"LoadResAdapterNormalization": "LoadResAdapterNormalization",
|
||||||
"Superprompt": "Superprompt",
|
"Superprompt": "Superprompt",
|
||||||
"RemapImageRange": "RemapImageRange",
|
"RemapImageRange": "RemapImageRange",
|
||||||
|
"CameraPoseVisualizer": "CameraPoseVisualizer",
|
||||||
}
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user