From 2ed74f7ac78d3ff713d0a8583695c31055914b76 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 4 Oct 2025 22:29:09 +0300 Subject: [PATCH] convert nodes_rodin.py to V3 schema (#10195) --- comfy_api_nodes/nodes_rodin.py | 941 +++++++++++++++++---------------- 1 file changed, 478 insertions(+), 463 deletions(-) diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 633ac46d3..bd758f762 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -7,14 +7,15 @@ Rodin API docs: https://developer.hyper3d.ai/ from __future__ import annotations from inspect import cleandoc -from comfy.comfy_types.node_typing import IO import folder_paths as comfy_paths import aiohttp import os import asyncio -import io import logging import math +from typing import Optional +from io import BytesIO +from typing_extensions import override from PIL import Image from comfy_api_nodes.apis.rodin_api import ( Rodin3DGenerateRequest, @@ -31,428 +32,436 @@ from comfy_api_nodes.apis.client import ( SynchronousOperation, PollingOperation, ) +from comfy_api.latest import ComfyExtension, io as comfy_io -COMMON_PARAMETERS = { - "Seed": ( - IO.INT, - { - "default":0, - "min":0, - "max":65535, - "display":"number" - } +COMMON_PARAMETERS = [ + comfy_io.Int.Input( + "Seed", + default=0, + min=0, + max=65535, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - "Material_Type": ( - IO.COMBO, - { - "options": ["PBR", "Shaded"], - "default": "PBR" - } + comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True), + comfy_io.Combo.Input( + "Polygon_count", + options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"], + default="18K-Quad", + optional=True, ), - "Polygon_count": ( - IO.COMBO, - { - "options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"], - "default": "18K-Quad" - } +] + + +def get_quality_mode(poly_count): + polycount = poly_count.split("-") + poly = polycount[1] + count = polycount[0] + if poly == "Triangle": + mesh_mode = "Raw" + elif poly == "Quad": + mesh_mode = "Quad" + else: + mesh_mode = "Quad" + + if count == "4K": + quality_override = 4000 + elif count == "8K": + quality_override = 8000 + elif count == "18K": + quality_override = 18000 + elif count == "50K": + quality_override = 50000 + elif count == "2K": + quality_override = 2000 + elif count == "20K": + quality_override = 20000 + elif count == "150K": + quality_override = 150000 + elif count == "500K": + quality_override = 500000 + else: + quality_override = 18000 + + return mesh_mode, quality_override + + +def tensor_to_filelike(tensor, max_pixels: int = 2048*2048): + """ + Converts a PyTorch tensor to a file-like object. + + Args: + - tensor (torch.Tensor): A tensor representing an image of shape (H, W, C) + where C is the number of channels (3 for RGB), H is height, and W is width. + + Returns: + - io.BytesIO: A file-like object containing the image data. + """ + array = tensor.cpu().numpy() + array = (array * 255).astype('uint8') + image = Image.fromarray(array, 'RGB') + + original_width, original_height = image.size + original_pixels = original_width * original_height + if original_pixels > max_pixels: + scale = math.sqrt(max_pixels / original_pixels) + new_width = int(original_width * scale) + new_height = int(original_height * scale) + else: + new_width, new_height = original_width, original_height + + if new_width != original_width or new_height != original_height: + image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) + + img_byte_arr = BytesIO() + image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression + img_byte_arr.seek(0) + return img_byte_arr + + +async def create_generate_task( + images=None, + seed=1, + material="PBR", + quality_override=18000, + tier="Regular", + mesh_mode="Quad", + TAPose = False, + auth_kwargs: Optional[dict[str, str]] = None, +): + if images is None: + raise Exception("Rodin 3D generate requires at least 1 image.") + if len(images) > 5: + raise Exception("Rodin 3D generate requires up to 5 image.") + + path = "/proxy/rodin/api/v2/rodin" + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=path, + method=HttpMethod.POST, + request_model=Rodin3DGenerateRequest, + response_model=Rodin3DGenerateResponse, + ), + request=Rodin3DGenerateRequest( + seed=seed, + tier=tier, + material=material, + quality_override=quality_override, + mesh_mode=mesh_mode, + TAPose=TAPose, + ), + files=[ + ( + "images", + open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image) + ) + for image in images if image is not None + ], + content_type="multipart/form-data", + auth_kwargs=auth_kwargs, ) -} -def create_task_error(response: Rodin3DGenerateResponse): - """Check if the response has error""" - return hasattr(response, "error") + response = await operation.execute() + + if hasattr(response, "error"): + error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}" + logging.error(error_message) + raise Exception(error_message) + + logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!") + subscription_key = response.jobs.subscription_key + task_uuid = response.uuid + logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}") + return task_uuid, subscription_key -class Rodin3DAPI: - """ - Generate 3D Assets using Rodin API - """ - RETURN_TYPES = (IO.STRING,) - RETURN_NAMES = ("3D Model Path",) - CATEGORY = "api node/3d/Rodin" - DESCRIPTION = cleandoc(__doc__ or "") - FUNCTION = "api_call" - API_NODE = True - - def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048): - """ - Converts a PyTorch tensor to a file-like object. - - Args: - - tensor (torch.Tensor): A tensor representing an image of shape (H, W, C) - where C is the number of channels (3 for RGB), H is height, and W is width. - - Returns: - - io.BytesIO: A file-like object containing the image data. - """ - array = tensor.cpu().numpy() - array = (array * 255).astype('uint8') - image = Image.fromarray(array, 'RGB') - - original_width, original_height = image.size - original_pixels = original_width * original_height - if original_pixels > max_pixels: - scale = math.sqrt(max_pixels / original_pixels) - new_width = int(original_width * scale) - new_height = int(original_height * scale) - else: - new_width, new_height = original_width, original_height - - if new_width != original_width or new_height != original_height: - image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) - - img_byte_arr = io.BytesIO() - image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression - img_byte_arr.seek(0) - return img_byte_arr - - def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str: - has_failed = any(job.status == JobStatus.Failed for job in response.jobs) - all_done = all(job.status == JobStatus.Done for job in response.jobs) - status_list = [str(job.status) for job in response.jobs] - logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}") - if has_failed: - logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.") - raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.") - elif all_done: - return "DONE" - else: - return "Generating" - - async def create_generate_task(self, images=None, seed=1, material="PBR", quality_override=18000, tier="Regular", mesh_mode="Quad", TAPose = False, **kwargs): - if images is None: - raise Exception("Rodin 3D generate requires at least 1 image.") - if len(images) > 5: - raise Exception("Rodin 3D generate requires up to 5 image.") - - path = "/proxy/rodin/api/v2/rodin" - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=Rodin3DGenerateRequest, - response_model=Rodin3DGenerateResponse, - ), - request=Rodin3DGenerateRequest( - seed=seed, - tier=tier, - material=material, - quality_override=quality_override, - mesh_mode=mesh_mode, - TAPose=TAPose, - ), - files=[ - ( - "images", - open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image) - ) - for image in images if image is not None - ], - content_type = "multipart/form-data", - auth_kwargs=kwargs, - ) - - response = await operation.execute() - - if create_task_error(response): - error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}" - logging.error(error_message) - raise Exception(error_message) - - logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!") - subscription_key = response.jobs.subscription_key - task_uuid = response.uuid - logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}") - return task_uuid, subscription_key - - async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse: - - path = "/proxy/rodin/api/v2/status" - - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path = path, - method=HttpMethod.POST, - request_model=Rodin3DCheckStatusRequest, - response_model=Rodin3DCheckStatusResponse, - ), - request=Rodin3DCheckStatusRequest( - subscription_key = subscription_key - ), - completed_statuses=["DONE"], - failed_statuses=["FAILED"], - status_extractor=self.check_rodin_status, - poll_interval=3.0, - auth_kwargs=kwargs, - ) - - logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") - - return await poll_operation.execute() - - async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse: - logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") - - path = "/proxy/rodin/api/v2/download" - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=Rodin3DDownloadRequest, - response_model=Rodin3DDownloadResponse, - ), - request=Rodin3DDownloadRequest( - task_uuid=uuid - ), - auth_kwargs=kwargs - ) - - return await operation.execute() - - def get_quality_mode(self, poly_count): - polycount = poly_count.split("-") - poly = polycount[1] - count = polycount[0] - if poly == "Triangle": - mesh_mode = "Raw" - elif poly == "Quad": - mesh_mode = "Quad" - else: - mesh_mode = "Quad" - - if count == "4K": - quality_override = 4000 - elif count == "8K": - quality_override = 8000 - elif count == "18K": - quality_override = 18000 - elif count == "50K": - quality_override = 50000 - elif count == "2K": - quality_override = 2000 - elif count == "20K": - quality_override = 20000 - elif count == "150K": - quality_override = 150000 - elif count == "500K": - quality_override = 500000 - else: - quality_override = 18000 - - return mesh_mode, quality_override - - async def download_files(self, url_list, task_uuid): - save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}") - os.makedirs(save_path, exist_ok=True) - model_file_path = None - async with aiohttp.ClientSession() as session: - for i in url_list.list: - url = i.url - file_name = i.name - file_path = os.path.join(save_path, file_name) - if file_path.endswith(".glb"): - model_file_path = file_path - logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") - max_retries = 5 - for attempt in range(max_retries): - try: - async with session.get(url) as resp: - resp.raise_for_status() - with open(file_path, "wb") as f: - async for chunk in resp.content.iter_chunked(32 * 1024): - f.write(chunk) - break - except Exception as e: - logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") - if attempt < max_retries - 1: - logging.info("Retrying...") - await asyncio.sleep(2) - else: - logging.info( - "[ Rodin3D API - download_files ] Failed to download %s after %s attempts.", - file_path, - max_retries, - ) - - return model_file_path +def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str: + all_done = all(job.status == JobStatus.Done for job in response.jobs) + status_list = [str(job.status) for job in response.jobs] + logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}") + if any(job.status == JobStatus.Failed for job in response.jobs): + logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.") + raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.") + if all_done: + return "DONE" + return "Generating" -class Rodin3D_Regular(Rodin3DAPI): +async def poll_for_task_status( + subscription_key, auth_kwargs: Optional[dict[str, str]] = None, +) -> Rodin3DCheckStatusResponse: + poll_operation = PollingOperation( + poll_endpoint=ApiEndpoint( + path="/proxy/rodin/api/v2/status", + method=HttpMethod.POST, + request_model=Rodin3DCheckStatusRequest, + response_model=Rodin3DCheckStatusResponse, + ), + request=Rodin3DCheckStatusRequest(subscription_key=subscription_key), + completed_statuses=["DONE"], + failed_statuses=["FAILED"], + status_extractor=check_rodin_status, + poll_interval=3.0, + auth_kwargs=auth_kwargs, + ) + logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") + return await poll_operation.execute() + + +async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse: + logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/rodin/api/v2/download", + method=HttpMethod.POST, + request_model=Rodin3DDownloadRequest, + response_model=Rodin3DDownloadResponse, + ), + request=Rodin3DDownloadRequest(task_uuid=uuid), + auth_kwargs=auth_kwargs, + ) + return await operation.execute() + + +async def download_files(url_list, task_uuid): + save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}") + os.makedirs(save_path, exist_ok=True) + model_file_path = None + async with aiohttp.ClientSession() as session: + for i in url_list.list: + url = i.url + file_name = i.name + file_path = os.path.join(save_path, file_name) + if file_path.endswith(".glb"): + model_file_path = file_path + logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") + max_retries = 5 + for attempt in range(max_retries): + try: + async with session.get(url) as resp: + resp.raise_for_status() + with open(file_path, "wb") as f: + async for chunk in resp.content.iter_chunked(32 * 1024): + f.write(chunk) + break + except Exception as e: + logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") + if attempt < max_retries - 1: + logging.info("Retrying...") + await asyncio.sleep(2) + else: + logging.info( + "[ Rodin3D API - download_files ] Failed to download %s after %s attempts.", + file_path, + max_retries, + ) + return model_file_path + + +class Rodin3D_Regular(comfy_io.ComfyNode): + """Generate 3D Assets using Rodin API""" + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - **COMMON_PARAMETERS - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Rodin3D_Regular", + display_name="Rodin 3D Generate - Regular Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("Images"), + *COMMON_PARAMETERS, + ], + outputs=[comfy_io.String.Output(display_name="3D Model Path")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + ], + is_api_node=True, + ) - async def api_call( - self, + @classmethod + async def execute( + cls, Images, Seed, Material_Type, Polygon_count, - **kwargs - ): + ) -> comfy_io.NodeOutput: tier = "Regular" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality_override = self.get_quality_mode(Polygon_count) - task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - **kwargs) - await self.poll_for_task_status(subscription_key, **kwargs) - download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list, task_uuid) - - return (model,) - - -class Rodin3D_Detail(Rodin3DAPI): - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - **COMMON_PARAMETERS - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, + mesh_mode, quality_override = get_quality_mode(Polygon_count) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } + task_uuid, subscription_key = await create_generate_task( + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + auth_kwargs=auth, + ) + await poll_for_task_status(subscription_key, auth_kwargs=auth) + download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + model = await download_files(download_list, task_uuid) - async def api_call( - self, + return comfy_io.NodeOutput(model) + + +class Rodin3D_Detail(comfy_io.ComfyNode): + """Generate 3D Assets using Rodin API""" + + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Rodin3D_Detail", + display_name="Rodin 3D Generate - Detail Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("Images"), + *COMMON_PARAMETERS, + ], + outputs=[comfy_io.String.Output(display_name="3D Model Path")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, Images, Seed, Material_Type, Polygon_count, - **kwargs - ): + ) -> comfy_io.NodeOutput: tier = "Detail" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality_override = self.get_quality_mode(Polygon_count) - task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - **kwargs) - await self.poll_for_task_status(subscription_key, **kwargs) - download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list, task_uuid) - - return (model,) - - -class Rodin3D_Smooth(Rodin3DAPI): - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - **COMMON_PARAMETERS - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, + mesh_mode, quality_override = get_quality_mode(Polygon_count) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } + task_uuid, subscription_key = await create_generate_task( + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + auth_kwargs=auth, + ) + await poll_for_task_status(subscription_key, auth_kwargs=auth) + download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + model = await download_files(download_list, task_uuid) - async def api_call( - self, + return comfy_io.NodeOutput(model) + + +class Rodin3D_Smooth(comfy_io.ComfyNode): + """Generate 3D Assets using Rodin API""" + + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Rodin3D_Smooth", + display_name="Rodin 3D Generate - Smooth Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("Images"), + *COMMON_PARAMETERS, + ], + outputs=[comfy_io.String.Output(display_name="3D Model Path")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, Images, Seed, Material_Type, Polygon_count, - **kwargs - ): + ) -> comfy_io.NodeOutput: tier = "Smooth" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality_override = self.get_quality_mode(Polygon_count) - task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - **kwargs) - await self.poll_for_task_status(subscription_key, **kwargs) - download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list, task_uuid) - - return (model,) - - -class Rodin3D_Sketch(Rodin3DAPI): - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - "Seed": - ( - IO.INT, - { - "default":0, - "min":0, - "max":65535, - "display":"number" - } - ) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, + mesh_mode, quality_override = get_quality_mode(Polygon_count) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } + task_uuid, subscription_key = await create_generate_task( + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + auth_kwargs=auth, + ) + await poll_for_task_status(subscription_key, auth_kwargs=auth) + download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + model = await download_files(download_list, task_uuid) - async def api_call( - self, + return comfy_io.NodeOutput(model) + + +class Rodin3D_Sketch(comfy_io.ComfyNode): + """Generate 3D Assets using Rodin API""" + + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Rodin3D_Sketch", + display_name="Rodin 3D Generate - Sketch Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("Images"), + comfy_io.Int.Input( + "Seed", + default=0, + min=0, + max=65535, + display_mode=comfy_io.NumberDisplay.number, + optional=True, + ), + ], + outputs=[comfy_io.String.Output(display_name="3D Model Path")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, Images, Seed, - **kwargs - ): + ) -> comfy_io.NodeOutput: tier = "Sketch" num_images = Images.shape[0] m_images = [] @@ -461,104 +470,110 @@ class Rodin3D_Sketch(Rodin3DAPI): material_type = "PBR" quality_override = 18000 mesh_mode = "Quad" - task_uuid, subscription_key = await self.create_generate_task( - images=m_images, seed=Seed, material=material_type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs - ) - await self.poll_for_task_status(subscription_key, **kwargs) - download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list, task_uuid) - - return (model,) - -class Rodin3D_Gen2(Rodin3DAPI): - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - "Seed": ( - IO.INT, - { - "default":0, - "min":0, - "max":65535, - "display":"number" - } - ), - "Material_Type": ( - IO.COMBO, - { - "options": ["PBR", "Shaded"], - "default": "PBR" - } - ), - "Polygon_count": ( - IO.COMBO, - { - "options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"], - "default": "500K-Triangle" - } - ), - "TAPose": ( - IO.BOOLEAN, - { - "default": False, - } - ) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } + task_uuid, subscription_key = await create_generate_task( + images=m_images, + seed=Seed, + material=material_type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + auth_kwargs=auth, + ) + await poll_for_task_status(subscription_key, auth_kwargs=auth) + download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + model = await download_files(download_list, task_uuid) - async def api_call( - self, + return comfy_io.NodeOutput(model) + + +class Rodin3D_Gen2(comfy_io.ComfyNode): + """Generate 3D Assets using Rodin API""" + + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Rodin3D_Gen2", + display_name="Rodin 3D Generate - Gen-2 Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("Images"), + comfy_io.Int.Input( + "Seed", + default=0, + min=0, + max=65535, + display_mode=comfy_io.NumberDisplay.number, + optional=True, + ), + comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True), + comfy_io.Combo.Input( + "Polygon_count", + options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"], + default="500K-Triangle", + optional=True, + ), + comfy_io.Boolean.Input("TAPose", default=False), + ], + outputs=[comfy_io.String.Output(display_name="3D Model Path")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, Images, Seed, Material_Type, Polygon_count, TAPose, - **kwargs - ): + ) -> comfy_io.NodeOutput: tier = "Gen-2" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality_override = self.get_quality_mode(Polygon_count) - task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, TAPose=TAPose, - **kwargs) - await self.poll_for_task_status(subscription_key, **kwargs) - download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list, task_uuid) + mesh_mode, quality_override = get_quality_mode(Polygon_count) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + task_uuid, subscription_key = await create_generate_task( + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + TAPose=TAPose, + auth_kwargs=auth, + ) + await poll_for_task_status(subscription_key, auth_kwargs=auth) + download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + model = await download_files(download_list, task_uuid) - return (model,) + return comfy_io.NodeOutput(model) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "Rodin3D_Regular": Rodin3D_Regular, - "Rodin3D_Detail": Rodin3D_Detail, - "Rodin3D_Smooth": Rodin3D_Smooth, - "Rodin3D_Sketch": Rodin3D_Sketch, - "Rodin3D_Gen2": Rodin3D_Gen2, -} -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "Rodin3D_Regular": "Rodin 3D Generate - Regular Generate", - "Rodin3D_Detail": "Rodin 3D Generate - Detail Generate", - "Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate", - "Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate", - "Rodin3D_Gen2": "Rodin 3D Generate - Gen-2 Generate", -} +class Rodin3DExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + Rodin3D_Regular, + Rodin3D_Detail, + Rodin3D_Smooth, + Rodin3D_Sketch, + Rodin3D_Gen2, + ] + + +async def comfy_entrypoint() -> Rodin3DExtension: + return Rodin3DExtension()