diff --git a/tests/conftest.py b/tests/conftest.py index 3cd93f4ad3289..e8e95357ff5b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -987,17 +987,7 @@ class VllmRunner: return [req_output.outputs.score for req_output in req_outputs] def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: - if hasattr(self.llm.llm_engine, "model_executor"): - # This works either in V0 or in V1 with - # VLLM_ENABLE_V1_MULTIPROCESSING=0 - executor = self.llm.llm_engine.model_executor - return executor.apply_model(func) - - # This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1 - def _apply_model(self): - return func(self.get_model()) - - return self.llm.llm_engine.collective_rpc(_apply_model) + return self.llm.apply_model(func) def get_llm(self) -> LLM: return self.llm diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index a3b8f07638d9a..61d3311cc1624 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -1,21 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import importlib import importlib.metadata from dataclasses import dataclass +from importlib.util import find_spec from typing import Optional import pytest import torch from packaging import version +from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501 + QuarkLinearMethod, QuarkW4A4MXFP4) +from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501 + QuarkW4A4MXFp4MoEMethod) from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer -QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( - "quark") is not None and version.parse( - importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') +QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse( + importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda( ) and current_platform.is_device_capability(100) @@ -39,6 +42,12 @@ class ModelCase: tp: int +@pytest.fixture(scope="function", autouse=True) +def enable_pickle(monkeypatch): + """`LLM.apply_model` requires pickling a function.""" + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + + @pytest.mark.parametrize('model_case', [ ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1), ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8), @@ -55,21 +64,19 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): tensor_parallel_size=model_case.tp, load_format="dummy") as llm: - # TODO: llm.apply_model(check_model) currently relies on V0 internals. - # Re-enable this later. 
- # def check_model(model): - # layer = model.model.layers[0] + def check_model(model): + layer = model.model.layers[0] - # qkv_proj = layer.self_attn.qkv_proj + qkv_proj = layer.self_attn.qkv_proj - # assert isinstance(qkv_proj.quant_method, QuarkLinearMethod) - # assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4) + assert isinstance(qkv_proj.quant_method, QuarkLinearMethod) + assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4) - # assert isinstance(layer.mlp.experts.quant_method, - # QuarkW4A4MXFp4MoEMethod) + assert isinstance(layer.mlp.experts.quant_method, + QuarkW4A4MXFp4MoEMethod) - # if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4": - # llm.apply_model(check_model) + if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4": + llm.apply_model(check_model) output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20) diff --git a/tests/models/multimodal/generation/test_qwen2_vl.py b/tests/models/multimodal/generation/test_qwen2_vl.py index a81f5e7ec8872..e56f4e4075be4 100644 --- a/tests/models/multimodal/generation/test_qwen2_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_vl.py @@ -10,6 +10,7 @@ from PIL import Image from vllm.multimodal.image import rescale_image_size from vllm.multimodal.video import rescale_video_size, sample_frames_from_video +from vllm.utils import set_default_torch_num_threads from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput, PromptVideoInput, VllmRunner) @@ -17,11 +18,9 @@ from ...utils import check_logprobs_close @pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - V1 Test: batch_make_xxxxx_embeddings calls a V0 internal - """ - monkeypatch.setenv('VLLM_USE_V1', '0') +def enable_pickle(monkeypatch): + """`LLM.apply_model` requires pickling a function.""" + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") models = ["Qwen/Qwen2-VL-2B-Instruct"] @@ -126,9 +125,8 @@ def batch_make_image_embeddings( image_grid_thw_on_device = image_grid_thw.to(visual.device, dtype=torch.int64) return visual(pixel_values_on_device, - grid_thw=image_grid_thw_on_device) + grid_thw=image_grid_thw_on_device).cpu() - # V1 Test: this calls a V0 internal. image_embeds = torch.concat(llm.apply_model(get_image_embeds)) # split into original batches @@ -210,7 +208,7 @@ def batch_make_video_embeddings( video_grid_thw_on_device = video_grid_thw.to(visual.device, dtype=torch.int64) return visual(pixel_values_on_device, - grid_thw=video_grid_thw_on_device) + grid_thw=video_grid_thw_on_device).cpu() # V1 Test: this calls a V0 internal. 
video_embeds = torch.concat(llm.apply_model(get_image_embeds)) @@ -266,19 +264,22 @@ def run_embedding_input_test( processor = AutoProcessor.from_pretrained(model) # max_model_len should be greater than image_feature_size - with vllm_runner(model, - runner="generate", - max_model_len=4000, - max_num_seqs=3, - dtype=dtype, - limit_mm_per_prompt={ - "image": mm_limit, - "video": mm_limit - }, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend - ) as vllm_model: + with set_default_torch_num_threads(1): + vllm_model = vllm_runner( + model, + runner="generate", + max_model_len=4000, + max_num_seqs=3, + dtype=dtype, + limit_mm_per_prompt={ + "image": mm_limit, + "video": mm_limit + }, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + ) + with vllm_model: outputs_per_case_for_original_input = [ vllm_model.generate_greedy_logprobs(prompts, max_tokens, @@ -329,9 +330,8 @@ def run_embedding_input_test( @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model, - size_factors, dtype: str, - max_tokens: int, - num_logprobs: int) -> None: + size_factors, dtype, max_tokens, + num_logprobs, monkeypatch) -> None: images = [asset.pil_image for asset in image_assets] inputs_per_case: list[tuple[ diff --git a/tests/models/quantization/test_awq.py b/tests/models/quantization/test_awq.py index bd696198931ff..7005e435ecf46 100644 --- a/tests/models/quantization/test_awq.py +++ b/tests/models/quantization/test_awq.py @@ -112,7 +112,7 @@ def test_awq_models(vllm_runner, image_assets, source_model, quant_model, monkeypatch) -> None: # Test V1: this test hangs during setup on single-scale input. - # TODO: fixure out why and re-enable this on V1. + # TODO: figure out why and re-enable this on V1. monkeypatch.setenv("VLLM_USE_V1", "0") run_awq_test( vllm_runner, diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 484f53246f349..b7949a488ad05 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -43,12 +43,9 @@ ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [ @pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This module relies on V0 internals, so set VLLM_USE_V1=0. 
- """ - if not current_platform.is_cpu(): - monkeypatch.setenv('VLLM_USE_V1', '0') +def enable_pickle(monkeypatch): + """`LLM.apply_model` requires pickling a function.""" + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") @pytest.mark.parametrize( @@ -176,10 +173,11 @@ def test_compressed_tensors_w8a8_logprobs( dtype = "bfloat16" - # skip language translation prompt for the static per tensor asym model - if (model_path == - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym" - ): # noqa: E501 + # skip language translation prompt for the static per tensor models + if model_path in ( + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym", + ): example_prompts = example_prompts[0:-1] with hf_runner(model_path, dtype=dtype) as hf_model: diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index d781f462b4ad7..db53061cf2d1a 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -60,8 +60,8 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - # vllm_runner.apply_model() relies on V0 internals. - monkeypatch.setenv("VLLM_USE_V1", "0") + # `LLM.apply_model` requires pickling a function. + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") with vllm_runner(model_id, kv_cache_dtype="fp8") as llm: def check_model(model): @@ -104,8 +104,8 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - # vllm_runner.apply_model() relies on V0 internals. - monkeypatch.setenv("VLLM_USE_V1", "0") + # `LLM.apply_model` requires pickling a function. + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") if force_marlin: monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1") diff --git a/tests/quantization/test_gptq_dynamic.py b/tests/quantization/test_gptq_dynamic.py index aea50e99c1dd5..00a5946ed0154 100644 --- a/tests/quantization/test_gptq_dynamic.py +++ b/tests/quantization/test_gptq_dynamic.py @@ -31,41 +31,46 @@ MODEL_QUANT = [ @pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT) def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool, monkeypatch): - # vllm_runner.apply_model() relies on V0 internals. - monkeypatch.setenv("VLLM_USE_V1", "0") - - vllm_model = vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) + # `LLM.apply_model` requires pickling a function. + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else ( GPTQLinearMethod) - for name, submodule in (vllm_model.llm.llm_engine.model_executor. 
- driver_worker.model_runner.model.named_modules()): - if name == "lm_head": - assert isinstance(submodule.quant_method, linear_method_cls) - elif name == 'model.layers.0.self_attn.qkv_proj': - # The first layer is quantized using bits=4, group_size=128 - # desc_act=True - assert isinstance(submodule.quant_method, linear_method_cls) - config = submodule.quant_method.quant_config - assert config.weight_bits == 4 - assert config.group_size == 128 - assert config.desc_act - elif name == 'model.layers.1.self_attn.qkv_proj': - # The second layer is quantized using bits=8, group_size=32 - # desc_act=False - assert isinstance(submodule.quant_method, linear_method_cls) - config = submodule.quant_method.quant_config - assert get_dynamic_override(config, layer_name=name, - key="bits") == 8 - assert get_dynamic_override(config, - layer_name=name, - key="group_size") == 32 - assert not get_dynamic_override( - config, layer_name=name, key="desc_act") - elif (name == 'model.layers.2.self_attn.qkv_proj' - or name == 'model.layers.2.mlp.gate_up_proj'): - # All other layers (layer index >= 2) are not quantized - assert isinstance(submodule.quant_method, UnquantizedLinearMethod) + with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as llm: - del vllm_model + def check_model(model): + for name, submodule in model.named_modules(): + if name == "lm_head": + assert isinstance(submodule.quant_method, + linear_method_cls) + elif name == 'model.layers.0.self_attn.qkv_proj': + # The first layer is quantized using bits=4, group_size=128 + # desc_act=True + assert isinstance(submodule.quant_method, + linear_method_cls) + config = submodule.quant_method.quant_config + assert config.weight_bits == 4 + assert config.group_size == 128 + assert config.desc_act + elif name == 'model.layers.1.self_attn.qkv_proj': + # The second layer is quantized using bits=8, group_size=32 + # desc_act=False + assert isinstance(submodule.quant_method, + linear_method_cls) + config = submodule.quant_method.quant_config + assert get_dynamic_override(config, + layer_name=name, + key="bits") == 8 + assert get_dynamic_override(config, + layer_name=name, + key="group_size") == 32 + assert not get_dynamic_override( + config, layer_name=name, key="desc_act") + elif (name == 'model.layers.2.self_attn.qkv_proj' + or name == 'model.layers.2.mlp.gate_up_proj'): + # All other layers (layer index >= 2) are not quantized + assert isinstance(submodule.quant_method, + UnquantizedLinearMethod) + + llm.apply_model(check_model) diff --git a/tests/quantization/test_lm_head.py b/tests/quantization/test_lm_head.py index b24964a9d0a9f..e69d4ad349c38 100644 --- a/tests/quantization/test_lm_head.py +++ b/tests/quantization/test_lm_head.py @@ -29,8 +29,8 @@ def test_lm_head( lm_head_quantized: bool, monkeypatch, ) -> None: - # vllm_runner.apply_model() relies on V0 internals. - monkeypatch.setenv("VLLM_USE_V1", "0") + # `LLM.apply_model` requires pickling a function. 
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
     with vllm_runner(model_id, dtype=torch.float16,
                      max_model_len=2048) as vllm_model:
diff --git a/tests/quantization/test_modelopt.py b/tests/quantization/test_modelopt.py
index c60a03f44baec..e7174be73626a 100644
--- a/tests/quantization/test_modelopt.py
+++ b/tests/quantization/test_modelopt.py
@@ -11,16 +11,12 @@
 import pytest
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
-from vllm.platforms import current_platform
 
 
 @pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    This module relies on V0 internals, so set VLLM_USE_V1=0.
-    """
-    if not current_platform.is_cpu():
-        monkeypatch.setenv('VLLM_USE_V1', '0')
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
 
 @pytest.mark.skipif(not is_quant_method_supported("modelopt"),
diff --git a/tests/quantization/test_ptpc_fp8.py b/tests/quantization/test_ptpc_fp8.py
index 5f78bc30504c0..088b68510cffa 100644
--- a/tests/quantization/test_ptpc_fp8.py
+++ b/tests/quantization/test_ptpc_fp8.py
@@ -13,6 +13,16 @@ from vllm.model_executor.layers.quantization.ptpc_fp8 import (
     PTPCFp8LinearMethod)
 from vllm.platforms import current_platform
 
+UNSUPPORTED_STR = (
+    "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only "
+    "support output dtype of bfloat16. torch.float16 is specified.")
+
+
+@pytest.fixture(scope="function", autouse=True)
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
 
 @pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"),
                     reason="PTPC FP8 is not supported on this GPU type.")
@@ -21,14 +31,22 @@ from vllm.platforms import current_platform
 @pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"])
 @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
 def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
     try:
-        with vllm_runner("facebook/opt-125m",
-                         dtype=dtype,
-                         quantization="ptpc_fp8",
-                         kv_cache_dtype=kv_cache_dtype) as llm:
+        llm = vllm_runner("facebook/opt-125m",
+                          dtype=dtype,
+                          quantization="ptpc_fp8",
+                          kv_cache_dtype=kv_cache_dtype)
+    except AssertionError as e:
+        if str(e) == UNSUPPORTED_STR:
+            # If the error message matches, the test passes
+            return
+        else:
+            # If the error message does not match, re-raise the exception
+            raise
 
-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+    with llm:
+
+        def check_model(model):
             fc1 = model.model.decoder.layers[0].fc1
             assert isinstance(fc1.quant_method, PTPCFp8LinearMethod)
             if kv_cache_dtype == "ptpc_fp8":
@@ -40,17 +58,8 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
             if current_platform.has_device_capability(94):
                 # For GPUs with hardware support, we keep weights in fp8
                 assert fc1.weight.dtype == torch.float8_e4m3fnuz
-            else:
-                pytest.skip()
 
-            output = llm.generate_greedy("Hello my name is", max_tokens=20)
-            assert output
-    except AssertionError as e:
-        if str(
-                e
-        ) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. 
torch.float16 is specified.": # noqa: E501 - # If the error message matches, the test passes - pass - else: - # If the error message does not match, re-raise the exception - raise + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py index c09931971e6fb..930f4acb328fd 100644 --- a/tests/quantization/test_quark.py +++ b/tests/quantization/test_quark.py @@ -7,10 +7,10 @@ Run `pytest tests/quantization/test_quark.py`. See also `tests/kernels/moe/test_mxfp4_moe.py`. """ -import importlib import importlib.metadata import os from dataclasses import dataclass +from importlib.util import find_spec import huggingface_hub import lm_eval @@ -24,9 +24,8 @@ from vllm.platforms import current_platform from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch -QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( - "quark") is not None and version.parse( - importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') +QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse( + importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') if QUARK_MXFP4_AVAILABLE: from quark.torch.export.nn.modules.realquantizer import ( @@ -43,11 +42,9 @@ except huggingface_hub.errors.RepositoryNotFoundError: @pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This module relies on V0 internals, so set VLLM_USE_V1=0. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') +def enable_pickle(monkeypatch): + """`LLM.apply_model` requires pickling a function.""" + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") @pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8']) @@ -132,13 +129,12 @@ def test_quark_fp8_parity(vllm_runner): } with (vllm_runner(quark_model_id, **llm_kwargs) as quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle): - quark_model = (quark_handle.llm.llm_engine.model_executor. - driver_worker.model_runner.model) - quark_state_dict = quark_model.state_dict() - fp8_model = (fp8_handle.llm.llm_engine.model_executor.driver_worker. - model_runner.model) - fp8_state_dict = fp8_model.state_dict() + def get_state_dict(model): + return {k: v.cpu() for k, v in model.state_dict().items()} + + quark_state_dict, = quark_handle.apply_model(get_state_dict) + fp8_state_dict, = fp8_handle.apply_model(get_state_dict) assert fp8_state_dict.keys() == quark_state_dict.keys() diff --git a/tests/quantization/test_register_quantization_config.py b/tests/quantization/test_register_quantization_config.py index 84705e92c85bb..03fe59d7e3bff 100644 --- a/tests/quantization/test_register_quantization_config.py +++ b/tests/quantization/test_register_quantization_config.py @@ -105,18 +105,21 @@ def test_register_quantization_config(): ]) def test_custom_quant(vllm_runner, model, monkeypatch): """Test infer with the custom quantization method.""" - # vllm_runner.apply_model() relies on V0 internals. - monkeypatch.setenv("VLLM_USE_V1", "0") + # `LLM.apply_model` requires pickling a function. 
+ monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + with vllm_runner(model_name=model, quantization="custom_quant", enforce_eager=True) as llm: - model = llm.llm.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 - layer = model.model.layers[0] - qkv_proj = layer.self_attn.qkv_proj + def check_model(model): + layer = model.model.layers[0] + qkv_proj = layer.self_attn.qkv_proj - # Check the quantization method is FakeQuantLinearMethod - assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod) + # Check the quantization method is FakeQuantLinearMethod + assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod) + + llm.apply_model(check_model) output = llm.generate_greedy("Hello my name is", max_tokens=20) assert output diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 708f3bbeeff15..014bc56bc8ece 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -13,6 +13,7 @@ from typing import Sequence as GenericSequence from typing import Set, Type, Union, cast import torch +import torch.nn as nn from typing_extensions import TypeVar import vllm.envs as envs @@ -55,6 +56,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind from vllm.version import __version__ as VLLM_VERSION from vllm.worker.model_runner_base import InputProcessingError +from vllm.worker.worker_base import WorkerBase logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -1817,13 +1819,16 @@ class LLMEngine: return sampling_params def collective_rpc(self, - method: Union[str, Callable[..., _R]], + method: Union[str, Callable[[WorkerBase], _R]], timeout: Optional[float] = None, args: tuple = (), kwargs: Optional[dict[str, Any]] = None) -> list[_R]: return self.model_executor.collective_rpc(method, timeout, args, kwargs) + def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: + return self.collective_rpc("apply_model", args=(func, )) + if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e21bfce0ab085..f2282c40f7073 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -522,9 +522,14 @@ class LLM: """ Run a function directly on the model inside each worker, returning the result for each of them. + + !!! warning + To reduce the overhead of data transfer, avoid returning large + arrays or tensors from this method. If you must return them, + make sure you move them to CPU first to avoid taking up additional + VRAM! 
""" - executor = self.llm_engine.model_executor - return executor.apply_model(func) + return self.llm_engine.apply_model(func) def _get_beam_search_lora_requests( self, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 42aa8d14a21eb..b75b94ad0acc2 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -5,11 +5,10 @@ import asyncio import time from abc import ABC, abstractmethod from functools import cached_property -from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, - Union) +from typing import Any, Awaitable, Callable, List, Optional, Set, Union import torch.nn as nn -from typing_extensions import TypeVar +from typing_extensions import TypeVar, deprecated import vllm.platforms from vllm.config import VllmConfig @@ -63,10 +62,10 @@ class ExecutorBase(ABC): @abstractmethod def collective_rpc(self, - method: Union[str, Callable[..., _R]], + method: Union[str, Callable[[WorkerBase], _R]], timeout: Optional[float] = None, - args: Tuple = (), - kwargs: Optional[Dict[str, Any]] = None) -> List[_R]: + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: """ Execute an RPC call on all workers. @@ -91,7 +90,7 @@ class ExecutorBase(ABC): """ raise NotImplementedError - def determine_num_available_blocks(self) -> Tuple[int, int]: + def determine_num_available_blocks(self) -> tuple[int, int]: """Determine the number of available blocks for the GPU KV cache and swappable CPU KV cache. @@ -99,9 +98,10 @@ class ExecutorBase(ABC): ExecutorBase may require modification of the result, e.g. to ensure the selected cache sizes are compatible with all workers. - Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks - are blocks that are "active" on the device and can be appended to. - num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + Returns a tuple `(num_gpu_blocks, num_cpu_blocks)`, where + `num_gpu_blocks` are blocks that are "active" on the device and can be + appended to. + `num_cpu_blocks` refers to "swapped" blocks in CPU memory and cannot be appended to. """ results = self.collective_rpc("determine_num_available_blocks") @@ -127,16 +127,15 @@ class ExecutorBase(ABC): self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) + @deprecated("`llm_engine.model_executor.apply_model` will no longer work " + "in V1 Engine. Please replace with `llm_engine.apply_model` " + "and set `VLLM_ALLOW_INSECURE_SERIALIZATION=1`.") def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: """ Run a function directly on the model inside each worker, returning the result for each of them. 
""" - - def rpc_func(worker: WorkerBase) -> _R: - return func(worker.get_model()) - - return self.collective_rpc(rpc_func) + return self.collective_rpc("apply_model", args=(func, )) @cached_property # Avoid unnecessary RPC calls def supported_tasks(self) -> tuple[SupportedTask, ...]: @@ -308,8 +307,8 @@ class DistributedExecutorBase(ExecutorBase): def collective_rpc(self, method: Union[str, Callable], timeout: Optional[float] = None, - args: Tuple = (), - kwargs: Optional[Dict] = None) -> List[Any]: + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[Any]: return self._run_workers(method, *args, **(kwargs or {})) @abstractmethod diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index c93bfc35f0aeb..907656d1b24cb 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -5,6 +5,7 @@ from collections.abc import Mapping from copy import copy from typing import Any, Callable, Optional, Union +import torch.nn as nn from typing_extensions import TypeVar import vllm.envs as envs @@ -33,6 +34,7 @@ from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase, StatLoggerFactory) from vllm.v1.metrics.reader import Metric, get_metrics_snapshot from vllm.v1.metrics.stats import IterationStats +from vllm.v1.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -319,12 +321,15 @@ class LLMEngine: return self.engine_core.pin_lora(lora_id) def collective_rpc(self, - method: Union[str, Callable[..., _R]], + method: Union[str, Callable[[WorkerBase], _R]], timeout: Optional[float] = None, args: tuple = (), kwargs: Optional[dict[str, Any]] = None) -> list[_R]: return self.engine_core.collective_rpc(method, timeout, args, kwargs) + def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: + return self.collective_rpc("apply_model", args=(func, )) + def __del__(self): if dp_group := getattr(self, "dp_group", None): stateless_destroy_torch_distributed_process_group(dp_group) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index aa76d21f0fcaa..d0a56f6ff4637 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -5,7 +5,8 @@ import dataclasses import os import time from abc import abstractmethod -from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union +from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, Type, + TypeVar, Union) import cloudpickle import torch @@ -28,6 +29,8 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput, logger = init_logger(__name__) +_R = TypeVar("_R") + @warn_for_unimplemented_methods class WorkerBase: @@ -70,6 +73,10 @@ class WorkerBase: def get_model(self) -> nn.Module: raise NotImplementedError + def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R: + """Apply a function on the model inside this worker.""" + return fn(self.get_model()) + def load_model(self) -> None: """Load model onto target device.""" raise NotImplementedError