diff --git a/tests/config/test_mp_reducer.py b/tests/config/test_mp_reducer.py
new file mode 100644
index 000000000000..ee351cbfa7c1
--- /dev/null
+++ b/tests/config/test_mp_reducer.py
@@ -0,0 +1,58 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import sys
+from unittest.mock import patch
+
+from vllm.config import VllmConfig
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.v1.engine.async_llm import AsyncLLM
+
+
+def test_mp_reducer(monkeypatch):
+    """
+    Test that _reduce_config reducer is registered when AsyncLLM is instantiated
+    without transformers_modules. This is a regression test for
+    https://github.com/vllm-project/vllm/pull/18640.
+    """
+
+    # Use V1 AsyncLLM which calls maybe_register_config_serialize_by_value
+    monkeypatch.setenv('VLLM_USE_V1', '1')
+
+    # Force `import transformers_modules` to raise ImportError. A None entry
+    # in sys.modules blocks the import even when the module is still
+    # importable from sys.path, and monkeypatch restores the original entry
+    # on teardown so other tests are unaffected.
+    monkeypatch.setitem(sys.modules, 'transformers_modules', None)
+
+    with patch('multiprocessing.reducer.register') as mock_register:
+        engine_args = AsyncEngineArgs(
+            model="facebook/opt-125m",
+            max_model_len=32,
+            gpu_memory_utilization=0.1,
+            disable_log_stats=True,
+            disable_log_requests=True,
+        )
+
+        async_llm = AsyncLLM.from_engine_args(
+            engine_args,
+            start_engine_loop=False,
+        )
+
+        # Always shut the engine down, even when an assertion below fails
+        try:
+            assert mock_register.called, (
+                "multiprocessing.reducer.register should have been called")
+
+            # Reducers registered for VllmConfig, in registration order
+            vllm_config_reducers = [
+                call.args[1] for call in mock_register.call_args_list
+                if len(call.args) >= 2 and call.args[0] is VllmConfig
+            ]
+
+            assert vllm_config_reducers, (
+                "VllmConfig should have been registered to "
+                "multiprocessing.reducer")
+            assert callable(vllm_config_reducers[0]), (
+                "Reducer function should be callable")
+        finally:
+            async_llm.shutdown()
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 
3cc112790013..5c422a9e3fce 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -655,34 +655,35 @@ def maybe_register_config_serialize_by_value() -> None: """ # noqa try: import transformers_modules + transformers_modules_available = True except ImportError: - # the config does not need trust_remote_code - return + transformers_modules_available = False try: - import cloudpickle - cloudpickle.register_pickle_by_value(transformers_modules) - - # ray vendors its own version of cloudpickle - from vllm.executor.ray_utils import ray - if ray: - ray.cloudpickle.register_pickle_by_value(transformers_modules) - - # multiprocessing uses pickle to serialize arguments when using spawn - # Here we get pickle to use cloudpickle to serialize config objects - # that contain instances of the custom config class to avoid - # serialization problems if the generated module (and model) has a `.` - # in its name import multiprocessing import pickle + import cloudpickle + from vllm.config import VllmConfig + # Register multiprocessing reducers to handle cross-process + # serialization of VllmConfig objects that may contain custom configs + # from transformers_modules def _reduce_config(config: VllmConfig): return (pickle.loads, (cloudpickle.dumps(config), )) multiprocessing.reducer.register(VllmConfig, _reduce_config) + # Register transformers_modules with cloudpickle if available + if transformers_modules_available: + cloudpickle.register_pickle_by_value(transformers_modules) + + # ray vendors its own version of cloudpickle + from vllm.executor.ray_utils import ray + if ray: + ray.cloudpickle.register_pickle_by_value(transformers_modules) + except Exception as e: logger.warning( "Unable to register remote classes used by"