Bring back fp8 torch compile performance to what it should be. (#10622)

Author: comfyanonymous (committed by GitHub)
Date:   2025-11-03 16:22:10 -08:00
Commit: 6b88478f9f
Parent: e199c8cc67


@@ -126,7 +126,7 @@ class QuantizedTensor(torch.Tensor):
         return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
 
     def __init__(self, qdata, layout_type, layout_params):
-        self._qdata = qdata.contiguous()
+        self._qdata = qdata
         self._layout_type = layout_type
         self._layout_params = layout_params
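Note on the hunk above: dropping the unconditional .contiguous() in __init__ means wrapping a tensor no longer forces an eager copy; the copy is deferred to the one call site that actually needs row-major data (see the next hunk), which is friendlier to torch.compile. A minimal sketch of the behavior being relied on, using plain fp8 tensors rather than ComfyUI's QuantizedTensor (assumes a recent PyTorch with float8 support; the variable names here are illustrative only):

import torch

q = torch.randn(4, 8).to(torch.float8_e4m3fn)  # a freshly cast tensor is contiguous
assert q.contiguous() is q                     # so .contiguous() is a no-op here, no copy made

qt = q.t()                                     # transpose is a non-contiguous view, still no copy
assert not qt.is_contiguous()
qc = qt.contiguous()                           # the copy is paid only when a kernel needs row-major data
assert qc.is_contiguous()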
@@ -411,7 +411,7 @@ def fp8_linear(func, args, kwargs):
         try:
             output = torch._scaled_mm(
-                plain_input.reshape(-1, input_shape[2]),
+                plain_input.reshape(-1, input_shape[2]).contiguous(),
                 weight_t,
                 bias=bias,
                 scale_a=scale_a,
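Note on the hunk above: reshape(-1, input_shape[2]) can return a non-contiguous tensor when the activation is itself a view, and torch._scaled_mm expects its first operand row-major (and the second column-major), so the .contiguous() now lives at the call site rather than in the tensor wrapper. A rough standalone sketch of that layout contract, with made-up shapes and per-tensor scales (assumes PyTorch 2.5+ and an fp8-capable GPU such as Ada or Hopper):

import torch

a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)      # (M, K), row-major
b = torch.randn(64, 32, device="cuda").to(torch.float8_e4m3fn).t()  # (K, N), column-major view
scale = torch.ones((), device="cuda")                               # float32 per-tensor scales

out = torch._scaled_mm(a, b, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([16, 64])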
@@ -447,6 +447,43 @@ def fp8_linear(func, args, kwargs):
     return torch.nn.functional.linear(input_tensor, weight, bias)
 
 
+@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
+def fp8_addmm(func, args, kwargs):
+    input_tensor = args[1]
+    weight = args[2]
+    bias = args[0]
+    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
+        out_dtype = kwargs.get("out_dtype")
+        if out_dtype is None:
+            out_dtype = input_tensor._layout_params['orig_dtype']
+        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
+        plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
+        output = torch._scaled_mm(
+            plain_input.contiguous(),
+            plain_weight,
+            bias=bias,
+            scale_a=scale_a,
+            scale_b=scale_b,
+            out_dtype=out_dtype,
+        )
+        if isinstance(output, tuple):  # TODO: remove when we drop support for torch 2.4
+            output = output[0]
+        return output
+
+    a = list(args)
+    if isinstance(args[0], QuantizedTensor):
+        a[0] = args[0].dequantize()
+    if isinstance(args[1], QuantizedTensor):
+        a[1] = args[1].dequantize()
+    if isinstance(args[2], QuantizedTensor):
+        a[2] = args[2].dequantize()
+    return func(*a, **kwargs)
+
+
 @register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
 @register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
 def fp8_func(func, args, kwargs):
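Note on the new fp8_addmm handler: when a bias is present, torch.compile commonly lowers torch.nn.functional.linear to aten.addmm(bias, input, transposed_weight), so without a handler registered for addmm the compiled graph would drop to the dequantize fallback; registering one keeps the fused torch._scaled_mm route. A small sketch of the argument order the handler unpacks, using plain tensors only (nothing ComfyUI-specific):

import torch

inp = torch.randn(4, 8)
weight = torch.randn(16, 8)   # nn.Linear stores weight as (out_features, in_features)
bias = torch.randn(16)

ref = torch.nn.functional.linear(inp, weight, bias)
# addmm(self, mat1, mat2) computes self + mat1 @ mat2, so linear passes the weight transposed
out = torch.ops.aten.addmm.default(bias, inp, weight.t())
assert torch.allclose(ref, out)

The isinstance(output, tuple) branch mirrors older torch._scaled_mm releases, which returned an (output, amax) pair; newer releases return just the output tensor, hence the TODO.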