mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 06:17:52 +08:00
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
parent
9b9cef3145
commit
250ee65d72
@ -1,5 +1,6 @@
|
|||||||
import enum
|
import enum
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional, Type, Union
|
from typing import Any, Dict, Optional, Type, Union
|
||||||
|
|
||||||
@ -41,6 +42,7 @@ else:
|
|||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
|
||||||
MISTRAL_CONFIG_NAME = "params.json"
|
MISTRAL_CONFIG_NAME = "params.json"
|
||||||
|
HF_TOKEN = os.getenv('HF_TOKEN', None)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -77,8 +79,8 @@ class ConfigFormat(str, enum.Enum):
|
|||||||
MISTRAL = "mistral"
|
MISTRAL = "mistral"
|
||||||
|
|
||||||
|
|
||||||
def file_or_path_exists(model: Union[str, Path], config_name, revision,
|
def file_or_path_exists(model: Union[str, Path], config_name: str,
|
||||||
token) -> bool:
|
revision: Optional[str]) -> bool:
|
||||||
if Path(model).exists():
|
if Path(model).exists():
|
||||||
return (Path(model) / config_name).is_file()
|
return (Path(model) / config_name).is_file()
|
||||||
|
|
||||||
@ -93,7 +95,10 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
|
|||||||
# NB: file_exists will only check for the existence of the config file on
|
# NB: file_exists will only check for the existence of the config file on
|
||||||
# hf_hub. This will fail in offline mode.
|
# hf_hub. This will fail in offline mode.
|
||||||
try:
|
try:
|
||||||
return file_exists(model, config_name, revision=revision, token=token)
|
return file_exists(model,
|
||||||
|
config_name,
|
||||||
|
revision=revision,
|
||||||
|
token=HF_TOKEN)
|
||||||
except huggingface_hub.errors.OfflineModeIsEnabled:
|
except huggingface_hub.errors.OfflineModeIsEnabled:
|
||||||
# Don't raise in offline mode, all we know is that we don't have this
|
# Don't raise in offline mode, all we know is that we don't have this
|
||||||
# file cached.
|
# file cached.
|
||||||
@ -161,7 +166,6 @@ def get_config(
|
|||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
code_revision: Optional[str] = None,
|
code_revision: Optional[str] = None,
|
||||||
config_format: ConfigFormat = ConfigFormat.AUTO,
|
config_format: ConfigFormat = ConfigFormat.AUTO,
|
||||||
token: Optional[str] = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> PretrainedConfig:
|
) -> PretrainedConfig:
|
||||||
# Separate model folder from file path for GGUF models
|
# Separate model folder from file path for GGUF models
|
||||||
@ -173,19 +177,20 @@ def get_config(
|
|||||||
|
|
||||||
if config_format == ConfigFormat.AUTO:
|
if config_format == ConfigFormat.AUTO:
|
||||||
if is_gguf or file_or_path_exists(
|
if is_gguf or file_or_path_exists(
|
||||||
model, HF_CONFIG_NAME, revision=revision, token=token):
|
model, HF_CONFIG_NAME, revision=revision):
|
||||||
config_format = ConfigFormat.HF
|
config_format = ConfigFormat.HF
|
||||||
elif file_or_path_exists(model,
|
elif file_or_path_exists(model, MISTRAL_CONFIG_NAME,
|
||||||
MISTRAL_CONFIG_NAME,
|
revision=revision):
|
||||||
revision=revision,
|
|
||||||
token=token):
|
|
||||||
config_format = ConfigFormat.MISTRAL
|
config_format = ConfigFormat.MISTRAL
|
||||||
else:
|
else:
|
||||||
# If we're in offline mode and found no valid config format, then
|
# If we're in offline mode and found no valid config format, then
|
||||||
# raise an offline mode error to indicate to the user that they
|
# raise an offline mode error to indicate to the user that they
|
||||||
# don't have files cached and may need to go online.
|
# don't have files cached and may need to go online.
|
||||||
# This is conveniently triggered by calling file_exists().
|
# This is conveniently triggered by calling file_exists().
|
||||||
file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)
|
file_exists(model,
|
||||||
|
HF_CONFIG_NAME,
|
||||||
|
revision=revision,
|
||||||
|
token=HF_TOKEN)
|
||||||
|
|
||||||
raise ValueError(f"No supported config format found in {model}")
|
raise ValueError(f"No supported config format found in {model}")
|
||||||
|
|
||||||
@ -194,7 +199,7 @@ def get_config(
|
|||||||
model,
|
model,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
code_revision=code_revision,
|
code_revision=code_revision,
|
||||||
token=token,
|
token=HF_TOKEN,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -206,7 +211,7 @@ def get_config(
|
|||||||
model,
|
model,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
code_revision=code_revision,
|
code_revision=code_revision,
|
||||||
token=token,
|
token=HF_TOKEN,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -216,7 +221,7 @@ def get_config(
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
code_revision=code_revision,
|
code_revision=code_revision,
|
||||||
token=token,
|
token=HF_TOKEN,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@ -234,7 +239,7 @@ def get_config(
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
elif config_format == ConfigFormat.MISTRAL:
|
elif config_format == ConfigFormat.MISTRAL:
|
||||||
config = load_params_config(model, revision, token=token, **kwargs)
|
config = load_params_config(model, revision, token=HF_TOKEN, **kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported config format: {config_format}")
|
raise ValueError(f"Unsupported config format: {config_format}")
|
||||||
|
|
||||||
@ -256,8 +261,7 @@ def get_config(
|
|||||||
|
|
||||||
def get_hf_file_to_dict(file_name: str,
|
def get_hf_file_to_dict(file_name: str,
|
||||||
model: Union[str, Path],
|
model: Union[str, Path],
|
||||||
revision: Optional[str] = 'main',
|
revision: Optional[str] = 'main'):
|
||||||
token: Optional[str] = None):
|
|
||||||
"""
|
"""
|
||||||
Downloads a file from the Hugging Face Hub and returns
|
Downloads a file from the Hugging Face Hub and returns
|
||||||
its contents as a dictionary.
|
its contents as a dictionary.
|
||||||
@ -266,7 +270,6 @@ def get_hf_file_to_dict(file_name: str,
|
|||||||
- file_name (str): The name of the file to download.
|
- file_name (str): The name of the file to download.
|
||||||
- model (str): The name of the model on the Hugging Face Hub.
|
- model (str): The name of the model on the Hugging Face Hub.
|
||||||
- revision (str): The specific version of the model.
|
- revision (str): The specific version of the model.
|
||||||
- token (str): The Hugging Face authentication token.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- config_dict (dict): A dictionary containing
|
- config_dict (dict): A dictionary containing
|
||||||
@ -276,8 +279,7 @@ def get_hf_file_to_dict(file_name: str,
|
|||||||
|
|
||||||
if file_or_path_exists(model=model,
|
if file_or_path_exists(model=model,
|
||||||
config_name=file_name,
|
config_name=file_name,
|
||||||
revision=revision,
|
revision=revision):
|
||||||
token=token):
|
|
||||||
|
|
||||||
if not file_path.is_file():
|
if not file_path.is_file():
|
||||||
try:
|
try:
|
||||||
@ -296,9 +298,7 @@ def get_hf_file_to_dict(file_name: str,
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_pooling_config(model: str,
|
def get_pooling_config(model: str, revision: Optional[str] = 'main'):
|
||||||
revision: Optional[str] = 'main',
|
|
||||||
token: Optional[str] = None):
|
|
||||||
"""
|
"""
|
||||||
This function gets the pooling and normalize
|
This function gets the pooling and normalize
|
||||||
config from the model - only applies to
|
config from the model - only applies to
|
||||||
@ -315,8 +315,7 @@ def get_pooling_config(model: str,
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
modules_file_name = "modules.json"
|
modules_file_name = "modules.json"
|
||||||
modules_dict = get_hf_file_to_dict(modules_file_name, model, revision,
|
modules_dict = get_hf_file_to_dict(modules_file_name, model, revision)
|
||||||
token)
|
|
||||||
|
|
||||||
if modules_dict is None:
|
if modules_dict is None:
|
||||||
return None
|
return None
|
||||||
@ -332,8 +331,7 @@ def get_pooling_config(model: str,
|
|||||||
if pooling:
|
if pooling:
|
||||||
|
|
||||||
pooling_file_name = "{}/config.json".format(pooling["path"])
|
pooling_file_name = "{}/config.json".format(pooling["path"])
|
||||||
pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision,
|
pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision)
|
||||||
token)
|
|
||||||
pooling_type_name = next(
|
pooling_type_name = next(
|
||||||
(item for item, val in pooling_dict.items() if val is True), None)
|
(item for item, val in pooling_dict.items() if val is True), None)
|
||||||
|
|
||||||
@ -368,8 +366,8 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
|
|||||||
|
|
||||||
|
|
||||||
def get_sentence_transformer_tokenizer_config(model: str,
|
def get_sentence_transformer_tokenizer_config(model: str,
|
||||||
revision: Optional[str] = 'main',
|
revision: Optional[str] = 'main'
|
||||||
token: Optional[str] = None):
|
):
|
||||||
"""
|
"""
|
||||||
Returns the tokenization configuration dictionary for a
|
Returns the tokenization configuration dictionary for a
|
||||||
given Sentence Transformer BERT model.
|
given Sentence Transformer BERT model.
|
||||||
@ -379,7 +377,6 @@ def get_sentence_transformer_tokenizer_config(model: str,
|
|||||||
BERT model.
|
BERT model.
|
||||||
- revision (str, optional): The revision of the m
|
- revision (str, optional): The revision of the m
|
||||||
odel to use. Defaults to 'main'.
|
odel to use. Defaults to 'main'.
|
||||||
- token (str): A Hugging Face access token.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- dict: A dictionary containing the configuration parameters
|
- dict: A dictionary containing the configuration parameters
|
||||||
@ -394,7 +391,7 @@ def get_sentence_transformer_tokenizer_config(model: str,
|
|||||||
"sentence_xlm-roberta_config.json",
|
"sentence_xlm-roberta_config.json",
|
||||||
"sentence_xlnet_config.json",
|
"sentence_xlnet_config.json",
|
||||||
]:
|
]:
|
||||||
encoder_dict = get_hf_file_to_dict(config_name, model, revision, token)
|
encoder_dict = get_hf_file_to_dict(config_name, model, revision)
|
||||||
if encoder_dict:
|
if encoder_dict:
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -474,16 +471,14 @@ def maybe_register_config_serialize_by_value() -> None:
|
|||||||
exc_info=e)
|
exc_info=e)
|
||||||
|
|
||||||
|
|
||||||
def load_params_config(model: Union[str, Path],
|
def load_params_config(model: Union[str, Path], revision: Optional[str],
|
||||||
revision: Optional[str],
|
|
||||||
token: Optional[str] = None,
|
|
||||||
**kwargs) -> PretrainedConfig:
|
**kwargs) -> PretrainedConfig:
|
||||||
# This function loads a params.json config which
|
# This function loads a params.json config which
|
||||||
# should be used when loading models in mistral format
|
# should be used when loading models in mistral format
|
||||||
|
|
||||||
config_file_name = "params.json"
|
config_file_name = "params.json"
|
||||||
|
|
||||||
config_dict = get_hf_file_to_dict(config_file_name, model, revision, token)
|
config_dict = get_hf_file_to_dict(config_file_name, model, revision)
|
||||||
assert isinstance(config_dict, dict)
|
assert isinstance(config_dict, dict)
|
||||||
|
|
||||||
config_mapping = {
|
config_mapping = {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user