[Misc] Fix linter issues in examples/fp8/quantizer/quantize.py (#3864)

Cade Daniel 2024-04-04 21:57:33 -07:00 committed by GitHub
parent e5043a3e75
commit e0dd4d3589

examples/fp8/quantizer/quantize.py

@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # noqa: E501
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -131,7 +131,8 @@ def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None):
         tokenizer.pad_token = tokenizer.eos_token
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
-    assert tokenizer.pad_token is not None, f"Pad token for {model_type} cannot be set!"
+    assert (tokenizer.pad_token
+            is not None), f"Pad token for {model_type} cannot be set!"

     return tokenizer
@@ -158,9 +159,9 @@ def get_model(ckpt_path, dtype="fp16", device="cuda"):
     model_dtype = next(model.parameters()).dtype
     if dtype != model_dtype:
-        print(
-            f"[TensorRT-LLM][WARNING] The manually set model data type is {dtype}, "
-            f"but the data type of the HuggingFace model is {model_dtype}.")
+        print("[TensorRT-LLM][WARNING] The manually set model data type is "
+              f"{dtype}, but the data type of the HuggingFace model is "
+              f"{model_dtype}.")

     return model
@@ -244,15 +245,13 @@ def main(args):
     else:
         if "awq" in args.qformat:
             if args.calib_size > 32:
-                print(
-                    f"AWQ calibration could take longer with calib_size = {args.calib_size}, Using"
-                    " calib_size=32 instead")
+                print("AWQ calibration could take longer with calib_size = "
+                      f"{args.calib_size}, Using calib_size=32 instead")
                 args.calib_size = 32
-            print(
-                "\nAWQ calibration could take longer than other calibration methods. Please"
-                " increase the batch size to speed up the calibration process. Batch size can be"
-                " set by adding the argument --batch_size <batch_size> to the command line.\n"
-            )
+            print("\nAWQ calibration could take longer than other calibration "
+                  "methods. Please increase the batch size to speed up the "
+                  "calibration process. Batch size can be set by adding the "
+                  "argument --batch_size <batch_size> to the command line.\n")

         calib_dataloader = get_calib_dataloader(
             tokenizer=tokenizer,
@@ -287,9 +286,8 @@ def main(args):
     with torch.inference_mode():
         if model_type is None:
-            print(
-                f"Unknown model type {type(model).__name__}. Continue exporting..."
-            )
+            print(f"Unknown model type {type(model).__name__}. Continue "
+                  "exporting...")
             model_type = f"unknown:{type(model).__name__}"

         export_path = args.output_dir
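
The message reformatting in this commit relies on Python's implicit concatenation of adjacent string literals, so a long message can be split across several short source lines (satisfying E501) without changing the printed text. A minimal standalone sketch of the pattern, with an illustrative calib_size value that is not taken from the script:

calib_size = 64  # hypothetical value, for illustration only

# Adjacent string literals (plain strings and f-strings) are joined at compile
# time, so this prints one continuous message while each source line stays
# under the line-length limit.
print("AWQ calibration could take longer with calib_size = "
      f"{calib_size}, Using calib_size=32 instead")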