[Bugfix] Fix for ROCM compressed tensor support (#11561)

This commit is contained in:
Selali 2024-12-27 12:12:11 -08:00 committed by GitHub
parent dde1fa18c9
commit ac79799403
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -41,10 +41,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
)
if current_platform.is_rocm():
input_scale = getattr(layer, 'input_scale', None)
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=max_w_scale,
input_scale=layer.input_scale)
input_scale=input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
@ -57,11 +59,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight = layer.weight
if current_platform.is_rocm():
input_scale = getattr(layer, 'input_scale', None)
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale)
input_scale=input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
@ -76,7 +80,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
raise ValueError(f"Unknown quantization strategy {self.strategy}")
# INPUT SCALE
if self.is_static_input_scheme:
if self.is_static_input_scheme and hasattr(layer, 'input_scale'):
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
else: