mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-06-05 00:24:30 +08:00
Add Superprompt test node
https://huggingface.co/roborovski/superprompt-v1
This commit is contained in:
parent
a7ea300a57
commit
8a7d3b9d58
3
.gitignore
vendored
3
.gitignore
vendored
@ -3,4 +3,5 @@ __pycache__
|
|||||||
.vscode
|
.vscode
|
||||||
*.ckpt
|
*.ckpt
|
||||||
*.pth
|
*.pth
|
||||||
types
|
types
|
||||||
|
models
|
||||||
38
nodes.py
38
nodes.py
@ -3836,6 +3836,40 @@ class LoadResAdapterNormalization:
|
|||||||
|
|
||||||
|
|
||||||
return (model_clone, )
|
return (model_clone, )
|
||||||
|
|
||||||
|
class Superprompt:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"instruction_prompt": ("STRING", {"default": 'Expand the following prompt to add more detail', "multiline": True}),
|
||||||
|
"prompt": ("STRING", {"default": '', "multiline": True, "forceInput": True}),
|
||||||
|
"max_new_tokens": ("INT", {"default": 128, "min": 1, "max": 4096, "step": 1}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
|
||||||
|
def process(self, instruction_prompt, prompt, max_new_tokens):
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||||
|
|
||||||
|
checkpoint_path = os.path.join(script_dir, "models","superprompt-v1")
|
||||||
|
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small", legacy=False)
|
||||||
|
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(checkpoint_path, device_map=device)
|
||||||
|
model.to(device)
|
||||||
|
input_text = instruction_prompt + ": " + prompt
|
||||||
|
print(input_text)
|
||||||
|
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
||||||
|
outputs = model.generate(input_ids, max_new_tokens=max_new_tokens)
|
||||||
|
out = (tokenizer.decode(outputs[0]))
|
||||||
|
out = out.replace('<pad>', '')
|
||||||
|
out = out.replace('</s>', '')
|
||||||
|
print(out)
|
||||||
|
|
||||||
|
return (out, )
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"INTConstant": INTConstant,
|
"INTConstant": INTConstant,
|
||||||
@ -3907,7 +3941,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"ImageNormalize_Neg1_To_1": ImageNormalize_Neg1_To_1,
|
"ImageNormalize_Neg1_To_1": ImageNormalize_Neg1_To_1,
|
||||||
"Intrinsic_lora_sampling": Intrinsic_lora_sampling,
|
"Intrinsic_lora_sampling": Intrinsic_lora_sampling,
|
||||||
"RemapMaskRange": RemapMaskRange,
|
"RemapMaskRange": RemapMaskRange,
|
||||||
"LoadResAdapterNormalization": LoadResAdapterNormalization
|
"LoadResAdapterNormalization": LoadResAdapterNormalization,
|
||||||
|
"Superprompt": Superprompt
|
||||||
}
|
}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"INTConstant": "INT Constant",
|
"INTConstant": "INT Constant",
|
||||||
@ -3979,4 +4014,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"Intrinsic_lora_sampling": "Intrinsic_lora_sampling",
|
"Intrinsic_lora_sampling": "Intrinsic_lora_sampling",
|
||||||
"RemapMaskRange": "RemapMaskRange",
|
"RemapMaskRange": "RemapMaskRange",
|
||||||
"LoadResAdapterNormalization": "LoadResAdapterNormalization",
|
"LoadResAdapterNormalization": "LoadResAdapterNormalization",
|
||||||
|
"Superprompt": "Superprompt",
|
||||||
}
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user