Merge remote-tracking branch 'upstream/master' into multitalk
This commit is contained in: 25063f25cc
@@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
 cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
+cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold, the cache removes large items to free RAM. Default 4GB.")

 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -146,7 +147,7 @@ class PerformanceFeature(enum.Enum):
     AutoTune = "autotune"
     PinnedMem = "pinned_memory"

-parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
+parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features, so using it might crash your ComfyUI. --fast with no arguments enables everything. You can pass a list of specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))

 parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
 parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
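Note on the new --cache-ram flag: it relies on argparse's nargs='?' plus const semantics, so omitting the flag keeps the feature off (default 0), passing the bare flag uses the 4 GB headroom constant, and passing a value overrides it. A minimal standalone sketch of that behavior (the parser below is illustrative, not ComfyUI's actual cli_args module):

import argparse

# Illustrative parser mirroring the --cache-ram semantics shown in the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0,
                    help="RAM pressure caching headroom in GB.")

print(parser.parse_args([]).cache_ram)                    # 0 -> feature disabled
print(parser.parse_args(["--cache-ram"]).cache_ram)       # 4.0 -> default headroom
print(parser.parse_args(["--cache-ram", "8"]).cache_ram)  # 8.0 -> explicit headroom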
@@ -522,7 +522,7 @@ class NextDiT(nn.Module):
         max_cap_len = max(l_effective_cap_len)
         max_img_len = max(l_effective_img_len)

-        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
+        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)

         for i in range(bsz):
             cap_len = l_effective_cap_len[i]
@@ -531,10 +531,22 @@ class NextDiT(nn.Module):
             H_tokens, W_tokens = H // pH, W // pW
             assert H_tokens * W_tokens == img_len

-            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
+            rope_options = transformer_options.get("rope_options", None)
+            h_scale = 1.0
+            w_scale = 1.0
+            h_start = 0
+            w_start = 0
+            if rope_options is not None:
+                h_scale = rope_options.get("scale_y", 1.0)
+                w_scale = rope_options.get("scale_x", 1.0)
+
+                h_start = rope_options.get("shift_y", 0.0)
+                w_start = rope_options.get("shift_x", 0.0)
+
+            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
             position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
-            row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
+            row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
-            col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
+            col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
             position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
             position_ids[i, cap_len:cap_len+img_len, 2] = col_ids

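For reference, the rope_options mapping introduced above rescales and offsets the 2-D token grid before RoPE is applied: a row index i becomes i * scale_y + shift_y and a column index j becomes j * scale_x + shift_x, now kept in float32 so fractional scales are representable. A hedged, standalone sketch of just that index math (not the NextDiT class itself):

import torch

def scaled_grid_ids(h_tokens, w_tokens, scale_y=1.0, shift_y=0.0, scale_x=1.0, shift_x=0.0):
    # Row/column indices as floats so fractional scales and shifts survive.
    rows = torch.arange(h_tokens, dtype=torch.float32) * scale_y + shift_y
    cols = torch.arange(w_tokens, dtype=torch.float32) * scale_x + shift_x
    # Broadcast to a flattened (h_tokens * w_tokens,) grid in row-major token order.
    row_ids = rows.view(-1, 1).repeat(1, w_tokens).flatten()
    col_ids = cols.view(1, -1).repeat(h_tokens, 1).flatten()
    return row_ids, col_ids

row_ids, col_ids = scaled_grid_ids(2, 3, scale_x=0.5, shift_x=1.0)
# col_ids -> tensor([1.0, 1.5, 2.0, 1.0, 1.5, 2.0]); row_ids -> tensor([0., 0., 0., 1., 1., 1.])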
@@ -599,7 +599,7 @@ class WanModel(torch.nn.Module):
         x = self.unpatchify(x, grid_sizes)
         return x

-    def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
+    def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
         h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@@ -612,10 +612,22 @@ class WanModel(torch.nn.Module):
         if steps_w is None:
             steps_w = w_len

+        h_start = 0
+        w_start = 0
+        rope_options = transformer_options.get("rope_options", None)
+        if rope_options is not None:
+            t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
+            h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
+            w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
+
+            t_start += rope_options.get("shift_t", 0.0)
+            h_start += rope_options.get("shift_y", 0.0)
+            w_start += rope_options.get("shift_x", 0.0)
+
         img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
         img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
-        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
+        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
-        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
+        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
         img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])

         freqs = self.rope_embedder(img_ids).movedim(1, 2)
@@ -641,7 +653,7 @@ class WanModel(torch.nn.Module):
         if self.ref_conv is not None and "reference_latent" in kwargs:
             t_len += 1

-        freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
+        freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
         return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]

     def unpatchify(self, x, grid_sizes):
@@ -276,6 +276,9 @@ class ModelPatcher:
         self.size = comfy.model_management.module_size(self.model)
         return self.size

+    def get_ram_usage(self):
+        return self.model_size()
+
     def loaded_size(self):
         return self.model.model_loaded_weight_memory

@@ -295,6 +298,7 @@ class ModelPatcher:
         n.backup = self.backup
         n.object_patches_backup = self.object_patches_backup
         n.parent = self
+        n.pinned = self.pinned

         n.force_cast_weights = self.force_cast_weights

@@ -451,6 +455,19 @@ class ModelPatcher:
     def set_model_post_input_patch(self, patch):
         self.set_model_patch(patch, "post_input")

+    def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
+        rope_options = self.model_options["transformer_options"].get("rope_options", {})
+        rope_options["scale_x"] = scale_x
+        rope_options["scale_y"] = scale_y
+        rope_options["scale_t"] = scale_t
+
+        rope_options["shift_x"] = shift_x
+        rope_options["shift_y"] = shift_y
+        rope_options["shift_t"] = shift_t
+
+        self.model_options["transformer_options"]["rope_options"] = rope_options
+
+
     def add_object_patch(self, name, obj):
         self.object_patches[name] = obj

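A hedged sketch of how a caller might drive the new helper: clone the patcher so the options stay per-workflow, then call set_model_rope_options; the values end up under model_options["transformer_options"]["rope_options"], which is exactly what the NextDiT and WanModel code above reads back. The wrapper function below is illustrative, not an actual ComfyUI node:

def apply_rope_scaling(model, scale_x=1.0, shift_x=0.0, scale_y=1.0, shift_y=0.0,
                       scale_t=1.0, shift_t=0.0):
    # Clone first so the original ModelPatcher (and other workflows) stay untouched.
    m = model.clone()
    m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t)
    # The options are now visible to the diffusion model via transformer_options.
    assert m.model_options["transformer_options"]["rope_options"]["scale_x"] == scale_x
    return m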
comfy/ops.py (47 changed lines)
@@ -84,7 +84,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
         if device is None:
             device = input.device

-    if offloadable:
+    if offloadable and (device != s.weight.device or
+                        (s.bias is not None and device != s.bias.device)):
         offload_stream = comfy.model_management.get_offload_stream(device)
     else:
         offload_stream = None
@@ -94,20 +95,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     else:
         wf_context = contextlib.nullcontext()

-    bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
-    if s.bias is not None:
-        has_function = len(s.bias_function) > 0
-        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)

-        if has_function:
+    weight_has_function = len(s.weight_function) > 0
+    bias_has_function = len(s.bias_function) > 0
+
+    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+
+    bias = None
+    if s.bias is not None:
+        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
+
+        if bias_has_function:
             with wf_context:
                 for f in s.bias_function:
                     bias = f(bias)

-    has_function = len(s.weight_function) > 0
-    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
-    if has_function:
+    weight = weight.to(dtype=dtype)
+    if weight_has_function:
         with wf_context:
             for f in s.weight_function:
                 weight = f(weight)
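The reordering above moves the weight to the target device first (leaving its storage dtype alone), handles the bias and its patch functions, and only then converts the weight's dtype and applies weight patch functions. A rough standalone sketch of that ordering with plain torch calls (offload streams and comfy.model_management handling deliberately omitted; this is an assumption-level simplification, not the library function):

import torch

def cast_params(weight, bias, device, dtype, weight_fns=(), bias_fns=()):
    # Move the weight to the target device first; keep its storage dtype for now.
    weight = weight.to(device)
    if bias is not None:
        bias = bias.to(device=device, dtype=dtype)
        for f in bias_fns:
            bias = f(bias)
    # Convert the weight's dtype last, then apply any weight patch functions.
    weight = weight.to(dtype=dtype)
    for f in weight_fns:
        weight = f(weight)
    return weight, bias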
@@ -401,15 +406,9 @@ def fp8_linear(self, input):
     if dtype not in [torch.float8_e4m3fn]:
         return None

-    tensor_2d = False
-    if len(input.shape) == 2:
-        tensor_2d = True
-        input = input.unsqueeze(1)
-
-    input_shape = input.shape
     input_dtype = input.dtype

-    if len(input.shape) == 3:
+    if input.ndim == 3 or input.ndim == 2:
         w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)

         scale_weight = self.scale_weight
@@ -422,24 +421,20 @@ def fp8_linear(self, input):
         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
             input = torch.clamp(input, min=-448, max=448, out=input)
-            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
             layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
-            quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight)
+            quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
         else:
             scale_input = scale_input.to(input.device)
-            quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
+            quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)

         # Wrap weight in QuantizedTensor - this enables unified dispatch
         # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
         layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
-        quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
+        quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
         o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)

         uncast_bias_weight(self, w, bias, offload_stream)
-        if tensor_2d:
-            return o.reshape(input_shape[0], -1)
-        return o.reshape((-1, input_shape[1], self.weight.shape[0]))
+        return o

     return None
@@ -540,12 +535,12 @@ if CUBLAS_IS_AVAILABLE:
 # ==============================================================================
 # Mixed Precision Operations
 # ==============================================================================
-from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
+from .quant_ops import QuantizedTensor

 QUANT_FORMAT_MIXINS = {
     "float8_e4m3fn": {
         "dtype": torch.float8_e4m3fn,
-        "layout_type": TensorCoreFP8Layout,
+        "layout_type": "TensorCoreFP8Layout",
         "parameters": {
             "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
             "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
@@ -123,7 +123,7 @@ class QuantizedTensor(torch.Tensor):
            layout_type: Layout class (subclass of QuantizedLayout)
            layout_params: Dict with layout-specific parameters
        """
-        return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
+        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)

    def __init__(self, qdata, layout_type, layout_params):
        self._qdata = qdata.contiguous()
@@ -183,11 +183,11 @@ class QuantizedTensor(torch.Tensor):

    @classmethod
    def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
-        qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
+        qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
        return cls(qdata, layout_type, layout_params)

    def dequantize(self) -> torch.Tensor:
-        return self._layout_type.dequantize(self._qdata, **self._layout_params)
+        return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
@@ -379,7 +379,12 @@ class TensorCoreFP8Layout(QuantizedLayout):
        return qtensor._qdata, qtensor._layout_params['scale']


-@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
+LAYOUTS = {
+    "TensorCoreFP8Layout": TensorCoreFP8Layout,
+}
+
+
+@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
 def fp8_linear(func, args, kwargs):
    input_tensor = args[0]
    weight = args[1]
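The switch from passing the TensorCoreFP8Layout class around to passing the string key "TensorCoreFP8Layout" works because lookups now go through the LAYOUTS dict. A hedged, self-contained sketch of that string-keyed registry pattern (simplified; the real QuantizedTensor and register_layout_op machinery lives in quant_ops.py):

# Minimal registry sketch: layouts register under a name, and callers refer to them by string.
LAYOUTS = {}

def register_layout(name):
    def deco(cls):
        LAYOUTS[name] = cls
        return cls
    return deco

@register_layout("TensorCoreFP8Layout")
class TensorCoreFP8Layout:
    @staticmethod
    def quantize(value, scale):
        return value / scale, {"scale": scale}

    @staticmethod
    def dequantize(qdata, scale):
        return qdata * scale

qdata, params = LAYOUTS["TensorCoreFP8Layout"].quantize(4.0, scale=2.0)
assert LAYOUTS["TensorCoreFP8Layout"].dequantize(qdata, **params) == 4.0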
@@ -422,7 +427,7 @@ def fp8_linear(func, args, kwargs):
            'scale': output_scale,
            'orig_dtype': input_tensor._layout_params['orig_dtype']
        }
-        return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
+        return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
    else:
        return output

@@ -436,3 +441,15 @@ def fp8_linear(func, args, kwargs):
        input_tensor = input_tensor.dequantize()

    return torch.nn.functional.linear(input_tensor, weight, bias)
+
+
+@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
+@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
+def fp8_func(func, args, kwargs):
+    input_tensor = args[0]
+    if isinstance(input_tensor, QuantizedTensor):
+        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
+        ar = list(args)
+        ar[0] = plain_input
+        return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
+    return func(*args, **kwargs)
comfy/sd.py (14 changed lines)
@@ -143,6 +143,9 @@ class CLIP:
        n.apply_hooks_to_conds = self.apply_hooks_to_conds
        return n

+    def get_ram_usage(self):
+        return self.patcher.get_ram_usage()
+
    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        return self.patcher.add_patches(patches, strength_patch, strength_model)

@@ -293,6 +296,7 @@ class VAE:
        self.working_dtypes = [torch.bfloat16, torch.float32]
        self.disable_offload = False
        self.not_video = False
+        self.size = None

        self.downscale_index_formula = None
        self.upscale_index_formula = None
@@ -595,6 +599,16 @@ class VAE:

        self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
        logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
+        self.model_size()
+
+    def model_size(self):
+        if self.size is not None:
+            return self.size
+        self.size = comfy.model_management.module_size(self.first_stage_model)
+        return self.size
+
+    def get_ram_usage(self):
+        return self.model_size()

    def throw_exception_if_invalid(self):
        if self.first_stage_model is None:
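VAE.model_size() above computes the module size once and caches it in self.size, and get_ram_usage() simply reuses it. A standalone sketch of the same compute-once pattern using raw parameter and buffer sizes (comfy.model_management.module_size is assumed to do roughly this; the SizedModel class is purely illustrative):

import torch

def module_size(module: torch.nn.Module) -> int:
    # Approximate resident size: bytes of all parameters and buffers.
    return sum(t.numel() * t.element_size()
               for t in list(module.parameters()) + list(module.buffers()))

class SizedModel:
    def __init__(self, module):
        self.module = module
        self.size = None  # filled lazily, mirroring VAE.size

    def model_size(self):
        if self.size is not None:
            return self.size
        self.size = module_size(self.module)
        return self.size

    def get_ram_usage(self):
        return self.model_size()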
@@ -1,73 +0,0 @@
-from __future__ import annotations
-import aiohttp
-import mimetypes
-from typing import Union
-from server import PromptServer
-
-import numpy as np
-from PIL import Image
-import torch
-import base64
-from io import BytesIO
-
-
-async def validate_and_cast_response(
-    response, timeout: int = None, node_id: Union[str, None] = None
-) -> torch.Tensor:
-    """Validates and casts a response to a torch.Tensor.
-
-    Args:
-        response: The response to validate and cast.
-        timeout: Request timeout in seconds. Defaults to None (no timeout).
-
-    Returns:
-        A torch.Tensor representing the image (1, H, W, C).
-
-    Raises:
-        ValueError: If the response is not valid.
-    """
-    # validate raw JSON response
-    data = response.data
-    if not data or len(data) == 0:
-        raise ValueError("No images returned from API endpoint")
-
-    # Initialize list to store image tensors
-    image_tensors: list[torch.Tensor] = []
-
-    # Process each image in the data array
-    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
-        for img_data in data:
-            img_bytes: bytes
-            if img_data.b64_json:
-                img_bytes = base64.b64decode(img_data.b64_json)
-            elif img_data.url:
-                if node_id:
-                    PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
-                async with session.get(img_data.url) as resp:
-                    if resp.status != 200:
-                        raise ValueError("Failed to download generated image")
-                    img_bytes = await resp.read()
-            else:
-                raise ValueError("Invalid image payload – neither URL nor base64 data present.")
-
-            pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
-            arr = np.asarray(pil_img).astype(np.float32) / 255.0
-            image_tensors.append(torch.from_numpy(arr))
-
-    return torch.stack(image_tensors, dim=0)
-
-
-def text_filepath_to_base64_string(filepath: str) -> str:
-    """Converts a text file to a base64 string."""
-    with open(filepath, "rb") as f:
-        file_content = f.read()
-    return base64.b64encode(file_content).decode("utf-8")
-
-
-def text_filepath_to_data_uri(filepath: str) -> str:
-    """Converts a text file to a data URI."""
-    base64_string = text_filepath_to_base64_string(filepath)
-    mime_type, _ = mimetypes.guess_type(filepath)
-    if mime_type is None:
-        mime_type = "application/octet-stream"
-    return f"data:{mime_type};base64,{base64_string}"
@@ -46,7 +46,7 @@ class TextToVideoNode(IO.ComfyNode):
                    multiline=True,
                    default="",
                ),
-                IO.Combo.Input("duration", options=[6, 8, 10], default=8),
+                IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
                IO.Combo.Input(
                    "resolution",
                    options=[
@@ -85,6 +85,10 @@ class TextToVideoNode(IO.ComfyNode):
        generate_audio: bool = False,
    ) -> IO.NodeOutput:
        validate_string(prompt, min_length=1, max_length=10000)
+        if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
+            raise ValueError(
+                "Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
+            )
        response = await sync_op_raw(
            cls,
            ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
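The added guard encodes a single business rule: durations above 10 seconds are accepted only for the Fast model at 1920x1080 and 25 FPS. A small hedged sketch of that check in isolation (parameter names mirror the node above; the non-Fast model string below is a placeholder, not a real option list):

def check_ltx_duration(duration: int, model: str, resolution: str, fps: int) -> None:
    # Long clips are only supported on one specific configuration.
    if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
        raise ValueError(
            "Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
        )

check_ltx_duration(8, "some-other-model", "1280x720", 25)    # ok: short clip
check_ltx_duration(16, "LTX-2 (Fast)", "1920x1080", 25)      # ok: allowed long-clip combo
# check_ltx_duration(16, "some-other-model", "1920x1080", 25)  # would raise ValueError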
@@ -118,7 +122,7 @@ class ImageToVideoNode(IO.ComfyNode):
                    multiline=True,
                    default="",
                ),
-                IO.Combo.Input("duration", options=[6, 8, 10], default=8),
+                IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
                IO.Combo.Input(
                    "resolution",
                    options=[
@@ -158,6 +162,10 @@ class ImageToVideoNode(IO.ComfyNode):
        generate_audio: bool = False,
    ) -> IO.NodeOutput:
        validate_string(prompt, min_length=1, max_length=10000)
+        if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
+            raise ValueError(
+                "Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
+            )
        if get_number_of_images(image) != 1:
            raise ValueError("Currently only one input image is supported.")
        response = await sync_op_raw(
(File diff suppressed because it is too large.)
@@ -7,24 +7,23 @@ from __future__ import annotations

 from io import BytesIO
 import logging
-from typing import Optional, TypeVar
+from typing import Optional

 import torch

 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, IO
 from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
-from comfy_api_nodes.apis import pika_defs
+from comfy_api_nodes.apis import pika_api as pika_defs
-from comfy_api_nodes.apis.client import (
+from comfy_api_nodes.util import (
+    validate_string,
+    download_url_to_video_output,
+    tensor_to_bytesio,
     ApiEndpoint,
-    EmptyRequest,
-    HttpMethod,
-    PollingOperation,
-    SynchronousOperation,
+    sync_op,
+    poll_op,
 )
-from comfy_api_nodes.util import validate_string, download_url_to_video_output, tensor_to_bytesio

-R = TypeVar("R")

 PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
 PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
@@ -40,28 +39,18 @@ PATH_VIDEO_GET = "/proxy/pika/videos"


 async def execute_task(
-    initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse],
-    auth_kwargs: Optional[dict[str, str]] = None,
-    node_id: Optional[str] = None,
+    task_id: str,
+    cls: type[IO.ComfyNode],
 ) -> IO.NodeOutput:
-    task_id = (await initial_operation.execute()).video_id
-    final_response: pika_defs.PikaVideoResponse = await PollingOperation(
-        poll_endpoint=ApiEndpoint(
-            path=f"{PATH_VIDEO_GET}/{task_id}",
-            method=HttpMethod.GET,
-            request_model=EmptyRequest,
-            response_model=pika_defs.PikaVideoResponse,
-        ),
-        completed_statuses=["finished"],
-        failed_statuses=["failed", "cancelled"],
+    final_response: pika_defs.PikaVideoResponse = await poll_op(
+        cls,
+        ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
+        response_model=pika_defs.PikaVideoResponse,
         status_extractor=lambda response: (response.status.value if response.status else None),
         progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
-        auth_kwargs=auth_kwargs,
-        result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
-        node_id=node_id,
         estimated_duration=60,
         max_poll_attempts=240,
-    ).execute()
+    )
     if not final_response.url:
         error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
         logging.error(error_msg)
@@ -124,23 +113,15 @@ class PikaImageToVideo(IO.ComfyNode):
            resolution=resolution,
            duration=duration,
        )
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_IMAGE_TO_VIDEO,
-                method=HttpMethod.POST,
-                request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
-                response_model=pika_defs.PikaGenerateResponse,
-            ),
-            request=pika_request_data,
+        initial_operation = await sync_op(
+            cls,
+            ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
+            response_model=pika_defs.PikaGenerateResponse,
+            data=pika_request_data,
            files=pika_files,
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )
-        return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
+        return await execute_task(initial_operation.video_id, cls)


class PikaTextToVideoNode(IO.ComfyNode):
@@ -183,18 +164,11 @@ class PikaTextToVideoNode(IO.ComfyNode):
        duration: int,
        aspect_ratio: float,
    ) -> IO.NodeOutput:
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_TEXT_TO_VIDEO,
-                method=HttpMethod.POST,
-                request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
-                response_model=pika_defs.PikaGenerateResponse,
-            ),
-            request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
+        initial_operation = await sync_op(
+            cls,
+            ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
+            response_model=pika_defs.PikaGenerateResponse,
+            data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
                promptText=prompt_text,
                negativePrompt=negative_prompt,
                seed=seed,
@@ -202,10 +176,9 @@ class PikaTextToVideoNode(IO.ComfyNode):
                duration=duration,
                aspectRatio=aspect_ratio,
            ),
-            auth_kwargs=auth,
            content_type="application/x-www-form-urlencoded",
        )
-        return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
+        return await execute_task(initial_operation.video_id, cls)


class PikaScenes(IO.ComfyNode):
@@ -309,24 +282,16 @@ class PikaScenes(IO.ComfyNode):
            duration=duration,
            aspectRatio=aspect_ratio,
        )
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_PIKASCENES,
-                method=HttpMethod.POST,
-                request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
-                response_model=pika_defs.PikaGenerateResponse,
-            ),
-            request=pika_request_data,
+        initial_operation = await sync_op(
+            cls,
+            ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
+            response_model=pika_defs.PikaGenerateResponse,
+            data=pika_request_data,
            files=pika_files,
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )

-        return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
+        return await execute_task(initial_operation.video_id, cls)


class PikAdditionsNode(IO.ComfyNode):
@@ -383,24 +348,16 @@ class PikAdditionsNode(IO.ComfyNode):
            negativePrompt=negative_prompt,
            seed=seed,
        )
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_PIKADDITIONS,
-                method=HttpMethod.POST,
-                request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
-                response_model=pika_defs.PikaGenerateResponse,
-            ),
-            request=pika_request_data,
+        initial_operation = await sync_op(
+            cls,
+            ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
+            response_model=pika_defs.PikaGenerateResponse,
+            data=pika_request_data,
            files=pika_files,
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )

-        return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
+        return await execute_task(initial_operation.video_id, cls)


class PikaSwapsNode(IO.ComfyNode):
@@ -472,23 +429,15 @@ class PikaSwapsNode(IO.ComfyNode):
            seed=seed,
            modifyRegionRoi=region_to_modify if region_to_modify else None,
        )
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_PIKASWAPS,
-                method=HttpMethod.POST,
-                request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
-                response_model=pika_defs.PikaGenerateResponse,
-            ),
-            request=pika_request_data,
+        initial_operation = await sync_op(
+            cls,
+            ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
+            response_model=pika_defs.PikaGenerateResponse,
+            data=pika_request_data,
            files=pika_files,
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )
-        return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
+        return await execute_task(initial_operation.video_id, cls)


class PikaffectsNode(IO.ComfyNode):
@@ -528,18 +477,11 @@ class PikaffectsNode(IO.ComfyNode):
        negative_prompt: str,
        seed: int,
    ) -> IO.NodeOutput:
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_PIKAFFECTS,
-                method=HttpMethod.POST,
-                request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
-                response_model=pika_defs.PikaGenerateResponse,
-            ),
-            request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
+        initial_operation = await sync_op(
+            cls,
+            ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
+            response_model=pika_defs.PikaGenerateResponse,
+            data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
                pikaffect=pikaffect,
                promptText=prompt_text,
                negativePrompt=negative_prompt,
@@ -547,9 +489,8 @@ class PikaffectsNode(IO.ComfyNode):
            ),
            files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )
-        return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
+        return await execute_task(initial_operation.video_id, cls)


class PikaStartEndFrameNode(IO.ComfyNode):
@@ -592,18 +533,11 @@ class PikaStartEndFrameNode(IO.ComfyNode):
            ("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
            ("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
        ]
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_PIKAFRAMES,
-                method=HttpMethod.POST,
-                request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
-                response_model=pika_defs.PikaGenerateResponse,
-            ),
-            request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
+        initial_operation = await sync_op(
+            cls,
+            ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
+            response_model=pika_defs.PikaGenerateResponse,
+            data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
                promptText=prompt_text,
                negativePrompt=negative_prompt,
                seed=seed,
@@ -612,9 +546,8 @@ class PikaStartEndFrameNode(IO.ComfyNode):
            ),
            files=pika_files,
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )
-        return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
+        return await execute_task(initial_operation.video_id, cls)


class PikaApiNodesExtension(ComfyExtension):
@@ -225,21 +225,20 @@ async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] =


 async def download_files(url_list, task_uuid):
-    save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}")
+    result_folder_name = f"Rodin3D_{task_uuid}"
+    save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
     os.makedirs(save_path, exist_ok=True)
     model_file_path = None
     async with aiohttp.ClientSession() as session:
         for i in url_list.list:
-            url = i.url
-            file_name = i.name
-            file_path = os.path.join(save_path, file_name)
+            file_path = os.path.join(save_path, i.name)
             if file_path.endswith(".glb"):
-                model_file_path = file_path
+                model_file_path = os.path.join(result_folder_name, i.name)
             logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
             max_retries = 5
             for attempt in range(max_retries):
                 try:
-                    async with session.get(url) as resp:
+                    async with session.get(i.url) as resp:
                         resp.raise_for_status()
                         with open(file_path, "wb") as f:
                             async for chunk in resp.content.iter_chunked(32 * 1024):
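The Rodin download loop above pairs chunked streaming with a bounded retry count. A hedged, standalone sketch of that download-with-retries shape using aiohttp (error handling and logging simplified relative to the node code; the backoff policy is an assumption, not taken from the source):

import asyncio
import aiohttp

async def download_with_retries(url: str, file_path: str, max_retries: int = 5) -> None:
    async with aiohttp.ClientSession() as session:
        for attempt in range(max_retries):
            try:
                async with session.get(url) as resp:
                    resp.raise_for_status()
                    with open(file_path, "wb") as f:
                        # Stream in 32 KiB chunks instead of buffering the whole file.
                        async for chunk in resp.content.iter_chunked(32 * 1024):
                            f.write(chunk)
                return
            except aiohttp.ClientError:
                if attempt == max_retries - 1:
                    raise
                await asyncio.sleep(2 ** attempt)  # simple backoff between attempts

# asyncio.run(download_with_retries("https://example.com/model.glb", "model.glb"))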
@@ -20,13 +20,6 @@ from comfy_api_nodes.apis.stability_api import (
    StabilityAudioInpaintRequest,
    StabilityAudioResponse,
)
-from comfy_api_nodes.apis.client import (
-    ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
-)
from comfy_api_nodes.util import (
    validate_audio_duration,
    validate_string,
@@ -34,6 +27,9 @@ from comfy_api_nodes.util import (
    bytesio_to_image_tensor,
    tensor_to_bytesio,
    audio_bytes_to_audio_input,
+    sync_op,
+    poll_op,
+    ApiEndpoint,
)

import torch
@@ -161,19 +157,11 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
                "image": image_binary
            }

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/stable-image/generate/ultra",
-                method=HttpMethod.POST,
-                request_model=StabilityStableUltraRequest,
-                response_model=StabilityStableUltraResponse,
-            ),
-            request=StabilityStableUltraRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
+            response_model=StabilityStableUltraResponse,
+            data=StabilityStableUltraRequest(
                prompt=prompt,
                negative_prompt=negative_prompt,
                aspect_ratio=aspect_ratio,
@@ -183,9 +171,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
            ),
            files=files,
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )
-        response_api = await operation.execute()

        if response_api.finish_reason != "SUCCESS":
            raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
@@ -313,19 +299,11 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
                "image": image_binary
            }

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/stable-image/generate/sd3",
-                method=HttpMethod.POST,
-                request_model=StabilityStable3_5Request,
-                response_model=StabilityStableUltraResponse,
-            ),
-            request=StabilityStable3_5Request(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
+            response_model=StabilityStableUltraResponse,
+            data=StabilityStable3_5Request(
                prompt=prompt,
                negative_prompt=negative_prompt,
                aspect_ratio=aspect_ratio,
@@ -338,9 +316,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
            ),
            files=files,
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )
-        response_api = await operation.execute()

        if response_api.finish_reason != "SUCCESS":
            raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
@@ -427,19 +403,11 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
                "image": image_binary
            }

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/stable-image/upscale/conservative",
-                method=HttpMethod.POST,
-                request_model=StabilityUpscaleConservativeRequest,
-                response_model=StabilityStableUltraResponse,
-            ),
-            request=StabilityUpscaleConservativeRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
+            response_model=StabilityStableUltraResponse,
+            data=StabilityUpscaleConservativeRequest(
                prompt=prompt,
                negative_prompt=negative_prompt,
                creativity=round(creativity,2),
@@ -447,9 +415,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
            ),
            files=files,
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )
-        response_api = await operation.execute()

        if response_api.finish_reason != "SUCCESS":
            raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
@@ -544,19 +510,11 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
                "image": image_binary
            }

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/stable-image/upscale/creative",
-                method=HttpMethod.POST,
-                request_model=StabilityUpscaleCreativeRequest,
-                response_model=StabilityAsyncResponse,
-            ),
-            request=StabilityUpscaleCreativeRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
+            response_model=StabilityAsyncResponse,
+            data=StabilityUpscaleCreativeRequest(
                prompt=prompt,
                negative_prompt=negative_prompt,
                creativity=round(creativity,2),
@@ -565,25 +523,15 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
            ),
            files=files,
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )
-        response_api = await operation.execute()

-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/stability/v2beta/results/{response_api.id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=StabilityResultsGetResponse,
-            ),
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
+            response_model=StabilityResultsGetResponse,
            poll_interval=3,
-            completed_statuses=[StabilityPollStatus.finished],
-            failed_statuses=[StabilityPollStatus.failed],
            status_extractor=lambda x: get_async_dummy_status(x),
-            auth_kwargs=auth,
-            node_id=cls.hidden.unique_id,
        )
-        response_poll: StabilityResultsGetResponse = await operation.execute()

        if response_poll.finish_reason != "SUCCESS":
            raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
@@ -628,24 +576,13 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
            "image": image_binary
        }

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/stable-image/upscale/fast",
-                method=HttpMethod.POST,
-                request_model=EmptyRequest,
-                response_model=StabilityStableUltraResponse,
-            ),
-            request=EmptyRequest(),
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
+            response_model=StabilityStableUltraResponse,
            files=files,
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )
-        response_api = await operation.execute()

        if response_api.finish_reason != "SUCCESS":
            raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
@@ -717,21 +654,13 @@ class StabilityTextToAudio(IO.ComfyNode):
    async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
        validate_string(prompt, max_length=10000)
        payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
-                method=HttpMethod.POST,
-                request_model=StabilityTextToAudioRequest,
-                response_model=StabilityAudioResponse,
-            ),
-            request=payload,
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
+            response_model=StabilityAudioResponse,
+            data=payload,
            content_type="multipart/form-data",
-            auth_kwargs= {
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
        )
-        response_api = await operation.execute()
        if not response_api.audio:
            raise ValueError("No audio file was received in response.")
        return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
@@ -814,22 +743,14 @@ class StabilityAudioToAudio(IO.ComfyNode):
         payload = StabilityAudioToAudioRequest(
             prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
         )
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
-                method=HttpMethod.POST,
-                request_model=StabilityAudioToAudioRequest,
-                response_model=StabilityAudioResponse,
-            ),
-            request=payload,
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
+            response_model=StabilityAudioResponse,
+            data=payload,
             content_type="multipart/form-data",
             files={"audio": audio_input_to_mp3(audio)},
-            auth_kwargs= {
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
         )
-        response_api = await operation.execute()
         if not response_api.audio:
             raise ValueError("No audio file was received in response.")
         return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
@@ -935,22 +856,14 @@ class StabilityAudioInpaint(IO.ComfyNode):
             mask_start=mask_start,
             mask_end=mask_end,
         )
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
-                method=HttpMethod.POST,
-                request_model=StabilityAudioInpaintRequest,
-                response_model=StabilityAudioResponse,
-            ),
-            request=payload,
+        response_api = await sync_op(
+            cls,
+            endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
+            response_model=StabilityAudioResponse,
+            data=payload,
             content_type="multipart/form-data",
             files={"audio": audio_input_to_mp3(audio)},
-            auth_kwargs={
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
         )
-        response_api = await operation.execute()
         if not response_api.audio:
             raise ValueError("No audio file was received in response.")
         return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
@@ -18,6 +18,8 @@ from .conversions import (
     tensor_to_base64_string,
     tensor_to_bytesio,
     tensor_to_pil,
+    text_filepath_to_base64_string,
+    text_filepath_to_data_uri,
     trim_video,
     video_to_base64_string,
 )
@@ -75,6 +77,8 @@ __all__ = [
     "tensor_to_base64_string",
     "tensor_to_bytesio",
     "tensor_to_pil",
+    "text_filepath_to_base64_string",
+    "text_filepath_to_data_uri",
     "trim_video",
     "video_to_base64_string",
     # Validation utilities
@@ -77,7 +77,7 @@ class _PollUIState:


 _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
-COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"]
+COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished"]
 FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
 QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]

@@ -589,7 +589,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
         operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
         logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)

-        payload_headers = {"Accept": "*/*"}
+        payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
         if not parsed_url.scheme and not parsed_url.netloc:  # is URL relative?
             payload_headers.update(get_auth_header(cfg.node_cls))
         if cfg.endpoint.headers:
@@ -1,6 +1,7 @@
 import base64
 import logging
 import math
+import mimetypes
 import uuid
 from io import BytesIO
 from typing import Optional
@@ -12,7 +13,7 @@ from PIL import Image

 from comfy.utils import common_upscale
 from comfy_api.latest import Input, InputImpl
-from comfy_api.util import VideoContainer, VideoCodec
+from comfy_api.util import VideoCodec, VideoContainer

 from ._helpers import mimetype_to_extension

@@ -451,3 +452,19 @@ def resize_mask_to_image(
     if not allow_gradient:
         mask = (mask > 0.5).float()
     return mask
+
+
+def text_filepath_to_base64_string(filepath: str) -> str:
+    """Converts a text file to a base64 string."""
+    with open(filepath, "rb") as f:
+        file_content = f.read()
+    return base64.b64encode(file_content).decode("utf-8")
+
+
+def text_filepath_to_data_uri(filepath: str) -> str:
+    """Converts a text file to a data URI."""
+    base64_string = text_filepath_to_base64_string(filepath)
+    mime_type, _ = mimetypes.guess_type(filepath)
+    if mime_type is None:
+        mime_type = "application/octet-stream"
+    return f"data:{mime_type};base64,{base64_string}"
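Illustrative note: the two helpers added above are thin wrappers around base64 and mimetypes. A minimal standalone sketch (not part of the commit; "caption.txt" is a hypothetical file) of what the data-URI helper produces:

    # Sketch only: reproduces the helper's behaviour with the standard library.
    import base64
    import mimetypes

    filepath = "caption.txt"  # hypothetical input file
    with open(filepath, "rb") as f:
        payload = base64.b64encode(f.read()).decode("utf-8")
    # Unknown extensions fall back to application/octet-stream, as in the helper.
    mime_type = mimetypes.guess_type(filepath)[0] or "application/octet-stream"
    print(f"data:{mime_type};base64,{payload}")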
@@ -1,4 +1,9 @@
+import bisect
+import gc
 import itertools
+import psutil
+import time
+import torch
 from typing import Sequence, Mapping, Dict
 from comfy_execution.graph import DynamicPrompt
 from abc import ABC, abstractmethod
@@ -48,7 +53,7 @@ class Unhashable:
 def to_hashable(obj):
     # So that we don't infinitely recurse since frozenset and tuples
     # are Sequences.
-    if isinstance(obj, (int, float, str, bool, type(None))):
+    if isinstance(obj, (int, float, str, bool, bytes, type(None))):
         return obj
     elif isinstance(obj, Mapping):
         return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
@@ -188,6 +193,9 @@ class BasicCache:
         self._clean_cache()
         self._clean_subcaches()

+    def poll(self, **kwargs):
+        pass
+
     def _set_immediate(self, node_id, value):
         assert self.initialized
         cache_key = self.cache_key_set.get_data_key(node_id)
@@ -276,6 +284,9 @@ class NullCache:
     def clean_unused(self):
         pass

+    def poll(self, **kwargs):
+        pass
+
     def get(self, node_id):
         return None

@@ -336,3 +347,75 @@ class LRUCache(BasicCache):
             self._mark_used(child_id)
             self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
         return self
+
+
+#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
+#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
+
+RAM_CACHE_HYSTERESIS = 1.1
+
+#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
+#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
+
+RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
+
+#Exponential bias towards evicting older workflows so garbage will be taken out
+#in constantly changing setups.
+
+RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
+
+class RAMPressureCache(LRUCache):
+
+    def __init__(self, key_class):
+        super().__init__(key_class, 0)
+        self.timestamps = {}
+
+    def clean_unused(self):
+        self._clean_subcaches()
+
+    def set(self, node_id, value):
+        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
+        super().set(node_id, value)
+
+    def get(self, node_id):
+        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
+        return super().get(node_id)
+
+    def poll(self, ram_headroom):
+        def _ram_gb():
+            return psutil.virtual_memory().available / (1024**3)
+
+        if _ram_gb() > ram_headroom:
+            return
+        gc.collect()
+        if _ram_gb() > ram_headroom:
+            return
+
+        clean_list = []
+
+        for key, (outputs, _), in self.cache.items():
+            oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
+
+            ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
+            def scan_list_for_ram_usage(outputs):
+                nonlocal ram_usage
+                for output in outputs:
+                    if isinstance(output, list):
+                        scan_list_for_ram_usage(output)
+                    elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
+                        #score Tensors at a 50% discount for RAM usage as they are likely to
+                        #be high value intermediates
+                        ram_usage += (output.numel() * output.element_size()) * 0.5
+                    elif hasattr(output, "get_ram_usage"):
+                        ram_usage += output.get_ram_usage()
+            scan_list_for_ram_usage(outputs)
+
+            oom_score *= ram_usage
+            #In the case where we have no information on the node ram usage at all,
+            #break OOM score ties on the last touch timestamp (pure LRU)
+            bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
+
+        while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
+            _, _, key = clean_list.pop()
+            del self.cache[key]
+            gc.collect()
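For orientation, a small standalone sketch (not part of the commit; numbers are made up) of the eviction-score arithmetic used by RAMPressureCache.poll() above: older entries get an exponential penalty and larger entries score higher, so they are popped off the sorted list and evicted first when available RAM drops below the headroom.

    # Sketch only: reproduces the scoring heuristic with dummy values.
    RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

    def oom_score(generations_since_use: int, ram_usage_gb: float) -> float:
        # Exponential age bias multiplied by the (approximate) RAM footprint.
        return (RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** generations_since_use) * ram_usage_gb

    print(oom_score(0, 2.0))  # 2.0   -- freshly used 2 GB entry
    print(oom_score(3, 2.0))  # ~4.39 -- the same entry, three workflows old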
@@ -209,10 +209,15 @@ class ExecutionList(TopologicalSort):
             self.execution_cache_listeners[from_node_id] = set()
         self.execution_cache_listeners[from_node_id].add(to_node_id)

-    def get_output_cache(self, from_node_id, to_node_id):
+    def get_cache(self, from_node_id, to_node_id):
         if not to_node_id in self.execution_cache:
             return None
-        return self.execution_cache[to_node_id].get(from_node_id)
+        value = self.execution_cache[to_node_id].get(from_node_id)
+        if value is None:
+            return None
+        #Write back to the main cache on touch.
+        self.output_cache.set(from_node_id, value)
+        return value

     def cache_update(self, node_id, value):
         if node_id in self.execution_cache_listeners:
@@ -2,6 +2,9 @@ import comfy.utils
 import folder_paths
 import torch
 import logging
+from comfy_api.latest import IO, ComfyExtension
+from typing_extensions import override
+

 def load_hypernetwork_patch(path, strength):
     sd = comfy.utils.load_torch_file(path, safe_load=True)
@@ -94,27 +97,42 @@ def load_hypernetwork_patch(path, strength):

     return hypernetwork_patch(out, strength)

-class HypernetworkLoader:
+class HypernetworkLoader(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-                              "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
-                              "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
-                              }}
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "load_hypernetwork"
-
-    CATEGORY = "loaders"
-
-    def load_hypernetwork(self, model, hypernetwork_name, strength):
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="HypernetworkLoader",
+            category="loaders",
+            inputs=[
+                IO.Model.Input("model"),
+                IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
+                IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
+            ],
+            outputs=[
+                IO.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput:
         hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
         model_hypernetwork = model.clone()
         patch = load_hypernetwork_patch(hypernetwork_path, strength)
         if patch is not None:
             model_hypernetwork.set_model_attn1_patch(patch)
             model_hypernetwork.set_model_attn2_patch(patch)
-        return (model_hypernetwork,)
+        return IO.NodeOutput(model_hypernetwork)

-NODE_CLASS_MAPPINGS = {
-    "HypernetworkLoader": HypernetworkLoader
-}
+    load_hypernetwork = execute # TODO: remove
+
+
+class HyperNetworkExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+        return [
+            HypernetworkLoader,
+        ]
+
+
+async def comfy_entrypoint() -> HyperNetworkExtension:
+    return HyperNetworkExtension()
comfy_extras/nodes_rope.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+from comfy_api.latest import ComfyExtension, io
+from typing_extensions import override
+
+
+class ScaleROPE(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ScaleROPE",
+            category="advanced/model_patches",
+            description="Scale and shift the ROPE of the model.",
+            is_experimental=True,
+            inputs=[
+                io.Model.Input("model"),
+                io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1),
+                io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1),
+
+                io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1),
+                io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1),
+
+                io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1),
+                io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1),
+
+
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) -> io.NodeOutput:
+        m = model.clone()
+        m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t)
+        return io.NodeOutput(m)
+
+
+class RopeExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            ScaleROPE
+        ]
+
+
+async def comfy_entrypoint() -> RopeExtension:
+    return RopeExtension()
execution.py (81 lines changed)
@@ -21,6 +21,7 @@ from comfy_execution.caching import (
     NullCache,
     HierarchicalCache,
     LRUCache,
+    RAMPressureCache,
 )
 from comfy_execution.graph import (
     DynamicPrompt,
@@ -88,49 +89,56 @@ class IsChangedCache:
         return self.is_changed[node_id]


+class CacheEntry(NamedTuple):
+    ui: dict
+    outputs: list
+
+
 class CacheType(Enum):
     CLASSIC = 0
     LRU = 1
     NONE = 2
+    RAM_PRESSURE = 3


 class CacheSet:
-    def __init__(self, cache_type=None, cache_size=None):
+    def __init__(self, cache_type=None, cache_args={}):
         if cache_type == CacheType.NONE:
             self.init_null_cache()
             logging.info("Disabling intermediate node cache.")
+        elif cache_type == CacheType.RAM_PRESSURE:
+            cache_ram = cache_args.get("ram", 16.0)
+            self.init_ram_cache(cache_ram)
+            logging.info("Using RAM pressure cache.")
         elif cache_type == CacheType.LRU:
-            if cache_size is None:
-                cache_size = 0
+            cache_size = cache_args.get("lru", 0)
             self.init_lru_cache(cache_size)
             logging.info("Using LRU cache")
         else:
             self.init_classic_cache()

-        self.all = [self.outputs, self.ui, self.objects]
+        self.all = [self.outputs, self.objects]

     # Performs like the old cache -- dump data ASAP
     def init_classic_cache(self):
         self.outputs = HierarchicalCache(CacheKeySetInputSignature)
-        self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)

     def init_lru_cache(self, cache_size):
         self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.objects = HierarchicalCache(CacheKeySetID)
+
+    def init_ram_cache(self, min_headroom):
+        self.outputs = RAMPressureCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)

     def init_null_cache(self):
         self.outputs = NullCache()
-        #The UI cache is expected to be iterable at the end of each workflow
-        #so it must cache at least a full workflow. Use Heirachical
-        self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = NullCache()

     def recursive_debug_dump(self):
         result = {
             "outputs": self.outputs.recursive_debug_dump(),
-            "ui": self.ui.recursive_debug_dump(),
         }
         return result

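For orientation, a tiny standalone sketch (dummy values, not part of the commit) of the CacheEntry shape that the execution code below stores in the outputs cache and reads back:

    # Sketch only: shows the two fields and how later hunks access them.
    from typing import NamedTuple

    class CacheEntry(NamedTuple):
        ui: dict
        outputs: list

    entry = CacheEntry(ui={"output": {"images": []}}, outputs=[["IMAGE_TENSOR"]])
    print(entry.outputs[0])    # what get_input_data() reads via cached.outputs
    print(entry.ui["output"])  # what gets replayed to the client / history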
@@ -157,14 +165,14 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
             if execution_list is None:
                 mark_missing()
                 continue # This might be a lazily-evaluated input
-            cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
-            if cached_output is None:
+            cached = execution_list.get_cache(input_unique_id, unique_id)
+            if cached is None or cached.outputs is None:
                 mark_missing()
                 continue
-            if output_index >= len(cached_output):
+            if output_index >= len(cached.outputs):
                 mark_missing()
                 continue
-            obj = cached_output[output_index]
+            obj = cached.outputs[output_index]
             input_data_all[x] = obj
         elif input_category is not None:
             input_data_all[x] = [input_data]
@@ -393,7 +401,7 @@ def format_value(x):
     else:
         return str(x)

-async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
+async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
     display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -401,12 +409,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-    if caches.outputs.get(unique_id) is not None:
+    cached = caches.outputs.get(unique_id)
+    if cached is not None:
         if server.client_id is not None:
-            cached_output = caches.ui.get(unique_id) or {}
-            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
+            cached_ui = cached.ui or {}
+            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
+        if cached.ui is not None:
+            ui_outputs[unique_id] = cached.ui
         get_progress_state().finish_progress(unique_id)
-        execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
+        execution_list.cache_update(unique_id, cached)
         return (ExecutionResult.SUCCESS, None, None)

     input_data_all = None
@@ -436,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             for r in result:
                 if is_link(r):
                     source_node, source_output = r[0], r[1]
-                    node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
-                    for o in node_output:
+                    node_cached = execution_list.get_cache(source_node, unique_id)
+                    for o in node_cached.outputs[source_output]:
                         resolved_output.append(o)

                 else:
@@ -507,7 +518,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             asyncio.create_task(await_completion())
             return (ExecutionResult.PENDING, None, None)
         if len(output_ui) > 0:
-            caches.ui.set(unique_id, {
+            ui_outputs[unique_id] = {
                 "meta": {
                     "node_id": unique_id,
                     "display_node": display_node_id,
@@ -515,7 +526,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                     "real_node_id": real_node_id,
                 },
                 "output": output_ui
-            })
+            }
             if server.client_id is not None:
                 server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
             if has_subgraph:
@@ -554,8 +565,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 pending_subgraph_results[unique_id] = cached_outputs
                 return (ExecutionResult.PENDING, None, None)

-        caches.outputs.set(unique_id, output_data)
-        execution_list.cache_update(unique_id, output_data)
+        cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
+        execution_list.cache_update(unique_id, cache_entry)
+        caches.outputs.set(unique_id, cache_entry)

     except comfy.model_management.InterruptProcessingException as iex:
         logging.info("Processing interrupted")
@@ -600,14 +612,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     return (ExecutionResult.SUCCESS, None, None)

 class PromptExecutor:
-    def __init__(self, server, cache_type=False, cache_size=None):
-        self.cache_size = cache_size
+    def __init__(self, server, cache_type=False, cache_args=None):
+        self.cache_args = cache_args
         self.cache_type = cache_type
         self.server = server
         self.reset()

     def reset(self):
-        self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
+        self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
         self.status_messages = []
         self.success = True

@@ -682,6 +694,7 @@ class PromptExecutor:
                             broadcast=False)
         pending_subgraph_results = {}
         pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
+        ui_node_outputs = {}
         executed = set()
         execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
         current_outputs = self.caches.outputs.all_node_ids()
@@ -695,7 +708,7 @@ class PromptExecutor:
                     break

                 assert node_id is not None, "Node ID should not be None at this point"
-                result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
+                result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
                 self.success = result != ExecutionResult.FAILURE
                 if result == ExecutionResult.FAILURE:
                     self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@@ -704,18 +717,16 @@ class PromptExecutor:
                     execution_list.unstage_node_execution()
                 else: # result == ExecutionResult.SUCCESS:
                     execution_list.complete_node_execution()
+                    self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
             else:
                 # Only execute when the while-loop ends without break
                 self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)

             ui_outputs = {}
             meta_outputs = {}
-            all_node_ids = self.caches.ui.all_node_ids()
-            for node_id in all_node_ids:
-                ui_info = self.caches.ui.get(node_id)
-                if ui_info is not None:
-                    ui_outputs[node_id] = ui_info["output"]
-                    meta_outputs[node_id] = ui_info["meta"]
+            for node_id, ui_info in ui_node_outputs.items():
+                ui_outputs[node_id] = ui_info["output"]
+                meta_outputs[node_id] = ui_info["meta"]
             self.history_result = {
                 "outputs": ui_outputs,
                 "meta": meta_outputs,
main.py (4 lines changed)
@@ -172,10 +172,12 @@ def prompt_worker(q, server_instance):
     cache_type = execution.CacheType.CLASSIC
     if args.cache_lru > 0:
         cache_type = execution.CacheType.LRU
+    elif args.cache_ram > 0:
+        cache_type = execution.CacheType.RAM_PRESSURE
     elif args.cache_none:
        cache_type = execution.CacheType.NONE

-    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
+    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0
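Usage note (illustrative, not part of the diff): with the mapping above, starting ComfyUI with the RAM-pressure flag selects CacheType.RAM_PRESSURE and forwards the chosen headroom through cache_args["ram"]; the executor then calls self.caches.outputs.poll(ram_headroom=...) after every completed node, so large, stale cache entries are evicted only when available system RAM drops below that headroom.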
nodes.py (1 line changed)
@@ -2329,6 +2329,7 @@ async def init_builtin_extra_nodes():
         "nodes_model_patch.py",
         "nodes_easycache.py",
         "nodes_audio_encoder.py",
+        "nodes_rope.py",
     ]

     import_failed = []
@@ -14,7 +14,7 @@ if not has_gpu():
     args.cpu = True

 from comfy import ops
-from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
+from comfy.quant_ops import QuantizedTensor


 class SimpleModel(torch.nn.Module):
@@ -104,14 +104,14 @@ class TestMixedPrecisionOps(unittest.TestCase):

         # Verify weights are wrapped in QuantizedTensor
         self.assertIsInstance(model.layer1.weight, QuantizedTensor)
-        self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")

         # Layer 2 should NOT be quantized
         self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)

         # Layer 3 should be quantized
         self.assertIsInstance(model.layer3.weight, QuantizedTensor)
-        self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")

         # Verify scales were loaded
         self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
@@ -155,7 +155,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
         # Verify layer1.weight is a QuantizedTensor with scale preserved
         self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
         self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
-        self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")

         # Verify non-quantized layers are standard tensors
         self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
@@ -25,14 +25,14 @@ class TestQuantizedTensor(unittest.TestCase):
         scale = torch.tensor(2.0)
         layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}

-        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

         self.assertIsInstance(qt, QuantizedTensor)
         self.assertEqual(qt.shape, (256, 128))
         self.assertEqual(qt.dtype, torch.float8_e4m3fn)
         self.assertEqual(qt._layout_params['scale'], scale)
         self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
-        self.assertEqual(qt._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")

     def test_dequantize(self):
         """Test explicit dequantization"""
@@ -41,7 +41,7 @@ class TestQuantizedTensor(unittest.TestCase):
         scale = torch.tensor(3.0)
         layout_params = {'scale': scale, 'orig_dtype': torch.float32}

-        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
         dequantized = qt.dequantize()

         self.assertEqual(dequantized.dtype, torch.float32)
@@ -54,7 +54,7 @@ class TestQuantizedTensor(unittest.TestCase):

         qt = QuantizedTensor.from_float(
             float_tensor,
-            TensorCoreFP8Layout,
+            "TensorCoreFP8Layout",
             scale=scale,
             dtype=torch.float8_e4m3fn
         )
@@ -77,28 +77,28 @@ class TestGenericUtilities(unittest.TestCase):
         fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
         scale = torch.tensor(1.5)
         layout_params = {'scale': scale, 'orig_dtype': torch.float32}
-        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

         # Detach should return a new QuantizedTensor
         qt_detached = qt.detach()

         self.assertIsInstance(qt_detached, QuantizedTensor)
         self.assertEqual(qt_detached.shape, qt.shape)
-        self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")

     def test_clone(self):
         """Test clone operation on quantized tensor"""
         fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
         scale = torch.tensor(1.5)
         layout_params = {'scale': scale, 'orig_dtype': torch.float32}
-        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

         # Clone should return a new QuantizedTensor
         qt_cloned = qt.clone()

         self.assertIsInstance(qt_cloned, QuantizedTensor)
         self.assertEqual(qt_cloned.shape, qt.shape)
-        self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")

         # Verify it's a deep copy
         self.assertIsNot(qt_cloned._qdata, qt._qdata)
@@ -109,7 +109,7 @@ class TestGenericUtilities(unittest.TestCase):
         fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
         scale = torch.tensor(1.5)
         layout_params = {'scale': scale, 'orig_dtype': torch.float32}
-        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

         # Moving to same device should work (CPU to CPU)
         qt_cpu = qt.to('cpu')
@@ -169,7 +169,7 @@ class TestFallbackMechanism(unittest.TestCase):
         scale = torch.tensor(1.0)
         a_q = QuantizedTensor.from_float(
             a_fp32,
-            TensorCoreFP8Layout,
+            "TensorCoreFP8Layout",
             scale=scale,
             dtype=torch.float8_e4m3fn
         )