[V1] Support LLM.apply_model (#18465)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-09-20 15:14:35 +08:00, committed by GitHub
parent be874c0201
commit 3d9a1d2de5
17 changed files with 194 additions and 169 deletions
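
In short, `LLM.apply_model(func)` now works on the V1 engine by routing through `collective_rpc`; because the function is pickled on its way to the worker processes, callers (and the updated tests below) set `VLLM_ALLOW_INSECURE_SERIALIZATION=1`. A minimal usage sketch, not part of the diff (the model name is borrowed from the tests below):

import os

# Required under V1 multiprocessing so the callable can be pickled to workers.
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

from vllm import LLM


def count_parameters(model):
    # Runs inside each worker; keep the return value small and CPU-friendly.
    return sum(p.numel() for p in model.parameters())


llm = LLM(model="facebook/opt-125m", enforce_eager=True)
print(llm.apply_model(count_parameters))  # one entry per worker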

View File

@@ -987,17 +987,7 @@ class VllmRunner:
         return [req_output.outputs.score for req_output in req_outputs]

     def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
-        if hasattr(self.llm.llm_engine, "model_executor"):
-            # This works either in V0 or in V1 with
-            # VLLM_ENABLE_V1_MULTIPROCESSING=0
-            executor = self.llm.llm_engine.model_executor
-            return executor.apply_model(func)
-
-        # This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1
-        def _apply_model(self):
-            return func(self.get_model())
-
-        return self.llm.llm_engine.collective_rpc(_apply_model)
+        return self.llm.apply_model(func)

     def get_llm(self) -> LLM:
         return self.llm

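For reference, the pattern the quantization tests below now build on top of this helper looks roughly like this (a sketch, not a file in the commit; `vllm_runner` is the pytest fixture seen throughout the test diffs and `model_id` is a placeholder):

# Worker-side check: receives the bare nn.Module, so assertions can inspect
# quant_method / scheme attributes directly instead of V0 executor internals.
def check_model(model):
    qkv_proj = model.model.layers[0].self_attn.qkv_proj
    assert qkv_proj.quant_method is not None  # illustrative assertion

with vllm_runner(model_id, load_format="dummy") as llm:
    llm.apply_model(check_model)
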
View File

@@ -1,20 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import importlib
 import importlib.metadata
 from dataclasses import dataclass
+from importlib.util import find_spec
 from typing import Optional

 import pytest
 import torch
 from packaging import version

+from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
+    QuarkLinearMethod, QuarkW4A4MXFP4)
+from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
+    QuarkW4A4MXFp4MoEMethod)
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_flashinfer

-QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
-    "quark") is not None and version.parse(
-        importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
+QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
+    importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')

 TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(

@@ -39,6 +42,12 @@ class ModelCase:
     tp: int


+@pytest.fixture(scope="function", autouse=True)
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
+
 @pytest.mark.parametrize('model_case', [
     ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
     ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),

@@ -55,21 +64,19 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
                      tensor_parallel_size=model_case.tp,
                      load_format="dummy") as llm:

-        # TODO: llm.apply_model(check_model) currently relies on V0 internals.
-        # Re-enable this later.
-        # def check_model(model):
-        #     layer = model.model.layers[0]
-        #     qkv_proj = layer.self_attn.qkv_proj
-        #     assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
-        #     assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
-        #     assert isinstance(layer.mlp.experts.quant_method,
-        #                       QuarkW4A4MXFp4MoEMethod)
-        # if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
-        #     llm.apply_model(check_model)
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
+            assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
+            assert isinstance(layer.mlp.experts.quant_method,
+                              QuarkW4A4MXFp4MoEMethod)
+
+        if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
+            llm.apply_model(check_model)

         output = llm.generate_greedy("Today I am in the French Alps and",
                                      max_tokens=20)

View File

@@ -10,6 +10,7 @@ from PIL import Image

 from vllm.multimodal.image import rescale_image_size
 from vllm.multimodal.video import rescale_video_size, sample_frames_from_video
+from vllm.utils import set_default_torch_num_threads

 from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput,
                           PromptVideoInput, VllmRunner)

@@ -17,11 +18,9 @@ from ...utils import check_logprobs_close

 @pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    V1 Test: batch_make_xxxxx_embeddings calls a V0 internal
-    """
-    monkeypatch.setenv('VLLM_USE_V1', '0')
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


 models = ["Qwen/Qwen2-VL-2B-Instruct"]

@@ -126,9 +125,8 @@ def batch_make_image_embeddings(
         image_grid_thw_on_device = image_grid_thw.to(visual.device,
                                                      dtype=torch.int64)
         return visual(pixel_values_on_device,
-                      grid_thw=image_grid_thw_on_device)
-
-    # V1 Test: this calls a V0 internal.
+                      grid_thw=image_grid_thw_on_device).cpu()
+
     image_embeds = torch.concat(llm.apply_model(get_image_embeds))

     # split into original batches

@@ -210,7 +208,7 @@ def batch_make_video_embeddings(
         video_grid_thw_on_device = video_grid_thw.to(visual.device,
                                                      dtype=torch.int64)
         return visual(pixel_values_on_device,
-                      grid_thw=video_grid_thw_on_device)
+                      grid_thw=video_grid_thw_on_device).cpu()

     # V1 Test: this calls a V0 internal.
     video_embeds = torch.concat(llm.apply_model(get_image_embeds))

@@ -266,7 +264,9 @@ def run_embedding_input_test(
     processor = AutoProcessor.from_pretrained(model)

     # max_model_len should be greater than image_feature_size
-    with vllm_runner(model,
+    with set_default_torch_num_threads(1):
+        vllm_model = vllm_runner(
+            model,
             runner="generate",
             max_model_len=4000,
             max_num_seqs=3,

@@ -276,9 +276,10 @@ def run_embedding_input_test(
                 "video": mm_limit
             },
             tensor_parallel_size=tensor_parallel_size,
-            distributed_executor_backend=distributed_executor_backend
-        ) as vllm_model:
+            distributed_executor_backend=distributed_executor_backend,
+        )

+    with vllm_model:
         outputs_per_case_for_original_input = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,

@@ -329,9 +330,8 @@ def run_embedding_input_test(
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [10])
 def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model,
-                                         size_factors, dtype: str,
-                                         max_tokens: int,
-                                         num_logprobs: int) -> None:
+                                         size_factors, dtype, max_tokens,
+                                         num_logprobs, monkeypatch) -> None:
     images = [asset.pil_image for asset in image_assets]
     inputs_per_case: list[tuple[

View File

@@ -112,7 +112,7 @@ def test_awq_models(vllm_runner, image_assets, source_model, quant_model,
                     monkeypatch) -> None:

     # Test V1: this test hangs during setup on single-scale input.
-    # TODO: fixure out why and re-enable this on V1.
+    # TODO: figure out why and re-enable this on V1.
     monkeypatch.setenv("VLLM_USE_V1", "0")
     run_awq_test(
         vllm_runner,

View File

@@ -43,12 +43,9 @@ ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [

 @pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    This module relies on V0 internals, so set VLLM_USE_V1=0.
-    """
-    if not current_platform.is_cpu():
-        monkeypatch.setenv('VLLM_USE_V1', '0')
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


 @pytest.mark.parametrize(

@@ -176,10 +173,11 @@ def test_compressed_tensors_w8a8_logprobs(
         dtype = "bfloat16"

-    # skip language translation prompt for the static per tensor asym model
-    if (model_path ==
-            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
-        ):  # noqa: E501
+    # skip language translation prompt for the static per tensor models
+    if model_path in (
+            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
+            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
+    ):
         example_prompts = example_prompts[0:-1]

     with hf_runner(model_path, dtype=dtype) as hf_model:

View File

@@ -60,8 +60,8 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
     if use_rocm_aiter:
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

     with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:

         def check_model(model):

@@ -104,8 +104,8 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
     if use_rocm_aiter:
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

     if force_marlin:
         monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

View File

@@ -31,22 +31,24 @@ MODEL_QUANT = [
 @pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT)
 def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool,
                            monkeypatch):
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
-
-    vllm_model = vllm_runner(model_id, dtype=torch.float16, max_model_len=2048)
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

     linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else (
         GPTQLinearMethod)

-    for name, submodule in (vllm_model.llm.llm_engine.model_executor.
-                            driver_worker.model_runner.model.named_modules()):
-        if name == "lm_head":
-            assert isinstance(submodule.quant_method, linear_method_cls)
-        elif name == 'model.layers.0.self_attn.qkv_proj':
-            # The first layer is quantized using bits=4, group_size=128
-            # desc_act=True
-            assert isinstance(submodule.quant_method, linear_method_cls)
-            config = submodule.quant_method.quant_config
-            assert config.weight_bits == 4
-            assert config.group_size == 128
+    with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as llm:
+
+        def check_model(model):
+            for name, submodule in model.named_modules():
+                if name == "lm_head":
+                    assert isinstance(submodule.quant_method,
+                                      linear_method_cls)
+                elif name == 'model.layers.0.self_attn.qkv_proj':
+                    # The first layer is quantized using bits=4, group_size=128
+                    # desc_act=True
+                    assert isinstance(submodule.quant_method,
+                                      linear_method_cls)
+                    config = submodule.quant_method.quant_config
+                    assert config.weight_bits == 4
+                    assert config.group_size == 128

@@ -54,9 +56,11 @@ def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool,
-        elif name == 'model.layers.1.self_attn.qkv_proj':
-            # The second layer is quantized using bits=8, group_size=32
-            # desc_act=False
-            assert isinstance(submodule.quant_method, linear_method_cls)
-            config = submodule.quant_method.quant_config
-            assert get_dynamic_override(config, layer_name=name,
-                                        key="bits") == 8
-            assert get_dynamic_override(config,
-                                        layer_name=name,
+                elif name == 'model.layers.1.self_attn.qkv_proj':
+                    # The second layer is quantized using bits=8, group_size=32
+                    # desc_act=False
+                    assert isinstance(submodule.quant_method,
+                                      linear_method_cls)
+                    config = submodule.quant_method.quant_config
+                    assert get_dynamic_override(config,
+                                                layer_name=name,
+                                                key="bits") == 8
+                    assert get_dynamic_override(config,
+                                                layer_name=name,

@@ -66,6 +70,7 @@ def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool,
-        elif (name == 'model.layers.2.self_attn.qkv_proj'
-              or name == 'model.layers.2.mlp.gate_up_proj'):
-            # All other layers (layer index >= 2) are not quantized
-            assert isinstance(submodule.quant_method, UnquantizedLinearMethod)
-
-    del vllm_model
+                elif (name == 'model.layers.2.self_attn.qkv_proj'
+                      or name == 'model.layers.2.mlp.gate_up_proj'):
+                    # All other layers (layer index >= 2) are not quantized
+                    assert isinstance(submodule.quant_method,
+                                      UnquantizedLinearMethod)
+
+        llm.apply_model(check_model)

View File

@@ -29,8 +29,8 @@ def test_lm_head(
     lm_head_quantized: bool,
     monkeypatch,
 ) -> None:
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
     with vllm_runner(model_id, dtype=torch.float16,
                      max_model_len=2048) as vllm_model:

View File

@@ -11,16 +11,12 @@ import pytest
 import torch

 from tests.quantization.utils import is_quant_method_supported
-from vllm.platforms import current_platform


 @pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    This module relies on V0 internals, so set VLLM_USE_V1=0.
-    """
-    if not current_platform.is_cpu():
-        monkeypatch.setenv('VLLM_USE_V1', '0')
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


 @pytest.mark.skipif(not is_quant_method_supported("modelopt"),

View File

@@ -13,6 +13,16 @@ from vllm.model_executor.layers.quantization.ptpc_fp8 import (
     PTPCFp8LinearMethod)
 from vllm.platforms import current_platform

+UNSUPPORTED_STR = (
+    "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only "
+    "support output dtype of bfloat16. torch.float16 is specified.")
+
+
+@pytest.fixture(scope="function", autouse=True)
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+

 @pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"),
                     reason="PTPC FP8 is not supported on this GPU type.")

@@ -21,14 +31,22 @@ from vllm.platforms import current_platform
 @pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"])
 @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
 def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
     try:
-        with vllm_runner("facebook/opt-125m",
-                         dtype=dtype,
-                         quantization="ptpc_fp8",
-                         kv_cache_dtype=kv_cache_dtype) as llm:
-
-            model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-            fc1 = model.model.decoder.layers[0].fc1
-            assert isinstance(fc1.quant_method, PTPCFp8LinearMethod)
-            if kv_cache_dtype == "ptpc_fp8":
+        llm = vllm_runner("facebook/opt-125m",
+                          dtype=dtype,
+                          quantization="ptpc_fp8",
+                          kv_cache_dtype=kv_cache_dtype)
+    except AssertionError as e:
+        if str(e) == UNSUPPORTED_STR:
+            # If the error message matches, the test passes
+            return
+        else:
+            # If the error message does not match, re-raise the exception
+            raise
+
+    with llm:
+
+        def check_model(model):
+            fc1 = model.model.decoder.layers[0].fc1
+            assert isinstance(fc1.quant_method, PTPCFp8LinearMethod)
+            if kv_cache_dtype == "ptpc_fp8":

@@ -40,17 +58,8 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
             if current_platform.has_device_capability(94):
                 # For GPUs with hardware support, we keep weights in fp8
                 assert fc1.weight.dtype == torch.float8_e4m3fnuz
-            else:
-                pytest.skip()

-            output = llm.generate_greedy("Hello my name is", max_tokens=20)
-            assert output
-    except AssertionError as e:
-        if str(
-                e
-        ) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified.":  # noqa: E501
-            # If the error message matches, the test passes
-            pass
-        else:
-            # If the error message does not match, re-raise the exception
-            raise
+        llm.apply_model(check_model)
+
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        assert output

View File

@@ -7,10 +7,10 @@ Run `pytest tests/quantization/test_quark.py`.

 See also `tests/kernels/moe/test_mxfp4_moe.py`.
 """
-import importlib
 import importlib.metadata
 import os
 from dataclasses import dataclass
+from importlib.util import find_spec

 import huggingface_hub
 import lm_eval

@@ -24,8 +24,7 @@ from vllm.platforms import current_platform

 from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch

-QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
-    "quark") is not None and version.parse(
-        importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
+QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
+    importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')

 if QUARK_MXFP4_AVAILABLE:

@@ -43,11 +42,9 @@ except huggingface_hub.errors.RepositoryNotFoundError:

 @pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    This module relies on V0 internals, so set VLLM_USE_V1=0.
-    """
-    monkeypatch.setenv('VLLM_USE_V1', '0')
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


 @pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8'])

@@ -132,13 +129,12 @@ def test_quark_fp8_parity(vllm_runner):
     }
     with (vllm_runner(quark_model_id, **llm_kwargs) as
           quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
-        quark_model = (quark_handle.llm.llm_engine.model_executor.
-                       driver_worker.model_runner.model)
-        quark_state_dict = quark_model.state_dict()
-        fp8_model = (fp8_handle.llm.llm_engine.model_executor.driver_worker.
-                     model_runner.model)
-        fp8_state_dict = fp8_model.state_dict()
+
+        def get_state_dict(model):
+            return {k: v.cpu() for k, v in model.state_dict().items()}
+
+        quark_state_dict, = quark_handle.apply_model(get_state_dict)
+        fp8_state_dict, = fp8_handle.apply_model(get_state_dict)

     assert fp8_state_dict.keys() == quark_state_dict.keys()

View File

@@ -105,18 +105,21 @@ def test_register_quantization_config():
 ])
 def test_custom_quant(vllm_runner, model, monkeypatch):
     """Test infer with the custom quantization method."""
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
     with vllm_runner(model_name=model,
                      quantization="custom_quant",
                      enforce_eager=True) as llm:
-        model = llm.llm.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-        layer = model.model.layers[0]
-        qkv_proj = layer.self_attn.qkv_proj
-        # Check the quantization method is FakeQuantLinearMethod
-        assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)
+
+        def check_model(model):
+            layer = model.model.layers[0]
+            qkv_proj = layer.self_attn.qkv_proj
+            # Check the quantization method is FakeQuantLinearMethod
+            assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)
+
+        llm.apply_model(check_model)

         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output

View File

@@ -13,6 +13,7 @@ from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast

 import torch
+import torch.nn as nn
 from typing_extensions import TypeVar

 import vllm.envs as envs

@@ -55,6 +56,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
 from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
 from vllm.version import __version__ as VLLM_VERSION
 from vllm.worker.model_runner_base import InputProcessingError
+from vllm.worker.worker_base import WorkerBase

 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5

@@ -1817,13 +1819,16 @@ class LLMEngine:
         return sampling_params

     def collective_rpc(self,
-                       method: Union[str, Callable[..., _R]],
+                       method: Union[str, Callable[[WorkerBase], _R]],
                        timeout: Optional[float] = None,
                        args: tuple = (),
                        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
         return self.model_executor.collective_rpc(method, timeout, args,
                                                   kwargs)

+    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
+        return self.collective_rpc("apply_model", args=(func, ))
+

 if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
     from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine

View File

@@ -522,9 +522,14 @@ class LLM:
         """
         Run a function directly on the model inside each worker,
         returning the result for each of them.
+
+        !!! warning
+            To reduce the overhead of data transfer, avoid returning large
+            arrays or tensors from this method. If you must return them,
+            make sure you move them to CPU first to avoid taking up additional
+            VRAM!
         """
-        executor = self.llm_engine.model_executor
-        return executor.apply_model(func)
+        return self.llm_engine.apply_model(func)

     def _get_beam_search_lora_requests(
         self,

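To illustrate the new docstring warning, a hedged sketch of a worker function that keeps the transferred data small (`llm` stands for an existing `vllm.LLM` instance; the layer attribute path is illustrative, not taken from the diff):

def get_first_layer_norm_weight(model):
    # Runs on the worker. Detach and move to CPU before returning so the
    # RPC payload stays small and no extra VRAM copy is kept alive.
    weight = model.model.layers[0].input_layernorm.weight
    return weight.detach().cpu()

# apply_model returns one entry per worker; unpack directly for a
# single-worker setup.
(weight, ) = llm.apply_model(get_first_layer_norm_weight)
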
View File

@@ -5,11 +5,10 @@ import asyncio
 import time
 from abc import ABC, abstractmethod
 from functools import cached_property
-from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
-                    Union)
+from typing import Any, Awaitable, Callable, List, Optional, Set, Union

 import torch.nn as nn
-from typing_extensions import TypeVar
+from typing_extensions import TypeVar, deprecated

 import vllm.platforms
 from vllm.config import VllmConfig

@@ -63,10 +62,10 @@ class ExecutorBase(ABC):
     @abstractmethod
     def collective_rpc(self,
-                       method: Union[str, Callable[..., _R]],
+                       method: Union[str, Callable[[WorkerBase], _R]],
                        timeout: Optional[float] = None,
-                       args: Tuple = (),
-                       kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
         """
         Execute an RPC call on all workers.

@@ -91,7 +90,7 @@ class ExecutorBase(ABC):
         """
         raise NotImplementedError

-    def determine_num_available_blocks(self) -> Tuple[int, int]:
+    def determine_num_available_blocks(self) -> tuple[int, int]:
         """Determine the number of available blocks for the GPU KV cache and
         swappable CPU KV cache.

@@ -99,9 +98,10 @@ class ExecutorBase(ABC):
         ExecutorBase may require modification of the result, e.g. to ensure the
         selected cache sizes are compatible with all workers.

-        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
-        are blocks that are "active" on the device and can be appended to.
-        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
+        Returns a tuple `(num_gpu_blocks, num_cpu_blocks)`, where
+        `num_gpu_blocks` are blocks that are "active" on the device and can be
+        appended to.
+        `num_cpu_blocks` refers to "swapped" blocks in CPU memory and cannot be
         appended to.
         """
         results = self.collective_rpc("determine_num_available_blocks")

@@ -127,16 +127,15 @@ class ExecutorBase(ABC):
         self.collective_rpc("initialize_cache",
                             args=(num_gpu_blocks, num_cpu_blocks))

+    @deprecated("`llm_engine.model_executor.apply_model` will no longer work "
+                "in V1 Engine. Please replace with `llm_engine.apply_model` "
+                "and set `VLLM_ALLOW_INSECURE_SERIALIZATION=1`.")
     def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
         """
         Run a function directly on the model inside each worker,
         returning the result for each of them.
         """
-
-        def rpc_func(worker: WorkerBase) -> _R:
-            return func(worker.get_model())
-
-        return self.collective_rpc(rpc_func)
+        return self.collective_rpc("apply_model", args=(func, ))

     @cached_property  # Avoid unnecessary RPC calls
     def supported_tasks(self) -> tuple[SupportedTask, ...]:

@@ -308,8 +307,8 @@ class DistributedExecutorBase(ExecutorBase):
     def collective_rpc(self,
                        method: Union[str, Callable],
                        timeout: Optional[float] = None,
-                       args: Tuple = (),
-                       kwargs: Optional[Dict] = None) -> List[Any]:
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[Any]:
         return self._run_workers(method, *args, **(kwargs or {}))

     @abstractmethod

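The deprecation message translates into the following migration for external callers (a sketch; `llm` stands for an existing `vllm.LLM` instance and `func` for the caller's model function):

# Before (V0-only): goes through the executor and breaks under the V1 engine.
results = llm.llm_engine.model_executor.apply_model(func)

# After: works on both engines; under V1 multiprocessing the function is
# pickled, so VLLM_ALLOW_INSECURE_SERIALIZATION=1 must be set.
results = llm.llm_engine.apply_model(func)
# or simply, at the entrypoint level:
results = llm.apply_model(func)
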
View File

@@ -5,6 +5,7 @@ from collections.abc import Mapping
 from copy import copy
 from typing import Any, Callable, Optional, Union

+import torch.nn as nn
 from typing_extensions import TypeVar

 import vllm.envs as envs

@@ -33,6 +34,7 @@ from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
                                      StatLoggerFactory)
 from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
 from vllm.v1.metrics.stats import IterationStats
+from vllm.v1.worker.worker_base import WorkerBase

 logger = init_logger(__name__)

@@ -319,12 +321,15 @@ class LLMEngine:
         return self.engine_core.pin_lora(lora_id)

     def collective_rpc(self,
-                       method: Union[str, Callable[..., _R]],
+                       method: Union[str, Callable[[WorkerBase], _R]],
                        timeout: Optional[float] = None,
                        args: tuple = (),
                        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
         return self.engine_core.collective_rpc(method, timeout, args, kwargs)

+    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
+        return self.collective_rpc("apply_model", args=(func, ))
+
     def __del__(self):
         if dp_group := getattr(self, "dp_group", None):
             stateless_destroy_torch_distributed_process_group(dp_group)

View File

@@ -5,7 +5,8 @@ import dataclasses
 import os
 import time
 from abc import abstractmethod
-from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, Type,
+                    TypeVar, Union)

 import cloudpickle
 import torch

@@ -28,6 +29,8 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,

 logger = init_logger(__name__)

+_R = TypeVar("_R")
+

 @warn_for_unimplemented_methods
 class WorkerBase:

@@ -70,6 +73,10 @@ class WorkerBase:
     def get_model(self) -> nn.Module:
         raise NotImplementedError

+    def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
+        """Apply a function on the model inside this worker."""
+        return fn(self.get_model())
+
     def load_model(self) -> None:
         """Load model onto target device."""
         raise NotImplementedError
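
Taken together, the pieces above form a simple dispatch chain: `LLM.apply_model` calls `LLMEngine.apply_model`, which issues `collective_rpc("apply_model", args=(func, ))`, and each worker resolves the string to `WorkerBase.apply_model`, which runs `func(self.get_model())`. A condensed, runnable sketch with simplified single-process stand-ins (the `_Worker`/`_Engine` classes are illustrative, not part of the commit):

from typing import Callable, TypeVar

import torch.nn as nn

_R = TypeVar("_R")


class _Worker:
    """Stand-in for WorkerBase: holds the model and resolves RPC names."""

    def __init__(self, model: nn.Module):
        self._model = model

    def get_model(self) -> nn.Module:
        return self._model

    def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
        return fn(self.get_model())


class _Engine:
    """Stand-in for LLMEngine: fans the call out to every worker."""

    def __init__(self, workers: list):
        self._workers = workers

    def collective_rpc(self, method: str, args: tuple = ()) -> list:
        # In vLLM the method name (and any callable args) cross a process
        # boundary, hence VLLM_ALLOW_INSECURE_SERIALIZATION for pickled funcs.
        return [getattr(w, method)(*args) for w in self._workers]

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list:
        return self.collective_rpc("apply_model", args=(func, ))


engine = _Engine([_Worker(nn.Linear(4, 4))])
print(engine.apply_model(lambda m: sum(p.numel() for p in m.parameters())))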