mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI
More fp8 torch.compile regressions fixed. (#10625)

parent 0f4ef3afa0
commit af4b7b5edb
@@ -446,15 +446,7 @@ def fp8_linear(func, args, kwargs):
     return torch.nn.functional.linear(input_tensor, weight, bias)
 
 
-@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
-def fp8_addmm(func, args, kwargs):
-    input_tensor = args[1]
-    weight = args[2]
-    bias = args[0]
-
-    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
-        out_dtype = kwargs.get("out_dtype")
-        if out_dtype is None:
-            out_dtype = input_tensor._layout_params['orig_dtype']
+def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
+    if out_dtype is None:
+        out_dtype = input_tensor._layout_params['orig_dtype']
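This hunk pulls the fp8 fast path out of the old fp8_addmm into a standalone fp8_mm_ helper, so the addmm and mm handlers added below can share it. For orientation only, here is a minimal sketch of what such a helper can look like on top of torch._scaled_mm (requires an fp8-capable GPU; the per-tensor scale arguments and the bfloat16 default below are assumptions for illustration, not ComfyUI's actual bookkeeping):

import torch

# Hedged sketch only -- NOT ComfyUI's fp8_mm_. Assumes per-tensor scales
# and a row-major (N, K) fp8 weight; the real helper reads scales and
# orig_dtype from the QuantizedTensor's layout params.
def fp8_mm_sketch(input_fp8, weight_fp8, input_scale, weight_scale,
                  bias=None, out_dtype=None):
    if out_dtype is None:
        out_dtype = torch.bfloat16  # stand-in for the orig_dtype lookup
    # torch._scaled_mm requires the second operand to be column-major,
    # hence the .t() on a row-major weight.
    return torch._scaled_mm(
        input_fp8,
        weight_fp8.t(),
        scale_a=input_scale,
        scale_b=weight_scale,
        bias=bias,
        out_dtype=out_dtype,
    )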
@@ -474,6 +466,15 @@ def fp8_addmm(func, args, kwargs):
         output = output[0]
     return output
 
+
+@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
+def fp8_addmm(func, args, kwargs):
+    input_tensor = args[1]
+    weight = args[2]
+    bias = args[0]
+
+    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
+        return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
     a = list(args)
     if isinstance(args[0], QuantizedTensor):
         a[0] = args[0].dequantize()
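The rewritten fp8_addmm is now a thin wrapper: quantized-times-quantized routes through fp8_mm_, anything else falls back to dequantize-and-rerun. The register_layout_op decorator ties these handlers into the QuantizedTensor subclass's torch dispatch; here is a rough sketch of that pattern, with invented internals (_LAYOUT_OPS and QuantizedTensorSketch are illustrative names, not ComfyUI's implementation):

import torch

_LAYOUT_OPS = {}  # (aten op, layout name) -> handler

def register_layout_op(op, layout):
    def deco(fn):
        _LAYOUT_OPS[(op, layout)] = fn
        return fn
    return deco

class QuantizedTensorSketch(torch.Tensor):
    # Every aten op that touches this subclass lands here; registered
    # handlers keep the fp8 fast path, everything else dequantizes.
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        handler = _LAYOUT_OPS.get((func, "TensorCoreFP8Layout"))
        if handler is not None:
            return handler(func, args, kwargs)
        args = [a.dequantize() if isinstance(a, cls) else a for a in args]
        return func(*args, **kwargs)

Keeping the dequantize fallback inside each handler, as the diff does, rather than centrally in __torch_dispatch__, lets each op decide which argument positions may be quantized (addmm's bias sits at args[0], its matmul operands at args[1] and args[2]).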
@@ -484,6 +485,21 @@ def fp8_addmm(func, args, kwargs):
     return func(*a, **kwargs)
 
+
+@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
+def fp8_mm(func, args, kwargs):
+    input_tensor = args[0]
+    weight = args[1]
+
+    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
+        return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
+
+    a = list(args)
+    if isinstance(args[0], QuantizedTensor):
+        a[0] = args[0].dequantize()
+    if isinstance(args[1], QuantizedTensor):
+        a[1] = args[1].dequantize()
+    return func(*a, **kwargs)
 
 @register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
 @register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
 def fp8_func(func, args, kwargs):
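The newly registered fp8_mm handler is the substance of the regression fix: under torch.compile, a bias-free torch.nn.functional.linear is typically decomposed to aten.mm rather than aten.addmm, so quantized inputs previously missed the fp8 fast path. A small self-contained illustration of that decomposition, using plain tensors (no ComfyUI types involved):

import torch

# Illustration, assuming standard inductor decompositions: a linear with
# no bias lowers to an aten.mm call, one with bias to aten.addmm -- which
# is why both ops need TensorCoreFP8Layout handlers registered.
def no_bias(x, w):
    return torch.nn.functional.linear(x, w)      # -> aten.mm path

def with_bias(x, w, b):
    return torch.nn.functional.linear(x, w, b)   # -> aten.addmm path

x = torch.randn(4, 8)
w = torch.randn(16, 8)
b = torch.randn(16)
print(torch.compile(no_bias)(x, w).shape)        # torch.Size([4, 16])
print(torch.compile(with_bias)(x, w, b).shape)   # torch.Size([4, 16])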