From 22cf8d89968a47ce26be919f750f2311159145d1 Mon Sep 17 00:00:00 2001 From: Kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 17 Apr 2024 18:32:39 +0300 Subject: [PATCH] Add node to use SD3 through API --- nodes.py | 133 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 132 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 19ae207..c006612 100644 --- a/nodes.py +++ b/nodes.py @@ -4613,6 +4613,135 @@ class SplineEditor: print(masks_out.shape) return (masks_out, coordinates, normalized_y_values,) +class StabilityAPI_SD3: + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "api_key": ("STRING", {"multiline": True}), + "prompt": ("STRING", {"multiline": True}), + "n_prompt": ("STRING", {"multiline": True}), + "seed": ("INT", {"default": 123,"min": 0, "max": 0xffffffffffffffff, "step": 1}), + "model": ( + [ + 'sd3', + 'sd3-turbo', + ], + { + "default": 'sd3' + }), + "aspect_ratio": ( + [ + '1:1', + '16:9', + '21:9', + '2:3', + '3:2', + '4:5', + '5:4', + '9:16', + '9:21', + ], + { + "default": '1:1' + }), + "output_format": ( + [ + 'png', + 'jpeg', + ], + { + "default": 'jpeg' + }), + }, + "optional": { + "image": ("IMAGE",), + "img2img_strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "apicall" + + CATEGORY = "KJNodes/experimental" + DESCRIPTION = """ +## Calls StabilityAI API +- Your Stability API key, used to authenticate your requests. +Although you may have multiple keys in your account, +you should use the same key for all requests to this API. + +Get your API key here: https://platform.stability.ai/account/keys +sd3 requires 6.5 credits per generation +sd3-turbo requires 4 credits per generation + +If no image is provided, mode is set to text-to-image + +""" + + def apicall(self, api_key, prompt, n_prompt, model, seed, aspect_ratio, output_format, + img2img_strength=0.5, image=None): + + import requests + from io import BytesIO + from torchvision import transforms + + data = { + "mode": "text-to-image", + "prompt": prompt, + "model": model, + "seed": seed, + "output_format": output_format + } + + if image is not None: + image = image.permute(0, 3, 1, 2).squeeze(0) + to_pil = transforms.ToPILImage() + pil_image = to_pil(image) + # Save the PIL Image to a BytesIO object + buffer = BytesIO() + pil_image.save(buffer, format='PNG') + buffer.seek(0) + files = {"image": ("image.png", buffer, "image/png")} + + data["mode"] = "image-to-image" + data["image"] = pil_image + data["strength"] = img2img_strength + else: + data["aspect_ratio"] = aspect_ratio, + files = {"none": ''} + + if model != "sd3-turbo": + data["negative_prompt"] = n_prompt + + response = requests.post( + f"https://api.stability.ai/v2beta/stable-image/generate/sd3", + headers={ + "authorization": api_key, + "accept": "image/*" + }, + files = files, + data = data, + ) + + if response.status_code == 200: + # Convert the response content to a PIL Image + image = Image.open(BytesIO(response.content)) + # Convert the PIL Image to a PyTorch tensor + transform = transforms.ToTensor() + tensor_image = transform(image) + tensor_image = tensor_image.unsqueeze(0) + tensor_image = tensor_image.permute(0, 2, 3, 1).cpu().float() + return (tensor_image,) + else: + try: + # Attempt to parse the response as JSON + error_data = response.json() + raise Exception(f"Server error: {error_data}") + except json.JSONDecodeError: + # If the response is not valid JSON, raise a different exception + raise Exception(f"Server error: {response.text}") + NODE_CLASS_MAPPINGS = { "INTConstant": INTConstant, @@ -4693,7 +4822,8 @@ NODE_CLASS_MAPPINGS = { "Sleep": Sleep, "ImagePadForOutpaintMasked": ImagePadForOutpaintMasked, "SplineEditor": SplineEditor, - "ImageAndMaskPreview": ImageAndMaskPreview + "ImageAndMaskPreview": ImageAndMaskPreview, + "StabilityAPI_SD3": StabilityAPI_SD3 } NODE_DISPLAY_NAME_MAPPINGS = { "INTConstant": "INT Constant", @@ -4775,4 +4905,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ImagePadForOutpaintMasked": "Pad Image For Outpaint Masked", "SplineEditor": "Spline Editor", "ImageAndMaskPreview": "Image & Mask Preview", + "StabilityAPI_SD3": "Stability API SD3", } \ No newline at end of file