[Misc] Fix linter issues in examples/fp8/quantizer/quantize.py (#3864)
commit e0dd4d3589
parent e5043a3e75
examples/fp8/quantizer/quantize.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # noqa: E501
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
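Note: E501 is the pycodestyle/ruff "line too long" rule, and a trailing `# noqa: E501` comment suppresses that one check for a single line, which is why the over-length SPDX notice is annotated rather than wrapped. A minimal illustration (the variable name and text are made up, not from this file):

LONG_NOTICE = "This deliberately over-long string would normally trip the line-length linter, so the check is silenced inline."  # noqa: E501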
@@ -131,7 +131,8 @@ def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None):
         tokenizer.pad_token = tokenizer.eos_token
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
-    assert tokenizer.pad_token is not None, f"Pad token for {model_type} cannot be set!"
+    assert (tokenizer.pad_token
+            is not None), f"Pad token for {model_type} cannot be set!"
 
     return tokenizer
 
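Here the assert is wrapped by parenthesizing only the condition; the message stays outside the parentheses. Wrapping the whole `condition, message` pair in one set of parentheses would instead assert a two-element tuple, which is always truthy. A small sketch with hypothetical values (not from this file):

pad_token, model_type = "</s>", "llama"  # illustrative values only
assert (pad_token
        is not None), f"Pad token for {model_type} cannot be set!"  # passes
# assert (pad_token is None, "message")  # would never fail: a non-empty tuple is truthy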
@@ -158,9 +159,9 @@ def get_model(ckpt_path, dtype="fp16", device="cuda"):
 
     model_dtype = next(model.parameters()).dtype
     if dtype != model_dtype:
-        print(
-            f"[TensorRT-LLM][WARNING] The manually set model data type is {dtype}, "
-            f"but the data type of the HuggingFace model is {model_dtype}.")
+        print("[TensorRT-LLM][WARNING] The manually set model data type is "
+              f"{dtype}, but the data type of the HuggingFace model is "
+              f"{model_dtype}.")
 
     return model
 
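The long warning is re-split using implicit concatenation of adjacent string literals (plain strings and f-strings mix freely), so the printed message is unchanged while each source line stays under the length limit. A runnable sketch with placeholder values (not taken from the file):

dtype, model_dtype = "fp16", "bf16"  # illustrative values only
msg = ("[TensorRT-LLM][WARNING] The manually set model data type is "
       f"{dtype}, but the data type of the HuggingFace model is "
       f"{model_dtype}.")
print(msg)  # prints one continuous warning line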
@@ -244,15 +245,13 @@ def main(args):
     else:
         if "awq" in args.qformat:
             if args.calib_size > 32:
-                print(
-                    f"AWQ calibration could take longer with calib_size = {args.calib_size}, Using"
-                    " calib_size=32 instead")
+                print("AWQ calibration could take longer with calib_size = "
+                      f"{args.calib_size}, Using calib_size=32 instead")
                 args.calib_size = 32
-            print(
-                "\nAWQ calibration could take longer than other calibration methods. Please"
-                " increase the batch size to speed up the calibration process. Batch size can be"
-                " set by adding the argument --batch_size <batch_size> to the command line.\n"
-            )
+            print("\nAWQ calibration could take longer than other calibration "
+                  "methods. Please increase the batch size to speed up the "
+                  "calibration process. Batch size can be set by adding the "
+                  "argument --batch_size <batch_size> to the command line.\n")
 
         calib_dataloader = get_calib_dataloader(
             tokenizer=tokenizer,
@@ -287,9 +286,8 @@ def main(args):
 
     with torch.inference_mode():
         if model_type is None:
-            print(
-                f"Unknown model type {type(model).__name__}. Continue exporting..."
-            )
+            print(f"Unknown model type {type(model).__name__}. Continue "
+                  "exporting...")
             model_type = f"unknown:{type(model).__name__}"
 
         export_path = args.output_dir