mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 06:44:30 +08:00
[Bugfix] fix custom op test (#25429)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
aac85cc6d6
commit
d96a3fc653
@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -34,15 +35,15 @@ class Relu3(ReLUSquaredActivation):
|
|||||||
[
|
[
|
||||||
# Default values based on compile level
|
# Default values based on compile level
|
||||||
# - All by default (no Inductor compilation)
|
# - All by default (no Inductor compilation)
|
||||||
("", 0, False, [True] * 4, True),
|
(None, 0, False, [True] * 4, True),
|
||||||
("", 1, True, [True] * 4, True),
|
(None, 1, True, [True] * 4, True),
|
||||||
("", 2, False, [True] * 4, True),
|
(None, 2, False, [True] * 4, True),
|
||||||
# - None by default (with Inductor)
|
# - None by default (with Inductor)
|
||||||
("", 3, True, [False] * 4, False),
|
(None, 3, True, [False] * 4, False),
|
||||||
("", 4, True, [False] * 4, False),
|
(None, 4, True, [False] * 4, False),
|
||||||
# - All by default (without Inductor)
|
# - All by default (without Inductor)
|
||||||
("", 3, False, [True] * 4, True),
|
(None, 3, False, [True] * 4, True),
|
||||||
("", 4, False, [True] * 4, True),
|
(None, 4, False, [True] * 4, True),
|
||||||
# Explicitly enabling/disabling
|
# Explicitly enabling/disabling
|
||||||
#
|
#
|
||||||
# Default: all
|
# Default: all
|
||||||
@ -54,7 +55,7 @@ class Relu3(ReLUSquaredActivation):
|
|||||||
# All but SiluAndMul
|
# All but SiluAndMul
|
||||||
("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
|
("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
|
||||||
# All but ReLU3 (even if ReLU2 is on)
|
# All but ReLU3 (even if ReLU2 is on)
|
||||||
("-relu3,relu2", 3, False, [1, 1, 1, 0], True),
|
("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
|
||||||
# RMSNorm and SiluAndMul
|
# RMSNorm and SiluAndMul
|
||||||
("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
|
("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
|
||||||
# All but RMSNorm
|
# All but RMSNorm
|
||||||
@ -67,12 +68,13 @@ class Relu3(ReLUSquaredActivation):
|
|||||||
# All but RMSNorm
|
# All but RMSNorm
|
||||||
("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
|
("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
|
||||||
])
|
])
|
||||||
def test_enabled_ops(env: str, torch_level: int, use_inductor: bool,
|
def test_enabled_ops(env: Optional[str], torch_level: int, use_inductor: bool,
|
||||||
ops_enabled: list[int], default_on: bool):
|
ops_enabled: list[int], default_on: bool):
|
||||||
|
custom_ops = env.split(',') if env else []
|
||||||
vllm_config = VllmConfig(
|
vllm_config = VllmConfig(
|
||||||
compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
|
compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
|
||||||
level=torch_level,
|
level=torch_level,
|
||||||
custom_ops=env.split(",")))
|
custom_ops=custom_ops))
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
assert CustomOp.default_on() == default_on
|
assert CustomOp.default_on() == default_on
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user