Fix ShardedStateLoader for vllm fp8 quantization (#7708)

parent a3fce56b88
commit 4f419c00a6
@@ -579,6 +579,10 @@ class ShardedStateLoader(BaseModelLoader):
         with torch.device(device_config.device):
             model = _initialize_model(model_config, self.load_config,
                                       lora_config, cache_config)
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    quant_method.process_weights_after_loading(module)
         rank = get_tensor_model_parallel_rank()
         pattern = os.path.join(
             local_model_path,