Enable safetensors loading for all models (#974)
commit c957c741d9 (parent c07ece5ca4)
@@ -24,9 +24,16 @@ class ModelConfig:
             downloading the model and tokenizer.
         download_dir: Directory to download and load the weights, default to the
             default cache directory of huggingface.
-        use_np_weights: Save a numpy copy of model weights for faster loading.
-            This can increase the disk usage by up to 2x.
-        use_dummy_weights: Use dummy values for model weights (for profiling).
+        load_format: The format of the model weights to load:
+            "auto" will try to load the weights in the safetensors format and
+                fall back to the pytorch bin format if safetensors format is
+                not available.
+            "pt" will load the weights in the pytorch bin format.
+            "safetensors" will load the weights in the safetensors format.
+            "npcache" will load the weights in pytorch format and store
+                a numpy cache to speed up the loading.
+            "dummy" will initialize the weights with random values, which is
+                mainly for profiling.
         dtype: Data type for model weights and activations. The "auto" option
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.
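
With this change the weight format becomes a single user-facing knob. A minimal usage sketch, assuming the `LLM` entrypoint forwards extra keyword arguments to `EngineArgs` (the model name is just an example):

```python
from vllm import LLM

# "auto" (the default) prefers *.safetensors and falls back to *.bin;
# pass the format explicitly to require safetensors.
llm = LLM(model="facebook/opt-125m", load_format="safetensors")
```

The same choice is exposed on the command line as --load-format (see the EngineArgs hunk below).
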
@@ -40,8 +47,7 @@ class ModelConfig:
         tokenizer_mode: str,
         trust_remote_code: bool,
         download_dir: Optional[str],
-        use_np_weights: bool,
-        use_dummy_weights: bool,
+        load_format: str,
         dtype: str,
         seed: int,
     ) -> None:
@@ -50,14 +56,24 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode
         self.trust_remote_code = trust_remote_code
         self.download_dir = download_dir
-        self.use_np_weights = use_np_weights
-        self.use_dummy_weights = use_dummy_weights
+        self.load_format = load_format
         self.seed = seed
 
         self.hf_config = get_config(model, trust_remote_code)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
+        self._verify_load_format()
         self._verify_tokenizer_mode()
 
+    def _verify_load_format(self) -> None:
+        load_format = self.load_format.lower()
+        if load_format not in [
+                "auto", "pt", "safetensors", "npcache", "dummy"
+        ]:
+            raise ValueError(
+                f"Unknown load format: {self.load_format}. Must be one of "
+                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
+        self.load_format = load_format
+
     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
         if tokenizer_mode not in ["auto", "slow"]:
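
`_verify_load_format` lower-cases the value before checking it, so the option is effectively case-insensitive. A standalone sketch of the same normalize-then-validate pattern (the function name here is illustrative, not part of the commit):

```python
def verify_load_format(load_format: str) -> str:
    load_format = load_format.lower()
    if load_format not in ["auto", "pt", "safetensors", "npcache", "dummy"]:
        raise ValueError(f"Unknown load format: {load_format}")
    return load_format

assert verify_load_format("SafeTensors") == "safetensors"  # normalized
# verify_load_format("np") would raise ValueError
```
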
@@ -15,8 +15,7 @@ class EngineArgs:
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
-    use_np_weights: bool = False
-    use_dummy_weights: bool = False
+    load_format: str = 'auto'
    dtype: str = 'auto'
    seed: int = 0
    worker_use_ray: bool = False
@@ -65,14 +64,21 @@ class EngineArgs:
                             help='directory to download and load the weights, '
                             'default to the default cache dir of '
                             'huggingface')
-        parser.add_argument('--use-np-weights',
-                            action='store_true',
-                            help='save a numpy copy of model weights for '
-                            'faster loading. This can increase the disk '
-                            'usage by up to 2x.')
-        parser.add_argument('--use-dummy-weights',
-                            action='store_true',
-                            help='use dummy values for model weights')
+        parser.add_argument(
+            '--load-format',
+            type=str,
+            default=EngineArgs.load_format,
+            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
+            help='The format of the model weights to load. '
+            '"auto" will try to load the weights in the safetensors format '
+            'and fall back to the pytorch bin format if safetensors format '
+            'is not available. '
+            '"pt" will load the weights in the pytorch bin format. '
+            '"safetensors" will load the weights in the safetensors format. '
+            '"npcache" will load the weights in pytorch format and store '
+            'a numpy cache to speed up the loading. '
+            '"dummy" will initialize the weights with random values, '
+            'which is mainly for profiling.')
         # TODO(woosuk): Support FP32.
         parser.add_argument(
             '--dtype',
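
Because the flag is declared with choices=, argparse rejects unknown formats before the engine ever starts. A self-contained sketch of the flag's behavior (the parser construction is illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--load-format',
                    type=str,
                    default='auto',
                    choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'])

args = parser.parse_args(['--load-format', 'safetensors'])
print(args.load_format)  # safetensors
# parser.parse_args(['--load-format', 'np']) exits with an "invalid choice" error
```
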
@@ -146,9 +152,8 @@ class EngineArgs:
         # Initialize the configs.
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
-                                   self.download_dir, self.use_np_weights,
-                                   self.use_dummy_weights, self.dtype,
-                                   self.seed)
+                                   self.download_dir, self.load_format,
+                                   self.dtype, self.seed)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
@@ -76,9 +76,8 @@ class LLMEngine:
             f"tokenizer_mode={model_config.tokenizer_mode}, "
             f"trust_remote_code={model_config.trust_remote_code}, "
             f"dtype={model_config.dtype}, "
-            f"use_dummy_weights={model_config.use_dummy_weights}, "
             f"download_dir={model_config.download_dir!r}, "
-            f"use_np_weights={model_config.use_np_weights}, "
+            f"load_format={model_config.load_format}, "
             f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
             f"seed={model_config.seed})")
         # TODO(woosuk): Print more configs in debug mode.
@@ -56,7 +56,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
         # Create a model instance.
         # The weights will be initialized as empty tensors.
         model = model_class(model_config.hf_config)
-        if model_config.use_dummy_weights:
+        if model_config.load_format == "dummy":
             model = model.cuda()
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
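
The "dummy" branch fills the empty parameters with random values instead of reading any checkpoint, so profiling runs skip disk and network I/O entirely. A hedged sketch of what that initialization could look like (the helper name and value range are assumptions, not taken from this diff):

```python
import torch.nn as nn

def initialize_dummy_weights(model: nn.Module) -> None:
    # Small uniform values keep activations numerically tame during profiling.
    for param in model.parameters():
        param.data.uniform_(-1e-3, 1e-3)
```
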
@@ -64,6 +64,6 @@ def get_model(model_config: ModelConfig) -> nn.Module:
         else:
             # Load the weights from the cached or downloaded files.
             model.load_weights(model_config.model, model_config.download_dir,
-                               model_config.use_np_weights)
+                               model_config.load_format)
             model = model.cuda()
     return model.eval()
@@ -288,7 +288,7 @@ class AquilaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -305,7 +305,7 @@ class AquilaForCausalLM(nn.Module):
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue
 
@@ -35,8 +35,8 @@ from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
                                                   PagedAttentionWithALiBi)
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
-    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
-    load_tensor_parallel_weights)
+    convert_pyslice_to_tensor, hf_model_weights_iterator,
+    load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -303,16 +303,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue
 
+            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
+
             if "W_pack" in name:
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
@@ -279,11 +279,11 @@ class BloomForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if name == "lm_head.weight":
                 # Since hidden_states are parallelized, we need to
                 # load lm_head.weight in parallel.
@@ -31,7 +31,8 @@ from vllm.model_executor.layers.attention import (PagedAttention,
                                                   PagedAttentionWithALiBi,
                                                   PagedAttentionWithRoPE)
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
+                                              hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -419,7 +420,7 @@ class FalconForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_size = (get_tensor_model_parallel_world_size())
         tp_rank = get_tensor_model_parallel_rank()
 
@@ -451,8 +452,9 @@ class FalconForCausalLM(nn.Module):
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "query_key_value" in name:
+                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                 loaded_weight_size = loaded_weight.size()
                 loaded_weight = loaded_weight.view(
                     total_num_kv_heads, num_query_heads_per_kv_head + 2,
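
Falcon's fused query_key_value weight groups each set of query heads with its shared key and value head, so the slice must become a real tensor before .view() can regroup it. A toy sketch of that regrouping (all sizes are invented for illustration; Falcon's real splitting logic handles more cases):

```python
import torch

num_kv_heads, q_per_kv, head_size, hidden = 2, 4, 8, 64
fused = torch.randn(num_kv_heads * (q_per_kv + 2) * head_size, hidden)

grouped = fused.view(num_kv_heads, q_per_kv + 2, head_size, hidden)
wq = grouped[:, :-2].reshape(-1, hidden)   # query heads of every group
wk = grouped[:, [-2]].reshape(-1, hidden)  # one key head per group
wv = grouped[:, [-1]].reshape(-1, hidden)  # one value head per group
```
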
@@ -32,8 +32,8 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
-    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
-    load_tensor_parallel_weights)
+    convert_pyslice_to_tensor, hf_model_weights_iterator,
+    load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -231,14 +231,14 @@ class GPT2LMHeadModel(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
@@ -251,6 +251,8 @@ class GPT2LMHeadModel(nn.Module):
             if not name.startswith("transformer."):
                 name = "transformer." + name
 
+            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
+
             # The HF's GPT-2 implementation uses Conv1D instead of Linear.
             # Because of this, we need to transpose the weights.
             for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
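
The conversion has to happen before the Conv1D handling because transposing requires a real tensor. For reference, HF's Conv1D stores weights as [in_features, out_features], the reverse of nn.Linear, hence the transpose:

```python
import torch

w_conv1d = torch.randn(768, 3 * 768)  # HF Conv1D layout: [in, out]
w_linear = w_conv1d.t()               # nn.Linear layout: [out, in]
```
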
@@ -33,8 +33,8 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
-    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
-    load_tensor_parallel_weights)
+    convert_pyslice_to_tensor, hf_model_weights_iterator,
+    load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -259,14 +259,14 @@ class GPTBigCodeForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
@@ -295,6 +295,7 @@ class GPTBigCodeForCausalLM(nn.Module):
                 head_start = tensor_model_parallel_rank * num_heads
                 head_end = (tensor_model_parallel_rank + 1) * num_heads
 
+                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                 wq, wk, wv = torch.split(
                     loaded_weight, [hidden_size, total_kv_size, total_kv_size],
                     dim=0)
@@ -222,11 +222,11 @@ class GPTJForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "attn.bias" in name or "attn.masked_bias" in name:
                 continue
 
@@ -231,11 +231,11 @@ class GPTNeoXForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if ("attention.bias" in name or "attention.masked_bias" in name
                     or "rotary_emb.inv_freq" in name):
                 continue
@@ -233,12 +233,12 @@ class InternLMForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue
 
@@ -271,8 +271,7 @@ class LlamaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False,
-                     use_safetensor: bool = True):
+                     load_format: str = "auto"):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -289,7 +288,7 @@ class LlamaForCausalLM(nn.Module):
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache, use_safetensor):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue
 
@@ -10,7 +10,8 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
+                                              hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -243,12 +244,12 @@ class MPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "Wqkv" in name:
                 # NOTE(woosuk): MPT's fused QKV has the shape of
                 # [3 * num_heads * head_size, hidden_size].
@@ -260,7 +261,7 @@ class MPTForCausalLM(nn.Module):
                 num_heads = total_num_heads // tp_world_size
                 head_start = tp_rank * num_heads
                 head_end = (tp_rank + 1) * num_heads
+                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                 if name.endswith(".weight"):
                     loaded_weight = loaded_weight.view(3, total_num_heads,
                                                        head_size, hidden_size)
@@ -297,12 +297,12 @@ class OPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "lm_head.weight" in name:
                 continue
 
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
+    convert_pyslice_to_tensor,
     hf_model_weights_iterator,
     load_padded_tensor_parallel_vocab,
     load_tensor_parallel_weights,
@@ -249,17 +250,19 @@ class QWenLMHeadModel(nn.Module):
         self,
         model_name_or_path: str,
         cache_dir: Optional[str] = None,
-        use_np_cache: bool = False,
+        load_format: str = "auto",
     ):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue
 
+            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
+
             if "c_attn" in name:
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
@@ -81,11 +81,12 @@ def convert_bin_to_safetensor_file(
 def prepare_hf_model_weights(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
-    use_safetensor: bool = False,
+    use_safetensors: bool = False,
+    fall_back_to_pt: bool = True,
 ):
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
-    allow_patterns = "*.safetensors" if use_safetensor else "*.bin"
+    allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
     if not is_local:
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
@@ -97,32 +98,53 @@ def prepare_hf_model_weights(
     else:
         hf_folder = model_name_or_path
     hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
-    if not use_safetensor:
+    if not use_safetensors:
         hf_weights_files = [
             x for x in hf_weights_files if not x.endswith("training_args.bin")
         ]
 
-    if len(hf_weights_files) == 0 and use_safetensor:
-        logger.warning("No *.safetensors files found, "
-                       "fall back to *.bin files")
+    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
         return prepare_hf_model_weights(model_name_or_path,
                                         cache_dir=cache_dir,
-                                        use_safetensor=False)
-    return hf_folder, hf_weights_files, use_safetensor
+                                        use_safetensors=False,
+                                        fall_back_to_pt=False)
+
+    if len(hf_weights_files) == 0:
+        raise RuntimeError(
+            f"Cannot find any model weights with `{model_name_or_path}`")
+
+    return hf_folder, hf_weights_files, use_safetensors
 
 
 def hf_model_weights_iterator(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
-    use_np_cache: bool = False,
-    use_safetensor: bool = False,
+    load_format: str = "auto",
 ) -> Iterator[Tuple[str, torch.Tensor]]:
-    hf_folder, hf_weights_files, use_safetensor = prepare_hf_model_weights(
-        model_name_or_path, cache_dir=cache_dir, use_safetensor=use_safetensor)
+    use_safetensors = False
+    use_np_cache = False
+    fall_back_to_pt = False
+    if load_format == "auto":
+        use_safetensors = True
+        fall_back_to_pt = True
+    elif load_format == "safetensors":
+        use_safetensors = True
+    elif load_format == "pt":
+        pass
+    elif load_format == "npcache":
+        use_np_cache = True
+    else:
+        raise ValueError(f"Unknown load_format: {load_format}")
+
+    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
+        model_name_or_path,
+        cache_dir=cache_dir,
+        use_safetensors=use_safetensors,
+        fall_back_to_pt=fall_back_to_pt)
 
     if use_np_cache:
         # Currently np_cache only support *.bin checkpoints
-        assert use_safetensor is False
+        assert use_safetensors is False
 
         # Convert the model weights from torch tensors to numpy arrays for
         # faster loading.
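
hf_model_weights_iterator now reduces the user-facing load_format to three internal flags. The same mapping, restated as a table-driven sketch (the commit itself uses the if/elif chain above; these names are illustrative):

```python
# load_format -> (use_safetensors, use_np_cache, fall_back_to_pt)
_LOAD_FORMAT_FLAGS = {
    "auto":        (True,  False, True),   # prefer safetensors, allow .bin fallback
    "safetensors": (True,  False, False),  # safetensors only
    "pt":          (False, False, False),  # pytorch .bin only
    "npcache":     (False, True,  False),  # .bin plus a numpy cache
}

def resolve_load_format(load_format: str):
    try:
        return _LOAD_FORMAT_FLAGS[load_format]
    except KeyError:
        raise ValueError(f"Unknown load_format: {load_format}") from None
```

"dummy" is absent because get_model short-circuits it before any file I/O (see the get_model hunk above).
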
@@ -152,7 +174,7 @@ def hf_model_weights_iterator(
             with open(param_path, "rb") as f:
                 param = np.load(f)
             yield name, torch.from_numpy(param)
-    elif use_safetensor:
+    elif use_safetensors:
         for st_file in hf_weights_files:
             with safe_open(st_file, framework="pt") as f:
                 for name in f.keys():
@@ -167,6 +189,21 @@ def hf_model_weights_iterator(
             torch.cuda.empty_cache()
 
 
+def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
+    """convert PySafeSlice object from safetensors to torch.Tensor
+
+    PySafeSlice object supports indexing, which is done before loading the
+    actual tensor and can reduce the amount of memory being read into the
+    memory. However, it does not support more advanced functionalities
+    like `.view()` or `.t()`. Therefore, if we need to modify the loaded
+    tensor with these more complicated operators, we need to convert to
+    tensor first.
+    """
+    if not isinstance(x, torch.Tensor):
+        x = x[:]
+    return x
+
+
 def load_padded_tensor_parallel_vocab(
     param: torch.Tensor,
     loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
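
A short demonstration of why the helper exists, assuming a local file model.safetensors containing a tensor named "w" (both hypothetical):

```python
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt") as f:
    sl = f.get_slice("w")  # lazy slice: sl[0:4] reads only those rows
    t = sl[:]              # materializes the full torch.Tensor
    t = t.t()              # .t()/.view() work only on the real tensor
```
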
@@ -176,11 +213,7 @@ def load_padded_tensor_parallel_vocab(
     start_idx = tensor_model_parallel_rank * shard_size
     end_idx = (tensor_model_parallel_rank + 1) * shard_size
     loaded_weight = loaded_weight[start_idx:end_idx]
-
-    # convert PySafeSlice object to torch.Tensor
-    if not isinstance(loaded_weight, torch.Tensor):
-        loaded_weight = loaded_weight[:]
-
+    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
     param[:loaded_weight.shape[0]].copy_(loaded_weight)
 
 
@@ -207,10 +240,7 @@ def load_tensor_parallel_weights(
             loaded_weight = loaded_weight[:, start_idx:end_idx]
             break
 
-    # convert PySafeSlice object to torch.Tensor
-    if not isinstance(loaded_weight, torch.Tensor):
-        loaded_weight = loaded_weight[:]
-
+    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
     assert param.shape == loaded_weight.shape, (
         f"{param_name} shape mismatch between model and checkpoint: "
         f"{param.shape} != {loaded_weight.shape}")