"""
|
|
This file is part of ComfyUI.
|
|
Copyright (C) 2024 Stability AI
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
(at your option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
"""
|
|
|
|
import torch
import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import contextlib
import json


def run_every_op():
    if torch.compiler.is_compiling():
        return

    comfy.model_management.throw_exception_if_processing_interrupted()


def scaled_dot_product_attention(q, k, v, *args, **kwargs):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)


try:
    if torch.cuda.is_available() and comfy.model_management.WINDOWS:
        from torch.nn.attention import SDPBackend, sdpa_kernel
        import inspect
        if "set_priority" in inspect.signature(sdpa_kernel).parameters:
            SDPA_BACKEND_PRIORITY = [
                SDPBackend.FLASH_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.MATH,
            ]

            SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)

            def scaled_dot_product_attention(q, k, v, *args, **kwargs):
                with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
                    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
        else:
            logging.warning("Torch version too old to set sdpa backend priority.")
except (ModuleNotFoundError, TypeError):
    logging.warning("Could not set sdpa backend priority.")

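
# Illustrative note (a sketch, not part of the upstream logic): the module-level
# scaled_dot_product_attention() defined above is a drop-in wrapper around
# torch.nn.functional.scaled_dot_product_attention; on Windows + CUDA with a recent enough
# torch it only pins a backend priority first. A hypothetical call, using the usual
# (batch, heads, seq_len, head_dim) layout:
#
#   out = scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)
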
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
try:
    if comfy.model_management.is_nvidia():
        cudnn_version = torch.backends.cudnn.version()
        if (cudnn_version >= 91002 and cudnn_version < 91500) and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
            # TODO: change upper bound version once it's fixed
            NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
            logging.info("working around nvidia conv3d memory bug.")
except:
    pass

cast_to = comfy.model_management.cast_to  # TODO: remove once no more references

def cast_to_input(weight, input, non_blocking=False, copy=True):
    return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)


def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
    # NOTE: offloadable=False is a legacy default. If you are a custom node author reading this,
    # please pass offloadable=True and call uncast_bias_weight() after your last use of the
    # weight/bias. This adds async-offload support to your cast and improves performance.
    if input is not None:
        if dtype is None:
            if isinstance(input, QuantizedTensor):
                dtype = input._layout_params["orig_dtype"]
            else:
                dtype = input.dtype
        if bias_dtype is None:
            bias_dtype = dtype
        if device is None:
            device = input.device

    if offloadable and (device != s.weight.device or
                        (s.bias is not None and device != s.bias.device)):
        offload_stream = comfy.model_management.get_offload_stream(device)
    else:
        offload_stream = None

    if offload_stream is not None:
        wf_context = offload_stream
        if hasattr(wf_context, "as_context"):
            wf_context = wf_context.as_context(offload_stream)
    else:
        wf_context = contextlib.nullcontext()

    non_blocking = comfy.model_management.device_supports_non_blocking(device)

    weight_has_function = len(s.weight_function) > 0
    bias_has_function = len(s.bias_function) > 0

    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)

    bias = None
    if s.bias is not None:
        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)

    comfy.model_management.sync_stream(device, offload_stream)

    bias_a = bias
    weight_a = weight

    if s.bias is not None:
        for f in s.bias_function:
            bias = f(bias)

    if weight_has_function or weight.dtype != dtype:
        weight = weight.to(dtype=dtype)
        if isinstance(weight, QuantizedTensor):
            weight = weight.dequantize()
        for f in s.weight_function:
            weight = f(weight)

    if offloadable:
        return weight, bias, (offload_stream, weight_a, bias_a)
    else:
        # Legacy function signature
        return weight, bias


def uncast_bias_weight(s, weight, bias, offload_stream):
    if offload_stream is None:
        return
    os, weight_a, bias_a = offload_stream
    if os is None:
        return
    if weight_a is not None:
        device = weight_a.device
    else:
        if bias_a is None:
            return
        device = bias_a.device
    os.wait_stream(comfy.model_management.current_stream(device))

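
# Illustrative sketch for custom node authors (hypothetical class, not used by ComfyUI itself):
# the offloadable=True pattern described in the NOTE inside cast_bias_weight() above. The same
# pattern appears in disable_weight_init.Linear below.
#
#   class MyCustomLinear(torch.nn.Linear, CastWeightBiasOp):
#       def forward(self, input):
#           weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
#           out = torch.nn.functional.linear(input, weight, bias)
#           # Release the cast weight/bias only after their last use.
#           uncast_bias_weight(self, weight, bias, offload_stream)
#           return out
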
class CastWeightBiasOp:
    comfy_cast_weights = False
    weight_function = []
    bias_function = []

class disable_weight_init:
    class Linear(torch.nn.Linear, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
            x = torch.nn.functional.linear(input, weight, bias)
            uncast_bias_weight(self, weight, bias, offload_stream)
            return x

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
            x = self._conv_forward(input, weight, bias)
            uncast_bias_weight(self, weight, bias, offload_stream)
            return x

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
            x = self._conv_forward(input, weight, bias)
            uncast_bias_weight(self, weight, bias, offload_stream)
            return x

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def _conv_forward(self, input, weight, bias, *args, **kwargs):
            if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
                out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
                if bias is not None:
                    out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
                return out
            else:
                return super()._conv_forward(input, weight, bias, *args, **kwargs)

        def forward_comfy_cast_weights(self, input):
            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
            x = self._conv_forward(input, weight, bias)
            uncast_bias_weight(self, weight, bias, offload_stream)
            return x

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
            x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
            uncast_bias_weight(self, weight, bias, offload_stream)
            return x

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            if self.weight is not None:
                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
            else:
                weight = None
                bias = None
                offload_stream = None
            x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
            uncast_bias_weight(self, weight, bias, offload_stream)
            return x

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
        def reset_parameters(self):
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input):
            if self.weight is not None:
                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
            else:
                weight = None
                bias = None
                offload_stream = None
            x = comfy.rmsnorm.rms_norm(input, weight, self.eps)  # TODO: switch to commented out line when old torch is deprecated
            # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
            uncast_bias_weight(self, weight, bias, offload_stream)
            return x

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 2
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
            x = torch.nn.functional.conv_transpose2d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)
            uncast_bias_weight(self, weight, bias, offload_stream)
            return x

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 1
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
            x = torch.nn.functional.conv_transpose1d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)
            uncast_bias_weight(self, weight, bias, offload_stream)
            return x

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Embedding(torch.nn.Embedding, CastWeightBiasOp):
        def reset_parameters(self):
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input, out_dtype=None):
            output_dtype = out_dtype
            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                out_dtype = None
            weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
            x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
            uncast_bias_weight(self, weight, bias, offload_stream)
            return x

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                if "out_dtype" in kwargs:
                    kwargs.pop("out_dtype")
                return super().forward(*args, **kwargs)

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        if dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")

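
# Illustrative sketch (hypothetical names and shapes): disable_weight_init exists so that freshly
# constructed layers skip random initialization (reset_parameters() is a no-op), which is what you
# want when the real weights are loaded from a checkpoint immediately afterwards.
#
#   ops = disable_weight_init
#   layer = ops.Linear(768, 3072, bias=True, device="cpu", dtype=torch.float16)
#   layer.load_state_dict(checkpoint_sd, strict=False)   # checkpoint_sd is hypothetical
#
# With comfy_cast_weights left False and no weight/bias functions attached, forward() falls
# through to the plain torch implementation of each layer.
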
class manual_cast(disable_weight_init):
    class Linear(disable_weight_init.Linear):
        comfy_cast_weights = True

    class Conv1d(disable_weight_init.Conv1d):
        comfy_cast_weights = True

    class Conv2d(disable_weight_init.Conv2d):
        comfy_cast_weights = True

    class Conv3d(disable_weight_init.Conv3d):
        comfy_cast_weights = True

    class GroupNorm(disable_weight_init.GroupNorm):
        comfy_cast_weights = True

    class LayerNorm(disable_weight_init.LayerNorm):
        comfy_cast_weights = True

    class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
        comfy_cast_weights = True

    class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
        comfy_cast_weights = True

    class RMSNorm(disable_weight_init.RMSNorm):
        comfy_cast_weights = True

    class Embedding(disable_weight_init.Embedding):
        comfy_cast_weights = True

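
# Illustrative sketch (hypothetical dtypes/devices): manual_cast keeps the stored weights in
# whatever dtype/device they were loaded with and casts them to match the input on every forward
# pass through cast_bias_weight(). This is the safe path pick_operations() falls back to when the
# weight dtype differs from the compute dtype.
#
#   layer = manual_cast.Linear(320, 1280, device="cpu", dtype=torch.float16)
#   x = torch.randn(1, 77, 320, device="cuda", dtype=torch.bfloat16)
#   y = layer(x)   # weight/bias are cast to bfloat16 on the input device just for this call
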
def fp8_linear(self, input):
    """
    Legacy FP8 linear function for backward compatibility.
    Uses QuantizedTensor subclass for dispatch.
    """
    dtype = self.weight.dtype
    if dtype not in [torch.float8_e4m3fn]:
        return None

    input_dtype = input.dtype

    if input.ndim == 3 or input.ndim == 2:
        w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)

        scale_input = torch.ones((), device=input.device, dtype=torch.float32)
        input = torch.clamp(input, min=-448, max=448, out=input)
        layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
        quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)

        # Wrap weight in QuantizedTensor - this enables unified dispatch
        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
        quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)

        uncast_bias_weight(self, w, bias, offload_stream)
        return o

    return None

class fp8_ops(manual_cast):
    class Linear(manual_cast.Linear):
        def reset_parameters(self):
            self.scale_weight = None
            self.scale_input = None
            return None

        def forward_comfy_cast_weights(self, input):
            if len(self.weight_function) == 0 and len(self.bias_function) == 0:
                try:
                    out = fp8_linear(self, input)
                    if out is not None:
                        return out
                except Exception as e:
                    logging.info("Exception during fp8 op: {}".format(e))

            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
            x = torch.nn.functional.linear(input, weight, bias)
            uncast_bias_weight(self, weight, bias, offload_stream)
            return x

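
# Illustrative note (a sketch, assuming an fp8-capable GPU): fp8_ops.Linear first tries the
# quantized path through fp8_linear() above; if the weight dtype is not float8_e4m3fn it returns
# None, and if the fp8 matmul raises, the exception is logged and the layer falls back to the
# regular manual_cast behaviour.
#
#   layer = fp8_ops.Linear(4096, 4096, bias=True, device="cuda", dtype=torch.float8_e4m3fn)
#   # forward() -> forward_comfy_cast_weights() -> fp8_linear() when possible
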
CUBLAS_IS_AVAILABLE = False
try:
    from cublas_ops import CublasLinear
    CUBLAS_IS_AVAILABLE = True
except ImportError:
    pass

if CUBLAS_IS_AVAILABLE:
    class cublas_ops(disable_weight_init):
        class Linear(CublasLinear, disable_weight_init.Linear):
            def reset_parameters(self):
                return None

            def forward_comfy_cast_weights(self, input):
                return super().forward(input)

            def forward(self, *args, **kwargs):
                return super().forward(*args, **kwargs)


# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
from .quant_ops import QuantizedTensor, QUANT_ALGOS


def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
    class MixedPrecisionOps(manual_cast):
        _quant_config = quant_config
        _compute_dtype = compute_dtype
        _full_precision_mm = full_precision_mm

        class Linear(torch.nn.Module, CastWeightBiasOp):
            def __init__(
                self,
                in_features: int,
                out_features: int,
                bias: bool = True,
                device=None,
                dtype=None,
            ) -> None:
                super().__init__()

                self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
                # self.factory_kwargs = {"device": device, "dtype": dtype}

                self.in_features = in_features
                self.out_features = out_features
                if bias:
                    self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
                else:
                    self.register_parameter("bias", None)

                self.tensor_class = None
                self._full_precision_mm = MixedPrecisionOps._full_precision_mm

            def reset_parameters(self):
                return None

            def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys, error_msgs):

                device = self.factory_kwargs["device"]
                layer_name = prefix.rstrip('.')
                weight_key = f"{prefix}weight"
                weight = state_dict.pop(weight_key, None)
                if weight is None:
                    raise ValueError(f"Missing weight for layer {layer_name}")

                manually_loaded_keys = [weight_key]

                layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
                if layer_conf is not None:
                    layer_conf = json.loads(layer_conf.numpy().tobytes())

                if layer_conf is None:
                    self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
                else:
                    self.quant_format = layer_conf.get("format", None)
                    if not self._full_precision_mm:
                        self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)

                    if self.quant_format is None:
                        raise ValueError(f"Unknown quantization format for layer {layer_name}")

                    qconfig = QUANT_ALGOS[self.quant_format]
                    self.layout_type = qconfig["comfy_tensor_layout"]

                    weight_scale_key = f"{prefix}weight_scale"
                    scale = state_dict.pop(weight_scale_key, None)
                    if scale is not None:
                        scale = scale.to(device)
                    layout_params = {
                        'scale': scale,
                        'orig_dtype': MixedPrecisionOps._compute_dtype,
                        'block_size': qconfig.get("group_size", None),
                    }

                    if scale is not None:
                        manually_loaded_keys.append(weight_scale_key)

                    self.weight = torch.nn.Parameter(
                        QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
                        requires_grad=False
                    )

                    for param_name in qconfig["parameters"]:
                        param_key = f"{prefix}{param_name}"
                        _v = state_dict.pop(param_key, None)
                        if _v is None:
                            continue
                        self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
                        manually_loaded_keys.append(param_key)

                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

                for key in manually_loaded_keys:
                    if key in missing_keys:
                        missing_keys.remove(key)

            def state_dict(self, *args, destination=None, prefix="", **kwargs):
                sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
                if isinstance(self.weight, QuantizedTensor):
                    sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
                    quant_conf = {"format": self.quant_format}
                    if self._full_precision_mm:
                        quant_conf["full_precision_matrix_mult"] = True
                    sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
                return sd

            def _forward(self, input, weight, bias):
                return torch.nn.functional.linear(input, weight, bias)

            def forward_comfy_cast_weights(self, input):
                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
                x = self._forward(input, weight, bias)
                uncast_bias_weight(self, weight, bias, offload_stream)
                return x

            def forward(self, input, *args, **kwargs):
                run_every_op()

                if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                    return self.forward_comfy_cast_weights(input, *args, **kwargs)
                if (getattr(self, 'layout_type', None) is not None and
                        not isinstance(input, QuantizedTensor)):
                    input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
                return self._forward(input, self.weight, self.bias)

            def convert_weight(self, weight, inplace=False, **kwargs):
                if isinstance(weight, QuantizedTensor):
                    return weight.dequantize()
                else:
                    return weight

            def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
                if getattr(self, 'layout_type', None) is not None:
                    weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
                else:
                    weight = weight.to(self.weight.dtype)
                if return_weight:
                    return weight

                assert inplace_update is False  # TODO: eventually remove the inplace_update stuff
                self.weight = torch.nn.Parameter(weight, requires_grad=False)

            def _apply(self, fn, recurse=True):  # This is to get torch.compile + moving weights to another device working
                if recurse:
                    for module in self.children():
                        module._apply(fn)

                for key, param in self._parameters.items():
                    if param is None:
                        continue
                    self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
                for key, buf in self._buffers.items():
                    if buf is not None:
                        self._buffers[key] = fn(buf)
                return self

    return MixedPrecisionOps

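
# Illustrative sketch (hypothetical checkpoint keys and config): mixed_precision_ops() returns an
# ops class whose Linear layers read per-layer "comfy_quant" metadata from the state dict and wrap
# quantized weights in QuantizedTensor. A quantized layer is expected to contribute entries roughly
# like:
#
#   "blocks.0.ffn.weight"        -> quantized storage tensor
#   "blocks.0.ffn.weight_scale"  -> scale tensor
#   "blocks.0.ffn.comfy_quant"   -> uint8 buffer holding JSON such as {"format": "..."}
#
#   ops = mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16)
#   layer = ops.Linear(4096, 11008, bias=False, device="cpu")
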
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)  # TODO: if we support more ops this needs to be more granular

    if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
        logging.info("Using mixed precision operations")
        return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)

    if (
        fp8_compute and
        (fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
        not disable_fast_fp8
    ):
        return fp8_ops

    if (
        PerformanceFeature.CublasOps in args.fast and
        CUBLAS_IS_AVAILABLE and
        weight_dtype == torch.float16 and
        (compute_dtype == torch.float16 or compute_dtype is None)
    ):
        logging.info("Using cublas ops")
        return cublas_ops

    if compute_dtype is None or weight_dtype == compute_dtype:
        return disable_weight_init

    return manual_cast
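
# Illustrative usage sketch (hypothetical arguments, default --fast flags assumed):
# pick_operations() is how model loading chooses one of the op sets above, roughly:
#   - quantized model config            -> mixed_precision_ops(...)
#   - fp8 compute + fast fp8 enabled    -> fp8_ops
#   - fp16 weights + CublasOps enabled  -> cublas_ops (if the cublas_ops package is installed)
#   - weight dtype == compute dtype     -> disable_weight_init (no per-call casting needed)
#   - otherwise                         -> manual_cast
#
#   ops = pick_operations(torch.float16, torch.bfloat16, load_device=torch.device("cuda"))
#   assert ops is manual_cast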