[V1] Support LLM.apply_model (#18465)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-09-20 15:14:35 +08:00 committed by GitHub
parent be874c0201
commit 3d9a1d2de5
17 changed files with 194 additions and 169 deletions

View File

@ -987,17 +987,7 @@ class VllmRunner:
return [req_output.outputs.score for req_output in req_outputs]
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
if hasattr(self.llm.llm_engine, "model_executor"):
# This works either in V0 or in V1 with
# VLLM_ENABLE_V1_MULTIPROCESSING=0
executor = self.llm.llm_engine.model_executor
return executor.apply_model(func)
# This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1
def _apply_model(self):
return func(self.get_model())
return self.llm.llm_engine.collective_rpc(_apply_model)
return self.llm.apply_model(func)
def get_llm(self) -> LLM:
return self.llm
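Since VllmRunner.apply_model now just forwards to LLM.apply_model, a test only needs the runner and a plain function. A minimal sketch under the same assumptions as the tests in this diff (facebook/opt-125m as in the PTPC FP8 test below, with VLLM_ALLOW_INSECURE_SERIALIZATION=1 already set):

def check_model(model):
    # Runs inside each worker; `model` is the loaded nn.Module.
    fc1 = model.model.decoder.layers[0].fc1
    assert fc1.weight is not None

with vllm_runner("facebook/opt-125m") as vllm_model:
    # One result per worker, routed through LLM.apply_model -> collective_rpc.
    vllm_model.apply_model(check_model)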

View File

@ -1,21 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import importlib.metadata
from dataclasses import dataclass
from importlib.util import find_spec
from typing import Optional
import pytest
import torch
from packaging import version
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
QuarkLinearMethod, QuarkW4A4MXFP4)
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
QuarkW4A4MXFp4MoEMethod)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
"quark") is not None and version.parse(
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
) and current_platform.is_device_capability(100)
@ -39,6 +42,12 @@ class ModelCase:
tp: int
@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.mark.parametrize('model_case', [
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
@ -55,21 +64,19 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
tensor_parallel_size=model_case.tp,
load_format="dummy") as llm:
# TODO: llm.apply_model(check_model) currently relies on V0 internals.
# Re-enable this later.
# def check_model(model):
# layer = model.model.layers[0]
def check_model(model):
layer = model.model.layers[0]
# qkv_proj = layer.self_attn.qkv_proj
qkv_proj = layer.self_attn.qkv_proj
# assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
# assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
# assert isinstance(layer.mlp.experts.quant_method,
# QuarkW4A4MXFp4MoEMethod)
assert isinstance(layer.mlp.experts.quant_method,
QuarkW4A4MXFp4MoEMethod)
# if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
# llm.apply_model(check_model)
if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
llm.apply_model(check_model)
output = llm.generate_greedy("Today I am in the French Alps and",
max_tokens=20)
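The enable_pickle fixture above exists because, on V1 with multiprocessing, the function handed to LLM.apply_model is serialized and shipped to the engine-core/worker processes, which vLLM only permits when VLLM_ALLOW_INSECURE_SERIALIZATION=1. A rough sketch of the same setup outside pytest (hypothetical standalone script; the variable must be set before the engine starts):

import os
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"  # allow pickling the callable

from vllm import LLM

def count_params(model):
    return sum(p.numel() for p in model.parameters())

llm = LLM(model="facebook/opt-125m")
per_worker_counts = llm.apply_model(count_params)  # one entry per worker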

View File

@ -10,6 +10,7 @@ from PIL import Image
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.video import rescale_video_size, sample_frames_from_video
from vllm.utils import set_default_torch_num_threads
from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput,
PromptVideoInput, VllmRunner)
@ -17,11 +18,9 @@ from ...utils import check_logprobs_close
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
V1 Test: batch_make_xxxxx_embeddings calls a V0 internal
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
models = ["Qwen/Qwen2-VL-2B-Instruct"]
@ -126,9 +125,8 @@ def batch_make_image_embeddings(
image_grid_thw_on_device = image_grid_thw.to(visual.device,
dtype=torch.int64)
return visual(pixel_values_on_device,
grid_thw=image_grid_thw_on_device)
grid_thw=image_grid_thw_on_device).cpu()
# V1 Test: this calls a V0 internal.
image_embeds = torch.concat(llm.apply_model(get_image_embeds))
# split into original batches
@ -210,7 +208,7 @@ def batch_make_video_embeddings(
video_grid_thw_on_device = video_grid_thw.to(visual.device,
dtype=torch.int64)
return visual(pixel_values_on_device,
grid_thw=video_grid_thw_on_device)
grid_thw=video_grid_thw_on_device).cpu()
# V1 Test: this calls a V0 internal.
video_embeds = torch.concat(llm.apply_model(get_image_embeds))
@ -266,19 +264,22 @@ def run_embedding_input_test(
processor = AutoProcessor.from_pretrained(model)
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
runner="generate",
max_model_len=4000,
max_num_seqs=3,
dtype=dtype,
limit_mm_per_prompt={
"image": mm_limit,
"video": mm_limit
},
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend
) as vllm_model:
with set_default_torch_num_threads(1):
vllm_model = vllm_runner(
model,
runner="generate",
max_model_len=4000,
max_num_seqs=3,
dtype=dtype,
limit_mm_per_prompt={
"image": mm_limit,
"video": mm_limit
},
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
)
with vllm_model:
outputs_per_case_for_original_input = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
@ -329,9 +330,8 @@ def run_embedding_input_test(
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model,
size_factors, dtype: str,
max_tokens: int,
num_logprobs: int) -> None:
size_factors, dtype, max_tokens,
num_logprobs, monkeypatch) -> None:
images = [asset.pil_image for asset in image_assets]
inputs_per_case: list[tuple[

View File

@ -112,7 +112,7 @@ def test_awq_models(vllm_runner, image_assets, source_model, quant_model,
monkeypatch) -> None:
# Test V1: this test hangs during setup on single-scale input.
# TODO: fixure out why and re-enable this on V1.
# TODO: figure out why and re-enable this on V1.
monkeypatch.setenv("VLLM_USE_V1", "0")
run_awq_test(
vllm_runner,

View File

@ -43,12 +43,9 @@ ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
if not current_platform.is_cpu():
monkeypatch.setenv('VLLM_USE_V1', '0')
def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.mark.parametrize(
@ -176,10 +173,11 @@ def test_compressed_tensors_w8a8_logprobs(
dtype = "bfloat16"
# skip language translation prompt for the static per tensor asym model
if (model_path ==
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
): # noqa: E501
# skip language translation prompt for the static per tensor models
if model_path in (
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
):
example_prompts = example_prompts[0:-1]
with hf_runner(model_path, dtype=dtype) as hf_model:

View File

@ -60,8 +60,8 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")
# `LLM.apply_model` requires pickling a function.
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
def check_model(model):
@ -104,8 +104,8 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")
# `LLM.apply_model` requires pickling a function.
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
if force_marlin:
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

View File

@ -31,41 +31,46 @@ MODEL_QUANT = [
@pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT)
def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool,
monkeypatch):
# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")
vllm_model = vllm_runner(model_id, dtype=torch.float16, max_model_len=2048)
# `LLM.apply_model` requires pickling a function.
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else (
GPTQLinearMethod)
for name, submodule in (vllm_model.llm.llm_engine.model_executor.
driver_worker.model_runner.model.named_modules()):
if name == "lm_head":
assert isinstance(submodule.quant_method, linear_method_cls)
elif name == 'model.layers.0.self_attn.qkv_proj':
# The first layer is quantized using bits=4, group_size=128
# desc_act=True
assert isinstance(submodule.quant_method, linear_method_cls)
config = submodule.quant_method.quant_config
assert config.weight_bits == 4
assert config.group_size == 128
assert config.desc_act
elif name == 'model.layers.1.self_attn.qkv_proj':
# The second layer is quantized using bits=8, group_size=32
# desc_act=False
assert isinstance(submodule.quant_method, linear_method_cls)
config = submodule.quant_method.quant_config
assert get_dynamic_override(config, layer_name=name,
key="bits") == 8
assert get_dynamic_override(config,
layer_name=name,
key="group_size") == 32
assert not get_dynamic_override(
config, layer_name=name, key="desc_act")
elif (name == 'model.layers.2.self_attn.qkv_proj'
or name == 'model.layers.2.mlp.gate_up_proj'):
# All other layers (layer index >= 2) are not quantized
assert isinstance(submodule.quant_method, UnquantizedLinearMethod)
with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as llm:
del vllm_model
def check_model(model):
for name, submodule in model.named_modules():
if name == "lm_head":
assert isinstance(submodule.quant_method,
linear_method_cls)
elif name == 'model.layers.0.self_attn.qkv_proj':
# The first layer is quantized using bits=4, group_size=128
# desc_act=True
assert isinstance(submodule.quant_method,
linear_method_cls)
config = submodule.quant_method.quant_config
assert config.weight_bits == 4
assert config.group_size == 128
assert config.desc_act
elif name == 'model.layers.1.self_attn.qkv_proj':
# The second layer is quantized using bits=8, group_size=32
# desc_act=False
assert isinstance(submodule.quant_method,
linear_method_cls)
config = submodule.quant_method.quant_config
assert get_dynamic_override(config,
layer_name=name,
key="bits") == 8
assert get_dynamic_override(config,
layer_name=name,
key="group_size") == 32
assert not get_dynamic_override(
config, layer_name=name, key="desc_act")
elif (name == 'model.layers.2.self_attn.qkv_proj'
or name == 'model.layers.2.mlp.gate_up_proj'):
# All other layers (layer index >= 2) are not quantized
assert isinstance(submodule.quant_method,
UnquantizedLinearMethod)
llm.apply_model(check_model)

View File

@ -29,8 +29,8 @@ def test_lm_head(
lm_head_quantized: bool,
monkeypatch,
) -> None:
# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")
# `LLM.apply_model` requires pickling a function.
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_id, dtype=torch.float16,
max_model_len=2048) as vllm_model:

View File

@ -11,16 +11,12 @@ import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from vllm.platforms import current_platform
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
if not current_platform.is_cpu():
monkeypatch.setenv('VLLM_USE_V1', '0')
def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.mark.skipif(not is_quant_method_supported("modelopt"),

View File

@ -13,6 +13,16 @@ from vllm.model_executor.layers.quantization.ptpc_fp8 import (
PTPCFp8LinearMethod)
from vllm.platforms import current_platform
UNSUPPORTED_STR = (
"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only "
"support output dtype of bfloat16. torch.float16 is specified.")
@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"),
reason="PTPC FP8 is not supported on this GPU type.")
@ -21,14 +31,22 @@ from vllm.platforms import current_platform
@pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
try:
with vllm_runner("facebook/opt-125m",
dtype=dtype,
quantization="ptpc_fp8",
kv_cache_dtype=kv_cache_dtype) as llm:
llm = vllm_runner("facebook/opt-125m",
dtype=dtype,
quantization="ptpc_fp8",
kv_cache_dtype=kv_cache_dtype)
except AssertionError as e:
if str(e) == UNSUPPORTED_STR:
# If the error message matches, the test passes
return
else:
# If the error message does not match, re-raise the exception
raise
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
with llm:
def check_model(model):
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.quant_method, PTPCFp8LinearMethod)
if kv_cache_dtype == "ptpc_fp8":
@ -40,17 +58,8 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
if current_platform.has_device_capability(94):
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fnuz
else:
pytest.skip()
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
except AssertionError as e:
if str(
e
) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified.": # noqa: E501
# If the error message matches, the test passes
pass
else:
# If the error message does not match, re-raise the exception
raise
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output

View File

@ -7,10 +7,10 @@ Run `pytest tests/quantization/test_quark.py`.
See also `tests/kernels/moe/test_mxfp4_moe.py`.
"""
import importlib
import importlib.metadata
import os
from dataclasses import dataclass
from importlib.util import find_spec
import huggingface_hub
import lm_eval
@ -24,9 +24,8 @@ from vllm.platforms import current_platform
from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
"quark") is not None and version.parse(
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
if QUARK_MXFP4_AVAILABLE:
from quark.torch.export.nn.modules.realquantizer import (
@ -43,11 +42,9 @@ except huggingface_hub.errors.RepositoryNotFoundError:
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8'])
@ -132,13 +129,12 @@ def test_quark_fp8_parity(vllm_runner):
}
with (vllm_runner(quark_model_id, **llm_kwargs) as
quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
quark_model = (quark_handle.llm.llm_engine.model_executor.
driver_worker.model_runner.model)
quark_state_dict = quark_model.state_dict()
fp8_model = (fp8_handle.llm.llm_engine.model_executor.driver_worker.
model_runner.model)
fp8_state_dict = fp8_model.state_dict()
def get_state_dict(model):
return {k: v.cpu() for k, v in model.state_dict().items()}
quark_state_dict, = quark_handle.apply_model(get_state_dict)
fp8_state_dict, = fp8_handle.apply_model(get_state_dict)
assert fp8_state_dict.keys() == quark_state_dict.keys()

View File

@ -105,18 +105,21 @@ def test_register_quantization_config():
])
def test_custom_quant(vllm_runner, model, monkeypatch):
"""Test infer with the custom quantization method."""
# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")
# `LLM.apply_model` requires pickling a function.
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_name=model,
quantization="custom_quant",
enforce_eager=True) as llm:
model = llm.llm.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
# Check the quantization method is FakeQuantLinearMethod
assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)
# Check the quantization method is FakeQuantLinearMethod
assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output

View File

@ -13,6 +13,7 @@ from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast
import torch
import torch.nn as nn
from typing_extensions import TypeVar
import vllm.envs as envs
@ -55,6 +56,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
from vllm.version import __version__ as VLLM_VERSION
from vllm.worker.model_runner_base import InputProcessingError
from vllm.worker.worker_base import WorkerBase
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
@ -1817,13 +1819,16 @@ class LLMEngine:
return sampling_params
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
method: Union[str, Callable[[WorkerBase], _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.model_executor.collective_rpc(method, timeout, args,
kwargs)
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
return self.collective_rpc("apply_model", args=(func, ))
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
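With this change, a callable passed to collective_rpc is documented to receive each rank's WorkerBase, and apply_model is simply the "apply_model" RPC with the function as its argument. A hedged sketch of the two entry points (llm_engine stands for any LLMEngine instance; with V1 multiprocessing the callable is pickled, hence the VLLM_ALLOW_INSECURE_SERIALIZATION=1 used in the tests above):

def model_class_name(worker):  # receives each rank's WorkerBase
    return type(worker.get_model()).__name__

class_names = llm_engine.collective_rpc(model_class_name)

# Equivalent via the new helper; the function gets the nn.Module directly.
class_names = llm_engine.apply_model(lambda model: type(model).__name__)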

View File

@ -522,9 +522,14 @@ class LLM:
"""
Run a function directly on the model inside each worker,
returning the result for each of them.
!!! warning
To reduce the overhead of data transfer, avoid returning large
arrays or tensors from this method. If you must return them,
make sure you move them to CPU first to avoid taking up additional
VRAM!
"""
executor = self.llm_engine.model_executor
return executor.apply_model(func)
return self.llm_engine.apply_model(func)
def _get_beam_search_lora_requests(
self,
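The warning matters because the return values are gathered from every worker; the quark parity test in this diff follows it by copying tensors to CPU before returning them. The same pattern as a standalone sketch:

def get_state_dict(model):
    # Copy to CPU so the returned tensors do not hold on to extra VRAM.
    return {k: v.cpu() for k, v in model.state_dict().items()}

state_dicts = llm.apply_model(get_state_dict)  # one state dict per worker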

View File

@ -5,11 +5,10 @@ import asyncio
import time
from abc import ABC, abstractmethod
from functools import cached_property
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
Union)
from typing import Any, Awaitable, Callable, List, Optional, Set, Union
import torch.nn as nn
from typing_extensions import TypeVar
from typing_extensions import TypeVar, deprecated
import vllm.platforms
from vllm.config import VllmConfig
@ -63,10 +62,10 @@ class ExecutorBase(ABC):
@abstractmethod
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
method: Union[str, Callable[[WorkerBase], _R]],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
"""
Execute an RPC call on all workers.
@ -91,7 +90,7 @@ class ExecutorBase(ABC):
"""
raise NotImplementedError
def determine_num_available_blocks(self) -> Tuple[int, int]:
def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
@ -99,9 +98,10 @@ class ExecutorBase(ABC):
ExecutorBase may require modification of the result, e.g. to ensure the
selected cache sizes are compatible with all workers.
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
Returns a tuple `(num_gpu_blocks, num_cpu_blocks)`, where
`num_gpu_blocks` are blocks that are "active" on the device and can be
appended to.
`num_cpu_blocks` refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
results = self.collective_rpc("determine_num_available_blocks")
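A one-line usage note for the clarified return value (executor stands for any ExecutorBase):

# "active" device blocks and "swapped" CPU blocks, per the docstring above
num_gpu_blocks, num_cpu_blocks = executor.determine_num_available_blocks()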
@ -127,16 +127,15 @@ class ExecutorBase(ABC):
self.collective_rpc("initialize_cache",
args=(num_gpu_blocks, num_cpu_blocks))
@deprecated("`llm_engine.model_executor.apply_model` will no longer work "
"in V1 Engine. Please replace with `llm_engine.apply_model` "
"and set `VLLM_ALLOW_INSECURE_SERIALIZATION=1`.")
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
"""
Run a function directly on the model inside each worker,
returning the result for each of them.
"""
def rpc_func(worker: WorkerBase) -> _R:
return func(worker.get_model())
return self.collective_rpc(rpc_func)
return self.collective_rpc("apply_model", args=(func, ))
@cached_property # Avoid unnecessary RPC calls
def supported_tasks(self) -> tuple[SupportedTask, ...]:
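The deprecation message spells out the migration path; as a before/after sketch (func is any Callable[[nn.Module], _R]):

# Deprecated: only works while the executor is reachable in-process (V0, or
# V1 with VLLM_ENABLE_V1_MULTIPROCESSING=0).
results = llm_engine.model_executor.apply_model(func)

# Replacement introduced here; on V1 with multiprocessing it additionally
# needs VLLM_ALLOW_INSECURE_SERIALIZATION=1 so `func` can be pickled.
results = llm_engine.apply_model(func)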
@ -308,8 +307,8 @@ class DistributedExecutorBase(ExecutorBase):
def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[Any]:
return self._run_workers(method, *args, **(kwargs or {}))
@abstractmethod

View File

@ -5,6 +5,7 @@ from collections.abc import Mapping
from copy import copy
from typing import Any, Callable, Optional, Union
import torch.nn as nn
from typing_extensions import TypeVar
import vllm.envs as envs
@ -33,6 +34,7 @@ from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
StatLoggerFactory)
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
from vllm.v1.worker.worker_base import WorkerBase
logger = init_logger(__name__)
@ -319,12 +321,15 @@ class LLMEngine:
return self.engine_core.pin_lora(lora_id)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
method: Union[str, Callable[[WorkerBase], _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
return self.collective_rpc("apply_model", args=(func, ))
def __del__(self):
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)

View File

@ -5,7 +5,8 @@ import dataclasses
import os
import time
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, Type,
TypeVar, Union)
import cloudpickle
import torch
@ -28,6 +29,8 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
logger = init_logger(__name__)
_R = TypeVar("_R")
@warn_for_unimplemented_methods
class WorkerBase:
@ -70,6 +73,10 @@ class WorkerBase:
def get_model(self) -> nn.Module:
raise NotImplementedError
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
"""Apply a function on the model inside this worker."""
return fn(self.get_model())
def load_model(self) -> None:
"""Load model onto target device."""
raise NotImplementedError
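WorkerBase.apply_model is the endpoint that the "apply_model" string RPC resolves to. Putting the pieces of this commit together, the call path is roughly:

# Sketch of the dispatch chain (names taken from this diff):
#   LLM.apply_model(func)
#     -> LLMEngine.apply_model(func)                    # V0 and V1 engines
#     -> collective_rpc("apply_model", args=(func,))    # fans out to all workers
#     -> WorkerBase.apply_model(func)                   # per worker
#     -> func(worker.get_model())                       # one result per worker

def layer_names(model):
    return [name for name, _ in model.named_modules()][:5]

per_worker_names = llm.apply_model(layer_names)  # list with one entry per worker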