ComfyUI/comfy_api_nodes/nodes_pixverse.py

432 lines
15 KiB
Python

import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest,
PixverseImageVideoRequest,
PixverseTransitionVideoRequest,
PixverseImageUploadResponse,
PixverseVideoResponse,
PixverseGenerationStatusResponse,
PixverseAspectRatio,
PixverseQuality,
PixverseDuration,
PixverseMotionMode,
PixverseStatus,
PixverseIO,
pixverse_templates,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_video_output,
poll_op,
sync_op,
tensor_to_bytesio,
validate_string,
)
AVERAGE_DURATION_T2V = 32
AVERAGE_DURATION_I2V = 30
AVERAGE_DURATION_T2T = 52
async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor):
response_upload = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"),
response_model=PixverseImageUploadResponse,
files={"image": tensor_to_bytesio(image)},
content_type="multipart/form-data",
)
if response_upload.Resp is None:
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
return response_upload.Resp.img_id
class PixverseTemplateNode(IO.ComfyNode):
"""
Select template for PixVerse Video generation.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseTemplateNode",
display_name="PixVerse Template",
category="api node/video/PixVerse",
inputs=[
IO.Combo.Input("template", options=list(pixverse_templates.keys())),
],
outputs=[IO.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")],
)
@classmethod
def execute(cls, template: str) -> IO.NodeOutput:
template_id = pixverse_templates.get(template, None)
if template_id is None:
raise Exception(f"Template '{template}' is not recognized.")
return IO.NodeOutput(template_id)
class PixverseTextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseTextToVideoNode",
display_name="PixVerse Text to Video",
category="api node/video/PixVerse",
description="Generates videos based on prompt and output_size.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
IO.Combo.Input(
"aspect_ratio",
options=PixverseAspectRatio,
),
IO.Combo.Input(
"quality",
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
IO.Combo.Input(
"duration_seconds",
options=PixverseDuration,
),
IO.Combo.Input(
"motion_mode",
options=PixverseMotionMode,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
IO.String.Input(
"negative_prompt",
default="",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
IO.Custom(PixverseIO.TEMPLATE).Input(
"pixverse_template",
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
optional=True,
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
aspect_ratio: str,
quality: str,
duration_seconds: int,
motion_mode: str,
seed,
negative_prompt: str = None,
pixverse_template: int = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=1)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
if quality == PixverseQuality.res_1080p:
motion_mode = PixverseMotionMode.normal
duration_seconds = PixverseDuration.dur_5
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"),
response_model=PixverseVideoResponse,
data=PixverseTextVideoRequest(
prompt=prompt,
aspect_ratio=aspect_ratio,
quality=quality,
duration=duration_seconds,
motion_mode=motion_mode,
negative_prompt=negative_prompt if negative_prompt else None,
template_id=pixverse_template,
seed=seed,
),
)
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
response_model=PixverseGenerationStatusResponse,
completed_statuses=[PixverseStatus.successful],
failed_statuses=[
PixverseStatus.contents_moderation,
PixverseStatus.failed,
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
estimated_duration=AVERAGE_DURATION_T2V,
)
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
class PixverseImageToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseImageToVideoNode",
display_name="PixVerse Image to Video",
category="api node/video/PixVerse",
description="Generates videos based on prompt and output_size.",
inputs=[
IO.Image.Input("image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
IO.Combo.Input(
"quality",
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
IO.Combo.Input(
"duration_seconds",
options=PixverseDuration,
),
IO.Combo.Input(
"motion_mode",
options=PixverseMotionMode,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
IO.String.Input(
"negative_prompt",
default="",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
IO.Custom(PixverseIO.TEMPLATE).Input(
"pixverse_template",
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
optional=True,
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
prompt: str,
quality: str,
duration_seconds: int,
motion_mode: str,
seed,
negative_prompt: str = None,
pixverse_template: int = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
img_id = await upload_image_to_pixverse(cls, image)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
if quality == PixverseQuality.res_1080p:
motion_mode = PixverseMotionMode.normal
duration_seconds = PixverseDuration.dur_5
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"),
response_model=PixverseVideoResponse,
data=PixverseImageVideoRequest(
img_id=img_id,
prompt=prompt,
quality=quality,
duration=duration_seconds,
motion_mode=motion_mode,
negative_prompt=negative_prompt if negative_prompt else None,
template_id=pixverse_template,
seed=seed,
),
)
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
response_model=PixverseGenerationStatusResponse,
completed_statuses=[PixverseStatus.successful],
failed_statuses=[
PixverseStatus.contents_moderation,
PixverseStatus.failed,
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
estimated_duration=AVERAGE_DURATION_I2V,
)
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
class PixverseTransitionVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseTransitionVideoNode",
display_name="PixVerse Transition Video",
category="api node/video/PixVerse",
description="Generates videos based on prompt and output_size.",
inputs=[
IO.Image.Input("first_frame"),
IO.Image.Input("last_frame"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
IO.Combo.Input(
"quality",
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
IO.Combo.Input(
"duration_seconds",
options=PixverseDuration,
),
IO.Combo.Input(
"motion_mode",
options=PixverseMotionMode,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
IO.String.Input(
"negative_prompt",
default="",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
prompt: str,
quality: str,
duration_seconds: int,
motion_mode: str,
seed,
negative_prompt: str = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
first_frame_id = await upload_image_to_pixverse(cls, first_frame)
last_frame_id = await upload_image_to_pixverse(cls, last_frame)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
if quality == PixverseQuality.res_1080p:
motion_mode = PixverseMotionMode.normal
duration_seconds = PixverseDuration.dur_5
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"),
response_model=PixverseVideoResponse,
data=PixverseTransitionVideoRequest(
first_frame_img=first_frame_id,
last_frame_img=last_frame_id,
prompt=prompt,
quality=quality,
duration=duration_seconds,
motion_mode=motion_mode,
negative_prompt=negative_prompt if negative_prompt else None,
seed=seed,
),
)
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
response_model=PixverseGenerationStatusResponse,
completed_statuses=[PixverseStatus.successful],
failed_statuses=[
PixverseStatus.contents_moderation,
PixverseStatus.failed,
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
estimated_duration=AVERAGE_DURATION_T2V,
)
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
class PixVerseExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
PixverseTextToVideoNode,
PixverseImageToVideoNode,
PixverseTransitionVideoNode,
PixverseTemplateNode,
]
async def comfy_entrypoint() -> PixVerseExtension:
return PixVerseExtension()