Mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI (synced 2025-12-09 22:14:34 +08:00)
feat(api-nodes): add Nano Banana Pro (#10814)

* feat(api-nodes): add Nano Banana Pro
* frontend bump to 1.28.9

parent 9e00ce5b76
commit 7b8389578e
@@ -68,7 +68,7 @@ class GeminiTextPart(BaseModel):
 
 
 class GeminiContent(BaseModel):
-    parts: list[GeminiPart] = Field(...)
+    parts: list[GeminiPart] = Field([])
     role: GeminiRole = Field(..., examples=["user"])
 
 
@@ -120,7 +120,7 @@ class GeminiGenerationConfig(BaseModel):
 
 class GeminiImageConfig(BaseModel):
     aspectRatio: str | None = Field(None)
-    resolution: str | None = Field(None)
+    imageSize: str | None = Field(None)
 
 
 class GeminiImageGenerationConfig(GeminiGenerationConfig):
@@ -227,3 +227,4 @@ class GeminiGenerateContentResponse(BaseModel):
     candidates: list[GeminiCandidate] | None = Field(None)
     promptFeedback: GeminiPromptFeedback | None = Field(None)
     usageMetadata: GeminiUsageMetadata | None = Field(None)
+    modelVersion: str | None = Field(None)
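Note: the new modelVersion field is what makes client-side pricing possible. Since every field on GeminiGenerateContentResponse is optional, a minimal payload parses cleanly; a small sketch (the payload shape is an assumption, field names are from the diff):

    # Sketch: all response fields default to None, so modelVersion alone validates.
    resp = GeminiGenerateContentResponse.model_validate(
        {"modelVersion": "gemini-3-pro-image-preview"}  # hypothetical payload
    )
    assert resp.modelVersion == "gemini-3-pro-image-preview"
    assert resp.usageMetadata is None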
@@ -29,11 +29,13 @@ from comfy_api_nodes.apis.gemini_api import (
     GeminiMimeType,
     GeminiPart,
     GeminiRole,
+    Modality,
 )
 from comfy_api_nodes.util import (
     ApiEndpoint,
     audio_to_base64_string,
     bytesio_to_image_tensor,
+    get_number_of_images,
     sync_op,
     tensor_to_base64_string,
     validate_string,
@@ -147,6 +149,49 @@ def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
     return torch.cat(image_tensors, dim=0)
 
 
+def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None:
+    if not response.modelVersion:
+        return None
+    # Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing
+    if response.modelVersion in ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"):
+        input_tokens_price = 1.25
+        output_text_tokens_price = 10.0
+        output_image_tokens_price = 0.0
+    elif response.modelVersion in (
+        "gemini-2.5-flash-preview-04-17",
+        "gemini-2.5-flash",
+    ):
+        input_tokens_price = 0.30
+        output_text_tokens_price = 2.50
+        output_image_tokens_price = 0.0
+    elif response.modelVersion in (
+        "gemini-2.5-flash-image-preview",
+        "gemini-2.5-flash-image",
+    ):
+        input_tokens_price = 0.30
+        output_text_tokens_price = 2.50
+        output_image_tokens_price = 30.0
+    elif response.modelVersion == "gemini-3-pro-preview":
+        input_tokens_price = 2
+        output_text_tokens_price = 12.0
+        output_image_tokens_price = 0.0
+    elif response.modelVersion == "gemini-3-pro-image-preview":
+        input_tokens_price = 2
+        output_text_tokens_price = 12.0
+        output_image_tokens_price = 120.0
+    else:
+        return None
+    final_price = response.usageMetadata.promptTokenCount * input_tokens_price
+    for i in response.usageMetadata.candidatesTokensDetails:
+        if i.modality == Modality.IMAGE:
+            final_price += output_image_tokens_price * i.tokenCount  # for Nano Banana models
+        else:
+            final_price += output_text_tokens_price * i.tokenCount
+    if response.usageMetadata.thoughtsTokenCount:
+        final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount
+    return final_price / 1_000_000.0
+
+
 class GeminiNode(IO.ComfyNode):
     """
     Node to generate text responses from a Gemini model.
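A worked example of the pricing math above, using hypothetical token counts for gemini-3-pro-image-preview ($2 per 1M input tokens, $120 per 1M image output tokens):

    # 1,000 prompt tokens and 1,290 image output tokens:
    #   input:  1_000 * 2   = 2_000
    #   image:  1_290 * 120 = 154_800
    # (2_000 + 154_800) / 1_000_000 = $0.1568
    price = (1_000 * 2 + 1_290 * 120) / 1_000_000
    assert abs(price - 0.1568) < 1e-12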
@@ -314,6 +359,7 @@ class GeminiNode(IO.ComfyNode):
                 ]
             ),
             response_model=GeminiGenerateContentResponse,
+            price_extractor=calculate_tokens_price,
         )
 
         output_text = get_text_from_response(response)
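From the call site's perspective, price_extractor is just a callback from the parsed response to a dollar amount, with None meaning "show no price". For example, an unrecognized model version prices to None:

    # calculate_tokens_price bails out before touching usageMetadata here.
    resp = GeminiGenerateContentResponse.model_validate({"modelVersion": "some-unknown-model"})
    assert calculate_tokens_price(resp) is None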
@@ -476,6 +522,13 @@ class GeminiImage(IO.ComfyNode):
                     "or otherwise generates 1:1 squares.",
                     optional=True,
                 ),
+                IO.Combo.Input(
+                    "response_modalities",
+                    options=["IMAGE+TEXT", "IMAGE"],
+                    tooltip="Choose 'IMAGE' for image-only output, or "
+                    "'IMAGE+TEXT' to return both the generated image and a text response.",
+                    optional=True,
+                ),
             ],
             outputs=[
                 IO.Image.Output(),
@@ -498,6 +551,7 @@ class GeminiImage(IO.ComfyNode):
         images: torch.Tensor | None = None,
         files: list[GeminiPart] | None = None,
         aspect_ratio: str = "auto",
+        response_modalities: str = "IMAGE+TEXT",
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=True, min_length=1)
         parts: list[GeminiPart] = [GeminiPart(text=prompt)]
@@ -520,17 +574,16 @@ class GeminiImage(IO.ComfyNode):
                     GeminiContent(role=GeminiRole.user, parts=parts),
                 ],
                 generationConfig=GeminiImageGenerationConfig(
-                    responseModalities=["TEXT", "IMAGE"],
+                    responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
                     imageConfig=None if aspect_ratio == "auto" else image_config,
                 ),
             ),
             response_model=GeminiGenerateContentResponse,
+            price_extractor=calculate_tokens_price,
         )
 
-        output_image = get_image_from_response(response)
         output_text = get_text_from_response(response)
         if output_text:
-            # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
             render_spec = {
                 "node_id": cls.hidden.unique_id,
                 "component": "ChatHistoryWidget",
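The combo value is mapped onto the API's responseModalities list by the inline conditional above; pulled out for clarity (hypothetical helper name):

    def to_response_modalities(choice: str) -> list[str]:
        # "IMAGE" -> image-only output; anything else keeps text alongside the image.
        return ["IMAGE"] if choice == "IMAGE" else ["TEXT", "IMAGE"]

    assert to_response_modalities("IMAGE") == ["IMAGE"]
    assert to_response_modalities("IMAGE+TEXT") == ["TEXT", "IMAGE"]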
@@ -551,9 +604,150 @@ class GeminiImage(IO.ComfyNode):
                 "display_component",
                 render_spec,
             )
+        return IO.NodeOutput(get_image_from_response(response), output_text)
 
-        output_text = output_text or "Empty response from Gemini model..."
-        return IO.NodeOutput(output_image, output_text)
+
+class GeminiImage2(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="GeminiImage2Node",
+            display_name="Nano Banana Pro (Google Gemini Image)",
+            category="api node/image/Gemini",
+            description="Generate or edit images synchronously via Google Vertex API.",
+            inputs=[
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="Text prompt describing the image to generate or the edits to apply. "
+                    "Include any constraints, styles, or details the model should follow.",
+                    default="",
+                ),
+                IO.Combo.Input(
+                    "model",
+                    options=["gemini-3-pro-image-preview"],
+                ),
+                IO.Int.Input(
+                    "seed",
+                    default=42,
+                    min=0,
+                    max=0xFFFFFFFFFFFFFFFF,
+                    control_after_generate=True,
+                    tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide "
+                    "the same response for repeated requests. Deterministic output isn't guaranteed. "
+                    "Also, changing the model or parameter settings, such as the temperature, "
+                    "can cause variations in the response even when you use the same seed value. "
+                    "By default, a random seed value is used.",
+                ),
+                IO.Combo.Input(
+                    "aspect_ratio",
+                    options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
+                    default="auto",
+                    tooltip="If set to 'auto', matches your input image's aspect ratio; "
+                    "if no image is provided, generates a 1:1 square.",
+                ),
+                IO.Combo.Input(
+                    "resolution",
+                    options=["1K", "2K", "4K"],
+                    tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
+                ),
+                IO.Combo.Input(
+                    "response_modalities",
+                    options=["IMAGE+TEXT", "IMAGE"],
+                    tooltip="Choose 'IMAGE' for image-only output, or "
+                    "'IMAGE+TEXT' to return both the generated image and a text response.",
+                ),
+                IO.Image.Input(
+                    "images",
+                    optional=True,
+                    tooltip="Optional reference image(s). "
+                    "To include multiple images, use the Batch Images node (up to 14).",
+                ),
+                IO.Custom("GEMINI_INPUT_FILES").Input(
+                    "files",
+                    optional=True,
+                    tooltip="Optional file(s) to use as context for the model. "
+                    "Accepts inputs from the Gemini Generate Content Input Files node.",
+                ),
+            ],
+            outputs=[
+                IO.Image.Output(),
+                IO.String.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        prompt: str,
+        model: str,
+        seed: int,
+        aspect_ratio: str,
+        resolution: str,
+        response_modalities: str,
+        images: torch.Tensor | None = None,
+        files: list[GeminiPart] | None = None,
+    ) -> IO.NodeOutput:
+        validate_string(prompt, strip_whitespace=True, min_length=1)
+
+        parts: list[GeminiPart] = [GeminiPart(text=prompt)]
+        if images is not None:
+            if get_number_of_images(images) > 14:
+                raise ValueError("The current maximum number of supported images is 14.")
+            parts.extend(create_image_parts(images))
+        if files is not None:
+            parts.extend(files)
+
+        image_config = GeminiImageConfig(imageSize=resolution)
+        if aspect_ratio != "auto":
+            image_config.aspectRatio = aspect_ratio
+
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
+            data=GeminiImageGenerateContentRequest(
+                contents=[
+                    GeminiContent(role=GeminiRole.user, parts=parts),
+                ],
+                generationConfig=GeminiImageGenerationConfig(
+                    responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
+                    imageConfig=image_config,
+                ),
+            ),
+            response_model=GeminiGenerateContentResponse,
+            price_extractor=calculate_tokens_price,
+        )
+
+        output_text = get_text_from_response(response)
+        if output_text:
+            render_spec = {
+                "node_id": cls.hidden.unique_id,
+                "component": "ChatHistoryWidget",
+                "props": {
+                    "history": json.dumps(
+                        [
+                            {
+                                "prompt": prompt,
+                                "response": output_text,
+                                "response_id": str(uuid.uuid4()),
+                                "timestamp": time.time(),
+                            }
+                        ]
+                    ),
+                },
+            }
+            PromptServer.instance.send_sync(
+                "display_component",
+                render_spec,
+            )
+        return IO.NodeOutput(get_image_from_response(response), output_text)
 
 
 class GeminiExtension(ComfyExtension):
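Unlike GeminiImage, which sends imageConfig=None when aspect_ratio is "auto", GeminiImage2 always sends an imageConfig so the new imageSize field reaches the API; a quick sketch of the two branches:

    # resolution is always forwarded; aspectRatio only when not "auto".
    image_config = GeminiImageConfig(imageSize="2K")   # aspect_ratio == "auto"
    assert image_config.aspectRatio is None

    image_config = GeminiImageConfig(imageSize="4K")
    image_config.aspectRatio = "16:9"                  # aspect_ratio != "auto"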
@@ -562,6 +756,7 @@ class GeminiExtension(ComfyExtension):
         return [
             GeminiNode,
             GeminiImage,
+            GeminiImage2,
             GeminiInputFiles,
         ]
@@ -63,6 +63,7 @@ class _RequestConfig:
     estimated_total: Optional[int] = None
     final_label_on_success: Optional[str] = "Completed"
     progress_origin_ts: Optional[float] = None
+    price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
 
 
 @dataclass
@@ -87,6 +88,7 @@ async def sync_op(
    endpoint: ApiEndpoint,
    *,
    response_model: Type[M],
+    price_extractor: Optional[Callable[[M], Optional[float]]] = None,
    data: Optional[BaseModel] = None,
    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
    content_type: str = "application/json",
@@ -104,6 +106,7 @@ async def sync_op(
     raw = await sync_op_raw(
         cls,
         endpoint,
+        price_extractor=_wrap_model_extractor(response_model, price_extractor),
         data=data,
         files=files,
         content_type=content_type,
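_wrap_model_extractor itself is not shown in this diff; given the two signatures above, a plausible reading (an assumption, not the committed code) is an adapter that validates the raw payload into the typed model before pricing:

    # Assumed shape of _wrap_model_extractor: adapts Callable[[M], Optional[float]]
    # to the Callable[[dict[str, Any]], Optional[float]] that sync_op_raw expects.
    def _wrap_model_extractor(response_model, price_extractor):
        if price_extractor is None:
            return None

        def _extract(payload):
            return price_extractor(response_model.model_validate(payload))

        return _extract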
@@ -175,6 +178,7 @@ async def sync_op_raw(
    cls: type[IO.ComfyNode],
    endpoint: ApiEndpoint,
    *,
+    price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
    data: Optional[Union[dict[str, Any], BaseModel]] = None,
    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
    content_type: str = "application/json",
@@ -216,6 +220,7 @@ async def sync_op_raw(
         estimated_total=estimated_duration,
         final_label_on_success=final_label_on_success,
         progress_origin_ts=progress_origin_ts,
+        price_extractor=price_extractor,
     )
     return await _request_base(cfg, expect_binary=as_binary)
@@ -425,6 +430,7 @@ def _display_text(
         display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
     if price is not None:
         p = f"{float(price):,.4f}".rstrip("0").rstrip(".")
+        if p != "0":
             display_lines.append(f"Price: ${p}")
     if text is not None:
         display_lines.append(text)
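The rstrip chain trims trailing zeros (then a dangling dot), so sub-$0.0001 prices collapse to the string "0"; the new guard keeps those zero-looking prices out of the display:

    # Formatting as in _display_text:
    for price in (0.1568, 1.0, 0.000004):
        p = f"{float(price):,.4f}".rstrip("0").rstrip(".")
        # -> "0.1568", "1", "0" (the last is now suppressed by the p != "0" check)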
@@ -581,6 +587,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
     delay = cfg.retry_delay
     operation_succeeded: bool = False
     final_elapsed_seconds: Optional[int] = None
+    extracted_price: Optional[float] = None
     while True:
         attempt += 1
         stop_event = asyncio.Event()
@@ -768,6 +775,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
             except json.JSONDecodeError:
                 payload = {"_raw": text}
             response_content_to_log = payload if isinstance(payload, dict) else text
+            with contextlib.suppress(Exception):
+                extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
             operation_succeeded = True
             final_elapsed_seconds = int(time.monotonic() - start_time)
             try:
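Wrapping the extractor in contextlib.suppress(Exception) means a buggy or surprised price calculation can never fail an otherwise successful request; the price simply stays None:

    import contextlib

    extracted_price = None
    with contextlib.suppress(Exception):
        extracted_price = {}["usageMetadata"]  # KeyError is swallowed
    assert extracted_price is None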
@@ -872,7 +881,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
                 else int(time.monotonic() - start_time)
             ),
             estimated_total=cfg.estimated_total,
-            price=None,
+            price=extracted_price,
             is_queued=False,
             processing_elapsed_seconds=final_elapsed_seconds,
         )
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.28.8
+comfyui-frontend-package==1.28.9
 comfyui-workflow-templates==0.3.1
 comfyui-embedded-docs==0.3.1
 torch