From 4f419c00a621eac24f954f2b7670cbd22eb232a8 Mon Sep 17 00:00:00 2001
From: Flex Wang
Date: Thu, 22 Aug 2024 05:25:04 -0700
Subject: [PATCH] Fix ShardedStateLoader for vllm fp8 quantization (#7708)

---
 vllm/model_executor/model_loader/loader.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index d0427fb9b16af..2f6cdbc6ce3e9 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -579,6 +579,10 @@ class ShardedStateLoader(BaseModelLoader):
         with torch.device(device_config.device):
             model = _initialize_model(model_config, self.load_config,
                                       lora_config, cache_config)
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    quant_method.process_weights_after_loading(module)
         rank = get_tensor_model_parallel_rank()
         pattern = os.path.join(
             local_model_path,
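
Note for reviewers: the inserted loop runs the quantization post-processing hook
on the freshly initialized model before the sharded tensors are copied in, so
that fp8 parameter layouts match what the sharded checkpoint stores. Below is a
minimal standalone sketch of that pattern; the Toy* classes are hypothetical
stand-ins for illustration only, not vLLM internals.

    import torch
    import torch.nn as nn


    class ToyFp8Method:
        """Hypothetical quant method that finalizes weights after loading."""

        def process_weights_after_loading(self, module: nn.Module) -> None:
            # Illustrative post-processing: cast to the inference dtype and
            # tag the module so we can observe that the hook ran.
            module.weight.data = module.weight.data.to(torch.float16)
            module.processed = True


    class ToyQuantLinear(nn.Linear):
        """Hypothetical quantized layer exposing a `quant_method` attribute."""

        def __init__(self, in_features: int, out_features: int) -> None:
            super().__init__(in_features, out_features)
            self.quant_method = ToyFp8Method()


    model = nn.Sequential(ToyQuantLinear(8, 8), nn.ReLU(), ToyQuantLinear(8, 4))

    # Mirrors the four lines the patch adds to ShardedStateLoader: walk every
    # submodule and let its quant method post-process weights, skipping modules
    # that have no quantization attached.
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)

    print([getattr(m, "processed", False) for m in model])
    # -> [True, False, True]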