diff --git a/vllm/config.py b/vllm/config.py
index ca260f279c5a..aa8c4dc3a0ce 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -38,6 +38,9 @@ class ModelConfig:
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.
         seed: Random seed for reproducibility.
+        revision: The specific model version to use. It can be a branch name,
+            a tag name, or a commit id. If unspecified, will use the default
+            version.
         max_model_len: Maximum length of a sequence (including prompt and
             output). If None, will be derived from the model.
     """
@@ -52,6 +55,7 @@ class ModelConfig:
         load_format: str,
         dtype: str,
         seed: int,
+        revision: Optional[str],
         max_model_len: Optional[int] = None,
     ) -> None:
         self.model = model
@@ -61,8 +65,9 @@ class ModelConfig:
         self.download_dir = download_dir
         self.load_format = load_format
         self.seed = seed
+        self.revision = revision
 
-        self.hf_config = get_config(model, trust_remote_code)
+        self.hf_config = get_config(model, trust_remote_code, revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_load_format()
         self._verify_tokenizer_mode()
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 1e3f9e644c6c..9478e8002d51 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -28,6 +28,7 @@ class EngineArgs:
     max_num_batched_tokens: int = 2560
     max_num_seqs: int = 256
     disable_log_stats: bool = False
+    revision: Optional[str] = None
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -49,6 +50,13 @@ class EngineArgs:
             type=str,
             default=EngineArgs.tokenizer,
             help='name or path of the huggingface tokenizer to use')
+        parser.add_argument(
+            '--revision',
+            type=str,
+            default=None,
+            help='the specific model version to use. It can be a branch '
+            'name, a tag name, or a commit id. If unspecified, will use '
+            'the default version.')
         parser.add_argument('--tokenizer-mode',
                             type=str,
                             default=EngineArgs.tokenizer_mode,
@@ -159,7 +167,8 @@ class EngineArgs:
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
-                                   self.dtype, self.seed, self.max_model_len)
+                                   self.dtype, self.seed, self.revision,
+                                   self.max_model_len)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 74093cf40808..1b0f50e3a55f 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -74,6 +74,7 @@ class LLMEngine:
             f"model={model_config.model!r}, "
             f"tokenizer={model_config.tokenizer!r}, "
             f"tokenizer_mode={model_config.tokenizer_mode}, "
+            f"revision={model_config.revision}, "
             f"trust_remote_code={model_config.trust_remote_code}, "
             f"dtype={model_config.dtype}, "
             f"download_dir={model_config.download_dir!r}, "
@@ -92,7 +93,8 @@ class LLMEngine:
         self.tokenizer = get_tokenizer(
             model_config.tokenizer,
             tokenizer_mode=model_config.tokenizer_mode,
-            trust_remote_code=model_config.trust_remote_code)
+            trust_remote_code=model_config.trust_remote_code,
+            revision=model_config.revision)
         self.seq_counter = Counter()
 
         # Create the parallel GPU workers.
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index c9ab68525503..6c2afe9e7272 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -38,6 +38,8 @@ class LLM:
             However, if the `torch_dtype` in the config is `float32`, we will
             use `float16` instead.
         seed: The seed to initialize the random number generator for sampling.
+        revision: The specific model version to use. It can be a branch name,
+            a tag name, or a commit id.
     """
 
     def __init__(
diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index 71a326883b62..cd6c6b672a76 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -64,6 +64,6 @@ def get_model(model_config: ModelConfig) -> nn.Module:
     else:
         # Load the weights from the cached or downloaded files.
         model.load_weights(model_config.model, model_config.download_dir,
-                           model_config.load_format)
+                           model_config.load_format, model_config.revision)
         model = model.cuda()
     return model.eval()
diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py
index dcd849d722f2..1551974112b2 100644
--- a/vllm/model_executor/models/aquila.py
+++ b/vllm/model_executor/models/aquila.py
@@ -288,7 +288,8 @@ class AquilaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -305,7 +306,7 @@ class AquilaForCausalLM(nn.Module):
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
 
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 77ada1e76522..17e971d7bb29 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -303,13 +303,14 @@ class BaiChuanBaseForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
 
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index e17d8d075e14..d7f7d1910bc5 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -279,11 +279,12 @@ class BloomForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if name == "lm_head.weight":
                 # Since hidden_states are parallelized, we need to
                 # load lm_head.weight in parallel.
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index 883faf636269..dbd8a8203e4b 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -420,7 +420,8 @@ class FalconForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_size = (get_tensor_model_parallel_world_size())
         tp_rank = get_tensor_model_parallel_rank()
 
@@ -452,7 +453,7 @@ class FalconForCausalLM(nn.Module):
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "query_key_value" in name:
                 loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                 loaded_weight_size = loaded_weight.size()
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index ba1ee05b6f0e..fe7e009aeaf7 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -231,14 +231,15 @@ class GPT2LMHeadModel(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 9d6d091cb1f1..049b4622839a 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -259,14 +259,15 @@ class GPTBigCodeForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index 456b192322ed..c3e8da239ace 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -222,11 +222,12 @@ class GPTJForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "attn.bias" in name or "attn.masked_bias" in name:
                 continue
 
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index 454600854a86..acbd1b47d626 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -231,11 +231,12 @@ class GPTNeoXForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if ("attention.bias" in name or "attention.masked_bias" in name
                     or "rotary_emb.inv_freq" in name):
                 continue
diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py
index 5cd68541141b..fdcac02a1b27 100644
--- a/vllm/model_executor/models/internlm.py
+++ b/vllm/model_executor/models/internlm.py
@@ -233,12 +233,13 @@ class InternLMForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
 
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index d217f447a498..a2804d889810 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -271,7 +271,8 @@ class LlamaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -288,7 +289,7 @@ class LlamaForCausalLM(nn.Module):
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
 
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index 4cd88900a4d9..293d77b6aa1d 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -244,12 +244,13 @@ class MPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "Wqkv" in name:
                 # NOTE(woosuk): MPT's fused QKV has the shape of
                 # [3 * num_heads * head_size, hidden_size].
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 508083df9297..2064e1aec2af 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -297,12 +297,13 @@ class OPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
                 continue
 
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index a3557b5818f5..4ce9aea5e2c7 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -251,13 +251,14 @@ class QWenLMHeadModel(nn.Module):
         model_name_or_path: str,
         cache_dir: Optional[str] = None,
         load_format: str = "auto",
+        revision: Optional[str] = None,
     ):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
 
diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py
index 578a1392845b..c99f02bdbb77 100644
--- a/vllm/model_executor/weight_utils.py
+++ b/vllm/model_executor/weight_utils.py
@@ -83,6 +83,7 @@ def prepare_hf_model_weights(
     cache_dir: Optional[str] = None,
     use_safetensors: bool = False,
     fall_back_to_pt: bool = True,
+    revision: Optional[str] = None,
 ):
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
@@ -94,7 +95,8 @@ def prepare_hf_model_weights(
         hf_folder = snapshot_download(model_name_or_path,
                                       allow_patterns=allow_patterns,
                                       cache_dir=cache_dir,
-                                      tqdm_class=Disabledtqdm)
+                                      tqdm_class=Disabledtqdm,
+                                      revision=revision)
     else:
         hf_folder = model_name_or_path
     hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
@@ -107,7 +109,8 @@ def prepare_hf_model_weights(
         return prepare_hf_model_weights(model_name_or_path,
                                         cache_dir=cache_dir,
                                         use_safetensors=False,
-                                        fall_back_to_pt=False)
+                                        fall_back_to_pt=False,
+                                        revision=revision)
 
     if len(hf_weights_files) == 0:
         raise RuntimeError(
@@ -120,6 +123,7 @@ def hf_model_weights_iterator(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
     load_format: str = "auto",
+    revision: Optional[str] = None,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
     use_safetensors = False
     use_np_cache = False
@@ -140,7 +144,8 @@ def hf_model_weights_iterator(
         model_name_or_path,
         cache_dir=cache_dir,
         use_safetensors=use_safetensors,
-        fall_back_to_pt=fall_back_to_pt)
+        fall_back_to_pt=fall_back_to_pt,
+        revision=revision)
 
     if use_np_cache:
         # Currently np_cache only support *.bin checkpoints
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index a203fad668a0..fd5618bd81ba 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from transformers import AutoConfig, PretrainedConfig
 
 from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import
@@ -12,10 +14,12 @@ _CONFIG_REGISTRY = {
 }
 
 
-def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
+def get_config(model: str,
+               trust_remote_code: bool,
+               revision: Optional[str] = None) -> PretrainedConfig:
     try:
         config = AutoConfig.from_pretrained(
-            model, trust_remote_code=trust_remote_code)
+            model, trust_remote_code=trust_remote_code, revision=revision)
     except ValueError as e:
         if (not trust_remote_code and
                 "requires you to execute the configuration file" in str(e)):
@@ -29,5 +33,5 @@ def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
             raise e
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
-        config = config_class.from_pretrained(model)
+        config = config_class.from_pretrained(model, revision=revision)
     return config
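
Usage note (not part of the diff): a minimal sketch of how the new parameter could be exercised end to end. It assumes the LLM entrypoint forwards the revision keyword through to EngineArgs, as the docstring change in vllm/entrypoints/llm.py suggests; the model id and revision value below are illustrative placeholders, not taken from this change.

    from vllm import LLM, SamplingParams

    # Pin the checkpoint to a specific Hugging Face revision
    # (a branch name, a tag name, or a commit id).
    # Model id and revision are placeholders for illustration only.
    llm = LLM(model="EleutherAI/pythia-70m", revision="main")

    sampling_params = SamplingParams(temperature=0.8, max_tokens=16)
    outputs = llm.generate(["Hello, my name is"], sampling_params)
    for output in outputs:
        print(output.outputs[0].text)

The same setting is exposed on the command line through the --revision flag registered in vllm/engine/arg_utils.py, so any entrypoint that builds its arguments from EngineArgs should accept, for example, --revision main.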