Fix torch compile regression on fp8 ops. (#10580)

comfyanonymous 2025-10-31 21:25:17 -07:00 committed by GitHub
parent 7f374e42c8
commit c58c13b2ba
4 changed files with 43 additions and 36 deletions


@@ -401,15 +401,9 @@ def fp8_linear(self, input):
     if dtype not in [torch.float8_e4m3fn]:
         return None
 
-    tensor_2d = False
-    if len(input.shape) == 2:
-        tensor_2d = True
-        input = input.unsqueeze(1)
-
-    input_shape = input.shape
     input_dtype = input.dtype
-    if len(input.shape) == 3:
+    if input.ndim == 3 or input.ndim == 2:
         w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
 
         scale_weight = self.scale_weight
@@ -422,24 +416,20 @@ def fp8_linear(self, input):
         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
             input = torch.clamp(input, min=-448, max=448, out=input)
-            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
             layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
-            quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight)
+            quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
         else:
             scale_input = scale_input.to(input.device)
-            quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
+            quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
 
         # Wrap weight in QuantizedTensor - this enables unified dispatch
         # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
         layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
-        quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
+        quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
 
         o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
         uncast_bias_weight(self, w, bias, offload_stream)
-
-        if tensor_2d:
-            return o.reshape(input_shape[0], -1)
-
-        return o.reshape((-1, input_shape[1], self.weight.shape[0]))
+        return o
 
     return None
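Note on the shape handling above: the removed tensor_2d/unsqueeze/reshape bookkeeping normalized activations to a fixed rank before and after the matmul, while the new code passes 2D or 3D activations straight through (any fp8-specific flattening is left to the dispatch handlers in quant_ops.py, per the comment in the hunk). That matches how torch.nn.functional.linear treats plain tensors: any number of leading batch dimensions is accepted and preserved in the output. A minimal stand-alone check with plain tensors, illustrative only (the shapes below are arbitrary, not from the diff):

import torch
import torch.nn.functional as F

w = torch.randn(32, 16)           # (out_features, in_features)
b = torch.randn(32)

x2d = torch.randn(8, 16)          # (N, K) activation
x3d = torch.randn(4, 8, 16)       # (B, S, K) activation

print(F.linear(x2d, w, b).shape)  # torch.Size([8, 32])
print(F.linear(x3d, w, b).shape)  # torch.Size([4, 8, 32])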
@@ -540,12 +530,12 @@ if CUBLAS_IS_AVAILABLE:
 # ==============================================================================
 # Mixed Precision Operations
 # ==============================================================================
-from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
+from .quant_ops import QuantizedTensor
 
 QUANT_FORMAT_MIXINS = {
     "float8_e4m3fn": {
         "dtype": torch.float8_e4m3fn,
-        "layout_type": TensorCoreFP8Layout,
+        "layout_type": "TensorCoreFP8Layout",
         "parameters": {
             "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
             "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),


@@ -123,7 +123,7 @@ class QuantizedTensor(torch.Tensor):
             layout_type: Layout class (subclass of QuantizedLayout)
             layout_params: Dict with layout-specific parameters
         """
-        return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
+        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
 
     def __init__(self, qdata, layout_type, layout_params):
         self._qdata = qdata.contiguous()
@@ -183,11 +183,11 @@ class QuantizedTensor(torch.Tensor):
 
     @classmethod
     def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
-        qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
+        qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
         return cls(qdata, layout_type, layout_params)
 
     def dequantize(self) -> torch.Tensor:
-        return self._layout_type.dequantize(self._qdata, **self._layout_params)
+        return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
@@ -379,7 +379,12 @@ class TensorCoreFP8Layout(QuantizedLayout):
         return qtensor._qdata, qtensor._layout_params['scale']
 
-@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
+LAYOUTS = {
+    "TensorCoreFP8Layout": TensorCoreFP8Layout,
+}
+
+
+@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
 def fp8_linear(func, args, kwargs):
     input_tensor = args[0]
     weight = args[1]
@@ -422,7 +427,7 @@ def fp8_linear(func, args, kwargs):
             'scale': output_scale,
             'orig_dtype': input_tensor._layout_params['orig_dtype']
         }
-        return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
+        return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
     else:
         return output
@@ -436,3 +441,15 @@ def fp8_linear(func, args, kwargs):
         input_tensor = input_tensor.dequantize()
 
     return torch.nn.functional.linear(input_tensor, weight, bias)
+
+
+@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
+@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
+def fp8_func(func, args, kwargs):
+    input_tensor = args[0]
+    if isinstance(input_tensor, QuantizedTensor):
+        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
+        ar = list(args)
+        ar[0] = plain_input
+        return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
+    return func(*args, **kwargs)
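The quant_ops.py changes above do three things: replace direct references to the TensorCoreFP8Layout class with a string key resolved through a module-level LAYOUTS dict, build the subclass with _make_wrapper_subclass instead of _make_subclass, and register pass-through handlers for aten.view/aten.t so shape-only ops keep the payload quantized. Storing a plain string (rather than a class object) as per-tensor metadata keeps the subclass state to simple Python data, which is generally friendlier to torch.compile tracing; given the commit title, that is presumably the point, though the diff itself is the authority. The sketch below shows the same wrapper-subclass-plus-registry pattern in isolation. Every name in it (DemoFP8Layout, DemoQuantTensor, register_op) is a made-up stand-in rather than the comfy.quant_ops API, and it assumes a PyTorch build with torch.float8_e4m3fn (2.1 or newer):

import torch

# Toy string-keyed registries, mirroring the LAYOUTS dict and the
# register_layout_op table in the diff.
_LAYOUTS = {}
_OP_TABLE = {}

def register_op(op, layout_key):
    def decorator(fn):
        _OP_TABLE[(op, layout_key)] = fn
        return fn
    return decorator

class DemoFP8Layout:
    @staticmethod
    def quantize(tensor, scale):
        qdata = (tensor / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
        return qdata, {"scale": scale, "orig_dtype": tensor.dtype}

    @staticmethod
    def dequantize(qdata, scale, orig_dtype):
        return qdata.to(orig_dtype) * scale

_LAYOUTS["DemoFP8Layout"] = DemoFP8Layout

class DemoQuantTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, qdata, layout_key, params):
        # Wrapper subclass: the outer tensor only carries shape/dtype metadata,
        # the quantized payload lives in self._qdata.
        return torch.Tensor._make_wrapper_subclass(
            cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)

    def __init__(self, qdata, layout_key, params):
        self._qdata = qdata
        self._layout_key = layout_key   # plain string key instead of a class object
        self._params = params

    def __repr__(self):
        return f"DemoQuantTensor({self._layout_key}, shape={tuple(self.shape)})"

    def dequantize(self):
        return _LAYOUTS[self._layout_key].dequantize(self._qdata, **self._params)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        qt = next(a for a in args if isinstance(a, cls))
        handler = _OP_TABLE.get((func, qt._layout_key))
        if handler is not None:
            return handler(func, args, kwargs)
        # Fallback: dequantize and run the op on plain tensors.
        plain = [a.dequantize() if isinstance(a, cls) else a for a in args]
        return func(*plain, **kwargs)

# Shape-only ops stay quantized: run the op on the inner payload and rewrap,
# the same idea as the aten.view/aten.t handler the diff adds.
@register_op(torch.ops.aten.view.default, "DemoFP8Layout")
@register_op(torch.ops.aten.t.default, "DemoFP8Layout")
def _demo_shape_op(func, args, kwargs):
    qt = args[0]
    out = func(qt._qdata, *args[1:], **kwargs)
    return DemoQuantTensor(out, qt._layout_key, qt._params)

# Usage: transpose keeps the tensor quantized; unhandled ops fall back.
qdata, params = DemoFP8Layout.quantize(torch.randn(4, 8), torch.tensor(2.0))
qt = DemoQuantTensor(qdata, "DemoFP8Layout", params)
print(qt.t().shape)             # torch.Size([8, 4]), still a DemoQuantTensor
print(qt.dequantize().dtype)    # back to torch.float32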


@@ -14,7 +14,7 @@ if not has_gpu():
     args.cpu = True
 
 from comfy import ops
-from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
+from comfy.quant_ops import QuantizedTensor
 
 
 class SimpleModel(torch.nn.Module):
@@ -104,14 +104,14 @@ class TestMixedPrecisionOps(unittest.TestCase):
 
         # Verify weights are wrapped in QuantizedTensor
         self.assertIsInstance(model.layer1.weight, QuantizedTensor)
-        self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
 
         # Layer 2 should NOT be quantized
         self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
 
         # Layer 3 should be quantized
         self.assertIsInstance(model.layer3.weight, QuantizedTensor)
-        self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
 
         # Verify scales were loaded
         self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
@@ -155,7 +155,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
         # Verify layer1.weight is a QuantizedTensor with scale preserved
         self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
         self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
-        self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
 
         # Verify non-quantized layers are standard tensors
         self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)


@@ -25,14 +25,14 @@ class TestQuantizedTensor(unittest.TestCase):
         scale = torch.tensor(2.0)
         layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
 
-        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
 
         self.assertIsInstance(qt, QuantizedTensor)
         self.assertEqual(qt.shape, (256, 128))
         self.assertEqual(qt.dtype, torch.float8_e4m3fn)
         self.assertEqual(qt._layout_params['scale'], scale)
         self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
-        self.assertEqual(qt._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
 
     def test_dequantize(self):
         """Test explicit dequantization"""
@@ -41,7 +41,7 @@ class TestQuantizedTensor(unittest.TestCase):
         scale = torch.tensor(3.0)
         layout_params = {'scale': scale, 'orig_dtype': torch.float32}
 
-        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
         dequantized = qt.dequantize()
 
         self.assertEqual(dequantized.dtype, torch.float32)
@@ -54,7 +54,7 @@ class TestQuantizedTensor(unittest.TestCase):
 
         qt = QuantizedTensor.from_float(
             float_tensor,
-            TensorCoreFP8Layout,
+            "TensorCoreFP8Layout",
            scale=scale,
             dtype=torch.float8_e4m3fn
         )
@@ -77,28 +77,28 @@ class TestGenericUtilities(unittest.TestCase):
         fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
         scale = torch.tensor(1.5)
         layout_params = {'scale': scale, 'orig_dtype': torch.float32}
-        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
 
         # Detach should return a new QuantizedTensor
         qt_detached = qt.detach()
         self.assertIsInstance(qt_detached, QuantizedTensor)
         self.assertEqual(qt_detached.shape, qt.shape)
-        self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")
 
     def test_clone(self):
         """Test clone operation on quantized tensor"""
         fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
         scale = torch.tensor(1.5)
         layout_params = {'scale': scale, 'orig_dtype': torch.float32}
-        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
 
         # Clone should return a new QuantizedTensor
         qt_cloned = qt.clone()
         self.assertIsInstance(qt_cloned, QuantizedTensor)
         self.assertEqual(qt_cloned.shape, qt.shape)
-        self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")
 
         # Verify it's a deep copy
         self.assertIsNot(qt_cloned._qdata, qt._qdata)
@@ -109,7 +109,7 @@ class TestGenericUtilities(unittest.TestCase):
         fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
         scale = torch.tensor(1.5)
         layout_params = {'scale': scale, 'orig_dtype': torch.float32}
-        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
 
         # Moving to same device should work (CPU to CPU)
         qt_cpu = qt.to('cpu')
@@ -169,7 +169,7 @@ class TestFallbackMechanism(unittest.TestCase):
         scale = torch.tensor(1.0)
         a_q = QuantizedTensor.from_float(
             a_fp32,
-            TensorCoreFP8Layout,
+            "TensorCoreFP8Layout",
             scale=scale,
             dtype=torch.float8_e4m3fn
         )
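A closing note on the constants these tests and the fp8_linear hunk rely on: 448 is the largest finite value of torch.float8_e4m3fn, which is why activations are clamped to ±448 before the cast, and a scale-based fp8 layout dequantizes by casting back and multiplying by the stored scale. A stand-alone round trip with plain tensors, illustrative only (the 0.1 tolerance is a loose sanity bound, not a spec):

import torch

x = torch.randn(64, 32)
scale = x.abs().max() / 448.0           # map the largest magnitude onto fp8 e4m3's max finite value
q = (x / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
x_rt = q.to(torch.float32) * scale      # cast back and rescale

rel_err = ((x - x_rt).abs().max() / x.abs().max()).item()
print(rel_err)                          # small: e4m3 carries 3 mantissa bits
assert rel_err < 0.1                    # loose sanity bound for this illustration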