Add node to use SD3 through API

This commit is contained in:
Kijai 2024-04-17 18:32:39 +03:00
parent 76d28598e7
commit 22cf8d8996

133
nodes.py
View File

@ -4613,6 +4613,135 @@ class SplineEditor:
print(masks_out.shape)
return (masks_out, coordinates, normalized_y_values,)
class StabilityAPI_SD3:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"api_key": ("STRING", {"multiline": True}),
"prompt": ("STRING", {"multiline": True}),
"n_prompt": ("STRING", {"multiline": True}),
"seed": ("INT", {"default": 123,"min": 0, "max": 0xffffffffffffffff, "step": 1}),
"model": (
[
'sd3',
'sd3-turbo',
],
{
"default": 'sd3'
}),
"aspect_ratio": (
[
'1:1',
'16:9',
'21:9',
'2:3',
'3:2',
'4:5',
'5:4',
'9:16',
'9:21',
],
{
"default": '1:1'
}),
"output_format": (
[
'png',
'jpeg',
],
{
"default": 'jpeg'
}),
},
"optional": {
"image": ("IMAGE",),
"img2img_strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "apicall"
CATEGORY = "KJNodes/experimental"
DESCRIPTION = """
## Calls StabilityAI API
- Your Stability API key, used to authenticate your requests.
Although you may have multiple keys in your account,
you should use the same key for all requests to this API.
Get your API key here: https://platform.stability.ai/account/keys
sd3 requires 6.5 credits per generation
sd3-turbo requires 4 credits per generation
If no image is provided, mode is set to text-to-image
"""
def apicall(self, api_key, prompt, n_prompt, model, seed, aspect_ratio, output_format,
img2img_strength=0.5, image=None):
import requests
from io import BytesIO
from torchvision import transforms
data = {
"mode": "text-to-image",
"prompt": prompt,
"model": model,
"seed": seed,
"output_format": output_format
}
if image is not None:
image = image.permute(0, 3, 1, 2).squeeze(0)
to_pil = transforms.ToPILImage()
pil_image = to_pil(image)
# Save the PIL Image to a BytesIO object
buffer = BytesIO()
pil_image.save(buffer, format='PNG')
buffer.seek(0)
files = {"image": ("image.png", buffer, "image/png")}
data["mode"] = "image-to-image"
data["image"] = pil_image
data["strength"] = img2img_strength
else:
data["aspect_ratio"] = aspect_ratio,
files = {"none": ''}
if model != "sd3-turbo":
data["negative_prompt"] = n_prompt
response = requests.post(
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
headers={
"authorization": api_key,
"accept": "image/*"
},
files = files,
data = data,
)
if response.status_code == 200:
# Convert the response content to a PIL Image
image = Image.open(BytesIO(response.content))
# Convert the PIL Image to a PyTorch tensor
transform = transforms.ToTensor()
tensor_image = transform(image)
tensor_image = tensor_image.unsqueeze(0)
tensor_image = tensor_image.permute(0, 2, 3, 1).cpu().float()
return (tensor_image,)
else:
try:
# Attempt to parse the response as JSON
error_data = response.json()
raise Exception(f"Server error: {error_data}")
except json.JSONDecodeError:
# If the response is not valid JSON, raise a different exception
raise Exception(f"Server error: {response.text}")
NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant,
@ -4693,7 +4822,8 @@ NODE_CLASS_MAPPINGS = {
"Sleep": Sleep,
"ImagePadForOutpaintMasked": ImagePadForOutpaintMasked,
"SplineEditor": SplineEditor,
"ImageAndMaskPreview": ImageAndMaskPreview
"ImageAndMaskPreview": ImageAndMaskPreview,
"StabilityAPI_SD3": StabilityAPI_SD3
}
NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant",
@ -4775,4 +4905,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImagePadForOutpaintMasked": "Pad Image For Outpaint Masked",
"SplineEditor": "Spline Editor",
"ImageAndMaskPreview": "Image & Mask Preview",
"StabilityAPI_SD3": "Stability API SD3",
}