diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py
index 2e123465c48d..194da4ff542e 100644
--- a/vllm/model_executor/models/aquila.py
+++ b/vllm/model_executor/models/aquila.py
@@ -34,8 +34,9 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (
+    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
+    load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -280,8 +281,7 @@ class AquilaForCausalLM(nn.Module):
         return next_tokens
 
     _column_parallel_weights = [
-        "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
-        "gate_proj.weight", "up_proj.weight"
+        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
     ]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
 
@@ -309,16 +309,6 @@ class AquilaForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            if "embed_tokens" in name or "lm_head" in name:
-                param = state_dict[name]
-                # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] * tp_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
             is_attention_weight = False
             for weight_name, shard_size, offset in attention_weight_specs:
                 if weight_name not in name:
@@ -356,6 +346,11 @@ class AquilaForCausalLM(nn.Module):
                 continue
 
             param = state_dict[name]
+            if "embed_tokens" in name or "lm_head" in name:
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue
+
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
                                          self._row_parallel_weights,
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 3ec9ddbacadc..47c79059b7a8 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -32,10 +32,12 @@ from vllm.sequence import SequenceOutputs
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
+from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
+                                                  PagedAttentionWithALiBi)
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (
+    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
+    load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -295,10 +297,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens
 
-    _column_parallel_weights = [
-        "embed_tokens.weight",
-        "lm_head.weight",
-    ]
+    _column_parallel_weights = []
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
 
     def load_weights(self,
@@ -314,16 +313,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            if "embed_tokens" in name or "lm_head" in name:
-                # Consider padding in the vocab size.
-                param = state_dict[name]
-                padded_vocab_size = param.shape[0] * tp_world_size
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
             if "W_pack" in name:
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
@@ -355,6 +344,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
                 continue
 
             param = state_dict[name]
+
+            if "embed_tokens" in name or "lm_head" in name:
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tp_rank)
+                continue
+
             load_tensor_parallel_weights(
                 param,
                 loaded_weight,
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index fd59e691e213..a3b6efe2af8d 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -31,8 +31,9 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (
+    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
+    load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -224,7 +225,7 @@ class GPT2LMHeadModel(nn.Module):
                                    input_metadata)
         return next_tokens
 
-    _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
+    _column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
     _row_parallel_weights = ["c_proj.weight"]
 
     def load_weights(self,
@@ -261,14 +262,9 @@ class GPT2LMHeadModel(nn.Module):
 
             param = state_dict[name]
             if name == "transformer.wte.weight":
-                # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] *
-                                     tensor_model_parallel_world_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue
 
             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 595c7b640e23..8694c975832c 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -32,8 +32,9 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (
+    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
+    load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -252,7 +253,7 @@ class GPTBigCodeForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens
 
-    _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
+    _column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
     _row_parallel_weights = ["c_proj.weight"]
 
     def load_weights(self,
@@ -328,14 +329,9 @@ class GPTBigCodeForCausalLM(nn.Module):
 
             param = state_dict[name]
             if name == "transformer.wte.weight":
-                # Consider padding in the vocab size.
-                padded_vocab_size = param.shape[
-                    0] * tensor_model_parallel_world_size
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue
 
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py
index 1998323352be..1aeeb91d94c6 100644
--- a/vllm/model_executor/models/internlm.py
+++ b/vllm/model_executor/models/internlm.py
@@ -14,8 +14,9 @@ from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
     ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (
+    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
+    load_tensor_parallel_weights)
 from vllm.sequence import SequenceOutputs
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -225,8 +226,7 @@ class InternLMForCausalLM(nn.Module):
         return next_tokens
 
     _column_parallel_weights = [
-        "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
-        "gate_proj.weight", "up_proj.weight"
+        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
     ]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
 
@@ -234,8 +234,6 @@ class InternLMForCausalLM(nn.Module):
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
@@ -246,14 +244,9 @@ class InternLMForCausalLM(nn.Module):
 
             if "embed_tokens" in name or "lm_head" in name:
                 param = state_dict[name]
-                # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] *
-                                     tensor_model_parallel_world_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue
 
             is_attention_weight = False
             for stride_id, att_weight_name in enumerate(
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 7d8c950cb2a9..d72c4ff6a0ea 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -36,8 +36,9 @@ from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (
+    load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
+    hf_model_weights_iterator)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -263,15 +264,15 @@ class LlamaForCausalLM(nn.Module):
         return next_tokens
 
     _column_parallel_weights = [
-        "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
-        "gate_proj.weight", "up_proj.weight"
+        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
     ]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
 
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     use_np_cache: bool = False,
+                     use_safetensor: bool = True):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -288,20 +289,10 @@ class LlamaForCausalLM(nn.Module):
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, use_np_cache, use_safetensor):
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            if "embed_tokens" in name or "lm_head" in name:
-                param = state_dict[name]
-                # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] * tp_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
             is_attention_weight = False
             for weight_name, shard_size, offset in attention_weight_specs:
                 if weight_name not in name:
@@ -339,6 +330,12 @@ class LlamaForCausalLM(nn.Module):
                 continue
 
             param = state_dict[name]
+
+            if "embed_tokens" in name or "lm_head" in name:
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue
+
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
                                          self._row_parallel_weights,
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index d81940ed28b0..1318bbe70288 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
     hf_model_weights_iterator,
+    load_padded_tensor_parallel_vocab,
     load_tensor_parallel_weights,
 )
 from vllm.model_executor.parallel_utils.parallel_state import (
@@ -241,7 +242,7 @@ class QWenLMHeadModel(nn.Module):
                                    input_metadata)
         return next_tokens
 
-    _column_parallel_weights = ["wte.weight", "lm_head.weight"]
+    _column_parallel_weights = []
     _row_parallel_weights = ["c_proj.weight"]
 
     def load_weights(
@@ -259,16 +260,6 @@ class QWenLMHeadModel(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            if "wte" in name or "lm_head" in name:
-                # Consider padding in the vocab size.
-                param = state_dict[name]
-                padded_vocab_size = param.shape[0] * tp_world_size
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
             if "c_attn" in name:
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
@@ -306,6 +297,12 @@ class QWenLMHeadModel(nn.Module):
                 continue
 
             param = state_dict[name]
+
+            if "wte" in name or "lm_head" in name:
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tp_rank)
+                continue
+
             load_tensor_parallel_weights(
                 param,
                 loaded_weight,
diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py
index 3127a36098e2..23844470faa5 100644
--- a/vllm/model_executor/weight_utils.py
+++ b/vllm/model_executor/weight_utils.py
@@ -3,13 +3,19 @@ import filelock
 import glob
 import json
 import os
-from typing import Iterator, List, Optional, Tuple
+from collections import defaultdict
+from typing import Iterator, List, Optional, Tuple, Any
 
 from huggingface_hub import snapshot_download
+from safetensors.torch import load_file, save_file, safe_open
 import numpy as np
 import torch
 from tqdm.auto import tqdm
 
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
 
 class Disabledtqdm(tqdm):
 
@@ -17,43 +23,118 @@ class Disabledtqdm(tqdm):
         super().__init__(*args, **kwargs, disable=True)
 
 
-def hf_model_weights_iterator(
-    model_name_or_path: str,
-    cache_dir: Optional[str] = None,
-    use_np_cache: bool = False,
-) -> Iterator[Tuple[str, torch.Tensor]]:
-    # Prepare file lock directory to prevent multiple processes from
-    # downloading the same model weights at the same time.
+def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
     lock_dir = cache_dir if cache_dir is not None else "/tmp"
     lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
     lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
+    return lock
 
+
+def _shared_pointers(tensors):
+    ptrs = defaultdict(list)
+    for k, v in tensors.items():
+        ptrs[v.data_ptr()].append(k)
+    failing = []
+    for _, names in ptrs.items():
+        if len(names) > 1:
+            failing.append(names)
+    return failing
+
+
+def convert_bin_to_safetensor_file(
+    pt_filename: str,
+    sf_filename: str,
+):
+    loaded = torch.load(pt_filename, map_location="cpu")
+    if "state_dict" in loaded:
+        loaded = loaded["state_dict"]
+    shared = _shared_pointers(loaded)
+    for shared_weights in shared:
+        for name in shared_weights[1:]:
+            loaded.pop(name)
+
+    # Force tensors to be contiguous.
+    loaded = {k: v.contiguous() for k, v in loaded.items()}
+
+    dirname = os.path.dirname(sf_filename)
+    os.makedirs(dirname, exist_ok=True)
+    save_file(loaded, sf_filename, metadata={"format": "pt"})
+
+    # Check the file size.
+    sf_size = os.stat(sf_filename).st_size
+    pt_size = os.stat(pt_filename).st_size
+    if (sf_size - pt_size) / pt_size > 0.01:
+        raise RuntimeError(f"""The file size difference is more than 1%:
+         - {sf_filename}: {sf_size}
+         - {pt_filename}: {pt_size}
+         """)
+
+    # Check that the tensors are the same.
+    reloaded = load_file(sf_filename)
+    for k in loaded:
+        pt_tensor = loaded[k]
+        sf_tensor = reloaded[k]
+        if not torch.equal(pt_tensor, sf_tensor):
+            raise RuntimeError(f"The output tensors do not match for key {k}")
+
+
+def prepare_hf_model_weights(
+    model_name_or_path: str,
+    cache_dir: Optional[str] = None,
+    use_safetensor: bool = False,
+):
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
+    allow_patterns = "*.safetensors" if use_safetensor else "*.bin"
     if not is_local:
-        with lock:
+        # Use file lock to prevent multiple processes from
+        # downloading the same model weights at the same time.
+        with get_lock(model_name_or_path, cache_dir):
             hf_folder = snapshot_download(model_name_or_path,
-                                          allow_patterns="*.bin",
+                                          allow_patterns=allow_patterns,
                                           cache_dir=cache_dir,
                                           tqdm_class=Disabledtqdm)
     else:
         hf_folder = model_name_or_path
+    hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
+    if not use_safetensor:
+        hf_weights_files = [
+            x for x in hf_weights_files if not x.endswith("training_args.bin")
+        ]
 
-    hf_bin_files = [
-        x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
-        if not x.endswith("training_args.bin")
-    ]
+    if len(hf_weights_files) == 0 and use_safetensor:
+        logger.warning("No *.safetensors files found, "
+                       "falling back to *.bin files")
+        return prepare_hf_model_weights(model_name_or_path,
+                                        cache_dir=cache_dir,
+                                        use_safetensor=False)
+    return hf_folder, hf_weights_files, use_safetensor
+
+
+def hf_model_weights_iterator(
+    model_name_or_path: str,
+    cache_dir: Optional[str] = None,
+    use_np_cache: bool = False,
+    use_safetensor: bool = False,
+) -> Iterator[Tuple[str, torch.Tensor]]:
+    hf_folder, hf_weights_files, use_safetensor = prepare_hf_model_weights(
+        model_name_or_path, cache_dir=cache_dir, use_safetensor=use_safetensor)
 
     if use_np_cache:
+        # Currently np_cache only supports *.bin checkpoints.
+        assert use_safetensor is False
+
         # Convert the model weights from torch tensors to numpy arrays for
         # faster loading.
         np_folder = os.path.join(hf_folder, "np")
         os.makedirs(np_folder, exist_ok=True)
         weight_names_file = os.path.join(np_folder, "weight_names.json")
-        with lock:
+        # Use file lock to prevent multiple processes from
+        # dumping the same model weights to numpy at the same time.
+        with get_lock(model_name_or_path, cache_dir):
             if not os.path.exists(weight_names_file):
                 weight_names = []
-                for bin_file in hf_bin_files:
+                for bin_file in hf_weights_files:
                     state = torch.load(bin_file, map_location="cpu")
                     for name, param in state.items():
                         param_path = os.path.join(np_folder, name)
@@ -71,8 +152,14 @@ def hf_model_weights_iterator(
             with open(param_path, "rb") as f:
                 param = np.load(f)
             yield name, torch.from_numpy(param)
+    elif use_safetensor:
+        for st_file in hf_weights_files:
+            with safe_open(st_file, framework="pt") as f:
+                for name in f.keys():
+                    param = f.get_slice(name)
+                    yield name, param
     else:
-        for bin_file in hf_bin_files:
+        for bin_file in hf_weights_files:
             state = torch.load(bin_file, map_location="cpu")
             for name, param in state.items():
                 yield name, param
@@ -80,9 +167,26 @@ def hf_model_weights_iterator(
             torch.cuda.empty_cache()
 
 
+def load_padded_tensor_parallel_vocab(
+    param: torch.Tensor,
+    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
+    tensor_model_parallel_rank: int,
+) -> None:
+    shard_size = param.shape[0]
+    start_idx = tensor_model_parallel_rank * shard_size
+    end_idx = (tensor_model_parallel_rank + 1) * shard_size
+    loaded_weight = loaded_weight[start_idx:end_idx]
+
+    # Convert PySafeSlice object to torch.Tensor.
+    if not isinstance(loaded_weight, torch.Tensor):
+        loaded_weight = loaded_weight[:]
+
+    param[:loaded_weight.shape[0]].copy_(loaded_weight)
+
+
 def load_tensor_parallel_weights(
     param: torch.Tensor,
-    loaded_weight: torch.Tensor,
+    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
     param_name: str,
     column_parallel_weight_names: List[str],
     row_parallel_weight_names: List[str],
@@ -102,6 +206,11 @@ def load_tensor_parallel_weights(
             end_idx = (tensor_model_parallel_rank + 1) * shard_size
             loaded_weight = loaded_weight[:, start_idx:end_idx]
             break
+
+    # Convert PySafeSlice object to torch.Tensor.
+    if not isinstance(loaded_weight, torch.Tensor):
+        loaded_weight = loaded_weight[:]
+
     assert param.shape == loaded_weight.shape, (
         f"{param_name} shape mismatch between model and checkpoint: "
         f"{param.shape} != {loaded_weight.shape}")
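
Reviewer note on the vocab-padding rewrite above: the old code in every loader appended num_extra_rows of uninitialized torch.empty rows to the checkpoint tensor and then let the generic column-parallel path shard it; the new load_padded_tensor_parallel_vocab instead has each rank slice its own rows straight out of the unpadded checkpoint and copy them into its already-padded parameter, so the padding rows keep whatever values the model initialized them to. Below is a minimal, self-contained sketch of that indexing; the shapes (vocab 32000 padded to 32256, tp_size 4) are illustrative assumptions, not values taken from this patch.

    import torch

    vocab_size = 32000          # rows present in the HF checkpoint
    padded_vocab_size = 32256   # model's padded vocab, divisible by tp_size
    hidden_size = 8
    tp_size = 4

    checkpoint = torch.randn(vocab_size, hidden_size)  # unpadded weight
    shard_size = padded_vocab_size // tp_size          # == param.shape[0]

    for rank in range(tp_size):
        param = torch.zeros(shard_size, hidden_size)   # rank-local parameter
        start_idx = rank * shard_size
        end_idx = (rank + 1) * shard_size
        # The last rank's slice runs past the end of the checkpoint and
        # comes back short; copying into param[:shard.shape[0]] leaves the
        # trailing padding rows untouched.
        shard = checkpoint[start_idx:end_idx]
        param[:shard.shape[0]].copy_(shard)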
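The safetensors branch of hf_model_weights_iterator deliberately yields PySafeSlice handles rather than materialized tensors: safe_open opens the file lazily, and a handle only reads bytes when it is indexed. That is why the loaders can slice out their tensor-parallel shard before anything is loaded, and why both load_* helpers materialize a still-unread handle with loaded_weight[:]. A hedged sketch of the mechanism, using the safetensors API directly; the file name is a placeholder.

    import torch
    from safetensors.torch import safe_open

    with safe_open("model.safetensors", framework="pt") as f:
        for name in f.keys():
            handle = f.get_slice(name)        # PySafeSlice; no data read yet
            num_rows = handle.get_shape()[0]  # shape metadata only
            shard = handle[:num_rows // 2]    # reads just these rows from disk
            assert isinstance(shard, torch.Tensor)
            full = handle[:]                  # materializes the whole tensor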
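convert_bin_to_safetensor_file can also be called on its own to migrate a checkpoint offline: it drops duplicate shared-storage tensors, forces contiguity, writes the .safetensors file, and then verifies the result both by file size (within 1%) and by exact tensor equality against the original. A usage sketch; the paths are hypothetical.

    from vllm.model_executor.weight_utils import convert_bin_to_safetensor_file

    # Hypothetical input/output paths, for illustration only.
    convert_bin_to_safetensor_file(
        "checkpoints/pytorch_model.bin",
        "checkpoints/safetensors/model.safetensors",
    )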