Mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI (synced 2025-12-09 22:14:34 +08:00)
More fp8 torch.compile regressions fixed. (#10625)
commit af4b7b5edb
parent 0f4ef3afa0
@@ -446,6 +446,25 @@ def fp8_linear(func, args, kwargs):
     return torch.nn.functional.linear(input_tensor, weight, bias)
 
 
+def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
+    if out_dtype is None:
+        out_dtype = input_tensor._layout_params['orig_dtype']
+
+    plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
+    plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
+
+    output = torch._scaled_mm(
+        plain_input.contiguous(),
+        plain_weight,
+        bias=bias,
+        scale_a=scale_a,
+        scale_b=scale_b,
+        out_dtype=out_dtype,
+    )
+
+    if isinstance(output, tuple):  # TODO: remove when we drop support for torch 2.4
+        output = output[0]
+    return output
 
 @register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
 def fp8_addmm(func, args, kwargs):
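For context, the new fp8_mm_ helper is a thin wrapper around torch._scaled_mm. The standalone sketch below (not part of the commit) shows roughly what that call looks like with per-tensor scales; it assumes PyTorch >= 2.4 and a CUDA device with fp8 support (compute capability 8.9+), and the tensor names and shapes are illustrative only.

# Standalone sketch of the torch._scaled_mm call that fp8_mm_ wraps.
# Assumes PyTorch >= 2.4 and a CUDA device with fp8 support (sm89+).
import torch

x = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)
w = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)

# Per-tensor scales so the values fit the float8_e4m3fn range.
scale_x = x.abs().amax().float() / torch.finfo(torch.float8_e4m3fn).max
scale_w = w.abs().amax().float() / torch.finfo(torch.float8_e4m3fn).max

x_fp8 = (x / scale_x).to(torch.float8_e4m3fn)        # mat1 must be row-major
w_fp8 = (w / scale_w).to(torch.float8_e4m3fn).t()    # mat2 must be column-major

out = torch._scaled_mm(
    x_fp8,
    w_fp8,
    scale_a=scale_x,
    scale_b=scale_w,
    out_dtype=torch.bfloat16,
)
if isinstance(out, tuple):  # older torch versions returned (out, amax)
    out = out[0]
print(out.shape)  # torch.Size([32, 128])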
@@ -454,25 +473,7 @@ def fp8_addmm(func, args, kwargs):
     bias = args[0]
 
     if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
-        out_dtype = kwargs.get("out_dtype")
-        if out_dtype is None:
-            out_dtype = input_tensor._layout_params['orig_dtype']
-
-        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
-        plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
-
-        output = torch._scaled_mm(
-            plain_input.contiguous(),
-            plain_weight,
-            bias=bias,
-            scale_a=scale_a,
-            scale_b=scale_b,
-            out_dtype=out_dtype,
-        )
-
-        if isinstance(output, tuple):  # TODO: remove when we drop support for torch 2.4
-            output = output[0]
-        return output
+        return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
 
     a = list(args)
     if isinstance(args[0], QuantizedTensor):
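One note on the argument order handled here: in aten.addmm.default the bias is the first argument, since addmm(input, mat1, mat2) computes input + mat1 @ mat2, which is why the handler reads the bias from args[0] and forwards it to fp8_mm_. A quick sketch with plain tensors (names are illustrative):

# Sketch: aten.addmm's argument order, which is why bias comes from args[0].
import torch

bias = torch.randn(16)
mat1 = torch.randn(8, 32)
mat2 = torch.randn(32, 16)

out = torch.ops.aten.addmm.default(bias, mat1, mat2)  # bias + mat1 @ mat2
print(torch.allclose(out, bias + mat1 @ mat2, atol=1e-5))  # True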
@@ -484,6 +485,21 @@ def fp8_addmm(func, args, kwargs):
     return func(*a, **kwargs)
 
 
+@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
+def fp8_mm(func, args, kwargs):
+    input_tensor = args[0]
+    weight = args[1]
+
+    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
+        return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
+
+    a = list(args)
+    if isinstance(args[0], QuantizedTensor):
+        a[0] = args[0].dequantize()
+    if isinstance(args[1], QuantizedTensor):
+        a[1] = args[1].dequantize()
+    return func(*a, **kwargs)
+
 @register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
 @register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
 def fp8_func(func, args, kwargs):
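The new fp8_mm handler mirrors fp8_addmm for the bias-free case. A plausible reason aten.mm.default shows up at all under torch.compile (and hence the regression) is that a traced linear call lowers to aten.addmm when it has a bias but to aten.t plus aten.mm when it does not, so both overloads need handlers to stay on the quantized path. The small sketch below is illustrative only; the function names are made up and exact traced graphs vary by PyTorch version.

# Sketch: tracing F.linear with and without a bias to see which aten ops it
# decomposes into. Exact node lists vary between PyTorch versions.
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def lin_no_bias(x, w):
    return torch.nn.functional.linear(x, w)

def lin_with_bias(x, w, b):
    return torch.nn.functional.linear(x, w, b)

x, w, b = torch.randn(4, 8), torch.randn(16, 8), torch.randn(16)

ops_no_bias = [n.target for n in make_fx(lin_no_bias)(x, w).graph.nodes if n.op == "call_function"]
ops_with_bias = [n.target for n in make_fx(lin_with_bias)(x, w, b).graph.nodes if n.op == "call_function"]

print(ops_no_bias)    # expect aten.t.default followed by aten.mm.default
print(ops_with_bias)  # expect aten.t.default followed by aten.addmm.default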