From 109e15a335a20251cbefa0a81bf51cd7624eae27 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Thu, 1 May 2025 23:23:42 -0700
Subject: [PATCH] Add `pt_load_map_location` to allow loading to cuda (#16869)

Signed-off-by: Jerry Zhang
---
 tests/quantization/test_torchao.py         | 26 +++++++++++++++++++
 tests/test_config.py                       | 16 +++++++++++-
 vllm/config.py                             | 10 +++++++
 vllm/engine/arg_utils.py                   | 18 ++++++++++++-
 vllm/model_executor/model_loader/loader.py |  2 ++
 .../model_loader/weight_utils.py           |  5 +++-
 6 files changed, 74 insertions(+), 3 deletions(-)

diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py
index 314ec90e34f9..1a20228765e8 100644
--- a/tests/quantization/test_torchao.py
+++ b/tests/quantization/test_torchao.py
@@ -3,6 +3,7 @@
 import importlib.metadata
 import importlib.util
 
 import pytest
+import torch
 
 DTYPE = ["bfloat16"]
@@ -21,5 +22,30 @@ def test_pre_quantized_model(vllm_runner):
     print(output)
 
 
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+@pytest.mark.parametrize(
+    "pt_load_map_location",
+    [
+        "cuda:0",
+        # {"": "cuda"},
+    ])
+def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
+                                                   pt_load_map_location):
+    """
+    Test loading an int4wo-quantized opt-125m model with pt_load_map_location.
+    """
+    torch._dynamo.reset()
+    model_name = "jerryzh168/opt-125m-int4wo"
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location=pt_load_map_location) as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+    print(output)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/test_config.py b/tests/test_config.py
index f2155d954db0..7db95e3f6450 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -5,7 +5,8 @@
 from typing import Literal, Union
 
 import pytest
 
-from vllm.config import ModelConfig, PoolerConfig, config, get_field
+from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
+                         config, get_field)
 from vllm.model_executor.layers.pooler import PoolingType
 from vllm.platforms import current_platform
@@ -410,3 +411,16 @@ def test_generation_config_loading():
         override_generation_config=override_generation_config)
 
     assert model_config.get_diff_sampling_param() == override_generation_config
+
+
+@pytest.mark.parametrize("pt_load_map_location", [
+    "cuda",
+    {
+        "": "cuda"
+    },
+])
+def test_load_config_pt_load_map_location(pt_load_map_location):
+    load_config = LoadConfig(pt_load_map_location=pt_load_map_location)
+    config = VllmConfig(load_config=load_config)
+
+    assert config.load_config.pt_load_map_location == pt_load_map_location
diff --git a/vllm/config.py b/vllm/config.py
index c2995cacaeb6..81e2460c2bbf 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1564,6 +1564,16 @@ class LoadConfig:
     use_tqdm_on_load: bool = True
     """Whether to enable tqdm for showing progress bar when loading model
    weights."""
+    pt_load_map_location: Union[str, dict[str, str]] = "cpu"
+    """
+    The map location for loading PyTorch checkpoints. Some checkpoints can
+    only be loaded on certain devices, e.g. "cuda"; a plain device string is
+    equivalent to {"": "cuda"}. A dict can also remap devices, e.g.
+    {"cuda:1": "cuda:0"} loads weights saved on GPU 1 onto GPU 0. When passed
+    on the command line, strings in the dict must be double quoted to parse
+    as JSON. For more details, see the documentation for `map_location` in
+    https://pytorch.org/docs/stable/generated/torch.load.html
+    """
 
     def compute_hash(self) -> str:
         """
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 3cafcb7c31f2..4ffc0b767e8c 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -64,6 +64,13 @@ def optional_type(
     return _optional_type
 
 
+def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
+    if not re.match("^{.*}$", val):
+        return str(val)
+    else:
+        return optional_type(json.loads)(val)
+
+
 @deprecated(
     "Passing a JSON argument as a string containing comma separated key=value "
     "pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
@@ -187,6 +194,10 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
             kwargs[name]["type"] = human_readable_int
         elif contains_type(type_hints, float):
             kwargs[name]["type"] = float
+        elif contains_type(type_hints,
+                           dict) and (contains_type(type_hints, str) or any(
+                               is_not_builtin(th) for th in type_hints)):
+            kwargs[name]["type"] = union_dict_and_str
         elif contains_type(type_hints, dict):
             # Dict arguments will always be optional
             kwargs[name]["type"] = optional_type(json.loads)
@@ -371,6 +382,7 @@ class EngineArgs:
     reasoning_parser: str = DecodingConfig.reasoning_backend
 
     use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
+    pt_load_map_location: str = LoadConfig.pt_load_map_location
 
     def __post_init__(self):
         # support `EngineArgs(compilation_config={...})`
@@ -491,6 +503,8 @@
                                 type=str,
                                 default=None,
                                 help='Name or path of the QLoRA adapter.')
+        load_group.add_argument('--pt-load-map-location',
+                                **load_kwargs["pt_load_map_location"])
 
         # Guided decoding arguments
         guided_decoding_kwargs = get_kwargs(DecodingConfig)
@@ -883,12 +897,14 @@
 
         if self.quantization == "bitsandbytes":
             self.load_format = "bitsandbytes"
+
         return LoadConfig(
             load_format=self.load_format,
             download_dir=self.download_dir,
             model_loader_extra_config=self.model_loader_extra_config,
             ignore_patterns=self.ignore_patterns,
             use_tqdm_on_load=self.use_tqdm_on_load,
+            pt_load_map_location=self.pt_load_map_location,
         )
 
     def create_speculative_config(
@@ -1513,7 +1529,7 @@ def _warn_or_fallback(feature_name: str) -> bool:
 def human_readable_int(value):
     """Parse human-readable integers like '1k', '2M', etc.
     Including decimal values with decimal multipliers.
- + Examples: - '1k' -> 1,000 - '1K' -> 1,024 diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index cb9100e35594..01f75db9ee86 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -384,6 +384,7 @@ class DefaultModelLoader(BaseModelLoader): weights_iterator = pt_weights_iterator( hf_weights_files, self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, ) if current_platform.is_tpu(): @@ -890,6 +891,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): iterator = pt_weights_iterator( hf_weights_files, self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, ) for org_name, param in iterator: # mapping weight names from transformers to vllm while preserving diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 37a8491cf63d..10bc55ca5f7d 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -502,6 +502,7 @@ def fastsafetensors_weights_iterator( def pt_weights_iterator( hf_weights_files: List[str], use_tqdm_on_load: bool, + pt_load_map_location: Union[str, dict[str, str]] = "cpu", ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model bin/pt files.""" for bin_file in tqdm( @@ -510,7 +511,9 @@ def pt_weights_iterator( disable=not enable_tqdm(use_tqdm_on_load), bar_format=_BAR_FORMAT, ): - state = torch.load(bin_file, map_location="cpu", weights_only=True) + state = torch.load(bin_file, + map_location=pt_load_map_location, + weights_only=True) yield from state.items() del state
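
Usage sketch (illustrative, not part of the patch): `LLM` forwards extra
keyword arguments to `EngineArgs`, so once this change is applied the new
option should be usable directly from the Python API. The checkpoint name
below is the one exercised in the torchao test above.

    from vllm import LLM, SamplingParams

    # Some quantized checkpoints can only be deserialized on GPU, so map
    # every tensor to cuda:0 at load time instead of the default "cpu".
    llm = LLM(model="jerryzh168/opt-125m-int4wo",
              quantization="torchao",
              dtype="bfloat16",
              pt_load_map_location="cuda:0")

    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=32))
    print(outputs[0].outputs[0].text)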
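On the command line, the value of `--pt-load-map-location` is parsed by
`union_dict_and_str`: anything shaped like `{...}` goes through
`json.loads`, everything else stays a plain device string, which is why the
docstring insists on double-quoted strings inside the dict. A simplified
standalone illustration of that logic (the real function also wraps
`json.loads` in `optional_type` so that "None" parses to None):

    import json
    import re

    def union_dict_and_str(val: str):
        if not re.match("^{.*}$", val):
            return str(val)      # plain device string, e.g. "cuda:0"
        return json.loads(val)   # JSON dict, e.g. '{"cuda:1": "cuda:0"}'

    assert union_dict_and_str("cuda:0") == "cuda:0"
    assert union_dict_and_str('{"cuda:1": "cuda:0"}') == {"cuda:1": "cuda:0"}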
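Downstream, `pt_weights_iterator` hands the value straight to `torch.load`,
so the accepted forms are exactly PyTorch's `map_location` forms. A minimal
sketch of what the iterator now does per checkpoint file (the file name here
is hypothetical):

    import torch

    # Accepts "cpu" (the default), a device string such as "cuda:0"
    # (equivalent to {"": "cuda:0"}), or a device-to-device mapping.
    state = torch.load("pytorch_model.bin",
                       map_location={"cuda:1": "cuda:0"},
                       weights_only=True)
    for name, tensor in state.items():
        print(name, tensor.device)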