[Bugfix][VLM] Make apply_fp8_linear work with >2D input (#9812)
Commit 226688bd61 (parent 64cb1cdc3f)
@@ -96,21 +96,26 @@ def apply_fp8_linear(
     # If dynamic, layer.input_scale is None and x_scale computed from x.
     # If static, layer.input_scale is scalar and x_scale is input_scale.

+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[1]]
+
     # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
     if cutlass_fp8_supported:
         qinput, x_scale = ops.scaled_fp8_quant(
-            input,
+            input_2d,
             input_scale,
             scale_ub=input_scale_ub,
             use_per_token_if_dynamic=use_per_token_if_dynamic)

         # Fused GEMM_DQ
-        return ops.cutlass_scaled_mm(qinput,
-                                     weight,
-                                     out_dtype=input.dtype,
-                                     scale_a=x_scale,
-                                     scale_b=weight_scale,
-                                     bias=bias)
+        output = ops.cutlass_scaled_mm(qinput,
+                                       weight,
+                                       out_dtype=input.dtype,
+                                       scale_a=x_scale,
+                                       scale_b=weight_scale,
+                                       bias=bias)
+        return output.view(*output_shape)

     # torch.scaled_mm supports per tensor weights + activations only
     # so fallback to naive if per channel or per token
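In short, the fix folds any leading activation dimensions into the token dimension before the fp8 kernels run and unfolds them again on the way out, so >2D inputs (for example the [batch, seq, hidden] activations a VLM produces) no longer break the GEMM. A minimal sketch of that round-trip in plain PyTorch, with illustrative shapes and no quantization, just to show the reshape logic the patch relies on:

import torch

# Hypothetical >2D activation, e.g. [batch, seq_len, hidden].
x = torch.randn(2, 5, 64)
w = torch.randn(64, 128)  # [in_features, out_features], matching weight.shape[1] above

# Fold leading dims into the token dim, run the 2D GEMM, then unfold.
x_2d = x.view(-1, x.shape[-1])              # [10, 64]
output_shape = [*x.shape[:-1], w.shape[1]]  # [2, 5, 128]
out = (x_2d @ w).view(*output_shape)

# Same result as applying the matmul directly over the last dimension.
assert torch.allclose(out, x @ w)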
@@ -119,7 +124,7 @@ def apply_fp8_linear(
         # for matrices with batch dimension > 16.
         # This could change in the future.
         qinput, x_scale = ops.scaled_fp8_quant(
-            input,
+            input_2d,
             input_scale,
             num_token_padding=17,
             use_per_token_if_dynamic=use_per_token_if_dynamic)
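The num_token_padding=17 argument above pads the flattened token dimension because, as the comment notes, torch._scaled_mm is more performant once the batch dimension exceeds 16; the padded rows then have to be sliced off again, which is why the later hunks narrow the result to input_2d.shape[0]. A rough sketch of that pad-then-unpad pattern, with a hypothetical pad_tokens helper standing in for the real quant kernel:

import torch
import torch.nn.functional as F

def pad_tokens(x_2d: torch.Tensor, min_tokens: int = 17) -> torch.Tensor:
    # Zero-pad rows so the GEMM sees a token (batch) dimension > 16.
    pad = max(0, min_tokens - x_2d.shape[0])
    return F.pad(x_2d, (0, 0, 0, pad))

x_2d = torch.randn(3, 64)                           # only 3 tokens after flattening
y_padded = pad_tokens(x_2d) @ torch.randn(64, 128)  # [17, 128]

# Slice the result back to the real token count before any reshaping.
y = torch.narrow(y_padded, 0, 0, x_2d.shape[0])     # [3, 128]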
@@ -138,8 +143,10 @@ def apply_fp8_linear(
             # A fix for discrepancy in scaled_mm which returns tuple
             # for torch < 2.5 and a single value in torch >= 2.5
             if type(output) is tuple and len(output) == 2:
-                return torch.narrow(output[0], 0, 0, input.shape[0])
-            return torch.narrow(output, 0, 0, input.shape[0])
+                output = output[0]
+
+            return torch.narrow(output, 0, 0,
+                                input_2d.shape[0]).view(*output_shape)

         else:
             # Fallback for channelwise case, where we use unfused DQ
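The tuple check exists because torch._scaled_mm returned a two-element tuple before PyTorch 2.5 and a single tensor from 2.5 onward; rather than returning early in the tuple case, the patched code normalizes the result once so the same narrow-and-reshape line serves both versions. A small sketch of that normalization as a standalone helper (the helper name is made up for illustration):

def unwrap_scaled_mm(result):
    # torch < 2.5: two-element tuple whose first item is the output;
    # torch >= 2.5: the output tensor itself.
    if isinstance(result, tuple) and len(result) == 2:
        return result[0]
    return result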
@@ -176,15 +183,15 @@ def apply_fp8_linear(
             if type(output) is tuple and len(output) == 2:
                 output = output[0]
             # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
+            output = torch.narrow(output, 0, 0, input_2d.shape[0])
+            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])

             # DQ
             # C = sw * sx * (X * W) + bias
             output = output * x_scale * weight_scale.t()
             if bias is not None:
                 output = output + bias
-            return output.to(dtype=input.dtype)
+            return output.to(dtype=input.dtype).view(*output_shape)


 def apply_int8_linear(
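In the channelwise fallback, dequantization is unfused: the GEMM result is scaled afterwards by the per-token activation scale and the per-channel weight scale, per the comment C = sw * sx * (X * W) + bias, and the only >2D-related changes are narrowing to input_2d.shape[0] instead of input.shape[0] and restoring output_shape in the final return. A small sketch of just that scaling step, with shapes assumed for illustration:

import torch

M, N = 8, 128
gemm_out = torch.randn(M, N)     # stand-in for the (already unpadded) scaled_mm result
x_scale = torch.rand(M, 1)       # per-token activation scale, one value per row
weight_scale = torch.rand(N, 1)  # per-channel weight scale, assumed stored as [N, 1]
bias = torch.randn(N)

# C = sw * sx * (X * W) + bias, applied after the GEMM instead of inside it
output = gemm_out * x_scale * weight_scale.t()  # broadcasts to [M, N]
output = output + bias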