[BUG] Remove token param #10921 (#11022)

Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
Flávia Béo 2024-12-10 14:38:15 -03:00 committed by GitHub
parent 9b9cef3145
commit 250ee65d72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,6 @@
import enum import enum
import json import json
import os
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Type, Union from typing import Any, Dict, Optional, Type, Union
@ -41,6 +42,7 @@ else:
from transformers import AutoConfig from transformers import AutoConfig
MISTRAL_CONFIG_NAME = "params.json" MISTRAL_CONFIG_NAME = "params.json"
HF_TOKEN = os.getenv('HF_TOKEN', None)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -77,8 +79,8 @@ class ConfigFormat(str, enum.Enum):
MISTRAL = "mistral" MISTRAL = "mistral"
def file_or_path_exists(model: Union[str, Path], config_name, revision, def file_or_path_exists(model: Union[str, Path], config_name: str,
token) -> bool: revision: Optional[str]) -> bool:
if Path(model).exists(): if Path(model).exists():
return (Path(model) / config_name).is_file() return (Path(model) / config_name).is_file()
@ -93,7 +95,10 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
# NB: file_exists will only check for the existence of the config file on # NB: file_exists will only check for the existence of the config file on
# hf_hub. This will fail in offline mode. # hf_hub. This will fail in offline mode.
try: try:
return file_exists(model, config_name, revision=revision, token=token) return file_exists(model,
config_name,
revision=revision,
token=HF_TOKEN)
except huggingface_hub.errors.OfflineModeIsEnabled: except huggingface_hub.errors.OfflineModeIsEnabled:
# Don't raise in offline mode, all we know is that we don't have this # Don't raise in offline mode, all we know is that we don't have this
# file cached. # file cached.
@ -161,7 +166,6 @@ def get_config(
revision: Optional[str] = None, revision: Optional[str] = None,
code_revision: Optional[str] = None, code_revision: Optional[str] = None,
config_format: ConfigFormat = ConfigFormat.AUTO, config_format: ConfigFormat = ConfigFormat.AUTO,
token: Optional[str] = None,
**kwargs, **kwargs,
) -> PretrainedConfig: ) -> PretrainedConfig:
# Separate model folder from file path for GGUF models # Separate model folder from file path for GGUF models
@ -173,19 +177,20 @@ def get_config(
if config_format == ConfigFormat.AUTO: if config_format == ConfigFormat.AUTO:
if is_gguf or file_or_path_exists( if is_gguf or file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision, token=token): model, HF_CONFIG_NAME, revision=revision):
config_format = ConfigFormat.HF config_format = ConfigFormat.HF
elif file_or_path_exists(model, elif file_or_path_exists(model, MISTRAL_CONFIG_NAME,
MISTRAL_CONFIG_NAME, revision=revision):
revision=revision,
token=token):
config_format = ConfigFormat.MISTRAL config_format = ConfigFormat.MISTRAL
else: else:
# If we're in offline mode and found no valid config format, then # If we're in offline mode and found no valid config format, then
# raise an offline mode error to indicate to the user that they # raise an offline mode error to indicate to the user that they
# don't have files cached and may need to go online. # don't have files cached and may need to go online.
# This is conveniently triggered by calling file_exists(). # This is conveniently triggered by calling file_exists().
file_exists(model, HF_CONFIG_NAME, revision=revision, token=token) file_exists(model,
HF_CONFIG_NAME,
revision=revision,
token=HF_TOKEN)
raise ValueError(f"No supported config format found in {model}") raise ValueError(f"No supported config format found in {model}")
@ -194,7 +199,7 @@ def get_config(
model, model,
revision=revision, revision=revision,
code_revision=code_revision, code_revision=code_revision,
token=token, token=HF_TOKEN,
**kwargs, **kwargs,
) )
@ -206,7 +211,7 @@ def get_config(
model, model,
revision=revision, revision=revision,
code_revision=code_revision, code_revision=code_revision,
token=token, token=HF_TOKEN,
**kwargs, **kwargs,
) )
else: else:
@ -216,7 +221,7 @@ def get_config(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
revision=revision, revision=revision,
code_revision=code_revision, code_revision=code_revision,
token=token, token=HF_TOKEN,
**kwargs, **kwargs,
) )
except ValueError as e: except ValueError as e:
@ -234,7 +239,7 @@ def get_config(
raise e raise e
elif config_format == ConfigFormat.MISTRAL: elif config_format == ConfigFormat.MISTRAL:
config = load_params_config(model, revision, token=token, **kwargs) config = load_params_config(model, revision, token=HF_TOKEN, **kwargs)
else: else:
raise ValueError(f"Unsupported config format: {config_format}") raise ValueError(f"Unsupported config format: {config_format}")
@ -256,8 +261,7 @@ def get_config(
def get_hf_file_to_dict(file_name: str, def get_hf_file_to_dict(file_name: str,
model: Union[str, Path], model: Union[str, Path],
revision: Optional[str] = 'main', revision: Optional[str] = 'main'):
token: Optional[str] = None):
""" """
Downloads a file from the Hugging Face Hub and returns Downloads a file from the Hugging Face Hub and returns
its contents as a dictionary. its contents as a dictionary.
@ -266,7 +270,6 @@ def get_hf_file_to_dict(file_name: str,
- file_name (str): The name of the file to download. - file_name (str): The name of the file to download.
- model (str): The name of the model on the Hugging Face Hub. - model (str): The name of the model on the Hugging Face Hub.
- revision (str): The specific version of the model. - revision (str): The specific version of the model.
- token (str): The Hugging Face authentication token.
Returns: Returns:
- config_dict (dict): A dictionary containing - config_dict (dict): A dictionary containing
@ -276,8 +279,7 @@ def get_hf_file_to_dict(file_name: str,
if file_or_path_exists(model=model, if file_or_path_exists(model=model,
config_name=file_name, config_name=file_name,
revision=revision, revision=revision):
token=token):
if not file_path.is_file(): if not file_path.is_file():
try: try:
@ -296,9 +298,7 @@ def get_hf_file_to_dict(file_name: str,
return None return None
def get_pooling_config(model: str, def get_pooling_config(model: str, revision: Optional[str] = 'main'):
revision: Optional[str] = 'main',
token: Optional[str] = None):
""" """
This function gets the pooling and normalize This function gets the pooling and normalize
config from the model - only applies to config from the model - only applies to
@ -315,8 +315,7 @@ def get_pooling_config(model: str,
""" """
modules_file_name = "modules.json" modules_file_name = "modules.json"
modules_dict = get_hf_file_to_dict(modules_file_name, model, revision, modules_dict = get_hf_file_to_dict(modules_file_name, model, revision)
token)
if modules_dict is None: if modules_dict is None:
return None return None
@ -332,8 +331,7 @@ def get_pooling_config(model: str,
if pooling: if pooling:
pooling_file_name = "{}/config.json".format(pooling["path"]) pooling_file_name = "{}/config.json".format(pooling["path"])
pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision, pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision)
token)
pooling_type_name = next( pooling_type_name = next(
(item for item, val in pooling_dict.items() if val is True), None) (item for item, val in pooling_dict.items() if val is True), None)
@ -368,8 +366,8 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
def get_sentence_transformer_tokenizer_config(model: str, def get_sentence_transformer_tokenizer_config(model: str,
revision: Optional[str] = 'main', revision: Optional[str] = 'main'
token: Optional[str] = None): ):
""" """
Returns the tokenization configuration dictionary for a Returns the tokenization configuration dictionary for a
given Sentence Transformer BERT model. given Sentence Transformer BERT model.
@ -379,7 +377,6 @@ def get_sentence_transformer_tokenizer_config(model: str,
BERT model. BERT model.
- revision (str, optional): The revision of the m - revision (str, optional): The revision of the m
odel to use. Defaults to 'main'. odel to use. Defaults to 'main'.
- token (str): A Hugging Face access token.
Returns: Returns:
- dict: A dictionary containing the configuration parameters - dict: A dictionary containing the configuration parameters
@ -394,7 +391,7 @@ def get_sentence_transformer_tokenizer_config(model: str,
"sentence_xlm-roberta_config.json", "sentence_xlm-roberta_config.json",
"sentence_xlnet_config.json", "sentence_xlnet_config.json",
]: ]:
encoder_dict = get_hf_file_to_dict(config_name, model, revision, token) encoder_dict = get_hf_file_to_dict(config_name, model, revision)
if encoder_dict: if encoder_dict:
break break
@ -474,16 +471,14 @@ def maybe_register_config_serialize_by_value() -> None:
exc_info=e) exc_info=e)
def load_params_config(model: Union[str, Path], def load_params_config(model: Union[str, Path], revision: Optional[str],
revision: Optional[str],
token: Optional[str] = None,
**kwargs) -> PretrainedConfig: **kwargs) -> PretrainedConfig:
# This function loads a params.json config which # This function loads a params.json config which
# should be used when loading models in mistral format # should be used when loading models in mistral format
config_file_name = "params.json" config_file_name = "params.json"
config_dict = get_hf_file_to_dict(config_file_name, model, revision, token) config_dict = get_hf_file_to_dict(config_file_name, model, revision)
assert isinstance(config_dict, dict) assert isinstance(config_dict, dict)
config_mapping = { config_mapping = {