Fix ShardedStateLoader for vllm fp8 quantization (#7708)

Flex Wang 2024-08-22 05:25:04 -07:00 committed by GitHub
parent a3fce56b88
commit 4f419c00a6


@@ -579,6 +579,10 @@ class ShardedStateLoader(BaseModelLoader):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config,
                                           lora_config, cache_config)
+                for _, module in model.named_modules():
+                    quant_method = getattr(module, "quant_method", None)
+                    if quant_method is not None:
+                        quant_method.process_weights_after_loading(module)
             rank = get_tensor_model_parallel_rank()
             pattern = os.path.join(
                 local_model_path,
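The added loop runs right after the model is initialized and before the saved shards are read, so each module's quant_method can post-process its parameters (as the default loader does) before the already-processed fp8 tensors from the sharded checkpoint are copied in. A minimal sketch of the save/load round trip this targets is below; the model name and output path are placeholders, and it assumes the save_sharded_state helper on the model executor and the "sharded_state" load format:

    # Sketch only: names and paths are illustrative, not part of this commit.
    from vllm import LLM, SamplingParams

    # 1) Load an fp8-quantized model once and save its sharded state.
    llm = LLM(model="meta-llama/Meta-Llama-3-8B", quantization="fp8")
    llm.llm_engine.model_executor.save_sharded_state(
        path="/tmp/llama3-8b-fp8-sharded")

    # 2) Reload directly from the sharded checkpoint. With this fix,
    #    ShardedStateLoader calls quant_method.process_weights_after_loading()
    #    on every module, so the fp8 weights load into correctly prepared
    #    parameters instead of failing or producing garbage.
    llm = LLM(model="/tmp/llama3-8b-fp8-sharded",
              load_format="sharded_state",
              quantization="fp8")
    print(llm.generate(["Hello"], SamplingParams(max_tokens=8)))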