Add pt_load_map_location to allow loading to cuda (#16869)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
commit 109e15a335 (parent f192ca90e6)
@@ -3,6 +3,7 @@ import importlib.metadata
 import importlib.util
 
 import pytest
+import torch
 
 DTYPE = ["bfloat16"]
 
@@ -21,5 +22,30 @@ def test_pre_quantized_model(vllm_runner):
     print(output)
 
 
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+@pytest.mark.parametrize(
+    "pt_load_map_location",
+    [
+        "cuda:0",
+        # {"": "cuda"},
+    ])
+def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
+                                                   pt_load_map_location):
+    """
+    Test loading an int4 weight-only quantized opt-125m model with an
+    explicit pt_load_map_location.
+    """
+    torch._dynamo.reset()
+    model_name = "jerryzh168/opt-125m-int4wo"
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location=pt_load_map_location) as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+    print(output)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
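For context, the new test exercises pt_load_map_location end to end by loading a torchao int4 weight-only checkpoint whose quantized tensors have to be deserialized on a CUDA device. A minimal sketch of the same flow through the public LLM entrypoint is below; it is illustrative only and assumes the LLM constructor forwards extra engine keyword arguments such as pt_load_map_location to EngineArgs.

from vllm import LLM

# Hypothetical standalone usage; model name taken from the test above.
llm = LLM(model="jerryzh168/opt-125m-int4wo",
          quantization="torchao",
          dtype="bfloat16",
          pt_load_map_location="cuda:0")  # assumed to be forwarded to EngineArgs
print(llm.generate(["The capital of France is"]))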
@@ -5,7 +5,8 @@ from typing import Literal, Union
 
 import pytest
 
-from vllm.config import ModelConfig, PoolerConfig, config, get_field
+from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
+                         config, get_field)
 from vllm.model_executor.layers.pooler import PoolingType
 from vllm.platforms import current_platform
 
@@ -410,3 +411,16 @@ def test_generation_config_loading():
         override_generation_config=override_generation_config)
 
     assert model_config.get_diff_sampling_param() == override_generation_config
+
+
+@pytest.mark.parametrize("pt_load_map_location", [
+    "cuda",
+    {
+        "": "cuda"
+    },
+])
+def test_load_config_pt_load_map_location(pt_load_map_location):
+    load_config = LoadConfig(pt_load_map_location=pt_load_map_location)
+    config = VllmConfig(load_config=load_config)
+
+    assert config.load_config.pt_load_map_location == pt_load_map_location
@@ -1564,6 +1564,16 @@ class LoadConfig:
     use_tqdm_on_load: bool = True
     """Whether to enable tqdm for showing progress bar when loading model
     weights."""
+    pt_load_map_location: Union[str, dict[str, str]] = "cpu"
+    """
+    pt_load_map_location: the map location for loading pytorch checkpoint, to
+    support loading checkpoints that can only be loaded on certain devices
+    like "cuda", i.e. equivalent to {"": "cuda"}. Another supported format is
+    a mapping between devices, e.g. from GPU 1 to GPU 0: {"cuda:1": "cuda:0"}.
+    Note that when passed from the command line, the strings in the dict need
+    to be double quoted for JSON parsing. For more details, see the docs for
+    `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html
+    """
 
     def compute_hash(self) -> str:
         """
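The docstring above mirrors the semantics of torch.load's map_location argument, which this field is passed through to. A short sketch of the two accepted forms (the checkpoint filename is hypothetical):

import torch

# String form: deserialize every tensor onto the given device.
state = torch.load("pytorch_model.bin", map_location="cuda", weights_only=True)

# Dict form: remap storages saved on one device onto another,
# e.g. tensors saved on GPU 1 are loaded onto GPU 0.
state = torch.load("pytorch_model.bin",
                   map_location={"cuda:1": "cuda:0"},
                   weights_only=True)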
@@ -64,6 +64,13 @@ def optional_type(
     return _optional_type
 
 
+def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
+    if not re.match("^{.*}$", val):
+        return str(val)
+    else:
+        return optional_type(json.loads)(val)
+
+
 @deprecated(
     "Passing a JSON argument as a string containing comma separated key=value "
     "pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
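An illustration of how the helper above treats the two accepted value shapes: a plain device string passes through unchanged, while anything wrapped in braces is parsed as JSON into a dict. The import path assumes the helper lands in vllm.engine.arg_utils alongside the rest of this diff.

from vllm.engine.arg_utils import union_dict_and_str

union_dict_and_str("cuda:0")                # -> "cuda:0"
union_dict_and_str('{"cuda:1": "cuda:0"}')  # -> {"cuda:1": "cuda:0"}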
@@ -187,6 +194,10 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
             kwargs[name]["type"] = human_readable_int
         elif contains_type(type_hints, float):
             kwargs[name]["type"] = float
+        elif contains_type(type_hints,
+                           dict) and (contains_type(type_hints, str) or any(
+                               is_not_builtin(th) for th in type_hints)):
+            kwargs[name]["type"] = union_dict_and_str
         elif contains_type(type_hints, dict):
             # Dict arguments will always be optional
             kwargs[name]["type"] = optional_type(json.loads)
@@ -371,6 +382,7 @@ class EngineArgs:
     reasoning_parser: str = DecodingConfig.reasoning_backend
 
     use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
+    pt_load_map_location: str = LoadConfig.pt_load_map_location
 
     def __post_init__(self):
         # support `EngineArgs(compilation_config={...})`
@@ -491,6 +503,8 @@ class EngineArgs:
                                 type=str,
                                 default=None,
                                 help='Name or path of the QLoRA adapter.')
+        load_group.add_argument('--pt-load-map-location',
+                                **load_kwargs["pt_load_map_location"])
 
         # Guided decoding arguments
         guided_decoding_kwargs = get_kwargs(DecodingConfig)
@@ -883,12 +897,14 @@ class EngineArgs:
 
         if self.quantization == "bitsandbytes":
             self.load_format = "bitsandbytes"
 
         return LoadConfig(
             load_format=self.load_format,
             download_dir=self.download_dir,
             model_loader_extra_config=self.model_loader_extra_config,
             ignore_patterns=self.ignore_patterns,
             use_tqdm_on_load=self.use_tqdm_on_load,
+            pt_load_map_location=self.pt_load_map_location,
         )
 
     def create_speculative_config(
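Putting the EngineArgs plumbing together, here is a hedged sketch of setting the option programmatically and letting the load-config factory above forward it into LoadConfig. It assumes EngineArgs.create_engine_config() is available and that the environment can actually resolve the model; it is illustrative, not part of the diff.

from vllm.engine.arg_utils import EngineArgs

# Illustrative only; unrelated engine arguments omitted.
args = EngineArgs(model="jerryzh168/opt-125m-int4wo",
                  quantization="torchao",
                  pt_load_map_location={"": "cuda"})
config = args.create_engine_config()
assert config.load_config.pt_load_map_location == {"": "cuda"}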
@@ -384,6 +384,7 @@ class DefaultModelLoader(BaseModelLoader):
             weights_iterator = pt_weights_iterator(
                 hf_weights_files,
                 self.load_config.use_tqdm_on_load,
+                self.load_config.pt_load_map_location,
             )
 
         if current_platform.is_tpu():
@@ -890,6 +891,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             iterator = pt_weights_iterator(
                 hf_weights_files,
                 self.load_config.use_tqdm_on_load,
+                self.load_config.pt_load_map_location,
             )
             for org_name, param in iterator:
                 # mapping weight names from transformers to vllm while preserving
@@ -502,6 +502,7 @@ def fastsafetensors_weights_iterator(
 def pt_weights_iterator(
     hf_weights_files: List[str],
     use_tqdm_on_load: bool,
+    pt_load_map_location: Union[str, dict[str, str]] = "cpu",
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model bin/pt files."""
     for bin_file in tqdm(
@@ -510,7 +511,9 @@ def pt_weights_iterator(
             disable=not enable_tqdm(use_tqdm_on_load),
             bar_format=_BAR_FORMAT,
     ):
-        state = torch.load(bin_file, map_location="cpu", weights_only=True)
+        state = torch.load(bin_file,
+                           map_location=pt_load_map_location,
+                           weights_only=True)
         yield from state.items()
         del state
 
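The change above is where the option finally takes effect: each checkpoint shard is deserialized with the configured map location instead of the hard-coded "cpu". A self-contained sketch of the same core loop (hypothetical helper name, simplified from the diff; the tqdm progress wrapping is omitted):

from typing import Union

import torch


def iter_pt_weights(files: list[str],
                    map_location: Union[str, dict[str, str]] = "cpu"):
    """Yield (name, tensor) pairs from .bin/.pt checkpoint shards."""
    for path in files:
        # weights_only=True restricts deserialization to tensors and containers.
        state = torch.load(path, map_location=map_location, weights_only=True)
        yield from state.items()
        del state  # free the shard before moving to the next file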