Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-20 08:15:01 +08:00)
[V1] Support LLM.apply_model (#18465)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent: be874c0201
commit: 3d9a1d2de5
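
Note (not part of the commit): the path exercised by this change can be driven roughly as below. This is a minimal sketch; the model name and the parameter-counting callback are placeholders, and `VLLM_ALLOW_INSECURE_SERIALIZATION=1` is assumed so the callback can be pickled and shipped to the V1 worker processes.

    import os

    # Assumed prerequisite for V1: allow the callback to be pickled across processes.
    os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

    from vllm import LLM

    llm = LLM(model="facebook/opt-125m", enforce_eager=True)  # placeholder model

    def count_parameters(model):
        # Runs inside each worker; return plain Python data, not GPU tensors.
        return sum(p.numel() for p in model.parameters())

    # One result per worker (e.g. one per tensor-parallel rank).
    print(llm.apply_model(count_parameters))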
@@ -987,17 +987,7 @@ class VllmRunner:
         return [req_output.outputs.score for req_output in req_outputs]
 
     def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
-        if hasattr(self.llm.llm_engine, "model_executor"):
-            # This works either in V0 or in V1 with
-            # VLLM_ENABLE_V1_MULTIPROCESSING=0
-            executor = self.llm.llm_engine.model_executor
-            return executor.apply_model(func)
-
-        # This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1
-        def _apply_model(self):
-            return func(self.get_model())
-
-        return self.llm.llm_engine.collective_rpc(_apply_model)
+        return self.llm.apply_model(func)
 
     def get_llm(self) -> LLM:
         return self.llm
@@ -1,20 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import importlib
 import importlib.metadata
 from dataclasses import dataclass
+from importlib.util import find_spec
 from typing import Optional
 
 import pytest
 import torch
 from packaging import version
 
+from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
+    QuarkLinearMethod, QuarkW4A4MXFP4)
+from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
+    QuarkW4A4MXFp4MoEMethod)
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_flashinfer
 
-QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
-    "quark") is not None and version.parse(
+QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
     importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
 
 TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(

@@ -39,6 +42,12 @@ class ModelCase:
     tp: int
 
 
+@pytest.fixture(scope="function", autouse=True)
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
+
 @pytest.mark.parametrize('model_case', [
     ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
     ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),

@@ -55,21 +64,19 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
                      tensor_parallel_size=model_case.tp,
                      load_format="dummy") as llm:
 
-        # TODO: llm.apply_model(check_model) currently relies on V0 internals.
-        # Re-enable this later.
-        # def check_model(model):
-        #     layer = model.model.layers[0]
-
-        #     qkv_proj = layer.self_attn.qkv_proj
-
-        #     assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
-        #     assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
-
-        #     assert isinstance(layer.mlp.experts.quant_method,
-        #                       QuarkW4A4MXFp4MoEMethod)
-
-        # if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
-        #     llm.apply_model(check_model)
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+
+            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
+            assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
+
+            assert isinstance(layer.mlp.experts.quant_method,
+                              QuarkW4A4MXFp4MoEMethod)
+
+        if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
+            llm.apply_model(check_model)
 
         output = llm.generate_greedy("Today I am in the French Alps and",
                                      max_tokens=20)
@@ -10,6 +10,7 @@ from PIL import Image
 
 from vllm.multimodal.image import rescale_image_size
 from vllm.multimodal.video import rescale_video_size, sample_frames_from_video
+from vllm.utils import set_default_torch_num_threads
 
 from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput,
                           PromptVideoInput, VllmRunner)

@@ -17,11 +18,9 @@ from ...utils import check_logprobs_close
 
 
 @pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    V1 Test: batch_make_xxxxx_embeddings calls a V0 internal
-    """
-    monkeypatch.setenv('VLLM_USE_V1', '0')
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
 
 models = ["Qwen/Qwen2-VL-2B-Instruct"]

@@ -126,9 +125,8 @@ def batch_make_image_embeddings(
         image_grid_thw_on_device = image_grid_thw.to(visual.device,
                                                      dtype=torch.int64)
         return visual(pixel_values_on_device,
-                      grid_thw=image_grid_thw_on_device)
+                      grid_thw=image_grid_thw_on_device).cpu()
 
-    # V1 Test: this calls a V0 internal.
     image_embeds = torch.concat(llm.apply_model(get_image_embeds))
 
     # split into original batches

@@ -210,7 +208,7 @@ def batch_make_video_embeddings(
         video_grid_thw_on_device = video_grid_thw.to(visual.device,
                                                      dtype=torch.int64)
         return visual(pixel_values_on_device,
-                      grid_thw=video_grid_thw_on_device)
+                      grid_thw=video_grid_thw_on_device).cpu()
 
     # V1 Test: this calls a V0 internal.
     video_embeds = torch.concat(llm.apply_model(get_image_embeds))

@@ -266,7 +264,9 @@ def run_embedding_input_test(
     processor = AutoProcessor.from_pretrained(model)
 
     # max_model_len should be greater than image_feature_size
-    with vllm_runner(model,
+    with set_default_torch_num_threads(1):
+        vllm_model = vllm_runner(
+            model,
             runner="generate",
             max_model_len=4000,
             max_num_seqs=3,

@@ -276,9 +276,10 @@ def run_embedding_input_test(
                 "video": mm_limit
             },
             tensor_parallel_size=tensor_parallel_size,
-            distributed_executor_backend=distributed_executor_backend
-        ) as vllm_model:
+            distributed_executor_backend=distributed_executor_backend,
+        )
 
+    with vllm_model:
         outputs_per_case_for_original_input = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,

@@ -329,9 +330,8 @@ def run_embedding_input_test(
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [10])
 def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model,
-                                         size_factors, dtype: str,
-                                         max_tokens: int,
-                                         num_logprobs: int) -> None:
+                                         size_factors, dtype, max_tokens,
+                                         num_logprobs, monkeypatch) -> None:
     images = [asset.pil_image for asset in image_assets]
 
     inputs_per_case: list[tuple[
@@ -112,7 +112,7 @@ def test_awq_models(vllm_runner, image_assets, source_model, quant_model,
                     monkeypatch) -> None:
 
     # Test V1: this test hangs during setup on single-scale input.
-    # TODO: fixure out why and re-enable this on V1.
+    # TODO: figure out why and re-enable this on V1.
     monkeypatch.setenv("VLLM_USE_V1", "0")
     run_awq_test(
         vllm_runner,
@@ -43,12 +43,9 @@ ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
 
 
 @pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    This module relies on V0 internals, so set VLLM_USE_V1=0.
-    """
-    if not current_platform.is_cpu():
-        monkeypatch.setenv('VLLM_USE_V1', '0')
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
 
 @pytest.mark.parametrize(

@@ -176,10 +173,11 @@ def test_compressed_tensors_w8a8_logprobs(
 
     dtype = "bfloat16"
 
-    # skip language translation prompt for the static per tensor asym model
-    if (model_path ==
-            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
-        ):  # noqa: E501
+    # skip language translation prompt for the static per tensor models
+    if model_path in (
+            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
+            "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
+    ):
         example_prompts = example_prompts[0:-1]
 
     with hf_runner(model_path, dtype=dtype) as hf_model:
@@ -60,8 +60,8 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
     if use_rocm_aiter:
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
     with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
 
         def check_model(model):

@@ -104,8 +104,8 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
     if use_rocm_aiter:
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
     if force_marlin:
         monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
@@ -31,22 +31,24 @@ MODEL_QUANT = [
 @pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT)
 def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool,
                            monkeypatch):
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
-
-    vllm_model = vllm_runner(model_id, dtype=torch.float16, max_model_len=2048)
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
     linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else (
         GPTQLinearMethod)
 
-    for name, submodule in (vllm_model.llm.llm_engine.model_executor.
-                            driver_worker.model_runner.model.named_modules()):
-        if name == "lm_head":
-            assert isinstance(submodule.quant_method, linear_method_cls)
-        elif name == 'model.layers.0.self_attn.qkv_proj':
-            # The first layer is quantized using bits=4, group_size=128
-            # desc_act=True
-            assert isinstance(submodule.quant_method, linear_method_cls)
-            config = submodule.quant_method.quant_config
-            assert config.weight_bits == 4
-            assert config.group_size == 128
+    with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as llm:
+
+        def check_model(model):
+            for name, submodule in model.named_modules():
+                if name == "lm_head":
+                    assert isinstance(submodule.quant_method,
+                                      linear_method_cls)
+                elif name == 'model.layers.0.self_attn.qkv_proj':
+                    # The first layer is quantized using bits=4, group_size=128
+                    # desc_act=True
+                    assert isinstance(submodule.quant_method,
+                                      linear_method_cls)
+                    config = submodule.quant_method.quant_config
+                    assert config.weight_bits == 4
+                    assert config.group_size == 128

@@ -54,9 +56,11 @@ def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool,
                 elif name == 'model.layers.1.self_attn.qkv_proj':
                     # The second layer is quantized using bits=8, group_size=32
                     # desc_act=False
-                    assert isinstance(submodule.quant_method, linear_method_cls)
+                    assert isinstance(submodule.quant_method,
+                                      linear_method_cls)
                     config = submodule.quant_method.quant_config
-                    assert get_dynamic_override(config, layer_name=name,
+                    assert get_dynamic_override(config,
+                                                layer_name=name,
                                                 key="bits") == 8
                     assert get_dynamic_override(config,
                                                 layer_name=name,

@@ -66,6 +70,7 @@ def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool,
                 elif (name == 'model.layers.2.self_attn.qkv_proj'
                       or name == 'model.layers.2.mlp.gate_up_proj'):
                     # All other layers (layer index >= 2) are not quantized
-                    assert isinstance(submodule.quant_method, UnquantizedLinearMethod)
+                    assert isinstance(submodule.quant_method,
+                                      UnquantizedLinearMethod)
 
-    del vllm_model
+        llm.apply_model(check_model)
@@ -29,8 +29,8 @@ def test_lm_head(
     lm_head_quantized: bool,
     monkeypatch,
 ) -> None:
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
     with vllm_runner(model_id, dtype=torch.float16,
                      max_model_len=2048) as vllm_model:
 
@@ -11,16 +11,12 @@ import pytest
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
-from vllm.platforms import current_platform
 
 
 @pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    This module relies on V0 internals, so set VLLM_USE_V1=0.
-    """
-    if not current_platform.is_cpu():
-        monkeypatch.setenv('VLLM_USE_V1', '0')
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
 
 @pytest.mark.skipif(not is_quant_method_supported("modelopt"),
@@ -13,6 +13,16 @@ from vllm.model_executor.layers.quantization.ptpc_fp8 import (
     PTPCFp8LinearMethod)
 from vllm.platforms import current_platform
 
+UNSUPPORTED_STR = (
+    "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only "
+    "support output dtype of bfloat16. torch.float16 is specified.")
+
+
+@pytest.fixture(scope="function", autouse=True)
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
 
 @pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"),
                     reason="PTPC FP8 is not supported on this GPU type.")

@@ -21,14 +31,22 @@ from vllm.platforms import current_platform
 @pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"])
 @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
 def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
 
     try:
-        with vllm_runner("facebook/opt-125m",
-                         dtype=dtype,
-                         quantization="ptpc_fp8",
-                         kv_cache_dtype=kv_cache_dtype) as llm:
+        llm = vllm_runner("facebook/opt-125m",
+                          dtype=dtype,
+                          quantization="ptpc_fp8",
+                          kv_cache_dtype=kv_cache_dtype)
+    except AssertionError as e:
+        if str(e) == UNSUPPORTED_STR:
+            # If the error message matches, the test passes
+            return
+        else:
+            # If the error message does not match, re-raise the exception
+            raise
 
-            model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+    with llm:
+
+        def check_model(model):
             fc1 = model.model.decoder.layers[0].fc1
             assert isinstance(fc1.quant_method, PTPCFp8LinearMethod)
             if kv_cache_dtype == "ptpc_fp8":

@@ -40,17 +58,8 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
             if current_platform.has_device_capability(94):
                 # For GPUs with hardware support, we keep weights in fp8
                 assert fc1.weight.dtype == torch.float8_e4m3fnuz
-            else:
-                pytest.skip()
+
+        llm.apply_model(check_model)
 
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
-    except AssertionError as e:
-        if str(
-                e
-        ) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified.":  # noqa: E501
-            # If the error message matches, the test passes
-            pass
-        else:
-            # If the error message does not match, re-raise the exception
-            raise
@@ -7,10 +7,10 @@ Run `pytest tests/quantization/test_quark.py`.
 See also `tests/kernels/moe/test_mxfp4_moe.py`.
 """
 
-import importlib
 import importlib.metadata
 import os
 from dataclasses import dataclass
+from importlib.util import find_spec
 
 import huggingface_hub
 import lm_eval

@@ -24,8 +24,7 @@ from vllm.platforms import current_platform
 
 from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch
 
-QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
-    "quark") is not None and version.parse(
+QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
     importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
 
 if QUARK_MXFP4_AVAILABLE:

@@ -43,11 +42,9 @@ except huggingface_hub.errors.RepositoryNotFoundError:
 
 
 @pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    This module relies on V0 internals, so set VLLM_USE_V1=0.
-    """
-    monkeypatch.setenv('VLLM_USE_V1', '0')
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
 
 @pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8'])

@@ -132,13 +129,12 @@ def test_quark_fp8_parity(vllm_runner):
     }
     with (vllm_runner(quark_model_id, **llm_kwargs) as
           quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
-        quark_model = (quark_handle.llm.llm_engine.model_executor.
-                       driver_worker.model_runner.model)
-        quark_state_dict = quark_model.state_dict()
 
-        fp8_model = (fp8_handle.llm.llm_engine.model_executor.driver_worker.
-                     model_runner.model)
-        fp8_state_dict = fp8_model.state_dict()
+        def get_state_dict(model):
+            return {k: v.cpu() for k, v in model.state_dict().items()}
+
+        quark_state_dict, = quark_handle.apply_model(get_state_dict)
+        fp8_state_dict, = fp8_handle.apply_model(get_state_dict)
 
     assert fp8_state_dict.keys() == quark_state_dict.keys()
 
@@ -105,18 +105,21 @@ def test_register_quantization_config():
 ])
 def test_custom_quant(vllm_runner, model, monkeypatch):
     """Test infer with the custom quantization method."""
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+    # `LLM.apply_model` requires pickling a function.
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
     with vllm_runner(model_name=model,
                      quantization="custom_quant",
                      enforce_eager=True) as llm:
 
-        model = llm.llm.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-        layer = model.model.layers[0]
-        qkv_proj = layer.self_attn.qkv_proj
+        def check_model(model):
+            layer = model.model.layers[0]
+            qkv_proj = layer.self_attn.qkv_proj
 
-        # Check the quantization method is FakeQuantLinearMethod
-        assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)
+            # Check the quantization method is FakeQuantLinearMethod
+            assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)
+
+        llm.apply_model(check_model)
 
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
@@ -13,6 +13,7 @@ from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast
 
 import torch
+import torch.nn as nn
 from typing_extensions import TypeVar
 
 import vllm.envs as envs

@@ -55,6 +56,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
 from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
 from vllm.version import __version__ as VLLM_VERSION
 from vllm.worker.model_runner_base import InputProcessingError
+from vllm.worker.worker_base import WorkerBase
 
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5

@@ -1817,13 +1819,16 @@ class LLMEngine:
         return sampling_params
 
     def collective_rpc(self,
-                       method: Union[str, Callable[..., _R]],
+                       method: Union[str, Callable[[WorkerBase], _R]],
                        timeout: Optional[float] = None,
                        args: tuple = (),
                        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
         return self.model_executor.collective_rpc(method, timeout, args,
                                                   kwargs)
 
+    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
+        return self.collective_rpc("apply_model", args=(func, ))
+
 
 if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
     from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
@@ -522,9 +522,14 @@ class LLM:
         """
         Run a function directly on the model inside each worker,
         returning the result for each of them.
+
+        !!! warning
+            To reduce the overhead of data transfer, avoid returning large
+            arrays or tensors from this method. If you must return them,
+            make sure you move them to CPU first to avoid taking up additional
+            VRAM!
         """
-        executor = self.llm_engine.model_executor
-        return executor.apply_model(func)
+        return self.llm_engine.apply_model(func)
 
     def _get_beam_search_lora_requests(
         self,
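
The warning added to the docstring above matches the pattern the updated tests follow: callbacks that must return tensors copy them to CPU inside the worker first. A hedged sketch (the `llm` handle and model layout are placeholders, not from the commit):

    def get_state_dict(model):
        # Move tensors to CPU before returning them from the worker,
        # so the RPC result does not keep extra VRAM pinned.
        return {k: v.cpu() for k, v in model.state_dict().items()}

    state_dict, = llm.apply_model(get_state_dict)  # one result per worker; tp=1 here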
@@ -5,11 +5,10 @@ import asyncio
 import time
 from abc import ABC, abstractmethod
 from functools import cached_property
-from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
-                    Union)
+from typing import Any, Awaitable, Callable, List, Optional, Set, Union
 
 import torch.nn as nn
-from typing_extensions import TypeVar
+from typing_extensions import TypeVar, deprecated
 
 import vllm.platforms
 from vllm.config import VllmConfig

@@ -63,10 +62,10 @@ class ExecutorBase(ABC):
 
     @abstractmethod
     def collective_rpc(self,
-                       method: Union[str, Callable[..., _R]],
+                       method: Union[str, Callable[[WorkerBase], _R]],
                        timeout: Optional[float] = None,
-                       args: Tuple = (),
-                       kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
         """
         Execute an RPC call on all workers.
 

@@ -91,7 +90,7 @@ class ExecutorBase(ABC):
         """
         raise NotImplementedError
 
-    def determine_num_available_blocks(self) -> Tuple[int, int]:
+    def determine_num_available_blocks(self) -> tuple[int, int]:
         """Determine the number of available blocks for the GPU KV cache and
         swappable CPU KV cache.
 

@@ -99,9 +98,10 @@ class ExecutorBase(ABC):
         ExecutorBase may require modification of the result, e.g. to ensure the
         selected cache sizes are compatible with all workers.
 
-        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
-        are blocks that are "active" on the device and can be appended to.
-        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
+        Returns a tuple `(num_gpu_blocks, num_cpu_blocks)`, where
+        `num_gpu_blocks` are blocks that are "active" on the device and can be
+        appended to.
+        `num_cpu_blocks` refers to "swapped" blocks in CPU memory and cannot be
         appended to.
         """
         results = self.collective_rpc("determine_num_available_blocks")

@@ -127,16 +127,15 @@ class ExecutorBase(ABC):
         self.collective_rpc("initialize_cache",
                             args=(num_gpu_blocks, num_cpu_blocks))
 
+    @deprecated("`llm_engine.model_executor.apply_model` will no longer work "
+                "in V1 Engine. Please replace with `llm_engine.apply_model` "
+                "and set `VLLM_ALLOW_INSECURE_SERIALIZATION=1`.")
     def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
         """
         Run a function directly on the model inside each worker,
         returning the result for each of them.
         """
-
-        def rpc_func(worker: WorkerBase) -> _R:
-            return func(worker.get_model())
-
-        return self.collective_rpc(rpc_func)
+        return self.collective_rpc("apply_model", args=(func, ))
 
     @cached_property  # Avoid unnecessary RPC calls
     def supported_tasks(self) -> tuple[SupportedTask, ...]:

@@ -308,8 +307,8 @@ class DistributedExecutorBase(ExecutorBase):
     def collective_rpc(self,
                        method: Union[str, Callable],
                        timeout: Optional[float] = None,
-                       args: Tuple = (),
-                       kwargs: Optional[Dict] = None) -> List[Any]:
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[Any]:
         return self._run_workers(method, *args, **(kwargs or {}))
 
     @abstractmethod
@@ -5,6 +5,7 @@ from collections.abc import Mapping
 from copy import copy
 from typing import Any, Callable, Optional, Union
 
+import torch.nn as nn
 from typing_extensions import TypeVar
 
 import vllm.envs as envs

@@ -33,6 +34,7 @@ from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
                                      StatLoggerFactory)
 from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
 from vllm.v1.metrics.stats import IterationStats
+from vllm.v1.worker.worker_base import WorkerBase
 
 logger = init_logger(__name__)
 

@@ -319,12 +321,15 @@ class LLMEngine:
         return self.engine_core.pin_lora(lora_id)
 
     def collective_rpc(self,
-                       method: Union[str, Callable[..., _R]],
+                       method: Union[str, Callable[[WorkerBase], _R]],
                        timeout: Optional[float] = None,
                        args: tuple = (),
                        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
         return self.engine_core.collective_rpc(method, timeout, args, kwargs)
 
+    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
+        return self.collective_rpc("apply_model", args=(func, ))
+
     def __del__(self):
         if dp_group := getattr(self, "dp_group", None):
             stateless_destroy_torch_distributed_process_group(dp_group)
@@ -5,7 +5,8 @@ import dataclasses
 import os
 import time
 from abc import abstractmethod
-from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, Type,
+                    TypeVar, Union)
 
 import cloudpickle
 import torch

@@ -28,6 +29,8 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
 
 logger = init_logger(__name__)
 
+_R = TypeVar("_R")
+
 
 @warn_for_unimplemented_methods
 class WorkerBase:

@@ -70,6 +73,10 @@ class WorkerBase:
     def get_model(self) -> nn.Module:
         raise NotImplementedError
 
+    def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
+        """Apply a function on the model inside this worker."""
+        return fn(self.get_model())
+
     def load_model(self) -> None:
         """Load model onto target device."""
         raise NotImplementedError
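
Taken together, the hunks above introduce a single delegation chain; the comment sketch below is a simplified restatement, not verbatim source:

    # LLM.apply_model(func)
    #   -> llm_engine.apply_model(func)                  # added to both the V0 and V1 LLMEngine
    #   -> collective_rpc("apply_model", args=(func, ))  # named RPC instead of pickling a closure
    #   -> worker.apply_model(func)                      # WorkerBase helper added above
    #   -> func(worker.get_model())
    #
    # In V1 the function itself still crosses a process boundary, which is why the
    # updated tests set VLLM_ALLOW_INSECURE_SERIALIZATION=1.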