Add Model Revision Support (#1014)

Co-authored-by: Jasmond Loh <Jasmond.Loh@hotmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Jasmond L authored on 2023-09-14 06:20:02 +08:00; committed by GitHub
parent 9841d48a10
commit ab019eea75
20 changed files with 75 additions and 35 deletions


@@ -38,6 +38,9 @@ class ModelConfig:
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.
         seed: Random seed for reproducibility.
+        revision: The specific model version to use. It can be a branch name,
+            a tag name, or a commit id. If unspecified, will use the default
+            version.
         max_model_len: Maximum length of a sequence (including prompt and
             output). If None, will be derived from the model.
     """
@@ -52,6 +55,7 @@ class ModelConfig:
         load_format: str,
         dtype: str,
         seed: int,
+        revision: Optional[str],
         max_model_len: Optional[int] = None,
     ) -> None:
         self.model = model
self.model = model self.model = model
@@ -61,8 +65,9 @@ class ModelConfig:
         self.download_dir = download_dir
         self.load_format = load_format
         self.seed = seed
+        self.revision = revision

-        self.hf_config = get_config(model, trust_remote_code)
+        self.hf_config = get_config(model, trust_remote_code, revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_load_format()
         self._verify_tokenizer_mode()
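For orientation (not part of the commit): a minimal sketch of calling the updated ModelConfig constructor once revision is threaded through. The concrete values are illustrative assumptions, not taken from this diff.

from vllm.config import ModelConfig

# Illustrative values; the argument order follows the updated signature above.
model_config = ModelConfig(
    model="facebook/opt-125m",
    tokenizer="facebook/opt-125m",
    tokenizer_mode="auto",
    trust_remote_code=False,
    download_dir=None,
    load_format="auto",
    dtype="auto",
    seed=0,
    revision=None,  # a branch name, a tag name, or a commit id
    max_model_len=None,
)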


@@ -28,6 +28,7 @@ class EngineArgs:
     max_num_batched_tokens: int = 2560
     max_num_seqs: int = 256
     disable_log_stats: bool = False
+    revision: Optional[str] = None

     def __post_init__(self):
         if self.tokenizer is None:
@@ -49,6 +50,13 @@ class EngineArgs:
                             type=str,
                             default=EngineArgs.tokenizer,
                             help='name or path of the huggingface tokenizer to use')
+        parser.add_argument(
+            '--revision',
+            type=str,
+            default=None,
+            help='the specific model version to use. It can be a branch '
+            'name, a tag name, or a commit id. If unspecified, will use '
+            'the default version.')
         parser.add_argument('--tokenizer-mode',
                             type=str,
                             default=EngineArgs.tokenizer_mode,
@@ -159,7 +167,8 @@ class EngineArgs:
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
-                                   self.dtype, self.seed, self.max_model_len)
+                                   self.dtype, self.seed, self.revision,
+                                   self.max_model_len)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
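A hedged usage sketch for the new engine argument: the model name is illustrative, and the create_engine_configs() call is an assumption about the surrounding method, not shown in this hunk.

from vllm.engine.arg_utils import EngineArgs

# Pin the checkpoint to a specific Hub revision (branch, tag, or commit id).
# Equivalent CLI flag: --revision main
engine_args = EngineArgs(model="facebook/opt-125m", revision="main")
configs = engine_args.create_engine_configs()  # method name assumed; builds the ModelConfig above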


@@ -74,6 +74,7 @@ class LLMEngine:
             f"model={model_config.model!r}, "
             f"tokenizer={model_config.tokenizer!r}, "
             f"tokenizer_mode={model_config.tokenizer_mode}, "
+            f"revision={model_config.revision}, "
             f"trust_remote_code={model_config.trust_remote_code}, "
             f"dtype={model_config.dtype}, "
             f"download_dir={model_config.download_dir!r}, "
@@ -92,7 +93,8 @@ class LLMEngine:
         self.tokenizer = get_tokenizer(
             model_config.tokenizer,
             tokenizer_mode=model_config.tokenizer_mode,
-            trust_remote_code=model_config.trust_remote_code)
+            trust_remote_code=model_config.trust_remote_code,
+            revision=model_config.revision)
         self.seq_counter = Counter()

         # Create the parallel GPU workers.
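A sketch of the tokenizer side, assuming get_tokenizer forwards extra keyword arguments to transformers.AutoTokenizer.from_pretrained (which accepts revision); the import path and model name are assumptions.

from vllm.transformers_utils.tokenizer import get_tokenizer  # import path assumed

tokenizer = get_tokenizer(
    "facebook/opt-125m",
    tokenizer_mode="auto",
    trust_remote_code=False,
    revision="main",  # assumed to be forwarded to AutoTokenizer.from_pretrained
)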


@@ -38,6 +38,8 @@ class LLM:
             However, if the `torch_dtype` in the config is `float32`, we will
             use `float16` instead.
         seed: The seed to initialize the random number generator for sampling.
+        revision: The specific model version to use. It can be a branch name,
+            a tag name, or a commit id.
     """

     def __init__(
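The docstring above implies the offline LLM entry point exposes the same knob; a minimal sketch, assuming the constructor accepts revision and forwards it to EngineArgs (the signature change itself is outside this hunk).

from vllm import LLM

llm = LLM(model="facebook/opt-125m", revision="main")  # revision kwarg assumed here
outputs = llm.generate("Hello, my name is")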


@@ -64,6 +64,6 @@ def get_model(model_config: ModelConfig) -> nn.Module:
     else:
         # Load the weights from the cached or downloaded files.
         model.load_weights(model_config.model, model_config.download_dir,
-                           model_config.load_format)
+                           model_config.load_format, model_config.revision)
     model = model.cuda()
     return model.eval()
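End to end, the pinned revision rides on ModelConfig from engine setup down to weight loading; a short sketch of the call path (import path assumed, and model_config is the object built earlier).

from vllm.model_executor.model_loader import get_model  # import path assumed

# get_model passes model_config.revision into model.load_weights(),
# as shown in the hunk above; requires a CUDA device.
model = get_model(model_config)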


@@ -288,7 +288,8 @@ class AquilaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -305,7 +306,7 @@ class AquilaForCausalLM(nn.Module):
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue


@@ -303,13 +303,14 @@ class BaiChuanBaseForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue


@@ -279,11 +279,12 @@ class BloomForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if name == "lm_head.weight":
                 # Since hidden_states are parallelized, we need to
                 # load lm_head.weight in parallel.


@@ -420,7 +420,8 @@ class FalconForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_size = (get_tensor_model_parallel_world_size())
         tp_rank = get_tensor_model_parallel_rank()
@@ -452,7 +453,7 @@ class FalconForCausalLM(nn.Module):
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "query_key_value" in name:
                 loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                 loaded_weight_size = loaded_weight.size()


@@ -231,14 +231,15 @@ class GPT2LMHeadModel(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.


@@ -259,14 +259,15 @@ class GPTBigCodeForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.


@@ -222,11 +222,12 @@ class GPTJForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "attn.bias" in name or "attn.masked_bias" in name:
                 continue


@@ -231,11 +231,12 @@ class GPTNeoXForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if ("attention.bias" in name or "attention.masked_bias" in name
                     or "rotary_emb.inv_freq" in name):
                 continue


@@ -233,12 +233,13 @@ class InternLMForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue


@@ -271,7 +271,8 @@ class LlamaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -288,7 +289,7 @@ class LlamaForCausalLM(nn.Module):
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue


@@ -244,12 +244,13 @@ class MPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "Wqkv" in name:
                 # NOTE(woosuk): MPT's fused QKV has the shape of
                 # [3 * num_heads * head_size, hidden_size].


@@ -297,12 +297,13 @@ class OPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
                 continue


@@ -251,13 +251,14 @@ class QWenLMHeadModel(nn.Module):
         model_name_or_path: str,
         cache_dir: Optional[str] = None,
         load_format: str = "auto",
+        revision: Optional[str] = None,
     ):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue


@@ -83,6 +83,7 @@ def prepare_hf_model_weights(
     cache_dir: Optional[str] = None,
     use_safetensors: bool = False,
     fall_back_to_pt: bool = True,
+    revision: Optional[str] = None,
 ):
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
@@ -94,7 +95,8 @@
         hf_folder = snapshot_download(model_name_or_path,
                                       allow_patterns=allow_patterns,
                                       cache_dir=cache_dir,
-                                      tqdm_class=Disabledtqdm)
+                                      tqdm_class=Disabledtqdm,
+                                      revision=revision)
     else:
         hf_folder = model_name_or_path
     hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
@@ -107,7 +109,8 @@
         return prepare_hf_model_weights(model_name_or_path,
                                         cache_dir=cache_dir,
                                         use_safetensors=False,
-                                        fall_back_to_pt=False)
+                                        fall_back_to_pt=False,
+                                        revision=revision)
     if len(hf_weights_files) == 0:
         raise RuntimeError(
@@ -120,6 +123,7 @@ def hf_model_weights_iterator(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
     load_format: str = "auto",
+    revision: Optional[str] = None,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
     use_safetensors = False
     use_np_cache = False
@@ -140,7 +144,8 @@
         model_name_or_path,
         cache_dir=cache_dir,
         use_safetensors=use_safetensors,
-        fall_back_to_pt=fall_back_to_pt)
+        fall_back_to_pt=fall_back_to_pt,
+        revision=revision)
     if use_np_cache:
         # Currently np_cache only support *.bin checkpoints
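Under the hood the revision ends up on huggingface_hub.snapshot_download, which accepts it natively; a standalone sketch of the effect, with an illustrative repo id, file pattern, and revision.

from huggingface_hub import snapshot_download

# Downloads the weight files the way prepare_hf_model_weights does above,
# pinned to a branch, tag, or commit id instead of the repo's default branch.
folder = snapshot_download(
    "facebook/opt-125m",
    allow_patterns="*.bin",
    revision="main",
)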


@@ -1,3 +1,5 @@
+from typing import Optional
+
 from transformers import AutoConfig, PretrainedConfig

 from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import
@@ -12,10 +14,12 @@ _CONFIG_REGISTRY = {
 }


-def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
+def get_config(model: str,
+               trust_remote_code: bool,
+               revision: Optional[str] = None) -> PretrainedConfig:
     try:
         config = AutoConfig.from_pretrained(
-            model, trust_remote_code=trust_remote_code)
+            model, trust_remote_code=trust_remote_code, revision=revision)
     except ValueError as e:
         if (not trust_remote_code and
                 "requires you to execute the configuration file" in str(e)):
@@ -29,5 +33,5 @@ def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
         raise e
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
-        config = config_class.from_pretrained(model)
+        config = config_class.from_pretrained(model, revision=revision)
     return config
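AutoConfig.from_pretrained already accepts revision, so the helper is a pass-through; a quick sketch of the updated function in use (model id and revision are illustrative, import path assumed).

from vllm.transformers_utils.config import get_config  # import path assumed

config = get_config("facebook/opt-125m", trust_remote_code=False, revision="main")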