Add Model Revision Support (#1014)
Co-authored-by: Jasmond Loh <Jasmond.Loh@hotmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
parent 9841d48a10
commit ab019eea75
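The diff below threads a new `revision` argument from the user-facing entry points (`ModelConfig`, `EngineArgs`, the `LLM` class) down to config, tokenizer, and weight loading, so a model can be pinned to a specific Hugging Face branch, tag, or commit. As a minimal usage sketch (the model name and revision are illustrative placeholders, not part of this commit):

    from vllm import LLM, SamplingParams

    # Pin the checkpoint to a specific revision (branch name, tag, or commit id).
    llm = LLM(model="facebook/opt-125m", revision="main")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.8, max_tokens=16))
    for output in outputs:
        print(output.outputs[0].text)

The same option is exposed on the command line through the `--revision` flag added to `EngineArgs` in this diff.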
@@ -38,6 +38,9 @@ class ModelConfig:
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.
         seed: Random seed for reproducibility.
+        revision: The specific model version to use. It can be a branch name,
+            a tag name, or a commit id. If unspecified, will use the default
+            version.
         max_model_len: Maximum length of a sequence (including prompt and
             output). If None, will be derived from the model.
     """
@@ -52,6 +55,7 @@ class ModelConfig:
         load_format: str,
         dtype: str,
         seed: int,
+        revision: Optional[str],
         max_model_len: Optional[int] = None,
     ) -> None:
         self.model = model
@@ -61,8 +65,9 @@ class ModelConfig:
         self.download_dir = download_dir
         self.load_format = load_format
         self.seed = seed
+        self.revision = revision

-        self.hf_config = get_config(model, trust_remote_code)
+        self.hf_config = get_config(model, trust_remote_code, revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_load_format()
         self._verify_tokenizer_mode()
@@ -28,6 +28,7 @@ class EngineArgs:
     max_num_batched_tokens: int = 2560
     max_num_seqs: int = 256
     disable_log_stats: bool = False
+    revision: Optional[str] = None

     def __post_init__(self):
         if self.tokenizer is None:
@@ -49,6 +50,13 @@ class EngineArgs:
                             type=str,
                             default=EngineArgs.tokenizer,
                             help='name or path of the huggingface tokenizer to use')
+        parser.add_argument(
+            '--revision',
+            type=str,
+            default=None,
+            help='the specific model version to use. It can be a branch '
+            'name, a tag name, or a commit id. If unspecified, will use '
+            'the default version.')
         parser.add_argument('--tokenizer-mode',
                             type=str,
                             default=EngineArgs.tokenizer_mode,
@@ -159,7 +167,8 @@ class EngineArgs:
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
-                                   self.dtype, self.seed, self.max_model_len)
+                                   self.dtype, self.seed, self.revision,
+                                   self.max_model_len)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
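With the `--revision` flag registered above, any server that builds its arguments through `EngineArgs.add_cli_args` can pin the model version from the command line. A hedged example, assuming the standalone API server entry point of this vLLM version and a placeholder model and tag:

    python -m vllm.entrypoints.api_server --model facebook/opt-125m --revision main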
@@ -74,6 +74,7 @@ class LLMEngine:
             f"model={model_config.model!r}, "
             f"tokenizer={model_config.tokenizer!r}, "
             f"tokenizer_mode={model_config.tokenizer_mode}, "
+            f"revision={model_config.revision}, "
             f"trust_remote_code={model_config.trust_remote_code}, "
             f"dtype={model_config.dtype}, "
             f"download_dir={model_config.download_dir!r}, "
@@ -92,7 +93,8 @@ class LLMEngine:
         self.tokenizer = get_tokenizer(
             model_config.tokenizer,
             tokenizer_mode=model_config.tokenizer_mode,
-            trust_remote_code=model_config.trust_remote_code)
+            trust_remote_code=model_config.trust_remote_code,
+            revision=model_config.revision)
         self.seq_counter = Counter()

         # Create the parallel GPU workers.
@@ -38,6 +38,8 @@ class LLM:
             However, if the `torch_dtype` in the config is `float32`, we will
             use `float16` instead.
         seed: The seed to initialize the random number generator for sampling.
+        revision: The specific model version to use. It can be a branch name,
+            a tag name, or a commit id.
     """

     def __init__(
@@ -64,6 +64,6 @@ def get_model(model_config: ModelConfig) -> nn.Module:
     else:
         # Load the weights from the cached or downloaded files.
         model.load_weights(model_config.model, model_config.download_dir,
-                           model_config.load_format)
+                           model_config.load_format, model_config.revision)
         model = model.cuda()
     return model.eval()
@@ -288,7 +288,8 @@ class AquilaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -305,7 +306,7 @@ class AquilaForCausalLM(nn.Module):
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue

@@ -303,13 +303,14 @@ class BaiChuanBaseForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue

@@ -279,11 +279,12 @@ class BloomForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if name == "lm_head.weight":
                 # Since hidden_states are parallelized, we need to
                 # load lm_head.weight in parallel.
|||||||
@ -420,7 +420,8 @@ class FalconForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
load_format: str = "auto"):
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
tp_size = (get_tensor_model_parallel_world_size())
|
tp_size = (get_tensor_model_parallel_world_size())
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
@ -452,7 +453,7 @@ class FalconForCausalLM(nn.Module):
|
|||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, load_format):
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
if "query_key_value" in name:
|
if "query_key_value" in name:
|
||||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
loaded_weight_size = loaded_weight.size()
|
loaded_weight_size = loaded_weight.size()
|
||||||
|
|||||||
@@ -231,14 +231,15 @@ class GPT2LMHeadModel(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
@@ -259,14 +259,15 @@ class GPTBigCodeForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
@@ -222,11 +222,12 @@ class GPTJForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "attn.bias" in name or "attn.masked_bias" in name:
                 continue

@@ -231,11 +231,12 @@ class GPTNeoXForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if ("attention.bias" in name or "attention.masked_bias" in name
                     or "rotary_emb.inv_freq" in name):
                 continue
@@ -233,12 +233,13 @@ class InternLMForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue

@@ -271,7 +271,8 @@ class LlamaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -288,7 +289,7 @@ class LlamaForCausalLM(nn.Module):
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue

@@ -244,12 +244,13 @@ class MPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "Wqkv" in name:
                 # NOTE(woosuk): MPT's fused QKV has the shape of
                 # [3 * num_heads * head_size, hidden_size].
@@ -297,12 +297,13 @@ class OPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
                 continue

@@ -251,13 +251,14 @@ class QWenLMHeadModel(nn.Module):
         model_name_or_path: str,
         cache_dir: Optional[str] = None,
         load_format: str = "auto",
+        revision: Optional[str] = None,
     ):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue

@@ -83,6 +83,7 @@ def prepare_hf_model_weights(
     cache_dir: Optional[str] = None,
     use_safetensors: bool = False,
     fall_back_to_pt: bool = True,
+    revision: Optional[str] = None,
 ):
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
@@ -94,7 +95,8 @@ def prepare_hf_model_weights(
         hf_folder = snapshot_download(model_name_or_path,
                                       allow_patterns=allow_patterns,
                                       cache_dir=cache_dir,
-                                      tqdm_class=Disabledtqdm)
+                                      tqdm_class=Disabledtqdm,
+                                      revision=revision)
     else:
         hf_folder = model_name_or_path
     hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
@@ -107,7 +109,8 @@ def prepare_hf_model_weights(
         return prepare_hf_model_weights(model_name_or_path,
                                         cache_dir=cache_dir,
                                         use_safetensors=False,
-                                        fall_back_to_pt=False)
+                                        fall_back_to_pt=False,
+                                        revision=revision)

     if len(hf_weights_files) == 0:
         raise RuntimeError(
@@ -120,6 +123,7 @@ def hf_model_weights_iterator(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
     load_format: str = "auto",
+    revision: Optional[str] = None,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
     use_safetensors = False
     use_np_cache = False
@@ -140,7 +144,8 @@ def hf_model_weights_iterator(
         model_name_or_path,
         cache_dir=cache_dir,
         use_safetensors=use_safetensors,
-        fall_back_to_pt=fall_back_to_pt)
+        fall_back_to_pt=fall_back_to_pt,
+        revision=revision)

     if use_np_cache:
         # Currently np_cache only support *.bin checkpoints
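Inside `prepare_hf_model_weights`, the revision is simply forwarded to `huggingface_hub.snapshot_download`, which is what ultimately selects which files are fetched. A standalone sketch of that behavior (repo id, revision, and pattern are placeholders):

    from huggingface_hub import snapshot_download

    # Download only the weight files of a repo pinned to a given revision.
    folder = snapshot_download("facebook/opt-125m",
                               allow_patterns=["*.bin"],
                               revision="main")
    print(folder)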
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from transformers import AutoConfig, PretrainedConfig

 from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import
@@ -12,10 +14,12 @@ _CONFIG_REGISTRY = {
 }


-def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
+def get_config(model: str,
+               trust_remote_code: bool,
+               revision: Optional[str] = None) -> PretrainedConfig:
     try:
         config = AutoConfig.from_pretrained(
-            model, trust_remote_code=trust_remote_code)
+            model, trust_remote_code=trust_remote_code, revision=revision)
     except ValueError as e:
         if (not trust_remote_code and
                 "requires you to execute the configuration file" in str(e)):
@@ -29,5 +33,5 @@ def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
         raise e
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
-        config = config_class.from_pretrained(model)
+        config = config_class.from_pretrained(model, revision=revision)
     return config
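On the Hugging Face side, the `revision` passed through `get_config` is the standard `revision` argument of `AutoConfig.from_pretrained` (and of the model-specific config classes in `_CONFIG_REGISTRY`). A quick sanity check, with placeholder repo id and revision:

    from transformers import AutoConfig

    # Resolve the model config at a pinned branch, tag, or commit id.
    config = AutoConfig.from_pretrained("facebook/opt-125m", revision="main")
    print(config.model_type)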