Fix small performance regression with fp8 fast and scaled fp8. (#10537)

This commit is contained in:
comfyanonymous 2025-10-29 16:29:01 -07:00 committed by GitHub
parent 25de7b1bfa
commit 906c089957
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 3 deletions

View File

@ -421,14 +421,18 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight)
else:
scale_input = scale_input.to(input.device)
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
uncast_bias_weight(self, w, bias, offload_stream)

View File

@ -357,9 +357,10 @@ class TensorCoreFP8Layout(QuantizedLayout):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
lp_amax = torch.finfo(dtype).max
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
# lp_amax = torch.finfo(dtype).max
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
layout_params = {