diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 69e6a56106293..47aa35ed25b00 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -153,11 +153,11 @@ def _initialize_model(
     return model_class(**kwargs)
 
 
-def _process_attention_weights_after_loading(
-        model: nn.Module, model_config: ModelConfig) -> None:
+def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
+                                   target_device: torch.device) -> None:
     # Currently only used by MLA.
-    # NOTE: This intentionally happens before other modules so
-    # we can easily decompress the weights for MLA.
+    # NOTE: This intentionally happens before other modules so we can easily
+    # decompress the weights for MLA.
     for _, module in model.named_modules():
         if isinstance(module, Attention) and \
             hasattr(module, "process_weights_after_loading"):
@@ -165,6 +165,17 @@ def _process_attention_weights_after_loading(
             # of process_weights_after_loading
             module.process_weights_after_loading(model_config.dtype)
 
+    for _, module in model.named_modules():
+        quant_method = getattr(module, "quant_method", None)
+        if isinstance(quant_method, QuantizeMethodBase):
+            # When quant methods need to process weights after loading
+            # (for repacking, quantizing, etc), they expect parameters
+            # to be on the global target device. This scope is for the
+            # case where cpu offloading is used, where we will move the
+            # parameters onto device for processing and back off after.
+            with device_loading_context(module, target_device):
+                quant_method.process_weights_after_loading(module)
+
 
 class BaseModelLoader(ABC):
     """Base class for model loaders."""
@@ -389,7 +400,6 @@ class DefaultModelLoader(BaseModelLoader):
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
-        target_device = torch.device(device_config.device)
 
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -407,18 +417,8 @@ class DefaultModelLoader(BaseModelLoader):
                     "Following weights were not initialized from "
                     f"checkpoint: {weights_not_loaded}")
 
-            _process_attention_weights_after_loading(model, model_config)
+            _process_weights_after_loading(model, model_config, target_device)
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if isinstance(quant_method, QuantizeMethodBase):
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
         return model.eval()
 
 
@@ -437,26 +437,15 @@ class DummyModelLoader(BaseModelLoader):
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
+        target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
                 # NOTE(woosuk): For accurate performance evaluation, we assign
                 # random values to the weights.
                 initialize_dummy_weights(model)
 
-            _process_attention_weights_after_loading(model, model_config)
-
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(
-                            module, torch.device(device_config.device)):
-                        quant_method.process_weights_after_loading(module)
+            _process_weights_after_loading(model, model_config, target_device)
         return model.eval()
 
 
@@ -637,6 +626,7 @@ class ShardedStateLoader(BaseModelLoader):
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
+        target_device = torch.device(device_config.device)
         from safetensors.torch import safe_open
 
         from vllm.distributed import get_tensor_model_parallel_rank
@@ -645,13 +635,10 @@ class ShardedStateLoader(BaseModelLoader):
                                                  model_config.revision)
 
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
-                _process_attention_weights_after_loading(model, model_config)
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)
+                _process_weights_after_loading(model, model_config,
+                                               target_device)
             rank = get_tensor_model_parallel_rank()
             pattern = os.path.join(
                 local_model_path,
@@ -1401,12 +1388,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
                 self._get_weights_iterator(model_weights,
                                            model_config.revision))
 
-            _process_attention_weights_after_loading(model, model_config)
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
+            _process_weights_after_loading(model, model_config, target_device)
         return model.eval()