ComfyUI/comfy_api_nodes/nodes_tripo.py
2025-10-24 15:48:34 -07:00

646 lines
24 KiB
Python

import os
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.tripo_api import (
TripoAnimateRetargetRequest,
TripoAnimateRigRequest,
TripoConvertModelRequest,
TripoFileEmptyReference,
TripoFileReference,
TripoImageToModelRequest,
TripoModelVersion,
TripoMultiviewToModelRequest,
TripoOrientation,
TripoRefineModelRequest,
TripoStyle,
TripoTaskResponse,
TripoTaskStatus,
TripoTaskType,
TripoTextToModelRequest,
TripoTextureModelRequest,
TripoUrlReference,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_as_bytesio,
poll_op,
sync_op,
upload_images_to_comfyapi,
)
from folder_paths import get_output_directory
def get_model_url_from_response(response: TripoTaskResponse) -> str:
if response.data is not None:
for key in ["pbr_model", "model", "base_model"]:
if getattr(response.data.output, key, None) is not None:
return getattr(response.data.output, key)
raise RuntimeError(f"Failed to get model url from response: {response}")
async def poll_until_finished(
node_cls: type[IO.ComfyNode],
response: TripoTaskResponse,
average_duration: Optional[int] = None,
) -> IO.NodeOutput:
"""Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response."""
if response.code != 0:
raise RuntimeError(f"Failed to generate mesh: {response.error}")
task_id = response.data.task_id
response_poll = await poll_op(
node_cls,
poll_endpoint=ApiEndpoint(path=f"/proxy/tripo/v2/openapi/task/{task_id}"),
response_model=TripoTaskResponse,
completed_statuses=[TripoTaskStatus.SUCCESS],
failed_statuses=[
TripoTaskStatus.FAILED,
TripoTaskStatus.CANCELLED,
TripoTaskStatus.UNKNOWN,
TripoTaskStatus.BANNED,
TripoTaskStatus.EXPIRED,
],
status_extractor=lambda x: x.data.status,
progress_extractor=lambda x: x.data.progress,
estimated_duration=average_duration,
)
if response_poll.data.status == TripoTaskStatus.SUCCESS:
url = get_model_url_from_response(response_poll)
bytesio = await download_url_as_bytesio(url)
# Save the downloaded model file
model_file = f"tripo_model_{task_id}.glb"
with open(os.path.join(get_output_directory(), model_file), "wb") as f:
f.write(bytesio.getvalue())
return IO.NodeOutput(model_file, task_id)
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
class TripoTextToModelNode(IO.ComfyNode):
"""
Generates 3D models synchronously based on a text prompt using Tripo's API.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoTextToModelNode",
display_name="Tripo: Text to Model",
category="api node/3d/Tripo",
inputs=[
IO.String.Input("prompt", multiline=True),
IO.String.Input("negative_prompt", multiline=True, optional=True),
IO.Combo.Input(
"model_version", options=TripoModelVersion, default=TripoModelVersion.v2_5_20250123, optional=True
),
IO.Combo.Input("style", options=TripoStyle, default="None", optional=True),
IO.Boolean.Input("texture", default=True, optional=True),
IO.Boolean.Input("pbr", default=True, optional=True),
IO.Int.Input("image_seed", default=42, optional=True),
IO.Int.Input("model_seed", default=42, optional=True),
IO.Int.Input("texture_seed", default=42, optional=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True),
],
outputs=[
IO.String.Output(display_name="model_file"),
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_output_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: Optional[str] = None,
model_version=None,
style: Optional[str] = None,
texture: Optional[bool] = None,
pbr: Optional[bool] = None,
image_seed: Optional[int] = None,
model_seed: Optional[int] = None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
face_limit: Optional[int] = None,
quad: Optional[bool] = None,
) -> IO.NodeOutput:
style_enum = None if style == "None" else style
if not prompt:
raise RuntimeError("Prompt is required")
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
response_model=TripoTaskResponse,
data=TripoTextToModelRequest(
type=TripoTaskType.TEXT_TO_MODEL,
prompt=prompt,
negative_prompt=negative_prompt if negative_prompt else None,
model_version=model_version,
style=style_enum,
texture=texture,
pbr=pbr,
image_seed=image_seed,
model_seed=model_seed,
texture_seed=texture_seed,
texture_quality=texture_quality,
face_limit=face_limit,
auto_size=True,
quad=quad,
),
)
return await poll_until_finished(cls, response, average_duration=80)
class TripoImageToModelNode(IO.ComfyNode):
"""
Generates 3D models synchronously based on a single image using Tripo's API.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoImageToModelNode",
display_name="Tripo: Image to Model",
category="api node/3d/Tripo",
inputs=[
IO.Image.Input("image"),
IO.Combo.Input(
"model_version",
options=TripoModelVersion,
tooltip="The model version to use for generation",
optional=True,
),
IO.Combo.Input("style", options=TripoStyle, default="None", optional=True),
IO.Boolean.Input("texture", default=True, optional=True),
IO.Boolean.Input("pbr", default=True, optional=True),
IO.Int.Input("model_seed", default=42, optional=True),
IO.Combo.Input(
"orientation", options=TripoOrientation, default=TripoOrientation.DEFAULT, optional=True
),
IO.Int.Input("texture_seed", default=42, optional=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
IO.Combo.Input(
"texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True),
],
outputs=[
IO.String.Output(display_name="model_file"),
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_output_node=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
model_version: Optional[str] = None,
style: Optional[str] = None,
texture: Optional[bool] = None,
pbr: Optional[bool] = None,
model_seed: Optional[int] = None,
orientation=None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
texture_alignment: Optional[str] = None,
face_limit: Optional[int] = None,
quad: Optional[bool] = None,
) -> IO.NodeOutput:
style_enum = None if style == "None" else style
if image is None:
raise RuntimeError("Image is required")
tripo_file = TripoFileReference(
root=TripoUrlReference(
url=(await upload_images_to_comfyapi(cls, image, max_images=1))[0],
type="jpeg",
)
)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
response_model=TripoTaskResponse,
data=TripoImageToModelRequest(
type=TripoTaskType.IMAGE_TO_MODEL,
file=tripo_file,
model_version=model_version,
style=style_enum,
texture=texture,
pbr=pbr,
model_seed=model_seed,
orientation=orientation,
texture_alignment=texture_alignment,
texture_seed=texture_seed,
texture_quality=texture_quality,
face_limit=face_limit,
auto_size=True,
quad=quad,
),
)
return await poll_until_finished(cls, response, average_duration=80)
class TripoMultiviewToModelNode(IO.ComfyNode):
"""
Generates 3D models synchronously based on up to four images (front, left, back, right) using Tripo's API.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoMultiviewToModelNode",
display_name="Tripo: Multiview to Model",
category="api node/3d/Tripo",
inputs=[
IO.Image.Input("image"),
IO.Image.Input("image_left", optional=True),
IO.Image.Input("image_back", optional=True),
IO.Image.Input("image_right", optional=True),
IO.Combo.Input(
"model_version",
options=TripoModelVersion,
optional=True,
tooltip="The model version to use for generation",
),
IO.Combo.Input(
"orientation",
options=TripoOrientation,
default=TripoOrientation.DEFAULT,
optional=True,
),
IO.Boolean.Input("texture", default=True, optional=True),
IO.Boolean.Input("pbr", default=True, optional=True),
IO.Int.Input("model_seed", default=42, optional=True),
IO.Int.Input("texture_seed", default=42, optional=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
IO.Combo.Input(
"texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True),
],
outputs=[
IO.String.Output(display_name="model_file"),
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_output_node=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
image_left: Optional[torch.Tensor] = None,
image_back: Optional[torch.Tensor] = None,
image_right: Optional[torch.Tensor] = None,
model_version: Optional[str] = None,
orientation: Optional[str] = None,
texture: Optional[bool] = None,
pbr: Optional[bool] = None,
model_seed: Optional[int] = None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
texture_alignment: Optional[str] = None,
face_limit: Optional[int] = None,
quad: Optional[bool] = None,
) -> IO.NodeOutput:
if image is None:
raise RuntimeError("front image for multiview is required")
images = []
image_dict = {"image": image, "image_left": image_left, "image_back": image_back, "image_right": image_right}
if image_left is None and image_back is None and image_right is None:
raise RuntimeError("At least one of left, back, or right image must be provided for multiview")
for image_name in ["image", "image_left", "image_back", "image_right"]:
image_ = image_dict[image_name]
if image_ is not None:
images.append(
TripoFileReference(
root=TripoUrlReference(
url=(await upload_images_to_comfyapi(cls, image_, max_images=1))[0], type="jpeg"
)
)
)
else:
images.append(TripoFileEmptyReference())
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
response_model=TripoTaskResponse,
data=TripoMultiviewToModelRequest(
type=TripoTaskType.MULTIVIEW_TO_MODEL,
files=images,
model_version=model_version,
orientation=orientation,
texture=texture,
pbr=pbr,
model_seed=model_seed,
texture_seed=texture_seed,
texture_quality=texture_quality,
texture_alignment=texture_alignment,
face_limit=face_limit,
quad=quad,
),
)
return await poll_until_finished(cls, response, average_duration=80)
class TripoTextureNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoTextureNode",
display_name="Tripo: Texture model",
category="api node/3d/Tripo",
inputs=[
IO.Custom("MODEL_TASK_ID").Input("model_task_id"),
IO.Boolean.Input("texture", default=True, optional=True),
IO.Boolean.Input("pbr", default=True, optional=True),
IO.Int.Input("texture_seed", default=42, optional=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
IO.Combo.Input(
"texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True
),
],
outputs=[
IO.String.Output(display_name="model_file"),
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_output_node=True,
)
@classmethod
async def execute(
cls,
model_task_id,
texture: Optional[bool] = None,
pbr: Optional[bool] = None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
texture_alignment: Optional[str] = None,
) -> IO.NodeOutput:
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
response_model=TripoTaskResponse,
data=TripoTextureModelRequest(
original_model_task_id=model_task_id,
texture=texture,
pbr=pbr,
texture_seed=texture_seed,
texture_quality=texture_quality,
texture_alignment=texture_alignment,
),
)
return await poll_until_finished(cls, response, average_duration=80)
class TripoRefineNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoRefineNode",
display_name="Tripo: Refine Draft model",
category="api node/3d/Tripo",
description="Refine a draft model created by v1.4 Tripo models only.",
inputs=[
IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"),
],
outputs=[
IO.String.Output(display_name="model_file"),
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_output_node=True,
)
@classmethod
async def execute(cls, model_task_id) -> IO.NodeOutput:
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
response_model=TripoTaskResponse,
data=TripoRefineModelRequest(draft_model_task_id=model_task_id),
)
return await poll_until_finished(cls, response, average_duration=240)
class TripoRigNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoRigNode",
display_name="Tripo: Rig model",
category="api node/3d/Tripo",
inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")],
outputs=[
IO.String.Output(display_name="model_file"),
IO.Custom("RIG_TASK_ID").Output(display_name="rig task_id"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_output_node=True,
)
@classmethod
async def execute(cls, original_model_task_id) -> IO.NodeOutput:
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
response_model=TripoTaskResponse,
data=TripoAnimateRigRequest(original_model_task_id=original_model_task_id, out_format="glb", spec="tripo"),
)
return await poll_until_finished(cls, response, average_duration=180)
class TripoRetargetNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoRetargetNode",
display_name="Tripo: Retarget rigged model",
category="api node/3d/Tripo",
inputs=[
IO.Custom("RIG_TASK_ID").Input("original_model_task_id"),
IO.Combo.Input(
"animation",
options=[
"preset:idle",
"preset:walk",
"preset:climb",
"preset:jump",
"preset:slash",
"preset:shoot",
"preset:hurt",
"preset:fall",
"preset:turn",
],
),
],
outputs=[
IO.String.Output(display_name="model_file"),
IO.Custom("RETARGET_TASK_ID").Output(display_name="retarget task_id"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_output_node=True,
)
@classmethod
async def execute(cls, original_model_task_id, animation: str) -> IO.NodeOutput:
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
response_model=TripoTaskResponse,
data=TripoAnimateRetargetRequest(
original_model_task_id=original_model_task_id,
animation=animation,
out_format="glb",
bake_animation=True,
),
)
return await poll_until_finished(cls, response, average_duration=30)
class TripoConversionNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoConversionNode",
display_name="Tripo: Convert model",
category="api node/3d/Tripo",
inputs=[
IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"),
IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]),
IO.Boolean.Input("quad", default=False, optional=True),
IO.Int.Input(
"face_limit",
default=-1,
min=-1,
max=500000,
optional=True,
),
IO.Int.Input(
"texture_size",
default=4096,
min=128,
max=4096,
optional=True,
),
IO.Combo.Input(
"texture_format",
options=["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"],
default="JPEG",
optional=True,
),
],
outputs=[],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_output_node=True,
)
@classmethod
def validate_inputs(cls, input_types):
# The min and max of input1 and input2 are still validated because
# we didn't take `input1` or `input2` as arguments
if input_types["original_model_task_id"] not in ("MODEL_TASK_ID", "RIG_TASK_ID", "RETARGET_TASK_ID"):
return "original_model_task_id must be MODEL_TASK_ID, RIG_TASK_ID or RETARGET_TASK_ID type"
return True
@classmethod
async def execute(
cls,
original_model_task_id,
format: str,
quad: bool,
face_limit: int,
texture_size: int,
texture_format: str,
) -> IO.NodeOutput:
if not original_model_task_id:
raise RuntimeError("original_model_task_id is required")
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
response_model=TripoTaskResponse,
data=TripoConvertModelRequest(
original_model_task_id=original_model_task_id,
format=format,
quad=quad if quad else None,
face_limit=face_limit if face_limit != -1 else None,
texture_size=texture_size if texture_size != 4096 else None,
texture_format=texture_format if texture_format != "JPEG" else None,
),
)
return await poll_until_finished(cls, response, average_duration=30)
class TripoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TripoTextToModelNode,
TripoImageToModelNode,
TripoMultiviewToModelNode,
TripoTextureNode,
TripoRefineNode,
TripoRigNode,
TripoRetargetNode,
TripoConversionNode,
]
async def comfy_entrypoint() -> TripoExtension:
return TripoExtension()