Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-25 05:35:02 +08:00
[Bugfix] Fix for ROCM compressed tensor support (#11561)
commit ac79799403
parent dde1fa18c9
@@ -41,10 +41,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
             )
 
             if current_platform.is_rocm():
+                input_scale = getattr(layer, 'input_scale', None)
+
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=max_w_scale,
-                    input_scale=layer.input_scale)
+                    input_scale=input_scale)
                 if input_scale is not None:
                     layer.input_scale = Parameter(input_scale,
                                                   requires_grad=False)
@@ -57,11 +59,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
             weight = layer.weight
 
             if current_platform.is_rocm():
+                input_scale = getattr(layer, 'input_scale', None)
+
                 weight, weight_scale, input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         weight=weight,
                         weight_scale=layer.weight_scale,
-                        input_scale=layer.input_scale)
+                        input_scale=input_scale)
                 if input_scale is not None:
                     layer.input_scale = Parameter(input_scale,
                                                   requires_grad=False)
@@ -76,7 +80,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
             raise ValueError(f"Unknown quantization strategy {self.strategy}")
 
         # INPUT SCALE
-        if self.is_static_input_scheme:
+        if self.is_static_input_scheme and hasattr(layer, 'input_scale'):
             layer.input_scale = Parameter(layer.input_scale.max(),
                                           requires_grad=False)
         else:
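What the getattr guard buys: checkpoints quantized with dynamic activation scales never register an input_scale on the layer, so the old unconditional layer.input_scale access raised AttributeError on ROCm. A minimal standalone sketch of the failure mode and the fix's pattern (the SimpleNamespace layer is a hypothetical stand-in for vLLM's module, used only for illustration):

from types import SimpleNamespace

import torch
from torch.nn import Parameter

# Hypothetical stand-in for a quantized linear layer restored from a
# checkpoint that uses dynamic activation scales: no input_scale exists.
layer = SimpleNamespace(weight=torch.zeros(4, 4, dtype=torch.float16))

# Pre-fix behavior on ROCm: unconditional attribute access blows up.
try:
    _ = layer.input_scale
except AttributeError:
    print("dynamic-scale checkpoint: no input_scale attribute")

# Post-fix behavior: fetch defensively, re-register only when present.
input_scale = getattr(layer, 'input_scale', None)
if input_scale is not None:
    layer.input_scale = Parameter(input_scale, requires_grad=False)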
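For context on why ROCm needs the conversion at all: the sketch below of a normalize_e4m3fn_to_e4m3fnuz helper is modeled on vLLM's fp8 utilities as I recall them, so treat it as an approximation rather than the exact implementation. The idea is that the bit pattern 0x80 encodes -0.0 in OCP e4m3fn but NaN in ROCm's e4m3fnuz, and every other bit pattern decodes to half its e4m3fn value, so the payload is rebitcast with 0x80 squashed to zero and all scales are doubled.

from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # Bit pattern 0x80 (-128 as int8) is -0.0 in e4m3fn but NaN in
    # e4m3fnuz, so zero it out before reinterpreting the storage.
    weight_as_int8 = weight.view(torch.int8)
    weight_as_int8[weight_as_int8 == -128] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
    # The same bits decode to half the value in e4m3fnuz, so double
    # the scales to keep dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale

Doubling weight_scale and input_scale keeps dequantized values identical, which is why the fix threads the possibly-None input_scale straight through this helper and only re-registers it on the layer when it actually exists.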