convert nodes_mahiro.py to V3 schema (#10070)

This commit is contained in:
Alexander Piskun 2025-09-29 22:35:51 +03:00 committed by GitHub
parent ed0f4a609b
commit 8accf50908
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,17 +1,29 @@
from typing_extensions import override
import torch
import torch.nn.functional as F
class Mahiro:
from comfy_api.latest import ComfyExtension, io
class Mahiro(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
DESCRIPTION = "Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt."
def patch(self, model):
def define_schema(cls):
return io.Schema(
node_id="Mahiro",
display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
category="_for_testing",
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
inputs=[
io.Model.Input("model"),
],
outputs=[
io.Model.Output(display_name="patched_model"),
],
is_experimental=True,
)
@classmethod
def execute(cls, model) -> io.NodeOutput:
m = model.clone()
def mahiro_normd(args):
scale: float = args['cond_scale']
@ -30,12 +42,16 @@ class Mahiro:
wm = (simsc*cfg + (4-simsc)*leap) / 4
return wm
m.set_model_sampler_post_cfg_function(mahiro_normd)
return (m, )
return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"Mahiro": Mahiro
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Mahiro": "Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
}
class MahiroExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
Mahiro,
]
async def comfy_entrypoint() -> MahiroExtension:
return MahiroExtension()