mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-08 21:44:33 +08:00
Quantized Ops fixes (#10715)
* offload support, bug fixes, remove mixins * add readme
This commit is contained in:
parent
8b0b93df51
commit
3b3ef9a77a
168
QUANTIZATION.md
Normal file
168
QUANTIZATION.md
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
# The Comfy guide to Quantization
|
||||||
|
|
||||||
|
|
||||||
|
## How does quantization work?
|
||||||
|
|
||||||
|
Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware.
|
||||||
|
|
||||||
|
When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues:
|
||||||
|
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
|
||||||
|
- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused"
|
||||||
|
|
||||||
|
By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling.
|
||||||
|
|
||||||
|
```
|
||||||
|
absmax = max(abs(tensor))
|
||||||
|
scale = amax / max_dynamic_range_low_precision
|
||||||
|
|
||||||
|
# Quantization
|
||||||
|
tensor_q = (tensor / scale).to(low_precision_dtype)
|
||||||
|
|
||||||
|
# De-Quantization
|
||||||
|
tensor_dq = tensor_q.to(fp16) * scale
|
||||||
|
|
||||||
|
tensor_dq ~ tensor
|
||||||
|
```
|
||||||
|
|
||||||
|
Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes.
|
||||||
|
|
||||||
|
|
||||||
|
## Quantization in Comfy
|
||||||
|
|
||||||
|
```
|
||||||
|
QuantizedTensor (torch.Tensor subclass)
|
||||||
|
↓ __torch_dispatch__
|
||||||
|
Two-Level Registry (generic + layout handlers)
|
||||||
|
↓
|
||||||
|
MixedPrecisionOps + Metadata Detection
|
||||||
|
```
|
||||||
|
|
||||||
|
### Representation
|
||||||
|
|
||||||
|
To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py`
|
||||||
|
|
||||||
|
A `Layout` class defines how a specific quantization format behaves:
|
||||||
|
- Required parameters
|
||||||
|
- Quantize method
|
||||||
|
- De-Quantize method
|
||||||
|
|
||||||
|
```python
|
||||||
|
from comfy.quant_ops import QuantizedLayout
|
||||||
|
|
||||||
|
class MyLayout(QuantizedLayout):
|
||||||
|
@classmethod
|
||||||
|
def quantize(cls, tensor, **kwargs):
|
||||||
|
# Convert to quantized format
|
||||||
|
qdata = ...
|
||||||
|
params = {'scale': ..., 'orig_dtype': tensor.dtype}
|
||||||
|
return qdata, params
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||||
|
return qdata.to(orig_dtype) * scale
|
||||||
|
```
|
||||||
|
|
||||||
|
To then run operations using these QuantizedTensors we use two registry systems to define supported operations.
|
||||||
|
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
|
||||||
|
|
||||||
|
The second registry is layout-specific and allows to implement fast-paths like nn.Linear.
|
||||||
|
```python
|
||||||
|
from comfy.quant_ops import register_layout_op
|
||||||
|
|
||||||
|
@register_layout_op(torch.ops.aten.linear.default, MyLayout)
|
||||||
|
def my_linear(func, args, kwargs):
|
||||||
|
# Extract tensors, call optimized kernel
|
||||||
|
...
|
||||||
|
```
|
||||||
|
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
|
||||||
|
For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation.
|
||||||
|
|
||||||
|
|
||||||
|
### Mixed Precision
|
||||||
|
|
||||||
|
The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
|
||||||
|
|
||||||
|
**Architecture:**
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MixedPrecisionOps(disable_weight_init):
|
||||||
|
_layer_quant_config = {} # Maps layer names to quantization configs
|
||||||
|
_compute_dtype = torch.bfloat16 # Default compute / dequantize precision
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key mechanism:**
|
||||||
|
|
||||||
|
The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
|
||||||
|
- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
|
||||||
|
- If the layer name **is** in `_layer_quant_config`:
|
||||||
|
- Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
|
||||||
|
- Load associated quantization parameters (scales, block_size, etc.)
|
||||||
|
|
||||||
|
**Why it's needed:**
|
||||||
|
|
||||||
|
Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
|
||||||
|
|
||||||
|
The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
|
||||||
|
|
||||||
|
|
||||||
|
## Checkpoint Format
|
||||||
|
|
||||||
|
Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
|
||||||
|
|
||||||
|
The quantized checkpoint will contain the same layers as the original checkpoint but:
|
||||||
|
- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8.
|
||||||
|
- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe.
|
||||||
|
- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used.
|
||||||
|
|
||||||
|
### Scaling Parameters details
|
||||||
|
We define 4 possible scaling parameters that should cover most recipes in the near-future:
|
||||||
|
- **weight_scale**: quantization scalers for the weights
|
||||||
|
- **weight_scale_2**: global scalers in the context of double scaling
|
||||||
|
- **pre_quant_scale**: scalers used for smoothing salient weights
|
||||||
|
- **input_scale**: quantization scalers for the activations
|
||||||
|
|
||||||
|
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|
||||||
|
|--------|---------------|--------------|----------------|-----------------|-------------|
|
||||||
|
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
|
||||||
|
|
||||||
|
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
|
||||||
|
|
||||||
|
### Quantization Metadata
|
||||||
|
|
||||||
|
The metadata stored alongside the checkpoint contains:
|
||||||
|
- **format_version**: String to define a version of the standard
|
||||||
|
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"_quantization_metadata": {
|
||||||
|
"format_version": "1.0",
|
||||||
|
"layers": {
|
||||||
|
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
|
||||||
|
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
|
||||||
|
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Creating Quantized Checkpoints
|
||||||
|
|
||||||
|
To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
|
||||||
|
|
||||||
|
### Weight Quantization
|
||||||
|
|
||||||
|
Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
|
||||||
|
|
||||||
|
### Calibration (for Activation Quantization)
|
||||||
|
|
||||||
|
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**:
|
||||||
|
|
||||||
|
1. **Collect statistics**: Run inference on N representative samples
|
||||||
|
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
|
||||||
|
3. **Compute scales**: Derive `input_scale` from collected statistics
|
||||||
|
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
|
||||||
|
|
||||||
|
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
|
||||||
37
comfy/ops.py
37
comfy/ops.py
@ -77,7 +77,10 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
# will add async-offload support to your cast and improve performance.
|
# will add async-offload support to your cast and improve performance.
|
||||||
if input is not None:
|
if input is not None:
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = input.dtype
|
if isinstance(input, QuantizedTensor):
|
||||||
|
dtype = input._layout_params["orig_dtype"]
|
||||||
|
else:
|
||||||
|
dtype = input.dtype
|
||||||
if bias_dtype is None:
|
if bias_dtype is None:
|
||||||
bias_dtype = dtype
|
bias_dtype = dtype
|
||||||
if device is None:
|
if device is None:
|
||||||
@ -534,18 +537,7 @@ if CUBLAS_IS_AVAILABLE:
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Mixed Precision Operations
|
# Mixed Precision Operations
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
from .quant_ops import QuantizedTensor
|
from .quant_ops import QuantizedTensor, QUANT_ALGOS
|
||||||
|
|
||||||
QUANT_FORMAT_MIXINS = {
|
|
||||||
"float8_e4m3fn": {
|
|
||||||
"dtype": torch.float8_e4m3fn,
|
|
||||||
"layout_type": "TensorCoreFP8Layout",
|
|
||||||
"parameters": {
|
|
||||||
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
|
||||||
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class MixedPrecisionOps(disable_weight_init):
|
class MixedPrecisionOps(disable_weight_init):
|
||||||
_layer_quant_config = {}
|
_layer_quant_config = {}
|
||||||
@ -596,23 +588,24 @@ class MixedPrecisionOps(disable_weight_init):
|
|||||||
if quant_format is None:
|
if quant_format is None:
|
||||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||||
|
|
||||||
mixin = QUANT_FORMAT_MIXINS[quant_format]
|
qconfig = QUANT_ALGOS[quant_format]
|
||||||
self.layout_type = mixin["layout_type"]
|
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||||
|
|
||||||
scale_key = f"{prefix}weight_scale"
|
weight_scale_key = f"{prefix}weight_scale"
|
||||||
layout_params = {
|
layout_params = {
|
||||||
'scale': state_dict.pop(scale_key, None),
|
'scale': state_dict.pop(weight_scale_key, None),
|
||||||
'orig_dtype': MixedPrecisionOps._compute_dtype
|
'orig_dtype': MixedPrecisionOps._compute_dtype,
|
||||||
|
'block_size': qconfig.get("group_size", None),
|
||||||
}
|
}
|
||||||
if layout_params['scale'] is not None:
|
if layout_params['scale'] is not None:
|
||||||
manually_loaded_keys.append(scale_key)
|
manually_loaded_keys.append(weight_scale_key)
|
||||||
|
|
||||||
self.weight = torch.nn.Parameter(
|
self.weight = torch.nn.Parameter(
|
||||||
QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
|
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
|
||||||
requires_grad=False
|
requires_grad=False
|
||||||
)
|
)
|
||||||
|
|
||||||
for param_name, param_value in mixin["parameters"].items():
|
for param_name in qconfig["parameters"]:
|
||||||
param_key = f"{prefix}{param_name}"
|
param_key = f"{prefix}{param_name}"
|
||||||
_v = state_dict.pop(param_key, None)
|
_v = state_dict.pop(param_key, None)
|
||||||
if _v is None:
|
if _v is None:
|
||||||
@ -643,7 +636,7 @@ class MixedPrecisionOps(disable_weight_init):
|
|||||||
if (getattr(self, 'layout_type', None) is not None and
|
if (getattr(self, 'layout_type', None) is not None and
|
||||||
getattr(self, 'input_scale', None) is not None and
|
getattr(self, 'input_scale', None) is not None and
|
||||||
not isinstance(input, QuantizedTensor)):
|
not isinstance(input, QuantizedTensor)):
|
||||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
|
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
|
||||||
return self._forward(input, self.weight, self.bias)
|
return self._forward(input, self.weight, self.bias)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -74,6 +74,12 @@ def _copy_layout_params(params):
|
|||||||
new_params[k] = v
|
new_params[k] = v
|
||||||
return new_params
|
return new_params
|
||||||
|
|
||||||
|
def _copy_layout_params_inplace(src, dst, non_blocking=False):
|
||||||
|
for k, v in src.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
dst[k].copy_(v, non_blocking=non_blocking)
|
||||||
|
else:
|
||||||
|
dst[k] = v
|
||||||
|
|
||||||
class QuantizedLayout:
|
class QuantizedLayout:
|
||||||
"""
|
"""
|
||||||
@ -318,13 +324,13 @@ def generic_to_dtype_layout(func, args, kwargs):
|
|||||||
def generic_copy_(func, args, kwargs):
|
def generic_copy_(func, args, kwargs):
|
||||||
qt_dest = args[0]
|
qt_dest = args[0]
|
||||||
src = args[1]
|
src = args[1]
|
||||||
|
non_blocking = args[2] if len(args) > 2 else False
|
||||||
if isinstance(qt_dest, QuantizedTensor):
|
if isinstance(qt_dest, QuantizedTensor):
|
||||||
if isinstance(src, QuantizedTensor):
|
if isinstance(src, QuantizedTensor):
|
||||||
# Copy from another quantized tensor
|
# Copy from another quantized tensor
|
||||||
qt_dest._qdata.copy_(src._qdata)
|
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
|
||||||
qt_dest._layout_type = src._layout_type
|
qt_dest._layout_type = src._layout_type
|
||||||
qt_dest._layout_params = _copy_layout_params(src._layout_params)
|
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
|
||||||
else:
|
else:
|
||||||
# Copy from regular tensor - just copy raw data
|
# Copy from regular tensor - just copy raw data
|
||||||
qt_dest._qdata.copy_(src)
|
qt_dest._qdata.copy_(src)
|
||||||
@ -336,6 +342,26 @@ def generic_copy_(func, args, kwargs):
|
|||||||
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
|
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@register_generic_util(torch.ops.aten.empty_like.default)
|
||||||
|
def generic_empty_like(func, args, kwargs):
|
||||||
|
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
|
||||||
|
qt = args[0]
|
||||||
|
if isinstance(qt, QuantizedTensor):
|
||||||
|
# Create empty tensor with same shape and dtype as the quantized data
|
||||||
|
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
|
||||||
|
new_qdata = torch.empty_like(qt._qdata, **kwargs)
|
||||||
|
|
||||||
|
# Handle device transfer for layout params
|
||||||
|
target_device = kwargs.get('device', new_qdata.device)
|
||||||
|
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||||
|
|
||||||
|
# Update orig_dtype if dtype is specified
|
||||||
|
new_params['orig_dtype'] = hp_dtype
|
||||||
|
|
||||||
|
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# FP8 Layout + Operation Handlers
|
# FP8 Layout + Operation Handlers
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
@ -378,6 +404,13 @@ class TensorCoreFP8Layout(QuantizedLayout):
|
|||||||
def get_plain_tensors(cls, qtensor):
|
def get_plain_tensors(cls, qtensor):
|
||||||
return qtensor._qdata, qtensor._layout_params['scale']
|
return qtensor._qdata, qtensor._layout_params['scale']
|
||||||
|
|
||||||
|
QUANT_ALGOS = {
|
||||||
|
"float8_e4m3fn": {
|
||||||
|
"storage_t": torch.float8_e4m3fn,
|
||||||
|
"parameters": {"weight_scale", "input_scale"},
|
||||||
|
"comfy_tensor_layout": "TensorCoreFP8Layout",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
LAYOUTS = {
|
LAYOUTS = {
|
||||||
"TensorCoreFP8Layout": TensorCoreFP8Layout,
|
"TensorCoreFP8Layout": TensorCoreFP8Layout,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user