mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 14:56:08 +08:00
[Feature] Support sequence parallelism for static fp8 quantization (#19181)
Signed-off-by: cascade812 <cascade812@outlook.com>
This commit is contained in:
parent
d0132f025d
commit
e6327c9b3e
@ -6,7 +6,9 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.fusion import FusionPass
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
||||
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
|
||||
PassConfig, VllmConfig)
|
||||
@ -14,12 +16,15 @@ from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
from ..utils import multi_gpu_test
|
||||
from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
@ -30,13 +35,16 @@ prompts = [
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
def __init__(self,
|
||||
hidden_size=16,
|
||||
intermediate_size=32,
|
||||
vllm_config: VllmConfig = None):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.gate_proj = torch.nn.Parameter(
|
||||
torch.empty((intermediate_size, hidden_size)))
|
||||
self.norm = RMSNorm(hidden_size, 1e-05)
|
||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
||||
# Initialize weights
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
|
||||
@ -79,32 +87,138 @@ class TestModel(torch.nn.Module):
|
||||
return [torch.ops._C.fused_add_rms_norm.default]
|
||||
|
||||
|
||||
class TestQuantModel(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size=16,
|
||||
intermediate_size=32,
|
||||
vllm_config: VllmConfig = None):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.vllm_config = vllm_config
|
||||
self.gate_proj = torch.nn.Parameter(torch.empty(
|
||||
(intermediate_size, hidden_size)),
|
||||
requires_grad=False)
|
||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
||||
# Initialize weights
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True,
|
||||
use_per_token_if_dynamic=False)
|
||||
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
# Create a weight that is compatible with torch._scaled_mm,
|
||||
# which expects a column-major layout.
|
||||
self.w = torch.rand(hidden_size,
|
||||
intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
"""
|
||||
Forward pass implementing the operations in the FX graph
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor
|
||||
residual: Residual tensor from previous layer
|
||||
|
||||
Returns:
|
||||
Tuple containing the output tensor
|
||||
"""
|
||||
# Reshape input
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
|
||||
#matrix multiplication
|
||||
permute = self.gate_proj.permute(1, 0)
|
||||
mm = torch.mm(view, permute)
|
||||
|
||||
# Tensor parallel all-reduce
|
||||
all_reduce = tensor_model_parallel_all_reduce(mm)
|
||||
|
||||
# layer normalization
|
||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||
|
||||
# for static input quantization
|
||||
# self.fp8_linear is initialized with use_per_token_if_dynamic=False
|
||||
fp8_linear_result = self.fp8_linear.apply(norm_output,
|
||||
self.w,
|
||||
self.wscale,
|
||||
input_scale=self.scale.to(
|
||||
norm_output.device))
|
||||
|
||||
return fp8_linear_result, residual_output
|
||||
|
||||
def ops_in_model_before(self):
|
||||
ops_to_remove = [torch.ops.vllm.all_reduce.default
|
||||
] # Always removed by SP
|
||||
# The following are only removed if fusion happens
|
||||
if self.vllm_config and self.vllm_config.compilation_config \
|
||||
.pass_config.enable_fusion:
|
||||
ops_to_remove.extend([
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
torch.ops._C.static_scaled_fp8_quant.default,
|
||||
])
|
||||
return ops_to_remove
|
||||
|
||||
def ops_in_model_after(self):
|
||||
ops_to_add = [
|
||||
torch.ops.vllm.reduce_scatter.default,
|
||||
torch.ops.vllm.all_gather.default
|
||||
]
|
||||
# The following is only added if fusion happens
|
||||
if self.vllm_config and self.vllm_config.compilation_config \
|
||||
.pass_config.enable_fusion:
|
||||
ops_to_add.append(
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
|
||||
return ops_to_add
|
||||
|
||||
def ops_in_model(self):
|
||||
if self.vllm_config and self.vllm_config.compilation_config \
|
||||
.pass_config.enable_fusion:
|
||||
# If fusion happens, the fused op is the one
|
||||
# we check for (de)functionalization
|
||||
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
|
||||
] # noqa: E501
|
||||
else:
|
||||
# If no fusion, the original ops are checked
|
||||
return [
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
# TODO functionalization pass does not handle this yet
|
||||
# torch.ops._C.static_scaled_fp8_quant.default,
|
||||
]
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [16])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("enable_fusion", [True, False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
||||
reason="Only test on CUDA")
|
||||
def test_sequence_parallelism_pass(batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype):
|
||||
def test_sequence_parallelism_pass(test_model_cls: type[torch.nn.Module],
|
||||
batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype,
|
||||
enable_fusion: bool):
|
||||
num_processes = 2
|
||||
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
# need to use torch.mp.spawn otherwise will have problems with
|
||||
# torch.distributed and cuda
|
||||
torch.multiprocessing.spawn(fn,
|
||||
args=(num_processes, batch_size, seq_len,
|
||||
hidden_size, dtype),
|
||||
args=(num_processes, test_model_cls,
|
||||
batch_size, seq_len, hidden_size,
|
||||
dtype, enable_fusion),
|
||||
nprocs=nprocs)
|
||||
|
||||
run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
|
||||
|
||||
|
||||
def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
|
||||
batch_size: int, seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype):
|
||||
def sequence_parallelism_pass_on_test_model(
|
||||
local_rank: int, world_size: int,
|
||||
test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype, enable_fusion: bool):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
@ -127,26 +241,39 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
|
||||
# configure vllm config for SequenceParallelismPass
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
|
||||
enable_sequence_parallelism=True))
|
||||
enable_sequence_parallelism=True,
|
||||
enable_fusion=enable_fusion,
|
||||
enable_noop=True)) # NoOp needed for fusion
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
|
||||
# this is a fake model name to construct the model config
|
||||
# in the vllm_config, it's not really used.
|
||||
model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
||||
vllm_config.model_config = ModelConfig(model=model,
|
||||
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
||||
vllm_config.model_config = ModelConfig(model=model_name,
|
||||
task="auto",
|
||||
tokenizer=model,
|
||||
tokenizer=model_name,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
seed=42)
|
||||
|
||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||
backend_no_func = TestBackend(sequence_parallelism_pass)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
backend_func = TestBackend(sequence_parallelism_pass, func_pass)
|
||||
|
||||
model = TestModel(hidden_size, hidden_size * 2)
|
||||
passes_for_backend = [noop_pass, sequence_parallelism_pass]
|
||||
|
||||
if enable_fusion:
|
||||
fusion_pass = FusionPass.instance(vllm_config)
|
||||
passes_for_backend.append(fusion_pass)
|
||||
|
||||
backend_no_func = TestBackend(*passes_for_backend)
|
||||
backend_func = TestBackend(*passes_for_backend, func_pass)
|
||||
|
||||
model = test_model_cls(hidden_size,
|
||||
hidden_size * 2,
|
||||
vllm_config=vllm_config)
|
||||
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
||||
dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
|
||||
@ -28,7 +28,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||
class ParallelSetup(NamedTuple):
|
||||
tp_size: int
|
||||
pp_size: int
|
||||
sp_enabled: bool
|
||||
enable_fusion: bool
|
||||
eager_mode: bool
|
||||
chunked_prefill: bool
|
||||
|
||||
@ -67,49 +67,18 @@ class SPTestSettings:
|
||||
task: TaskOption = "auto",
|
||||
load_format: Optional[str] = None,
|
||||
):
|
||||
parallel_setups = []
|
||||
for eager_mode_val in [False, True]:
|
||||
for pp_multiplier in [1, 2]:
|
||||
for chunked_prefill_val in [False, True]:
|
||||
parallel_setups.append(
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=pp_multiplier * pp_base,
|
||||
enable_fusion=False,
|
||||
eager_mode=eager_mode_val,
|
||||
chunked_prefill=chunked_prefill_val))
|
||||
return SPTestSettings(
|
||||
parallel_setups=[
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=pp_base,
|
||||
sp_enabled=True,
|
||||
eager_mode=False,
|
||||
chunked_prefill=False),
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=pp_base,
|
||||
sp_enabled=True,
|
||||
eager_mode=False,
|
||||
chunked_prefill=True),
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=pp_base,
|
||||
sp_enabled=True,
|
||||
eager_mode=True,
|
||||
chunked_prefill=False),
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=pp_base,
|
||||
sp_enabled=True,
|
||||
eager_mode=True,
|
||||
chunked_prefill=True),
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=2 * pp_base,
|
||||
sp_enabled=True,
|
||||
eager_mode=False,
|
||||
chunked_prefill=False),
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=2 * pp_base,
|
||||
sp_enabled=True,
|
||||
eager_mode=False,
|
||||
chunked_prefill=True),
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=2 * pp_base,
|
||||
sp_enabled=True,
|
||||
eager_mode=True,
|
||||
chunked_prefill=False),
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=2 * pp_base,
|
||||
sp_enabled=True,
|
||||
eager_mode=True,
|
||||
chunked_prefill=True)
|
||||
],
|
||||
parallel_setups=parallel_setups,
|
||||
distributed_backends=["mp", "ray"],
|
||||
vllm_major_versions=["1", "1"],
|
||||
task=task,
|
||||
@ -126,19 +95,44 @@ class SPTestSettings:
|
||||
multi_node_only: bool = False,
|
||||
load_format: Optional[str] = None,
|
||||
):
|
||||
parallel_setups = []
|
||||
for eager_mode_val in [False, True]:
|
||||
for pp_multiplier in [1, 2]:
|
||||
for chunked_prefill_val in [False, True]:
|
||||
parallel_setups.append(
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=pp_multiplier * pp_base,
|
||||
enable_fusion=False,
|
||||
eager_mode=eager_mode_val,
|
||||
chunked_prefill=chunked_prefill_val))
|
||||
return SPTestSettings(
|
||||
parallel_setups=[
|
||||
parallel_setups=parallel_setups,
|
||||
distributed_backends=["mp", "ray"],
|
||||
vllm_major_versions=["1", "1"],
|
||||
task=task,
|
||||
test_options=SPTestOptions(multi_node_only=multi_node_only,
|
||||
load_format=load_format),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def fp8_quant(
|
||||
*,
|
||||
tp_base: int = 2,
|
||||
pp_base: int = 1,
|
||||
task: TaskOption = "auto",
|
||||
multi_node_only: bool = False,
|
||||
load_format: Optional[str] = None,
|
||||
):
|
||||
parallel_setups = []
|
||||
for fusion_val in [False, True]:
|
||||
parallel_setups.append(
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=pp_base,
|
||||
sp_enabled=True,
|
||||
eager_mode=False,
|
||||
chunked_prefill=False),
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=2 * pp_base,
|
||||
sp_enabled=True,
|
||||
eager_mode=False,
|
||||
chunked_prefill=False),
|
||||
],
|
||||
enable_fusion=fusion_val,
|
||||
eager_mode=True,
|
||||
chunked_prefill=False))
|
||||
return SPTestSettings(
|
||||
parallel_setups=parallel_setups,
|
||||
distributed_backends=["mp", "ray"],
|
||||
vllm_major_versions=["1", "1"],
|
||||
task=task,
|
||||
@ -171,7 +165,7 @@ def _compare_sp(
|
||||
(
|
||||
tp_size,
|
||||
pp_size,
|
||||
sp_enabled,
|
||||
enable_fusion,
|
||||
eager_mode,
|
||||
chunked_prefill,
|
||||
) = parallel_setup
|
||||
@ -240,9 +234,9 @@ def _compare_sp(
|
||||
'compile_sizes': [4, 8],
|
||||
'splitting_ops': [],
|
||||
'pass_config': {
|
||||
'enable_sequence_parallelism': sp_enabled,
|
||||
'enable_sequence_parallelism': True,
|
||||
'enable_fusion': enable_fusion,
|
||||
'enable_noop': True,
|
||||
'enable_fusion': True,
|
||||
},
|
||||
}
|
||||
|
||||
@ -291,12 +285,14 @@ def _compare_sp(
|
||||
SP_TEXT_GENERATION_MODELS = {
|
||||
# [Decoder-only]
|
||||
"meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(),
|
||||
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8": SPTestSettings.fp8_quant(),
|
||||
}
|
||||
|
||||
SP_TEST_MODELS = [
|
||||
# TODO support other models
|
||||
# [LANGUAGE GENERATION]
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -193,7 +193,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501
|
||||
"LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct",
|
||||
extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501
|
||||
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B"}), # noqa: E501
|
||||
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501
|
||||
"fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"}), # noqa: E501
|
||||
"LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
|
||||
is_available_online=False),
|
||||
"MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
|
||||
|
||||
@ -345,8 +345,8 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
# 0 is always None
|
||||
fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)}
|
||||
self.insert_fused_node(fused_return_mapping,
|
||||
epsilon=rms_node.kwargs["epsilon"],
|
||||
**kwargs)
|
||||
**kwargs,
|
||||
epsilon=rms_node.kwargs["epsilon"])
|
||||
|
||||
|
||||
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
@ -51,15 +51,15 @@ class PostGradPassManager(CustomGraphPass):
|
||||
if self.pass_config.enable_noop:
|
||||
self.passes += [NoOpEliminationPass(config)]
|
||||
|
||||
if self.pass_config.enable_fusion:
|
||||
self.passes += [FusionPass.instance(config)]
|
||||
self.passes += [ActivationQuantFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_sequence_parallelism:
|
||||
self.passes += [SequenceParallelismPass(config)]
|
||||
if self.pass_config.enable_async_tp:
|
||||
self.passes += [AsyncTPPass(config)]
|
||||
|
||||
if self.pass_config.enable_fusion:
|
||||
self.passes += [FusionPass.instance(config)]
|
||||
self.passes += [ActivationQuantFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_attn_fusion:
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
|
||||
|
||||
@ -12,91 +12,142 @@ from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AllReduceRMSNormPattern:
|
||||
class _RMSNormAndQuantOpHelper:
|
||||
"""Base helper for RMSNorm and RMSNorm + Quantization functionalization."""
|
||||
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
quant_op: Optional[torch._ops.OpOverload] = None,
|
||||
**kwargs):
|
||||
self.epsilon = epsilon
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.quant_op = quant_op
|
||||
|
||||
def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor):
|
||||
return torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.rms_norm.default,
|
||||
result=result_buffer,
|
||||
input=input_tensor,
|
||||
weight=weight_tensor,
|
||||
epsilon=self.epsilon)
|
||||
|
||||
def _functional_fused_add_rmsnorm(self, input_tensor, residual_tensor,
|
||||
weight_tensor):
|
||||
return torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
input=input_tensor,
|
||||
residual=residual_tensor,
|
||||
weight=weight_tensor,
|
||||
epsilon=self.epsilon)
|
||||
|
||||
def _functional_rmsnorm_then_quant(self, rmsnorm_result_buffer,
|
||||
quant_result_buffer, input_tensor,
|
||||
weight_tensor, scale_tensor):
|
||||
if self.quant_op is None:
|
||||
raise RuntimeError(
|
||||
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
|
||||
)
|
||||
rmsnorm_out_tuple = self._functional_rmsnorm(rmsnorm_result_buffer,
|
||||
input_tensor,
|
||||
weight_tensor)
|
||||
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
|
||||
self.quant_op,
|
||||
result=quant_result_buffer,
|
||||
input=rmsnorm_out_tuple[1],
|
||||
scale=scale_tensor)
|
||||
return quant_out_tuple
|
||||
|
||||
def _functional_fused_add_rmsnorm_then_quant(self, quant_result_buffer,
|
||||
input_tensor, residual_tensor,
|
||||
weight_tensor, scale_tensor):
|
||||
if self.quant_op is None:
|
||||
raise RuntimeError(
|
||||
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
|
||||
)
|
||||
fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
|
||||
input_tensor, residual_tensor, weight_tensor)
|
||||
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
|
||||
self.quant_op,
|
||||
result=quant_result_buffer,
|
||||
input=fused_add_rmsnorm_out_tuple[1],
|
||||
scale=scale_tensor)
|
||||
return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]
|
||||
|
||||
|
||||
class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern):
|
||||
class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
|
||||
"""Helper for sequence parallelism patterns."""
|
||||
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
quant_op: Optional[torch._ops.OpOverload] = None,
|
||||
**kwargs):
|
||||
super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
|
||||
self.tp_group = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return tensor_model_parallel_all_reduce(x)
|
||||
|
||||
def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.vllm.reduce_scatter.default(
|
||||
x,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp_group.unique_name)
|
||||
|
||||
def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.vllm.all_gather.default(
|
||||
x,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp_group.unique_name)
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
|
||||
def get_inputs(self):
|
||||
arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype)
|
||||
mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]],
|
||||
device=self.device,
|
||||
dtype=torch.long)
|
||||
unsqueeze = torch.rand([1, 8, 1], device=self.device, \
|
||||
dtype=self.dtype) > 0.5
|
||||
full_default = torch.zeros([1, 8, 4], device=self.device, \
|
||||
dtype=self.dtype)
|
||||
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [arg2_1, mul_6, unsqueeze, full_default, permute, arg3_1]
|
||||
return [input, permute, arg3_1]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(
|
||||
arg2_1: torch.Tensor,
|
||||
mul_6: torch.Tensor,
|
||||
unsqueeze: torch.Tensor,
|
||||
full_default: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
permute: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
):
|
||||
embedding = torch.ops.aten.embedding.default(arg2_1, mul_6)
|
||||
where = torch.ops.aten.where.self(unsqueeze, full_default,
|
||||
embedding)
|
||||
all_reduce = tensor_model_parallel_all_reduce(where)
|
||||
rmsnorm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.rms_norm.default,
|
||||
result=permute,
|
||||
input=all_reduce,
|
||||
weight=arg3_1,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
all_reduce = self._all_reduce(input)
|
||||
rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1)
|
||||
|
||||
return rmsnorm[1], all_reduce
|
||||
|
||||
def replacement(
|
||||
arg2_1: torch.Tensor,
|
||||
mul_6: torch.Tensor,
|
||||
unsqueeze: torch.Tensor,
|
||||
full_default: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
permute: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
):
|
||||
embedding = torch.ops.aten.embedding.default(arg2_1, mul_6)
|
||||
where = torch.ops.aten.where.self(unsqueeze, full_default,
|
||||
embedding)
|
||||
|
||||
tp = get_tp_group()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
where, dim=0, world_size=tp_size, group_name=tp.unique_name)
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
|
||||
rmsnorm_result = torch.empty_like(reduce_scatter)
|
||||
rmsnorm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.rms_norm.default,
|
||||
result=rmsnorm_result,
|
||||
input=reduce_scatter,
|
||||
weight=arg3_1,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter,
|
||||
arg3_1)
|
||||
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
rmsnorm[1],
|
||||
dim=0,
|
||||
world_size=tp_size,
|
||||
group_name=tp.unique_name)
|
||||
all_gather = self._all_gather(rmsnorm[1])
|
||||
|
||||
return all_gather, reduce_scatter
|
||||
|
||||
@ -104,7 +155,7 @@ class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern):
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
|
||||
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
@ -127,16 +178,9 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = tensor_model_parallel_all_reduce(mm_1)
|
||||
|
||||
rmsnorm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
input=all_reduce,
|
||||
residual=residual,
|
||||
weight=rms_norm_weights,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
all_reduce, residual, rms_norm_weights)
|
||||
return rmsnorm[1], rmsnorm[2]
|
||||
|
||||
def replacement(
|
||||
@ -144,32 +188,17 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
tp = get_tp_group()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name)
|
||||
|
||||
# TODO is it possible to extract epsilon from somewhere
|
||||
rmsnorm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
input=reduce_scatter,
|
||||
residual=residual,
|
||||
weight=rms_norm_weights,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
rmsnorm[1],
|
||||
dim=0,
|
||||
world_size=tp_size,
|
||||
group_name=tp.unique_name)
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
reduce_scatter, residual, rms_norm_weights)
|
||||
all_gather = self._all_gather(rmsnorm[1])
|
||||
return all_gather, rmsnorm[2]
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
|
||||
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
@ -192,16 +221,9 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = tensor_model_parallel_all_reduce(mm_1)
|
||||
|
||||
rmsnorm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
input=all_reduce,
|
||||
residual=residual,
|
||||
weight=rms_norm_weights,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
all_reduce, residual, rms_norm_weights)
|
||||
return rmsnorm[1]
|
||||
|
||||
def replacement(
|
||||
@ -209,26 +231,185 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
tp = get_tp_group()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name)
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
reduce_scatter, residual, rms_norm_weights)
|
||||
normalized = self._all_gather(rmsnorm[1])
|
||||
return normalized
|
||||
|
||||
# TODO is it possible to extract epsilon from somewhere
|
||||
rmsnorm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
input=reduce_scatter,
|
||||
residual=residual,
|
||||
weight=rms_norm_weights,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
normalized = torch.ops.vllm.all_gather.default(
|
||||
rmsnorm[1],
|
||||
dim=0,
|
||||
world_size=tp_size,
|
||||
group_name=tp.unique_name)
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
|
||||
op: torch._ops.OpOverload):
|
||||
super().__init__(epsilon, dtype, device, quant_op=op)
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
rmsnorm_result = torch.empty([1, 8, 4],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
quant_result = torch.empty([1, 8, 4],
|
||||
device=self.device,
|
||||
dtype=FP8_DTYPE)
|
||||
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
|
||||
return [input, rmsnorm_result, quant_result, weight, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
rmsnorm_result: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
all_reduce = self._all_reduce(input)
|
||||
static_fp8 = self._functional_rmsnorm_then_quant(
|
||||
rmsnorm_result, quant_result, all_reduce, weight, scale)
|
||||
return static_fp8[1], all_reduce
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
rmsnorm_result: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
|
||||
rmsnorm_result = torch.empty_like(reduce_scatter,
|
||||
dtype=rmsnorm_result.dtype)
|
||||
quant_result = torch.empty_like(
|
||||
rmsnorm_result, # Output of RMSNorm
|
||||
dtype=quant_result.dtype)
|
||||
static_fp8 = self._functional_rmsnorm_then_quant(
|
||||
rmsnorm_result, quant_result, reduce_scatter, weight, scale)
|
||||
all_gather = self._all_gather(static_fp8[1])
|
||||
|
||||
return all_gather, reduce_scatter
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
|
||||
op: torch._ops.OpOverload):
|
||||
super().__init__(epsilon, dtype, device, quant_op=op)
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
|
||||
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [
|
||||
result,
|
||||
residual,
|
||||
mm_1,
|
||||
rms_norm_weights,
|
||||
scale,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
|
||||
result, all_reduce, residual, rms_norm_weights, scale)
|
||||
return static_fp8[1], rmsnorm_residual_out
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
quant_result_buf = torch.empty_like(reduce_scatter,
|
||||
dtype=result.dtype)
|
||||
static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
|
||||
quant_result_buf, reduce_scatter, residual, rms_norm_weights,
|
||||
scale)
|
||||
all_gather = self._all_gather(static_fp8[1])
|
||||
return all_gather, rmsnorm_residual_out
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
|
||||
op: torch._ops.OpOverload):
|
||||
super().__init__(epsilon, dtype, device, quant_op=op)
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
|
||||
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [
|
||||
result,
|
||||
residual,
|
||||
mm_1,
|
||||
rms_norm_weights,
|
||||
scale,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
|
||||
result, all_reduce, residual, rms_norm_weights, scale)
|
||||
return static_fp8[1]
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
quant_result_buf = torch.empty_like(reduce_scatter,
|
||||
dtype=result.dtype)
|
||||
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
|
||||
quant_result_buf, reduce_scatter, residual, rms_norm_weights,
|
||||
scale)
|
||||
normalized = self._all_gather(static_fp8[1])
|
||||
return normalized
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
@ -236,21 +417,54 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
|
||||
|
||||
|
||||
class SequenceParallelismPass(VllmInductorPass):
|
||||
"""
|
||||
This pass enables sequence parallelism for models.
|
||||
It identifies patterns where an AllReduce operation is followed by
|
||||
an RMSNorm (or RMSNorm and then Quantization) operation.
|
||||
These patterns are replaced with a ReduceScatter operation, followed by
|
||||
a local RMSNorm/Quantization, and then an AllGather operation.
|
||||
|
||||
The general transformation is:
|
||||
Input -> AllReduce -> RMSNorm -> Output
|
||||
becomes
|
||||
Input -> ReduceScatter -> RMSNorm -> AllGather -> Output
|
||||
|
||||
While this pass itself does not directly yield performance improvements,
|
||||
it lays the groundwork for subsequent fusion passes, such as
|
||||
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
|
||||
significantly reduce communication overhead and improve overall model
|
||||
performance.
|
||||
"""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="sequence_parallelism_pass")
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
EmbeddingAllReduceRMSNormPattern(
|
||||
epsilon, self.model_dtype, self.device).register(self.patterns)
|
||||
# RMSNorm + Static FP8 quantization patterns
|
||||
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
|
||||
FirstAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device,
|
||||
fp8_quant_op).register(self.patterns)
|
||||
MiddleAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device,
|
||||
fp8_quant_op).register(self.patterns)
|
||||
LastAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device,
|
||||
fp8_quant_op).register(self.patterns)
|
||||
|
||||
# Normal RMSNorm patterns
|
||||
FirstAllReduceRMSNormPattern(epsilon, self.model_dtype,
|
||||
self.device).register(self.patterns)
|
||||
|
||||
MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype,
|
||||
self.device).register(self.patterns)
|
||||
|
||||
LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
|
||||
self.device).register(self.patterns)
|
||||
|
||||
# WARNING: This is a hack to clear the pattern matcher cache
|
||||
# and allow multiple values of epsilon.
|
||||
torch._inductor.pattern_matcher._seen_patterns.clear()
|
||||
|
||||
@ -3802,11 +3802,11 @@ class PassConfig:
|
||||
its own stages (before, after, maybe in-between)."""
|
||||
dump_graph_dir: Path = Path(".")
|
||||
"""Directory to dump the graphs."""
|
||||
enable_fusion: bool = True
|
||||
enable_fusion: bool = field(default_factory=lambda: not envs.VLLM_USE_V1)
|
||||
"""Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
|
||||
enable_attn_fusion: bool = False
|
||||
"""Whether to enable the custom attention+quant fusion pass."""
|
||||
enable_noop: bool = True
|
||||
enable_noop: bool = field(default_factory=lambda: not envs.VLLM_USE_V1)
|
||||
"""Whether to enable the custom no-op elimination pass."""
|
||||
enable_sequence_parallelism: bool = False
|
||||
"""Whether to enable sequence parallelism."""
|
||||
@ -4451,8 +4451,6 @@ class VllmConfig:
|
||||
# By default, V1 uses piecewise CUDA graphs. If full_cuda_graph
|
||||
# is set to True, full CUDA graphs will be used.
|
||||
self.compilation_config.cudagraph_num_of_warmups = 1
|
||||
self.compilation_config.pass_config.enable_fusion = False
|
||||
self.compilation_config.pass_config.enable_noop = False
|
||||
self.compilation_config.level = CompilationLevel.PIECEWISE
|
||||
self.compilation_config.set_splitting_ops_for_v1()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user