diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 215f35bad34d9..51900de1cc099 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1274,8 +1274,7 @@ def scaled_fp8_quant(
         scale = torch.zeros(1, device=input.device, dtype=torch.float32)
         torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
     else:
-        # num_token_padding not implemented for this case
-        assert (scale.numel() == 1 and num_token_padding is None)
+        assert scale.numel() == 1
         torch.ops._C.static_scaled_fp8_quant(output, input, scale)
 
     return output, scale