mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-08 21:44:33 +08:00
Fix torch compile regression on fp8 ops. (#10580)
This commit is contained in:
parent
7f374e42c8
commit
c58c13b2ba
24
comfy/ops.py
24
comfy/ops.py
@ -401,15 +401,9 @@ def fp8_linear(self, input):
|
|||||||
if dtype not in [torch.float8_e4m3fn]:
|
if dtype not in [torch.float8_e4m3fn]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
tensor_2d = False
|
|
||||||
if len(input.shape) == 2:
|
|
||||||
tensor_2d = True
|
|
||||||
input = input.unsqueeze(1)
|
|
||||||
|
|
||||||
input_shape = input.shape
|
|
||||||
input_dtype = input.dtype
|
input_dtype = input.dtype
|
||||||
|
|
||||||
if len(input.shape) == 3:
|
if input.ndim == 3 or input.ndim == 2:
|
||||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||||
|
|
||||||
scale_weight = self.scale_weight
|
scale_weight = self.scale_weight
|
||||||
@ -422,24 +416,20 @@ def fp8_linear(self, input):
|
|||||||
if scale_input is None:
|
if scale_input is None:
|
||||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||||
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
|
|
||||||
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
||||||
quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight)
|
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
||||||
else:
|
else:
|
||||||
scale_input = scale_input.to(input.device)
|
scale_input = scale_input.to(input.device)
|
||||||
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
|
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
|
||||||
|
|
||||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||||
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
|
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
|
||||||
quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
|
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
|
||||||
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||||
|
|
||||||
uncast_bias_weight(self, w, bias, offload_stream)
|
uncast_bias_weight(self, w, bias, offload_stream)
|
||||||
|
return o
|
||||||
if tensor_2d:
|
|
||||||
return o.reshape(input_shape[0], -1)
|
|
||||||
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -540,12 +530,12 @@ if CUBLAS_IS_AVAILABLE:
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Mixed Precision Operations
|
# Mixed Precision Operations
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
from .quant_ops import QuantizedTensor
|
||||||
|
|
||||||
QUANT_FORMAT_MIXINS = {
|
QUANT_FORMAT_MIXINS = {
|
||||||
"float8_e4m3fn": {
|
"float8_e4m3fn": {
|
||||||
"dtype": torch.float8_e4m3fn,
|
"dtype": torch.float8_e4m3fn,
|
||||||
"layout_type": TensorCoreFP8Layout,
|
"layout_type": "TensorCoreFP8Layout",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
||||||
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
||||||
|
|||||||
@ -123,7 +123,7 @@ class QuantizedTensor(torch.Tensor):
|
|||||||
layout_type: Layout class (subclass of QuantizedLayout)
|
layout_type: Layout class (subclass of QuantizedLayout)
|
||||||
layout_params: Dict with layout-specific parameters
|
layout_params: Dict with layout-specific parameters
|
||||||
"""
|
"""
|
||||||
return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
|
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
||||||
|
|
||||||
def __init__(self, qdata, layout_type, layout_params):
|
def __init__(self, qdata, layout_type, layout_params):
|
||||||
self._qdata = qdata.contiguous()
|
self._qdata = qdata.contiguous()
|
||||||
@ -183,11 +183,11 @@ class QuantizedTensor(torch.Tensor):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
|
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
|
||||||
qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
|
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
|
||||||
return cls(qdata, layout_type, layout_params)
|
return cls(qdata, layout_type, layout_params)
|
||||||
|
|
||||||
def dequantize(self) -> torch.Tensor:
|
def dequantize(self) -> torch.Tensor:
|
||||||
return self._layout_type.dequantize(self._qdata, **self._layout_params)
|
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||||
@ -379,7 +379,12 @@ class TensorCoreFP8Layout(QuantizedLayout):
|
|||||||
return qtensor._qdata, qtensor._layout_params['scale']
|
return qtensor._qdata, qtensor._layout_params['scale']
|
||||||
|
|
||||||
|
|
||||||
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
|
LAYOUTS = {
|
||||||
|
"TensorCoreFP8Layout": TensorCoreFP8Layout,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
|
||||||
def fp8_linear(func, args, kwargs):
|
def fp8_linear(func, args, kwargs):
|
||||||
input_tensor = args[0]
|
input_tensor = args[0]
|
||||||
weight = args[1]
|
weight = args[1]
|
||||||
@ -422,7 +427,7 @@ def fp8_linear(func, args, kwargs):
|
|||||||
'scale': output_scale,
|
'scale': output_scale,
|
||||||
'orig_dtype': input_tensor._layout_params['orig_dtype']
|
'orig_dtype': input_tensor._layout_params['orig_dtype']
|
||||||
}
|
}
|
||||||
return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
|
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
|
||||||
else:
|
else:
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -436,3 +441,15 @@ def fp8_linear(func, args, kwargs):
|
|||||||
input_tensor = input_tensor.dequantize()
|
input_tensor = input_tensor.dequantize()
|
||||||
|
|
||||||
return torch.nn.functional.linear(input_tensor, weight, bias)
|
return torch.nn.functional.linear(input_tensor, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
|
||||||
|
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
|
||||||
|
def fp8_func(func, args, kwargs):
|
||||||
|
input_tensor = args[0]
|
||||||
|
if isinstance(input_tensor, QuantizedTensor):
|
||||||
|
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||||
|
ar = list(args)
|
||||||
|
ar[0] = plain_input
|
||||||
|
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|||||||
@ -14,7 +14,7 @@ if not has_gpu():
|
|||||||
args.cpu = True
|
args.cpu = True
|
||||||
|
|
||||||
from comfy import ops
|
from comfy import ops
|
||||||
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
from comfy.quant_ops import QuantizedTensor
|
||||||
|
|
||||||
|
|
||||||
class SimpleModel(torch.nn.Module):
|
class SimpleModel(torch.nn.Module):
|
||||||
@ -104,14 +104,14 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
|||||||
|
|
||||||
# Verify weights are wrapped in QuantizedTensor
|
# Verify weights are wrapped in QuantizedTensor
|
||||||
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
|
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
|
||||||
self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)
|
self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
|
||||||
|
|
||||||
# Layer 2 should NOT be quantized
|
# Layer 2 should NOT be quantized
|
||||||
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
|
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
|
||||||
|
|
||||||
# Layer 3 should be quantized
|
# Layer 3 should be quantized
|
||||||
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
|
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
|
||||||
self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)
|
self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
|
||||||
|
|
||||||
# Verify scales were loaded
|
# Verify scales were loaded
|
||||||
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
|
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
|
||||||
@ -155,7 +155,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
|||||||
# Verify layer1.weight is a QuantizedTensor with scale preserved
|
# Verify layer1.weight is a QuantizedTensor with scale preserved
|
||||||
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
|
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
|
||||||
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
|
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
|
||||||
self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)
|
self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
|
||||||
|
|
||||||
# Verify non-quantized layers are standard tensors
|
# Verify non-quantized layers are standard tensors
|
||||||
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
|
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
|
||||||
|
|||||||
@ -25,14 +25,14 @@ class TestQuantizedTensor(unittest.TestCase):
|
|||||||
scale = torch.tensor(2.0)
|
scale = torch.tensor(2.0)
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
|
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
|
||||||
|
|
||||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||||
|
|
||||||
self.assertIsInstance(qt, QuantizedTensor)
|
self.assertIsInstance(qt, QuantizedTensor)
|
||||||
self.assertEqual(qt.shape, (256, 128))
|
self.assertEqual(qt.shape, (256, 128))
|
||||||
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
||||||
self.assertEqual(qt._layout_params['scale'], scale)
|
self.assertEqual(qt._layout_params['scale'], scale)
|
||||||
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
|
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
|
||||||
self.assertEqual(qt._layout_type, TensorCoreFP8Layout)
|
self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
|
||||||
|
|
||||||
def test_dequantize(self):
|
def test_dequantize(self):
|
||||||
"""Test explicit dequantization"""
|
"""Test explicit dequantization"""
|
||||||
@ -41,7 +41,7 @@ class TestQuantizedTensor(unittest.TestCase):
|
|||||||
scale = torch.tensor(3.0)
|
scale = torch.tensor(3.0)
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||||
|
|
||||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||||
dequantized = qt.dequantize()
|
dequantized = qt.dequantize()
|
||||||
|
|
||||||
self.assertEqual(dequantized.dtype, torch.float32)
|
self.assertEqual(dequantized.dtype, torch.float32)
|
||||||
@ -54,7 +54,7 @@ class TestQuantizedTensor(unittest.TestCase):
|
|||||||
|
|
||||||
qt = QuantizedTensor.from_float(
|
qt = QuantizedTensor.from_float(
|
||||||
float_tensor,
|
float_tensor,
|
||||||
TensorCoreFP8Layout,
|
"TensorCoreFP8Layout",
|
||||||
scale=scale,
|
scale=scale,
|
||||||
dtype=torch.float8_e4m3fn
|
dtype=torch.float8_e4m3fn
|
||||||
)
|
)
|
||||||
@ -77,28 +77,28 @@ class TestGenericUtilities(unittest.TestCase):
|
|||||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||||
scale = torch.tensor(1.5)
|
scale = torch.tensor(1.5)
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||||
|
|
||||||
# Detach should return a new QuantizedTensor
|
# Detach should return a new QuantizedTensor
|
||||||
qt_detached = qt.detach()
|
qt_detached = qt.detach()
|
||||||
|
|
||||||
self.assertIsInstance(qt_detached, QuantizedTensor)
|
self.assertIsInstance(qt_detached, QuantizedTensor)
|
||||||
self.assertEqual(qt_detached.shape, qt.shape)
|
self.assertEqual(qt_detached.shape, qt.shape)
|
||||||
self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)
|
self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")
|
||||||
|
|
||||||
def test_clone(self):
|
def test_clone(self):
|
||||||
"""Test clone operation on quantized tensor"""
|
"""Test clone operation on quantized tensor"""
|
||||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||||
scale = torch.tensor(1.5)
|
scale = torch.tensor(1.5)
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||||
|
|
||||||
# Clone should return a new QuantizedTensor
|
# Clone should return a new QuantizedTensor
|
||||||
qt_cloned = qt.clone()
|
qt_cloned = qt.clone()
|
||||||
|
|
||||||
self.assertIsInstance(qt_cloned, QuantizedTensor)
|
self.assertIsInstance(qt_cloned, QuantizedTensor)
|
||||||
self.assertEqual(qt_cloned.shape, qt.shape)
|
self.assertEqual(qt_cloned.shape, qt.shape)
|
||||||
self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)
|
self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")
|
||||||
|
|
||||||
# Verify it's a deep copy
|
# Verify it's a deep copy
|
||||||
self.assertIsNot(qt_cloned._qdata, qt._qdata)
|
self.assertIsNot(qt_cloned._qdata, qt._qdata)
|
||||||
@ -109,7 +109,7 @@ class TestGenericUtilities(unittest.TestCase):
|
|||||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||||
scale = torch.tensor(1.5)
|
scale = torch.tensor(1.5)
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||||
|
|
||||||
# Moving to same device should work (CPU to CPU)
|
# Moving to same device should work (CPU to CPU)
|
||||||
qt_cpu = qt.to('cpu')
|
qt_cpu = qt.to('cpu')
|
||||||
@ -169,7 +169,7 @@ class TestFallbackMechanism(unittest.TestCase):
|
|||||||
scale = torch.tensor(1.0)
|
scale = torch.tensor(1.0)
|
||||||
a_q = QuantizedTensor.from_float(
|
a_q = QuantizedTensor.from_float(
|
||||||
a_fp32,
|
a_fp32,
|
||||||
TensorCoreFP8Layout,
|
"TensorCoreFP8Layout",
|
||||||
scale=scale,
|
scale=scale,
|
||||||
dtype=torch.float8_e4m3fn
|
dtype=torch.float8_e4m3fn
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user