mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-08 21:44:33 +08:00
191 lines
6.8 KiB
Python
191 lines
6.8 KiB
Python
import unittest
|
|
import torch
|
|
import sys
|
|
import os
|
|
|
|
# Make the `comfy` package importable without installing it: the directory two
# levels above this file is put at the front of sys.path.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
|
|
|
|
def has_gpu():
    """Report whether PyTorch can see a CUDA device.

    Returns:
        bool: the value of ``torch.cuda.is_available()``.
    """
    cuda_ready = torch.cuda.is_available()
    return cuda_ready
|
|
|
|
from comfy.cli_args import args

# No CUDA device visible: flip comfy's CLI args into CPU mode. This is done
# before `comfy.quant_ops` is imported further down (presumably because comfy
# consults `args` at import time — confirm).
if not has_gpu():
    args.cpu = True
|
|
|
|
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
|
|
|
|
|
class TestQuantizedTensor(unittest.TestCase):
|
|
"""Test the QuantizedTensor subclass with FP8 layout"""
|
|
|
|
def test_creation(self):
|
|
"""Test creating a QuantizedTensor with TensorCoreFP8Layout"""
|
|
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
|
|
scale = torch.tensor(2.0)
|
|
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
|
|
|
|
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
|
|
|
self.assertIsInstance(qt, QuantizedTensor)
|
|
self.assertEqual(qt.shape, (256, 128))
|
|
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
|
self.assertEqual(qt._layout_params['scale'], scale)
|
|
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
|
|
self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
|
|
|
|
def test_dequantize(self):
|
|
"""Test explicit dequantization"""
|
|
|
|
fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
|
scale = torch.tensor(3.0)
|
|
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
|
|
|
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
|
dequantized = qt.dequantize()
|
|
|
|
self.assertEqual(dequantized.dtype, torch.float32)
|
|
self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
|
|
|
|
def test_from_float(self):
|
|
"""Test creating QuantizedTensor from float tensor"""
|
|
float_tensor = torch.randn(64, 32, dtype=torch.float32)
|
|
scale = torch.tensor(1.5)
|
|
|
|
qt = QuantizedTensor.from_float(
|
|
float_tensor,
|
|
"TensorCoreFP8Layout",
|
|
scale=scale,
|
|
dtype=torch.float8_e4m3fn
|
|
)
|
|
|
|
self.assertIsInstance(qt, QuantizedTensor)
|
|
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
|
self.assertEqual(qt.shape, (64, 32))
|
|
|
|
# Verify dequantization gives approximately original values
|
|
dequantized = qt.dequantize()
|
|
mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
|
|
self.assertLess(mean_rel_error, 0.1)
|
|
|
|
|
|
class TestGenericUtilities(unittest.TestCase):
    """Test generic utility operations (detach, clone, device transfer) on QuantizedTensor."""

    @staticmethod
    def _make_quantized():
        """Build a 10x20 FP8 QuantizedTensor (scale 1.5, orig float32) on the CPU."""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        return QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

    def test_detach(self):
        """Test detach operation on quantized tensor"""
        qt = self._make_quantized()

        # Detach should return a new QuantizedTensor
        qt_detached = qt.detach()

        self.assertIsInstance(qt_detached, QuantizedTensor)
        self.assertEqual(qt_detached.shape, qt.shape)
        self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")

    def test_clone(self):
        """Test clone operation on quantized tensor"""
        qt = self._make_quantized()

        # Clone should return a new QuantizedTensor
        qt_cloned = qt.clone()

        self.assertIsInstance(qt_cloned, QuantizedTensor)
        self.assertEqual(qt_cloned.shape, qt.shape)
        self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")

        # Verify it's a deep copy: the FP8 payload must not be shared
        self.assertIsNot(qt_cloned._qdata, qt._qdata)

    def test_to_device(self):
        """Test device transfer to the same device (CPU to CPU); needs no GPU."""
        qt = self._make_quantized()

        # Moving to same device should work (CPU to CPU)
        qt_cpu = qt.to('cpu')

        self.assertIsInstance(qt_cpu, QuantizedTensor)
        self.assertEqual(qt_cpu.device.type, 'cpu')
        self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')

    @unittest.skipUnless(has_gpu(), "GPU not available")
    def test_to_cuda_device(self):
        """Test device transfer CPU -> CUDA -> CPU; the scale must follow the data."""
        qt = self._make_quantized()

        qt_gpu = qt.to('cuda')
        self.assertIsInstance(qt_gpu, QuantizedTensor)
        self.assertEqual(qt_gpu.device.type, 'cuda')
        self.assertEqual(qt_gpu._layout_params['scale'].device.type, 'cuda')

        qt_back = qt_gpu.to('cpu')
        self.assertIsInstance(qt_back, QuantizedTensor)
        self.assertEqual(qt_back.device.type, 'cpu')
        self.assertEqual(qt_back._layout_params['scale'].device.type, 'cpu')
|
|
|
|
|
|
class TestTensorCoreFP8Layout(unittest.TestCase):
|
|
"""Test the TensorCoreFP8Layout implementation"""
|
|
|
|
def test_quantize(self):
|
|
"""Test quantization method"""
|
|
float_tensor = torch.randn(32, 64, dtype=torch.float32)
|
|
scale = torch.tensor(1.5)
|
|
|
|
qdata, layout_params = TensorCoreFP8Layout.quantize(
|
|
float_tensor,
|
|
scale=scale,
|
|
dtype=torch.float8_e4m3fn
|
|
)
|
|
|
|
self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
|
|
self.assertEqual(qdata.shape, float_tensor.shape)
|
|
self.assertIn('scale', layout_params)
|
|
self.assertIn('orig_dtype', layout_params)
|
|
self.assertEqual(layout_params['orig_dtype'], torch.float32)
|
|
|
|
def test_dequantize(self):
|
|
"""Test dequantization method"""
|
|
float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
|
|
scale = torch.tensor(1.0)
|
|
|
|
qdata, layout_params = TensorCoreFP8Layout.quantize(
|
|
float_tensor,
|
|
scale=scale,
|
|
dtype=torch.float8_e4m3fn
|
|
)
|
|
|
|
dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)
|
|
|
|
# Should approximately match original
|
|
self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))
|
|
|
|
|
|
class TestFallbackMechanism(unittest.TestCase):
|
|
"""Test fallback for unsupported operations"""
|
|
|
|
def test_unsupported_op_dequantizes(self):
|
|
"""Test that unsupported operations fall back to dequantization"""
|
|
# Set seed for reproducibility
|
|
torch.manual_seed(42)
|
|
|
|
# Create quantized tensor
|
|
a_fp32 = torch.randn(10, 20, dtype=torch.float32)
|
|
scale = torch.tensor(1.0)
|
|
a_q = QuantizedTensor.from_float(
|
|
a_fp32,
|
|
"TensorCoreFP8Layout",
|
|
scale=scale,
|
|
dtype=torch.float8_e4m3fn
|
|
)
|
|
|
|
# Call an operation that doesn't have a registered handler
|
|
# For example, torch.abs
|
|
result = torch.abs(a_q)
|
|
|
|
# Should work via fallback (dequantize → abs → return)
|
|
self.assertNotIsInstance(result, QuantizedTensor)
|
|
expected = torch.abs(a_fp32)
|
|
# FP8 introduces quantization error, so use loose tolerance
|
|
mean_error = (result - expected).abs().mean()
|
|
self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script (not imported): hand control to unittest's runner.
    unittest.main()
|