From 65ecf487ad134e521cc6fe93a370fbaa5d989d92 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 12 Nov 2025 07:43:16 +0000 Subject: [PATCH] optional input scales Signed-off-by: vllmellm --- tests/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 1fc8b260d1e9b..848c4efa8bcde 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1325,7 +1325,8 @@ class TestFP8Layer(torch.nn.Module): weight_quant_key (QuantKey): Key for weight quantization configuration. weight (torch.Tensor): Weight tensor for linear transformation. weight_scale (torch.Tensor): Per-tensor or per-group scale for weights. - input_scale (torch.Tensor): Scale tensor for input quantization. + input_scale (torch.Tensor, optional): Scale tensor for input quantization. + Defaults to None. out_dtype (torch.dtype, optional): Output tensor data type. Defaults to torch.get_default_dtype(). """ @@ -1336,7 +1337,7 @@ class TestFP8Layer(torch.nn.Module): weight_quant_key: QuantKey, weight: torch.Tensor, weight_scale: torch.Tensor, - input_scale: torch.Tensor, + input_scale: torch.Tensor | None = None, out_dtype: torch.dtype | None = None, force_kernel: FP8ScaledMMLinearKernel | None = None, ):