mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 16:05:35 +08:00
[V1] Support LLM.apply_model (#18465)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
be874c0201
commit
3d9a1d2de5
@ -987,17 +987,7 @@ class VllmRunner:
|
||||
return [req_output.outputs.score for req_output in req_outputs]
|
||||
|
||||
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
|
||||
if hasattr(self.llm.llm_engine, "model_executor"):
|
||||
# This works either in V0 or in V1 with
|
||||
# VLLM_ENABLE_V1_MULTIPROCESSING=0
|
||||
executor = self.llm.llm_engine.model_executor
|
||||
return executor.apply_model(func)
|
||||
|
||||
# This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1
|
||||
def _apply_model(self):
|
||||
return func(self.get_model())
|
||||
|
||||
return self.llm.llm_engine.collective_rpc(_apply_model)
|
||||
return self.llm.apply_model(func)
|
||||
|
||||
def get_llm(self) -> LLM:
|
||||
return self.llm
|
||||
|
||||
@ -1,21 +1,24 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
from dataclasses import dataclass
|
||||
from importlib.util import find_spec
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
|
||||
QuarkLinearMethod, QuarkW4A4MXFP4)
|
||||
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
|
||||
QuarkW4A4MXFp4MoEMethod)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
|
||||
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
|
||||
"quark") is not None and version.parse(
|
||||
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
|
||||
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
|
||||
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
|
||||
|
||||
TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
|
||||
) and current_platform.is_device_capability(100)
|
||||
@ -39,6 +42,12 @@ class ModelCase:
|
||||
tp: int
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def enable_pickle(monkeypatch):
|
||||
"""`LLM.apply_model` requires pickling a function."""
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_case', [
|
||||
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
|
||||
ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
|
||||
@ -55,21 +64,19 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
|
||||
tensor_parallel_size=model_case.tp,
|
||||
load_format="dummy") as llm:
|
||||
|
||||
# TODO: llm.apply_model(check_model) currently relies on V0 internals.
|
||||
# Re-enable this later.
|
||||
# def check_model(model):
|
||||
# layer = model.model.layers[0]
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
|
||||
# qkv_proj = layer.self_attn.qkv_proj
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
|
||||
# assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
|
||||
# assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
|
||||
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
|
||||
assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
|
||||
|
||||
# assert isinstance(layer.mlp.experts.quant_method,
|
||||
# QuarkW4A4MXFp4MoEMethod)
|
||||
assert isinstance(layer.mlp.experts.quant_method,
|
||||
QuarkW4A4MXFp4MoEMethod)
|
||||
|
||||
# if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
|
||||
# llm.apply_model(check_model)
|
||||
if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
|
||||
llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy("Today I am in the French Alps and",
|
||||
max_tokens=20)
|
||||
|
||||
@ -10,6 +10,7 @@ from PIL import Image
|
||||
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.multimodal.video import rescale_video_size, sample_frames_from_video
|
||||
from vllm.utils import set_default_torch_num_threads
|
||||
|
||||
from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput,
|
||||
PromptVideoInput, VllmRunner)
|
||||
@ -17,11 +18,9 @@ from ...utils import check_logprobs_close
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
V1 Test: batch_make_xxxxx_embeddings calls a V0 internal
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
def enable_pickle(monkeypatch):
|
||||
"""`LLM.apply_model` requires pickling a function."""
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
models = ["Qwen/Qwen2-VL-2B-Instruct"]
|
||||
@ -126,9 +125,8 @@ def batch_make_image_embeddings(
|
||||
image_grid_thw_on_device = image_grid_thw.to(visual.device,
|
||||
dtype=torch.int64)
|
||||
return visual(pixel_values_on_device,
|
||||
grid_thw=image_grid_thw_on_device)
|
||||
grid_thw=image_grid_thw_on_device).cpu()
|
||||
|
||||
# V1 Test: this calls a V0 internal.
|
||||
image_embeds = torch.concat(llm.apply_model(get_image_embeds))
|
||||
|
||||
# split into original batches
|
||||
@ -210,7 +208,7 @@ def batch_make_video_embeddings(
|
||||
video_grid_thw_on_device = video_grid_thw.to(visual.device,
|
||||
dtype=torch.int64)
|
||||
return visual(pixel_values_on_device,
|
||||
grid_thw=video_grid_thw_on_device)
|
||||
grid_thw=video_grid_thw_on_device).cpu()
|
||||
|
||||
# V1 Test: this calls a V0 internal.
|
||||
video_embeds = torch.concat(llm.apply_model(get_image_embeds))
|
||||
@ -266,19 +264,22 @@ def run_embedding_input_test(
|
||||
processor = AutoProcessor.from_pretrained(model)
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model,
|
||||
runner="generate",
|
||||
max_model_len=4000,
|
||||
max_num_seqs=3,
|
||||
dtype=dtype,
|
||||
limit_mm_per_prompt={
|
||||
"image": mm_limit,
|
||||
"video": mm_limit
|
||||
},
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend
|
||||
) as vllm_model:
|
||||
with set_default_torch_num_threads(1):
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
runner="generate",
|
||||
max_model_len=4000,
|
||||
max_num_seqs=3,
|
||||
dtype=dtype,
|
||||
limit_mm_per_prompt={
|
||||
"image": mm_limit,
|
||||
"video": mm_limit
|
||||
},
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
)
|
||||
|
||||
with vllm_model:
|
||||
outputs_per_case_for_original_input = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
@ -329,9 +330,8 @@ def run_embedding_input_test(
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model,
|
||||
size_factors, dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
size_factors, dtype, max_tokens,
|
||||
num_logprobs, monkeypatch) -> None:
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_case: list[tuple[
|
||||
|
||||
@ -112,7 +112,7 @@ def test_awq_models(vllm_runner, image_assets, source_model, quant_model,
|
||||
monkeypatch) -> None:
|
||||
|
||||
# Test V1: this test hangs during setup on single-scale input.
|
||||
# TODO: fixure out why and re-enable this on V1.
|
||||
# TODO: figure out why and re-enable this on V1.
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
run_awq_test(
|
||||
vllm_runner,
|
||||
|
||||
@ -43,12 +43,9 @@ ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
This module relies on V0 internals, so set VLLM_USE_V1=0.
|
||||
"""
|
||||
if not current_platform.is_cpu():
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
def enable_pickle(monkeypatch):
|
||||
"""`LLM.apply_model` requires pickling a function."""
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -176,10 +173,11 @@ def test_compressed_tensors_w8a8_logprobs(
|
||||
|
||||
dtype = "bfloat16"
|
||||
|
||||
# skip language translation prompt for the static per tensor asym model
|
||||
if (model_path ==
|
||||
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
|
||||
): # noqa: E501
|
||||
# skip language translation prompt for the static per tensor models
|
||||
if model_path in (
|
||||
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
|
||||
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
|
||||
):
|
||||
example_prompts = example_prompts[0:-1]
|
||||
|
||||
with hf_runner(model_path, dtype=dtype) as hf_model:
|
||||
|
||||
@ -60,8 +60,8 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
|
||||
if use_rocm_aiter:
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
# vllm_runner.apply_model() relies on V0 internals.
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
# `LLM.apply_model` requires pickling a function.
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
|
||||
|
||||
def check_model(model):
|
||||
@ -104,8 +104,8 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
|
||||
if use_rocm_aiter:
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
# vllm_runner.apply_model() relies on V0 internals.
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
# `LLM.apply_model` requires pickling a function.
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
if force_marlin:
|
||||
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
|
||||
|
||||
@ -31,41 +31,46 @@ MODEL_QUANT = [
|
||||
@pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT)
|
||||
def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool,
|
||||
monkeypatch):
|
||||
# vllm_runner.apply_model() relies on V0 internals.
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
|
||||
vllm_model = vllm_runner(model_id, dtype=torch.float16, max_model_len=2048)
|
||||
# `LLM.apply_model` requires pickling a function.
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else (
|
||||
GPTQLinearMethod)
|
||||
|
||||
for name, submodule in (vllm_model.llm.llm_engine.model_executor.
|
||||
driver_worker.model_runner.model.named_modules()):
|
||||
if name == "lm_head":
|
||||
assert isinstance(submodule.quant_method, linear_method_cls)
|
||||
elif name == 'model.layers.0.self_attn.qkv_proj':
|
||||
# The first layer is quantized using bits=4, group_size=128
|
||||
# desc_act=True
|
||||
assert isinstance(submodule.quant_method, linear_method_cls)
|
||||
config = submodule.quant_method.quant_config
|
||||
assert config.weight_bits == 4
|
||||
assert config.group_size == 128
|
||||
assert config.desc_act
|
||||
elif name == 'model.layers.1.self_attn.qkv_proj':
|
||||
# The second layer is quantized using bits=8, group_size=32
|
||||
# desc_act=False
|
||||
assert isinstance(submodule.quant_method, linear_method_cls)
|
||||
config = submodule.quant_method.quant_config
|
||||
assert get_dynamic_override(config, layer_name=name,
|
||||
key="bits") == 8
|
||||
assert get_dynamic_override(config,
|
||||
layer_name=name,
|
||||
key="group_size") == 32
|
||||
assert not get_dynamic_override(
|
||||
config, layer_name=name, key="desc_act")
|
||||
elif (name == 'model.layers.2.self_attn.qkv_proj'
|
||||
or name == 'model.layers.2.mlp.gate_up_proj'):
|
||||
# All other layers (layer index >= 2) are not quantized
|
||||
assert isinstance(submodule.quant_method, UnquantizedLinearMethod)
|
||||
with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as llm:
|
||||
|
||||
del vllm_model
|
||||
def check_model(model):
|
||||
for name, submodule in model.named_modules():
|
||||
if name == "lm_head":
|
||||
assert isinstance(submodule.quant_method,
|
||||
linear_method_cls)
|
||||
elif name == 'model.layers.0.self_attn.qkv_proj':
|
||||
# The first layer is quantized using bits=4, group_size=128
|
||||
# desc_act=True
|
||||
assert isinstance(submodule.quant_method,
|
||||
linear_method_cls)
|
||||
config = submodule.quant_method.quant_config
|
||||
assert config.weight_bits == 4
|
||||
assert config.group_size == 128
|
||||
assert config.desc_act
|
||||
elif name == 'model.layers.1.self_attn.qkv_proj':
|
||||
# The second layer is quantized using bits=8, group_size=32
|
||||
# desc_act=False
|
||||
assert isinstance(submodule.quant_method,
|
||||
linear_method_cls)
|
||||
config = submodule.quant_method.quant_config
|
||||
assert get_dynamic_override(config,
|
||||
layer_name=name,
|
||||
key="bits") == 8
|
||||
assert get_dynamic_override(config,
|
||||
layer_name=name,
|
||||
key="group_size") == 32
|
||||
assert not get_dynamic_override(
|
||||
config, layer_name=name, key="desc_act")
|
||||
elif (name == 'model.layers.2.self_attn.qkv_proj'
|
||||
or name == 'model.layers.2.mlp.gate_up_proj'):
|
||||
# All other layers (layer index >= 2) are not quantized
|
||||
assert isinstance(submodule.quant_method,
|
||||
UnquantizedLinearMethod)
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
@ -29,8 +29,8 @@ def test_lm_head(
|
||||
lm_head_quantized: bool,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
# vllm_runner.apply_model() relies on V0 internals.
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
# `LLM.apply_model` requires pickling a function.
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
with vllm_runner(model_id, dtype=torch.float16,
|
||||
max_model_len=2048) as vllm_model:
|
||||
|
||||
|
||||
@ -11,16 +11,12 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
This module relies on V0 internals, so set VLLM_USE_V1=0.
|
||||
"""
|
||||
if not current_platform.is_cpu():
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
def enable_pickle(monkeypatch):
|
||||
"""`LLM.apply_model` requires pickling a function."""
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("modelopt"),
|
||||
|
||||
@ -13,6 +13,16 @@ from vllm.model_executor.layers.quantization.ptpc_fp8 import (
|
||||
PTPCFp8LinearMethod)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
UNSUPPORTED_STR = (
|
||||
"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only "
|
||||
"support output dtype of bfloat16. torch.float16 is specified.")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def enable_pickle(monkeypatch):
|
||||
"""`LLM.apply_model` requires pickling a function."""
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"),
|
||||
reason="PTPC FP8 is not supported on this GPU type.")
|
||||
@ -21,14 +31,22 @@ from vllm.platforms import current_platform
|
||||
@pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"])
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
|
||||
def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
|
||||
|
||||
try:
|
||||
with vllm_runner("facebook/opt-125m",
|
||||
dtype=dtype,
|
||||
quantization="ptpc_fp8",
|
||||
kv_cache_dtype=kv_cache_dtype) as llm:
|
||||
llm = vllm_runner("facebook/opt-125m",
|
||||
dtype=dtype,
|
||||
quantization="ptpc_fp8",
|
||||
kv_cache_dtype=kv_cache_dtype)
|
||||
except AssertionError as e:
|
||||
if str(e) == UNSUPPORTED_STR:
|
||||
# If the error message matches, the test passes
|
||||
return
|
||||
else:
|
||||
# If the error message does not match, re-raise the exception
|
||||
raise
|
||||
|
||||
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
|
||||
with llm:
|
||||
|
||||
def check_model(model):
|
||||
fc1 = model.model.decoder.layers[0].fc1
|
||||
assert isinstance(fc1.quant_method, PTPCFp8LinearMethod)
|
||||
if kv_cache_dtype == "ptpc_fp8":
|
||||
@ -40,17 +58,8 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
|
||||
if current_platform.has_device_capability(94):
|
||||
# For GPUs with hardware support, we keep weights in fp8
|
||||
assert fc1.weight.dtype == torch.float8_e4m3fnuz
|
||||
else:
|
||||
pytest.skip()
|
||||
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=20)
|
||||
assert output
|
||||
except AssertionError as e:
|
||||
if str(
|
||||
e
|
||||
) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified.": # noqa: E501
|
||||
# If the error message matches, the test passes
|
||||
pass
|
||||
else:
|
||||
# If the error message does not match, re-raise the exception
|
||||
raise
|
||||
llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=20)
|
||||
assert output
|
||||
|
||||
@ -7,10 +7,10 @@ Run `pytest tests/quantization/test_quark.py`.
|
||||
See also `tests/kernels/moe/test_mxfp4_moe.py`.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from importlib.util import find_spec
|
||||
|
||||
import huggingface_hub
|
||||
import lm_eval
|
||||
@ -24,9 +24,8 @@ from vllm.platforms import current_platform
|
||||
|
||||
from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch
|
||||
|
||||
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
|
||||
"quark") is not None and version.parse(
|
||||
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
|
||||
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
|
||||
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
|
||||
|
||||
if QUARK_MXFP4_AVAILABLE:
|
||||
from quark.torch.export.nn.modules.realquantizer import (
|
||||
@ -43,11 +42,9 @@ except huggingface_hub.errors.RepositoryNotFoundError:
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
This module relies on V0 internals, so set VLLM_USE_V1=0.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
def enable_pickle(monkeypatch):
|
||||
"""`LLM.apply_model` requires pickling a function."""
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
@pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8'])
|
||||
@ -132,13 +129,12 @@ def test_quark_fp8_parity(vllm_runner):
|
||||
}
|
||||
with (vllm_runner(quark_model_id, **llm_kwargs) as
|
||||
quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
|
||||
quark_model = (quark_handle.llm.llm_engine.model_executor.
|
||||
driver_worker.model_runner.model)
|
||||
quark_state_dict = quark_model.state_dict()
|
||||
|
||||
fp8_model = (fp8_handle.llm.llm_engine.model_executor.driver_worker.
|
||||
model_runner.model)
|
||||
fp8_state_dict = fp8_model.state_dict()
|
||||
def get_state_dict(model):
|
||||
return {k: v.cpu() for k, v in model.state_dict().items()}
|
||||
|
||||
quark_state_dict, = quark_handle.apply_model(get_state_dict)
|
||||
fp8_state_dict, = fp8_handle.apply_model(get_state_dict)
|
||||
|
||||
assert fp8_state_dict.keys() == quark_state_dict.keys()
|
||||
|
||||
|
||||
@ -105,18 +105,21 @@ def test_register_quantization_config():
|
||||
])
|
||||
def test_custom_quant(vllm_runner, model, monkeypatch):
|
||||
"""Test infer with the custom quantization method."""
|
||||
# vllm_runner.apply_model() relies on V0 internals.
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
# `LLM.apply_model` requires pickling a function.
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
with vllm_runner(model_name=model,
|
||||
quantization="custom_quant",
|
||||
enforce_eager=True) as llm:
|
||||
|
||||
model = llm.llm.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
|
||||
layer = model.model.layers[0]
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
|
||||
# Check the quantization method is FakeQuantLinearMethod
|
||||
assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)
|
||||
# Check the quantization method is FakeQuantLinearMethod
|
||||
assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=20)
|
||||
assert output
|
||||
|
||||
@ -13,6 +13,7 @@ from typing import Sequence as GenericSequence
|
||||
from typing import Set, Type, Union, cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import vllm.envs as envs
|
||||
@ -55,6 +56,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||
from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
from vllm.worker.model_runner_base import InputProcessingError
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
||||
@ -1817,13 +1819,16 @@ class LLMEngine:
|
||||
return sampling_params
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
method: Union[str, Callable[[WorkerBase], _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
return self.model_executor.collective_rpc(method, timeout, args,
|
||||
kwargs)
|
||||
|
||||
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
|
||||
return self.collective_rpc("apply_model", args=(func, ))
|
||||
|
||||
|
||||
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
|
||||
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
|
||||
|
||||
@ -522,9 +522,14 @@ class LLM:
|
||||
"""
|
||||
Run a function directly on the model inside each worker,
|
||||
returning the result for each of them.
|
||||
|
||||
!!! warning
|
||||
To reduce the overhead of data transfer, avoid returning large
|
||||
arrays or tensors from this method. If you must return them,
|
||||
make sure you move them to CPU first to avoid taking up additional
|
||||
VRAM!
|
||||
"""
|
||||
executor = self.llm_engine.model_executor
|
||||
return executor.apply_model(func)
|
||||
return self.llm_engine.apply_model(func)
|
||||
|
||||
def _get_beam_search_lora_requests(
|
||||
self,
|
||||
|
||||
@ -5,11 +5,10 @@ import asyncio
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
|
||||
Union)
|
||||
from typing import Any, Awaitable, Callable, List, Optional, Set, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from typing_extensions import TypeVar
|
||||
from typing_extensions import TypeVar, deprecated
|
||||
|
||||
import vllm.platforms
|
||||
from vllm.config import VllmConfig
|
||||
@ -63,10 +62,10 @@ class ExecutorBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
method: Union[str, Callable[[WorkerBase], _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: Tuple = (),
|
||||
kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
"""
|
||||
Execute an RPC call on all workers.
|
||||
|
||||
@ -91,7 +90,7 @@ class ExecutorBase(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||
"""Determine the number of available blocks for the GPU KV cache and
|
||||
swappable CPU KV cache.
|
||||
|
||||
@ -99,9 +98,10 @@ class ExecutorBase(ABC):
|
||||
ExecutorBase may require modification of the result, e.g. to ensure the
|
||||
selected cache sizes are compatible with all workers.
|
||||
|
||||
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
|
||||
are blocks that are "active" on the device and can be appended to.
|
||||
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
|
||||
Returns a tuple `(num_gpu_blocks, num_cpu_blocks)`, where
|
||||
`num_gpu_blocks` are blocks that are "active" on the device and can be
|
||||
appended to.
|
||||
`num_cpu_blocks` refers to "swapped" blocks in CPU memory and cannot be
|
||||
appended to.
|
||||
"""
|
||||
results = self.collective_rpc("determine_num_available_blocks")
|
||||
@ -127,16 +127,15 @@ class ExecutorBase(ABC):
|
||||
self.collective_rpc("initialize_cache",
|
||||
args=(num_gpu_blocks, num_cpu_blocks))
|
||||
|
||||
@deprecated("`llm_engine.model_executor.apply_model` will no longer work "
|
||||
"in V1 Engine. Please replace with `llm_engine.apply_model` "
|
||||
"and set `VLLM_ALLOW_INSECURE_SERIALIZATION=1`.")
|
||||
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
|
||||
"""
|
||||
Run a function directly on the model inside each worker,
|
||||
returning the result for each of them.
|
||||
"""
|
||||
|
||||
def rpc_func(worker: WorkerBase) -> _R:
|
||||
return func(worker.get_model())
|
||||
|
||||
return self.collective_rpc(rpc_func)
|
||||
return self.collective_rpc("apply_model", args=(func, ))
|
||||
|
||||
@cached_property # Avoid unnecessary RPC calls
|
||||
def supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
@ -308,8 +307,8 @@ class DistributedExecutorBase(ExecutorBase):
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable],
|
||||
timeout: Optional[float] = None,
|
||||
args: Tuple = (),
|
||||
kwargs: Optional[Dict] = None) -> List[Any]:
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[Any]:
|
||||
return self._run_workers(method, *args, **(kwargs or {}))
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -5,6 +5,7 @@ from collections.abc import Mapping
|
||||
from copy import copy
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import vllm.envs as envs
|
||||
@ -33,6 +34,7 @@ from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
|
||||
StatLoggerFactory)
|
||||
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -319,12 +321,15 @@ class LLMEngine:
|
||||
return self.engine_core.pin_lora(lora_id)
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
method: Union[str, Callable[[WorkerBase], _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
|
||||
|
||||
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
|
||||
return self.collective_rpc("apply_model", args=(func, ))
|
||||
|
||||
def __del__(self):
|
||||
if dp_group := getattr(self, "dp_group", None):
|
||||
stateless_destroy_torch_distributed_process_group(dp_group)
|
||||
|
||||
@ -5,7 +5,8 @@ import dataclasses
|
||||
import os
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, Type,
|
||||
TypeVar, Union)
|
||||
|
||||
import cloudpickle
|
||||
import torch
|
||||
@ -28,6 +29,8 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
@warn_for_unimplemented_methods
|
||||
class WorkerBase:
|
||||
@ -70,6 +73,10 @@ class WorkerBase:
|
||||
def get_model(self) -> nn.Module:
|
||||
raise NotImplementedError
|
||||
|
||||
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
|
||||
"""Apply a function on the model inside this worker."""
|
||||
return fn(self.get_model())
|
||||
|
||||
def load_model(self) -> None:
|
||||
"""Load model onto target device."""
|
||||
raise NotImplementedError
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user