mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 05:51:21 +08:00
Revert #26113 "[Frontend] CompilationConfig overhaul (#20283): deprecate use_inductor in favor of backend, simplify custom_ops" (#26472)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
parent
92be3f3517
commit
5728da11ea
@ -258,13 +258,13 @@ def tractable_computation(
|
|||||||
|
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def run_model(
|
def run_model(
|
||||||
llama_config, use_compile: bool, backend: str, split_attn: bool = False
|
llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if use_compile:
|
if use_compile:
|
||||||
compilation_config = CompilationConfig(
|
compilation_config = CompilationConfig(
|
||||||
level=CompilationLevel.PIECEWISE,
|
level=CompilationLevel.PIECEWISE,
|
||||||
use_cudagraph=True,
|
use_cudagraph=True,
|
||||||
backend=backend,
|
use_inductor=use_inductor,
|
||||||
cudagraph_capture_sizes=[1, 2],
|
cudagraph_capture_sizes=[1, 2],
|
||||||
)
|
)
|
||||||
if split_attn:
|
if split_attn:
|
||||||
@ -338,8 +338,8 @@ def run_model(
|
|||||||
return output.cpu()
|
return output.cpu()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("backend", ["inductor", "eager"])
|
@pytest.mark.parametrize("use_inductor", [True, False])
|
||||||
def test_toy_llama(backend: str):
|
def test_toy_llama(use_inductor: bool):
|
||||||
# compare output with and without piecewise compilation
|
# compare output with and without piecewise compilation
|
||||||
|
|
||||||
llama_config = LlamaConfig(
|
llama_config = LlamaConfig(
|
||||||
@ -358,10 +358,10 @@ def test_toy_llama(backend: str):
|
|||||||
num_backend_compilations=0,
|
num_backend_compilations=0,
|
||||||
num_cudagraph_captured=0,
|
num_cudagraph_captured=0,
|
||||||
):
|
):
|
||||||
outputs.append(run_model(llama_config, backend="eager", use_compile=False))
|
outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
|
||||||
run_model(tractable_config, backend="eager", use_compile=False)
|
run_model(tractable_config, use_inductor=False, use_compile=False)
|
||||||
|
|
||||||
if backend == "inductor":
|
if use_inductor:
|
||||||
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
|
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
|
||||||
else:
|
else:
|
||||||
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
|
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
|
||||||
@ -377,8 +377,10 @@ def test_toy_llama(backend: str):
|
|||||||
num_cudagraph_captured=2,
|
num_cudagraph_captured=2,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
outputs.append(run_model(llama_config, backend=backend, use_compile=True))
|
outputs.append(
|
||||||
run_model(tractable_config, backend=backend, use_compile=True)
|
run_model(llama_config, use_inductor=use_inductor, use_compile=True)
|
||||||
|
)
|
||||||
|
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
|
||||||
|
|
||||||
with compilation_counter.expect(
|
with compilation_counter.expect(
|
||||||
num_graphs_seen=1, # one graph for the model
|
num_graphs_seen=1, # one graph for the model
|
||||||
@ -393,9 +395,16 @@ def test_toy_llama(backend: str):
|
|||||||
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||||
):
|
):
|
||||||
outputs.append(
|
outputs.append(
|
||||||
run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
|
run_model(
|
||||||
|
llama_config,
|
||||||
|
use_inductor=use_inductor,
|
||||||
|
use_compile=True,
|
||||||
|
split_attn=True,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)
|
run_model(
|
||||||
|
tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(1, len(outputs)):
|
for i in range(1, len(outputs)):
|
||||||
assert torch.allclose(outputs[0], outputs[i])
|
assert torch.allclose(outputs[0], outputs[i])
|
||||||
|
|||||||
@ -37,59 +37,57 @@ class Relu3(ReLUSquaredActivation):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"env, torch_level, backend, ops_enabled, default_on",
|
"env, torch_level, use_inductor, ops_enabled, default_on",
|
||||||
[
|
[
|
||||||
# Default values based on compile level
|
# Default values based on compile level
|
||||||
# - All by default (no Inductor compilation)
|
# - All by default (no Inductor compilation)
|
||||||
(None, 0, "eager", [True] * 4, True),
|
(None, 0, False, [True] * 4, True),
|
||||||
(None, 1, "eager", [True] * 4, True),
|
(None, 1, True, [True] * 4, True),
|
||||||
(None, 2, "eager", [True] * 4, True),
|
(None, 2, False, [True] * 4, True),
|
||||||
(None, 3, "eager", [True] * 4, True),
|
|
||||||
# - None by default (with Inductor)
|
# - None by default (with Inductor)
|
||||||
(None, 0, "inductor", [True] * 4, True),
|
(None, 3, True, [False] * 4, False),
|
||||||
# - None by default (with Inductor)
|
(None, 4, True, [False] * 4, False),
|
||||||
(None, 1, "inductor", [False] * 4, False),
|
# - All by default (without Inductor)
|
||||||
(None, 2, "inductor", [False] * 4, False),
|
(None, 3, False, [True] * 4, True),
|
||||||
(None, 3, "inductor", [False] * 4, False),
|
(None, 4, False, [True] * 4, True),
|
||||||
# Explicitly enabling/disabling
|
# Explicitly enabling/disabling
|
||||||
#
|
#
|
||||||
# Default: all
|
# Default: all
|
||||||
#
|
#
|
||||||
# All but SiluAndMul
|
# All but SiluAndMul
|
||||||
("+rms_norm,-silu_and_mul", 0, "inductor", [1, 0, 1, 1], True),
|
("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
|
||||||
# Only ReLU3
|
# Only ReLU3
|
||||||
("none,-rms_norm,+relu3", 1, "eager", [0, 0, 0, 1], False),
|
("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
|
||||||
# All but SiluAndMul
|
# All but SiluAndMul
|
||||||
("all,-silu_and_mul", 2, "inductor", [1, 0, 1, 1], True),
|
("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
|
||||||
# All but ReLU3 (even if ReLU2 is on)
|
# All but ReLU3 (even if ReLU2 is on)
|
||||||
("-relu3,+relu2", 3, "eager", [1, 1, 1, 0], True),
|
("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
|
||||||
# RMSNorm and SiluAndMul
|
# RMSNorm and SiluAndMul
|
||||||
("none,-relu3,+rms_norm,+silu_and_mul", 3, "eager", [1, 1, 0, 0], False),
|
("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
|
||||||
# All but RMSNorm
|
# All but RMSNorm
|
||||||
("-rms_norm", 3, "eager", [0, 1, 1, 1], True),
|
("-rms_norm", 3, False, [0, 1, 1, 1], True),
|
||||||
#
|
#
|
||||||
# Default: none
|
# Default: none
|
||||||
#
|
#
|
||||||
# Only ReLU3
|
# Only ReLU3
|
||||||
("none,+relu3", 3, "inductor", [0, 0, 0, 1], False),
|
("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
|
||||||
# All but RMSNorm
|
# All but RMSNorm
|
||||||
("all,-rms_norm", 3, "inductor", [0, 1, 1, 1], True),
|
("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_enabled_ops(
|
def test_enabled_ops(
|
||||||
env: Optional[str],
|
env: Optional[str],
|
||||||
torch_level: int,
|
torch_level: int,
|
||||||
backend: str,
|
use_inductor: bool,
|
||||||
ops_enabled: list[int],
|
ops_enabled: list[int],
|
||||||
default_on: bool,
|
default_on: bool,
|
||||||
):
|
):
|
||||||
custom_ops = env.split(",") if env else []
|
custom_ops = env.split(",") if env else []
|
||||||
vllm_config = VllmConfig(
|
vllm_config = VllmConfig(
|
||||||
compilation_config=CompilationConfig(
|
compilation_config=CompilationConfig(
|
||||||
backend=backend, level=torch_level, custom_ops=custom_ops
|
use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# breakpoint()
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
assert CustomOp.default_on() == default_on
|
assert CustomOp.default_on() == default_on
|
||||||
|
|
||||||
|
|||||||
@ -34,7 +34,7 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
|
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
|
||||||
if compilation_config.backend == "inductor":
|
if compilation_config.use_inductor:
|
||||||
# Use standalone compile only if requested, version is new enough,
|
# Use standalone compile only if requested, version is new enough,
|
||||||
# and the symbol actually exists in this PyTorch build.
|
# and the symbol actually exists in this PyTorch build.
|
||||||
if (
|
if (
|
||||||
@ -48,10 +48,6 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
|
|||||||
logger.debug("Using InductorAdaptor")
|
logger.debug("Using InductorAdaptor")
|
||||||
return InductorAdaptor()
|
return InductorAdaptor()
|
||||||
else:
|
else:
|
||||||
assert compilation_config.backend == "eager", (
|
|
||||||
"Custom backends not supported with CompilationLevel.PIECEWISE"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Using EagerAdaptor")
|
logger.debug("Using EagerAdaptor")
|
||||||
return EagerAdaptor()
|
return EagerAdaptor()
|
||||||
|
|
||||||
|
|||||||
@ -180,11 +180,10 @@ class CompilationConfig:
|
|||||||
"""The directory to store the compiled graph, to accelerate Inductor
|
"""The directory to store the compiled graph, to accelerate Inductor
|
||||||
compilation. By default, it will use model-related information to generate
|
compilation. By default, it will use model-related information to generate
|
||||||
a cache directory."""
|
a cache directory."""
|
||||||
backend: str = "inductor"
|
backend: str = ""
|
||||||
"""The backend for compilation. It needs to be a string:
|
"""The backend for compilation. It needs to be a string:
|
||||||
|
|
||||||
- "" (empty string): use the default backend ("inductor" on CUDA-alike
|
- "" (empty string): use the default backend.
|
||||||
platforms).
|
|
||||||
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
|
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
|
||||||
- "full.module.name": a qualified name which can be used to import the
|
- "full.module.name": a qualified name which can be used to import the
|
||||||
|
|
||||||
@ -193,11 +192,7 @@ class CompilationConfig:
|
|||||||
distributed setting. When the compilation level is 1 or 2, the backend is
|
distributed setting. When the compilation level is 1 or 2, the backend is
|
||||||
used for the compilation directly (it sees the whole graph). When the
|
used for the compilation directly (it sees the whole graph). When the
|
||||||
compilation level is 3, the backend is used for the piecewise compilation
|
compilation level is 3, the backend is used for the piecewise compilation
|
||||||
(it sees a part of the graph). The backend can not be custom for compilation
|
(it sees a part of the graph)."""
|
||||||
level 3. Furthermore, compilation is only piecewise if splitting ops is set
|
|
||||||
accordingly and use_inductor_cudagraphs_partition is off. Note that the
|
|
||||||
default options for splitting ops are sufficient for piecewise compilation.
|
|
||||||
"""
|
|
||||||
custom_ops: list[str] = field(default_factory=list)
|
custom_ops: list[str] = field(default_factory=list)
|
||||||
"""Fine-grained control over which custom ops to enable/disable. Use 'all'
|
"""Fine-grained control over which custom ops to enable/disable. Use 'all'
|
||||||
to enable all, 'none' to disable all. Also specify a list of custom op
|
to enable all, 'none' to disable all. Also specify a list of custom op
|
||||||
@ -215,12 +210,8 @@ class CompilationConfig:
|
|||||||
compilation."""
|
compilation."""
|
||||||
|
|
||||||
# Inductor capture
|
# Inductor capture
|
||||||
use_inductor: Optional[bool] = None
|
use_inductor: bool = True
|
||||||
"""
|
"""Whether to use inductor compilation:
|
||||||
Whether to use inductor compilation.
|
|
||||||
|
|
||||||
This flag is deprecated and will be removed.
|
|
||||||
Please use the 'backend' option instead.
|
|
||||||
|
|
||||||
- False: inductor compilation is not used. graph runs in eager
|
- False: inductor compilation is not used. graph runs in eager
|
||||||
(custom_ops enabled by default).
|
(custom_ops enabled by default).
|
||||||
@ -228,11 +219,7 @@ class CompilationConfig:
|
|||||||
One graph for symbolic shape and one graph per size in compile_sizes
|
One graph for symbolic shape and one graph per size in compile_sizes
|
||||||
are compiled using configurations in inductor_compile_config.
|
are compiled using configurations in inductor_compile_config.
|
||||||
|
|
||||||
This setting is ignored if level<PIECEWISE.
|
This setting is ignored if level<PIECEWISE."""
|
||||||
|
|
||||||
For future compatibility:
|
|
||||||
If use_inductor is True, backend="inductor" otherwise backend="eager".
|
|
||||||
"""
|
|
||||||
compile_sizes: Optional[list[Union[int, str]]] = None
|
compile_sizes: Optional[list[Union[int, str]]] = None
|
||||||
"""Sizes to compile for inductor. In addition
|
"""Sizes to compile for inductor. In addition
|
||||||
to integers, it also supports "cudagraph_capture_sizes" to
|
to integers, it also supports "cudagraph_capture_sizes" to
|
||||||
@ -538,43 +525,7 @@ class CompilationConfig:
|
|||||||
"(where 'op' is the registered op name)"
|
"(where 'op' is the registered op name)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Currently only eager and inductor backend are supported.
|
|
||||||
# for piecewise compilation. Custom backends are not suppported for
|
|
||||||
# piecewise compilation. Update when more backends are supported.
|
|
||||||
if self.level == CompilationLevel.PIECEWISE and self.backend not in [
|
|
||||||
"",
|
|
||||||
"eager",
|
|
||||||
"inductor",
|
|
||||||
]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid backend for piecewise compilation: {self.backend}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.use_inductor is not None:
|
|
||||||
logger.warning_once(
|
|
||||||
"The 'use_inductor' flag is deprecated and will be\
|
|
||||||
removed in a future release."
|
|
||||||
"Please use the 'backend' option instead.",
|
|
||||||
)
|
|
||||||
self.backend = "inductor" if self.use_inductor else "eager"
|
|
||||||
|
|
||||||
if self.backend == "":
|
|
||||||
self.backend = "inductor"
|
|
||||||
|
|
||||||
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
|
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
|
||||||
"""
|
|
||||||
Initialize the backend for the compilation config from a vllm config.
|
|
||||||
Arguments:
|
|
||||||
vllm_config: The vllm config to initialize the backend from.
|
|
||||||
Returns:
|
|
||||||
The backend for the compilation config.
|
|
||||||
"""
|
|
||||||
if self.level is None:
|
|
||||||
raise ValueError(
|
|
||||||
"No compilation level is set. This method should only be \
|
|
||||||
called via vllm config where the level is set if none is \
|
|
||||||
provided."
|
|
||||||
)
|
|
||||||
if self.level == CompilationLevel.NO_COMPILATION:
|
if self.level == CompilationLevel.NO_COMPILATION:
|
||||||
raise ValueError("No compilation level is set.")
|
raise ValueError("No compilation level is set.")
|
||||||
|
|
||||||
@ -582,15 +533,15 @@ class CompilationConfig:
|
|||||||
|
|
||||||
torch_backends = list_backends(exclude_tags=tuple())
|
torch_backends = list_backends(exclude_tags=tuple())
|
||||||
if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
|
if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
|
||||||
|
if self.backend == "":
|
||||||
|
return "eager"
|
||||||
if self.backend in torch_backends:
|
if self.backend in torch_backends:
|
||||||
return self.backend
|
return self.backend
|
||||||
return resolve_obj_by_qualname(self.backend)
|
return resolve_obj_by_qualname(self.backend)
|
||||||
|
|
||||||
|
# TODO: pass user-specified backend to piecewise compilation
|
||||||
|
# merge with the config use_inductor
|
||||||
assert self.level == CompilationLevel.PIECEWISE
|
assert self.level == CompilationLevel.PIECEWISE
|
||||||
if self.backend not in ["eager", "inductor"]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid backend for piecewise compilation: {self.backend}"
|
|
||||||
)
|
|
||||||
|
|
||||||
from vllm.compilation.backends import VllmBackend
|
from vllm.compilation.backends import VllmBackend
|
||||||
|
|
||||||
@ -743,7 +694,7 @@ class CompilationConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
inductor_used = (
|
inductor_used = (
|
||||||
self.level == CompilationLevel.PIECEWISE and self.backend == "inductor"
|
self.level == CompilationLevel.PIECEWISE and self.use_inductor
|
||||||
) or (
|
) or (
|
||||||
self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
|
self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -318,25 +318,6 @@ class VllmConfig:
|
|||||||
# NB: Passing both --enforce-eager and a compilation level
|
# NB: Passing both --enforce-eager and a compilation level
|
||||||
# in V0 means the compilation level wins out.
|
# in V0 means the compilation level wins out.
|
||||||
self.compilation_config.level = CompilationLevel.NO_COMPILATION
|
self.compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||||
else:
|
|
||||||
assert self.compilation_config.level >= CompilationLevel.NO_COMPILATION
|
|
||||||
assert self.compilation_config.level <= CompilationLevel.PIECEWISE
|
|
||||||
assert self.compilation_config.level <= 3
|
|
||||||
|
|
||||||
# If user does not set custom ops via none or all set it here based on
|
|
||||||
# compilation level and backend.
|
|
||||||
if (
|
|
||||||
self.compilation_config.custom_ops.count("none")
|
|
||||||
+ self.compilation_config.custom_ops.count("all")
|
|
||||||
== 0
|
|
||||||
):
|
|
||||||
if (
|
|
||||||
self.compilation_config.level > 0
|
|
||||||
and self.compilation_config.backend != "eager"
|
|
||||||
):
|
|
||||||
self.compilation_config.custom_ops.append("none")
|
|
||||||
else:
|
|
||||||
self.compilation_config.custom_ops.append("all")
|
|
||||||
|
|
||||||
# async tp is built on top of sequence parallelism
|
# async tp is built on top of sequence parallelism
|
||||||
# and requires it to be enabled.
|
# and requires it to be enabled.
|
||||||
|
|||||||
@ -114,9 +114,7 @@ class CustomOp(nn.Module):
|
|||||||
custom_ops = compilation_config.custom_ops
|
custom_ops = compilation_config.custom_ops
|
||||||
if not hasattr(cls, "name"):
|
if not hasattr(cls, "name"):
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"Custom op %s was not registered, which means it won't appear\
|
"Custom op %s was not registered, which means it won't appear in the op registry. It will be enabled/disabled based on the global settings.", # noqa: E501
|
||||||
in the op registry. It will be enabled/disabled based on the\
|
|
||||||
global settings.", # noqa: E501
|
|
||||||
cls.__name__,
|
cls.__name__,
|
||||||
)
|
)
|
||||||
return CustomOp.default_on()
|
return CustomOp.default_on()
|
||||||
@ -130,17 +128,19 @@ class CustomOp(nn.Module):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def default_on() -> bool:
|
def default_on() -> bool:
|
||||||
"""
|
"""
|
||||||
Behavior controlled by `CompilationConfig.custom_ops`: On by default if
|
On by default if PyTorch Inductor is not used.
|
||||||
'all', off by default if 'none'.
|
Specifying 'all' or 'none' in custom_op takes precedence.
|
||||||
When PyTorch Inductor is used, 'none' is the default value,
|
|
||||||
otherwise 'all'.
|
|
||||||
"""
|
"""
|
||||||
|
from vllm.config import CompilationLevel
|
||||||
|
|
||||||
compilation_config = get_cached_compilation_config()
|
compilation_config = get_cached_compilation_config()
|
||||||
|
default_on = (
|
||||||
|
compilation_config.level < CompilationLevel.PIECEWISE
|
||||||
|
or not compilation_config.use_inductor
|
||||||
|
)
|
||||||
count_none = compilation_config.custom_ops.count("none")
|
count_none = compilation_config.custom_ops.count("none")
|
||||||
count_all = compilation_config.custom_ops.count("all")
|
count_all = compilation_config.custom_ops.count("all")
|
||||||
assert count_none + count_all == 1
|
return default_on and not count_none > 0 or count_all > 0
|
||||||
|
|
||||||
return not count_none > 0 or count_all > 0
|
|
||||||
|
|
||||||
# Dictionary of all custom ops (classes, indexed by registered name).
|
# Dictionary of all custom ops (classes, indexed by registered name).
|
||||||
# To check if an op with a name is enabled, call .enabled() on the class.
|
# To check if an op with a name is enabled, call .enabled() on the class.
|
||||||
|
|||||||
@ -274,6 +274,8 @@ class CpuPlatform(Platform):
|
|||||||
"epilogue_fusion": True,
|
"epilogue_fusion": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
if compilation_config.use_inductor:
|
||||||
|
compilation_config.custom_ops = ["none"]
|
||||||
|
|
||||||
if vllm_config.lora_config is not None:
|
if vllm_config.lora_config is not None:
|
||||||
compilation_config.level = CompilationLevel.NO_COMPILATION
|
compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user