[misc] improve cloudpickle registration and tests (#10202)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-11-10 16:10:53 -08:00 committed by GitHub
parent 20cf2f553c
commit 73b9083e99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 50 additions and 31 deletions

View File

@ -32,6 +32,8 @@ class PPTestOptions(NamedTuple):
multi_node_only: bool multi_node_only: bool
trust_remote_code: bool trust_remote_code: bool
tokenizer_mode: Optional[str] tokenizer_mode: Optional[str]
load_format: Optional[str] = None
hf_overrides: Optional[str] = None
@dataclass @dataclass
@ -50,6 +52,8 @@ class PPTestSettings:
task: TaskOption = "auto", task: TaskOption = "auto",
trust_remote_code: bool = False, trust_remote_code: bool = False,
tokenizer_mode: Optional[str] = None, tokenizer_mode: Optional[str] = None,
load_format: Optional[str] = None,
hf_overrides: Optional[str] = None,
): ):
return PPTestSettings( return PPTestSettings(
parallel_setups=[ parallel_setups=[
@ -78,7 +82,9 @@ class PPTestSettings:
task=task, task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only, test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode), tokenizer_mode=tokenizer_mode,
load_format=load_format,
hf_overrides=hf_overrides),
) )
@staticmethod @staticmethod
@ -90,6 +96,8 @@ class PPTestSettings:
multi_node_only: bool = False, multi_node_only: bool = False,
trust_remote_code: bool = False, trust_remote_code: bool = False,
tokenizer_mode: Optional[str] = None, tokenizer_mode: Optional[str] = None,
load_format: Optional[str] = None,
hf_overrides: Optional[str] = None,
): ):
return PPTestSettings( return PPTestSettings(
parallel_setups=[ parallel_setups=[
@ -102,7 +110,9 @@ class PPTestSettings:
task=task, task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only, test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode), tokenizer_mode=tokenizer_mode,
load_format=load_format,
hf_overrides=hf_overrides),
) )
def iter_params(self, model_name: str): def iter_params(self, model_name: str):
@ -161,9 +171,8 @@ TEXT_GENERATION_MODELS = {
"facebook/opt-iml-max-1.3b": PPTestSettings.fast(), "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
"OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True), "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
"microsoft/phi-2": PPTestSettings.fast(), "microsoft/phi-2": PPTestSettings.fast(),
"microsoft/Phi-3-mini-4k-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True), # noqa: E501 "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True, load_format="dummy", hf_overrides='{"num_hidden_layers": 4, "hidden_size": 512, "intermediate_size": 800, "num_attention_heads": 4, "num_key_value_heads": 1}'), # noqa: E501
"microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"adept/persimmon-8b-chat": PPTestSettings.fast(), "adept/persimmon-8b-chat": PPTestSettings.fast(),
"Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True), "Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
"Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(), "Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(),
@ -214,9 +223,9 @@ MULTIMODAL_MODELS = {
# NOTE: You can update this on your local machine to run specific tests # NOTE: You can update this on your local machine to run specific tests
TEST_MODELS = [ TEST_MODELS = [
# [LANGUAGE GENERATION] # [LANGUAGE GENERATION]
"microsoft/Phi-3.5-MoE-instruct",
"meta-llama/Meta-Llama-3-8B", "meta-llama/Meta-Llama-3-8B",
"ibm/PowerLM-3b", "ibm/PowerLM-3b",
"microsoft/Phi-3-mini-4k-instruct",
# [LANGUAGE EMBEDDING] # [LANGUAGE EMBEDDING]
"intfloat/e5-mistral-7b-instruct", "intfloat/e5-mistral-7b-instruct",
"BAAI/bge-multilingual-gemma2", "BAAI/bge-multilingual-gemma2",
@ -238,7 +247,8 @@ def _compare_tp(
method: Literal["generate", "encode"], method: Literal["generate", "encode"],
): ):
tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
multi_node_only, trust_remote_code, tokenizer_mode = test_options multi_node_only, trust_remote_code, tokenizer_mode, \
load_format, hf_overrides = test_options
if num_gpus_available < tp_size * pp_size: if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
@ -267,6 +277,10 @@ def _compare_tp(
common_args.append("--trust-remote-code") common_args.append("--trust-remote-code")
if tokenizer_mode: if tokenizer_mode:
common_args.extend(["--tokenizer-mode", tokenizer_mode]) common_args.extend(["--tokenizer-mode", tokenizer_mode])
if load_format:
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", hf_overrides])
if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2 if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
and chunked_prefill): and chunked_prefill):

View File

@ -19,8 +19,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolingType from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import FlexibleArgumentParser, StoreBoolean from vllm.utils import FlexibleArgumentParser, StoreBoolean
@ -1013,8 +1011,6 @@ class EngineArgs:
"supported for multimodal models and has been disabled.") "supported for multimodal models and has been disabled.")
self.enable_prefix_caching = False self.enable_prefix_caching = False
maybe_register_config_serialize_by_value(self.trust_remote_code)
cache_config = CacheConfig( cache_config = CacheConfig(
# neuron needs block_size = max_model_len # neuron needs block_size = max_model_len
block_size=self.block_size if self.device != "neuron" else block_size=self.block_size if self.device != "neuron" else

View File

@ -234,6 +234,9 @@ def get_config(
patch_rope_scaling(config) patch_rope_scaling(config)
if trust_remote_code:
maybe_register_config_serialize_by_value()
return config return config
@ -389,33 +392,39 @@ def get_sentence_transformer_tokenizer_config(model: str,
return None return None
def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None: def maybe_register_config_serialize_by_value() -> None:
"""Try to register HF model configuration class to serialize by value """Try to register HF model configuration class to serialize by value
With trust_remote_code, the config class is typically an instance of a If trust_remote_code is set, and the model's config file specifies an
custom class imported from the HF modules cache. The class will not be `AutoConfig` class, then the config class is typically an instance of
importable in spawned workers by default (and won't exist at all on a custom class imported from the HF modules cache.
other nodes), which breaks serialization of the config.
Examples:
>>> from transformers import AutoConfig
>>> klass = AutoConfig.from_pretrained('meta-llama/Meta-Llama-3-8B', trust_remote_code=True)
>>> klass.__class__ # transformers.models.llama.configuration_llama.LlamaConfig
>>> import transformers_modules # error, not initialized
>>> klass = AutoConfig.from_pretrained('deepseek-ai/DeepSeek-V2.5', trust_remote_code=True)
>>> import transformers_modules # success, initialized
>>> klass.__class__ # transformers_modules.deepseek-ai.DeepSeek-V2.5.98b11844770b2c3ffc18b175c758a803640f4e77.configuration_deepseek.DeepseekV2Config
In the DeepSeek example, the config class is an instance of a custom
class that is not serializable by default. This class will not be
importable in spawned workers, and won't exist at all on
other nodes, which breaks serialization of the config.
In this function we tell the cloudpickle serialization library to pass In this function we tell the cloudpickle serialization library to pass
instances of these generated classes by value instead of by reference, instances of these generated classes by value instead of by reference,
i.e. the class definition is serialized along with its data so that the i.e. the class definition is serialized along with its data so that the
class module does not need to be importable on the receiving end. This class module does not need to be importable on the receiving end.
registration only works if the modules cache has already been
initialized.
See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs
""" """ # noqa
if not trust_remote_code:
return
try: try:
import transformers_modules import transformers_modules
except ImportError: except ImportError:
logger.debug("Could not import transformers_modules used for remote" # the config does not need trust_remote_code
" code. If remote code is not needed remove"
" `--trust-remote-code`.")
return return
try: try:
@ -428,19 +437,19 @@ def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
ray.cloudpickle.register_pickle_by_value(transformers_modules) ray.cloudpickle.register_pickle_by_value(transformers_modules)
# multiprocessing uses pickle to serialize arguments when using spawn # multiprocessing uses pickle to serialize arguments when using spawn
# Here we get pickle to use cloudpickle to serialize ModelConfig objects # Here we get pickle to use cloudpickle to serialize config objects
# that contain instances of the custom config class to avoid # that contain instances of the custom config class to avoid
# serialization problems if the generated module (and model) has a `.` # serialization problems if the generated module (and model) has a `.`
# in its name # in its name
import multiprocessing import multiprocessing
import pickle import pickle
from vllm.config import ModelConfig from vllm.config import VllmConfig
def _reduce_modelconfig(mc: ModelConfig): def _reduce_config(config: VllmConfig):
return (pickle.loads, (cloudpickle.dumps(mc), )) return (pickle.loads, (cloudpickle.dumps(config), ))
multiprocessing.reducer.register(ModelConfig, _reduce_modelconfig) multiprocessing.reducer.register(VllmConfig, _reduce_config)
except Exception as e: except Exception as e:
logger.warning( logger.warning(