diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index fee201850f20..49c80bd64042 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -28,19 +28,25 @@ class ParallelSetup(NamedTuple):
     chunked_prefill: bool
 
 
+class PPTestOptions(NamedTuple):
+    multi_node_only: bool
+    trust_remote_code: bool
+    tokenizer_mode: Optional[str]
+
+
 @dataclass
 class PPTestSettings:
     parallel_setups: List[ParallelSetup]
     distributed_backends: List[str]
     task: TaskOption
-    trust_remote_code: bool
-    tokenizer_mode: Optional[str]
+    test_options: PPTestOptions
 
     @staticmethod
     def detailed(
         *,
         tp_base: int = 1,
         pp_base: int = 2,
+        multi_node_only: bool = False,
         task: TaskOption = "auto",
         trust_remote_code: bool = False,
         tokenizer_mode: Optional[str] = None,
@@ -70,8 +76,9 @@ class PPTestSettings:
             ],
             distributed_backends=["mp", "ray"],
             task=task,
-            trust_remote_code=trust_remote_code,
-            tokenizer_mode=tokenizer_mode,
+            test_options=PPTestOptions(multi_node_only=multi_node_only,
+                                       trust_remote_code=trust_remote_code,
+                                       tokenizer_mode=tokenizer_mode),
         )
 
     @staticmethod
@@ -80,6 +87,7 @@ class PPTestSettings:
         tp_base: int = 1,
         pp_base: int = 2,
         task: TaskOption = "auto",
+        multi_node_only: bool = False,
         trust_remote_code: bool = False,
         tokenizer_mode: Optional[str] = None,
     ):
@@ -92,15 +100,18 @@ class PPTestSettings:
             ],
             distributed_backends=["mp"],
             task=task,
-            trust_remote_code=trust_remote_code,
-            tokenizer_mode=tokenizer_mode,
+            test_options=PPTestOptions(multi_node_only=multi_node_only,
+                                       trust_remote_code=trust_remote_code,
+                                       tokenizer_mode=tokenizer_mode),
         )
 
     def iter_params(self, model_name: str):
+        opts = self.test_options
+
         for parallel_setup in self.parallel_setups:
             for distributed_backend in self.distributed_backends:
                 yield (model_name, parallel_setup, distributed_backend,
-                       self.task, self.trust_remote_code, self.tokenizer_mode)
+                       self.task, opts)
 
 
 # NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
@@ -110,6 +121,7 @@ class PPTestSettings:
 GENERATION_MODEL_SETTINGS = {
     # [DETAILED TESTS]
     "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
+    "microsoft/Phi-3-mini-4k-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True),  # noqa: E501
     # [FAST TESTS]
     # Uses Llama
     # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
@@ -151,10 +163,8 @@ GENERATION_MODEL_SETTINGS = {
     "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
     "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
     "microsoft/phi-2": PPTestSettings.fast(),
-    "microsoft/Phi-3-mini-4k-instruct": PPTestSettings.fast(),
     "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
-    # FIXME: https://github.com/vllm-project/vllm/issues/8553
-    # "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
+    "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
     "adept/persimmon-8b-chat": PPTestSettings.fast(),
     "Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
     "Qwen/Qwen2-beta-7B-Chat": PPTestSettings.fast(),
@@ -205,6 +215,7 @@ TEST_MODELS = [
     # [LANGUAGE GENERATION]
     "meta-llama/Meta-Llama-3-8B",
     "ibm/PowerLM-3b",
+    "microsoft/Phi-3-mini-4k-instruct",
     # [LANGUAGE EMBEDDING]
     "intfloat/e5-mistral-7b-instruct",
     "BAAI/bge-multilingual-gemma2",
@@ -220,19 +231,21 @@ def _compare_tp(
     parallel_setup: ParallelSetup,
     distributed_backend: str,
     task: TaskOption,
-    trust_remote_code: bool,
-    tokenizer_mode: Optional[str],
+    test_options: PPTestOptions,
     num_gpus_available: int,
     *,
-    method: Literal["generate", "encode"] = "encode",
+    method: Literal["generate", "encode"],
 ):
     tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
+    multi_node_only, trust_remote_code, tokenizer_mode = test_options
 
     if num_gpus_available < tp_size * pp_size:
         pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
     if VLLM_MULTI_NODE and distributed_backend == "mp":
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")
+    if multi_node_only and not VLLM_MULTI_NODE:
+        pytest.skip("Not in multi-node setting")
 
     common_args = [
         # use half precision for speed and memory savings in CI environment
@@ -307,7 +320,7 @@ def _compare_tp(
 
 @pytest.mark.parametrize(
     ("model_name", "parallel_setup", "distributed_backend", "task",
-     "trust_remote_code", "tokenizer_mode"),
+     "test_options"),
     [
         params for model_name, settings in GENERATION_MODEL_SETTINGS.items()
         for params in settings.iter_params(model_name)
@@ -320,23 +333,21 @@ def test_tp_language_generation(
     parallel_setup: ParallelSetup,
     distributed_backend: str,
     task: TaskOption,
-    trust_remote_code: bool,
-    tokenizer_mode: Optional[str],
+    test_options: PPTestOptions,
     num_gpus_available,
 ):
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
                 task,
-                trust_remote_code,
-                tokenizer_mode,
+                test_options,
                 num_gpus_available,
                 method="generate")
 
 
 @pytest.mark.parametrize(
     ("model_name", "parallel_setup", "distributed_backend", "task",
-     "trust_remote_code", "tokenizer_mode"),
+     "test_options"),
     [
         params for model_name, settings in EMBEDDING_MODEL_SETTINGS.items()
         for params in settings.iter_params(model_name)
@@ -349,23 +360,21 @@ def test_tp_language_embedding(
     parallel_setup: ParallelSetup,
     distributed_backend: str,
     task: TaskOption,
-    trust_remote_code: bool,
-    tokenizer_mode: Optional[str],
+    test_options: PPTestOptions,
     num_gpus_available,
 ):
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
                 task,
-                trust_remote_code,
-                tokenizer_mode,
+                test_options,
                 num_gpus_available,
                 method="encode")
 
 
 @pytest.mark.parametrize(
     ("model_name", "parallel_setup", "distributed_backend", "task",
-     "trust_remote_code", "tokenizer_mode"),
+     "test_options"),
     [
         params for model_name, settings in MULTIMODAL_MODEL_SETTINGS.items()
        for params in settings.iter_params(model_name)
@@ -378,15 +387,13 @@ def test_tp_multimodal_generation(
     parallel_setup: ParallelSetup,
     distributed_backend: str,
     task: TaskOption,
-    trust_remote_code: bool,
-    tokenizer_mode: Optional[str],
+    test_options: PPTestOptions,
     num_gpus_available,
 ):
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
                 task,
-                trust_remote_code,
-                tokenizer_mode,
+                test_options,
                 num_gpus_available,
                 method="generate")
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 56582ab61879..a5cfaf3977a4 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -16,6 +16,8 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.transformers_utils.config import (
+    maybe_register_config_serialize_by_value)
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import FlexibleArgumentParser
 
@@ -924,6 +926,8 @@ class EngineArgs:
                     "supported for multimodal models and has been disabled.")
             self.enable_prefix_caching = False
 
+        maybe_register_config_serialize_by_value(self.trust_remote_code)
+
         cache_config = CacheConfig(
             # neuron needs block_size = max_model_len
             block_size=self.block_size if self.device != "neuron" else
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 46405f352921..9bd2531d7a15 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -232,6 +232,68 @@ def get_config(
     return config
 
 
+def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
+    """Try to register HF model configuration class to serialize by value
+
+    With trust_remote_code, the config class is typically an instance of a
+    custom class imported from the HF modules cache. The class will not be
+    importable in spawned workers by default (and won't exist at all on
+    other nodes), which breaks serialization of the config.
+
+    In this function we tell the cloudpickle serialization library to pass
+    instances of these generated classes by value instead of by reference,
+    i.e. the class definition is serialized along with its data so that the
+    class module does not need to be importable on the receiving end. This
+    registration only works if the modules cache has already been
+    initialized.
+
+
+    See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs
+    """
+    if not trust_remote_code:
+        return
+
+    try:
+        import transformers_modules
+    except ImportError:
+        logger.debug("Could not import transformers_modules used for remote"
+                     " code. If remote code is not needed remove"
+                     " `--trust-remote-code`.")
+        return
+
+    try:
+        import cloudpickle
+        cloudpickle.register_pickle_by_value(transformers_modules)
+
+        # ray vendors its own version of cloudpickle
+        from vllm.executor.ray_utils import ray
+        if ray:
+            ray.cloudpickle.register_pickle_by_value(transformers_modules)
+
+        # multiprocessing uses pickle to serialize arguments when using spawn
+        # Here we get pickle to use cloudpickle to serialize ModelConfig
+        # objects that contain instances of the custom config class to avoid
+        # serialization problems if the generated module (and model) has a `.`
+        # in its name
+        import multiprocessing
+        import pickle
+
+        from vllm.config import ModelConfig
+
+        def _reduce_modelconfig(mc: ModelConfig):
+            return (pickle.loads, (cloudpickle.dumps(mc), ))
+
+        multiprocessing.reducer.register(ModelConfig, _reduce_modelconfig)
+
+    except Exception as e:
+        logger.warning(
+            "Unable to register remote classes used by"
+            " trust_remote_code with by-value serialization. This may"
+            " lead to a later error. If remote code is not needed"
+            " remove `--trust-remote-code`",
+            exc_info=e)
+
+
 def load_params_config(model, revision) -> PretrainedConfig:
     # This function loads a params.json config which
     # should be used when loading models in mistral format
diff --git a/vllm/utils.py b/vllm/utils.py
index 695764dadc12..d1a995a3ac8c 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -968,6 +968,8 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
     return [item for sublist in lists for item in sublist]
 
 
+# TODO: This function can be removed if transformer_modules classes are
+# serialized by value when communicating between processes
 def init_cached_hf_modules() -> None:
     """
     Lazy initialization of the Hugging Face modules.
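
Note (illustration, not part of the diff): the sketch below shows the cloudpickle mechanism that the new maybe_register_config_serialize_by_value helper relies on, using a hypothetical in-memory module named "dynamic" as a stand-in for the transformers_modules cache that trust_remote_code populates. It assumes cloudpickle >= 2.0 (which provides register_pickle_by_value) is installed; class and variable names here are invented for the example.

import pickle
import sys
import types

import cloudpickle  # assumption: cloudpickle >= 2.0 is available

# Simulate a module that exists only in this process, similar to the HF
# modules cache created at runtime by trust_remote_code.
dynamic = types.ModuleType("dynamic")
sys.modules["dynamic"] = dynamic


class DynamicConfig:  # hypothetical stand-in for a remote-code config class
    def __init__(self, n: int) -> None:
        self.n = n


# Pretend the class was defined inside the dynamic module.
DynamicConfig.__module__ = "dynamic"
dynamic.DynamicConfig = DynamicConfig

cfg = DynamicConfig(4)

# Plain pickle stores the class by reference ("dynamic.DynamicConfig"), so
# unpickling fails in any process where that module cannot be imported.
# Registering the module with cloudpickle makes its classes travel by value:
cloudpickle.register_pickle_by_value(dynamic)
payload = cloudpickle.dumps(cfg)

# The receiver can load this with plain pickle.loads; cloudpickle must be
# installed there, but the "dynamic" module need not exist. This mirrors the
# _reduce_modelconfig hook in the diff, which returns
# (pickle.loads, (cloudpickle.dumps(mc), )) so that multiprocessing's spawn
# path serializes ModelConfig via cloudpickle.
restored = pickle.loads(payload)
print(restored.n)  # -> 4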