mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-20 17:03:34 +08:00
[Feature] support sequence parallelism using compilation pass (#16155)
Signed-off-by: cascade812 <cascade812@outlook.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
ed7a29d9f8
commit
690fe019f0
@ -299,6 +299,7 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s compile/test_pass_manager.py
|
||||
- pytest -v -s compile/test_fusion.py
|
||||
- pytest -v -s compile/test_sequence_parallelism.py
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||
source_file_dependencies:
|
||||
@ -583,6 +584,8 @@ steps:
|
||||
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
|
||||
# test sequence parallel
|
||||
- pytest -v -s distributed/test_sequence_parallel.py
|
||||
# this test fails consistently.
|
||||
# TODO: investigate and fix
|
||||
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
|
||||
@ -10,7 +10,7 @@ from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
|
||||
kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.config import CompilationConfig, VllmConfig
|
||||
|
||||
from .backend import TestBackend
|
||||
|
||||
@ -49,13 +49,15 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
|
||||
do_fusion: bool):
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
|
||||
enable_noop=True)
|
||||
noop_pass = NoOpEliminationPass(config)
|
||||
fusion_pass = FusionPass.instance(config)
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig(pass_config= \
|
||||
CompilationConfig.PassConfig(enable_fusion=do_fusion,
|
||||
enable_noop=True))
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = FusionPass.instance(vllm_config)
|
||||
|
||||
passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
|
||||
func_pass = FixFunctionalizationPass(config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
backend_func = TestBackend(*passes, func_pass)
|
||||
backend_no_func = TestBackend(*passes)
|
||||
|
||||
|
||||
@ -77,12 +77,13 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
|
||||
vllm_config.compilation_config.pass_config = \
|
||||
CompilationConfig.PassConfig(enable_fusion=True,
|
||||
enable_noop=True)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
# Reshape pass is needed for the fusion pass to work
|
||||
config = CompilationConfig.PassConfig(enable_fusion=True,
|
||||
enable_noop=True)
|
||||
noop_pass = NoOpEliminationPass(config)
|
||||
fusion_pass = FusionPass.instance(config)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = FusionPass.instance(vllm_config)
|
||||
|
||||
backend = TestBackend(noop_pass, fusion_pass)
|
||||
model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)
|
||||
|
||||
@ -6,7 +6,7 @@ import torch
|
||||
|
||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||
from vllm.compilation.pass_manager import PostGradPassManager
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
|
||||
# dummy custom pass that doesn't inherit
|
||||
@ -16,7 +16,7 @@ def simple_callable(graph: torch.fx.Graph):
|
||||
|
||||
# Should fail to add directly to the pass manager
|
||||
def test_bad_callable():
|
||||
config = CompilationConfig().pass_config
|
||||
config = VllmConfig()
|
||||
|
||||
pass_manager = PostGradPassManager()
|
||||
pass_manager.configure(config)
|
||||
@ -43,7 +43,7 @@ class ProperPass(InductorPass):
|
||||
],
|
||||
)
|
||||
def test_pass_manager_uuid(callable):
|
||||
config = CompilationConfig().pass_config
|
||||
config = VllmConfig()
|
||||
|
||||
pass_manager = PostGradPassManager()
|
||||
pass_manager.configure(config)
|
||||
@ -64,7 +64,8 @@ def test_pass_manager_uuid(callable):
|
||||
|
||||
# UUID should be different due to config change
|
||||
config2 = copy.deepcopy(config)
|
||||
config2.enable_fusion = not config2.enable_fusion
|
||||
config2.compilation_config.pass_config.enable_fusion = not \
|
||||
config2.compilation_config.pass_config.enable_fusion
|
||||
pass_manager3 = PostGradPassManager()
|
||||
pass_manager3.configure(config2)
|
||||
pass_manager3.add(callable)
|
||||
|
||||
190
tests/compile/test_sequence_parallelism.py
Normal file
190
tests/compile/test_sequence_parallelism.py
Normal file
@ -0,0 +1,190 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
|
||||
find_specified_fn,
|
||||
find_specified_fn_maybe, is_func)
|
||||
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
||||
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
|
||||
VllmConfig)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
from ..utils import multi_gpu_test
|
||||
from .backend import TestBackend
|
||||
|
||||
OPS_IN_MODEL_BEFORE = [
|
||||
torch.ops.vllm.all_reduce.default,
|
||||
]
|
||||
|
||||
OPS_IN_MODEL_AFTER = [
|
||||
torch.ops.vllm.reduce_scatter.default,
|
||||
torch.ops.vllm.all_gather.default,
|
||||
]
|
||||
|
||||
OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default]
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.gate_proj = torch.nn.Parameter(
|
||||
torch.empty((intermediate_size, hidden_size)))
|
||||
self.norm = RMSNorm(hidden_size, 1e-05)
|
||||
# Initialize weights
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
"""
|
||||
Forward pass implementing the operations in the FX graph
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor
|
||||
residual: Residual tensor from previous layer
|
||||
|
||||
Returns:
|
||||
Tuple containing the output tensor
|
||||
"""
|
||||
# Reshape input
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
|
||||
#matrix multiplication
|
||||
permute = self.gate_proj.permute(1, 0)
|
||||
mm = torch.mm(view, permute)
|
||||
|
||||
# Tensor parallel all-reduce
|
||||
all_reduce = tensor_model_parallel_all_reduce(mm)
|
||||
|
||||
# layer normalization
|
||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||
|
||||
return norm_output, residual_output
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [16])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
||||
reason="Only test on CUDA")
|
||||
def test_sequence_parallelism_pass(batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype):
|
||||
num_processes = 2
|
||||
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
# need to use torch.mp.spawn otherwise will have problems with
|
||||
# torch.distributed and cuda
|
||||
torch.multiprocessing.spawn(fn,
|
||||
args=(num_processes, batch_size, seq_len,
|
||||
hidden_size, dtype),
|
||||
nprocs=nprocs)
|
||||
|
||||
run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
|
||||
|
||||
|
||||
def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
|
||||
batch_size: int, seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': '12345',
|
||||
})
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# configure vllm config for SequenceParallelismPass
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig(
|
||||
pass_config=CompilationConfig.PassConfig(
|
||||
enable_sequence_parallelism=True, ), )
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
|
||||
# this is a fake model name to construct the model config
|
||||
# in the vllm_config, it's not really used.
|
||||
model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
||||
vllm_config.model_config = ModelConfig(model=model,
|
||||
task="auto",
|
||||
tokenizer=model,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
seed=42)
|
||||
|
||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||
backend_no_func = TestBackend(sequence_parallelism_pass)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
backend_func = TestBackend(sequence_parallelism_pass, func_pass)
|
||||
|
||||
model = TestModel(hidden_size, hidden_size * 2)
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
||||
dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
|
||||
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
|
||||
compiled_model_no_func(hidden_states, residual)
|
||||
compiled_model_func = torch.compile(model, backend=backend_func)
|
||||
compiled_model_func(hidden_states, residual)
|
||||
|
||||
# Check substitution worked
|
||||
pre_nodes = backend_no_func.graph_pre_pass.nodes
|
||||
post_nodes = backend_no_func.graph_post_pass.nodes
|
||||
|
||||
# In pre-nodes, all reduce should be there,
|
||||
# reduce scatter and all gather should not
|
||||
for op in OPS_IN_MODEL_BEFORE:
|
||||
find_specified_fn(pre_nodes, op)
|
||||
for op in OPS_IN_MODEL_AFTER:
|
||||
assert find_specified_fn_maybe(pre_nodes, op) is None
|
||||
|
||||
# In post-nodes, reduce scatter and all gather should be there,
|
||||
# all reduce should not
|
||||
for op in OPS_IN_MODEL_AFTER:
|
||||
find_specified_fn(post_nodes, op)
|
||||
for op in OPS_IN_MODEL_BEFORE:
|
||||
assert find_specified_fn_maybe(post_nodes, op) is None
|
||||
|
||||
# check if the functionalization pass is applied
|
||||
for op in OPS_IN_MODEL:
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
|
||||
op) is None # noqa: E501
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
for node in backend_func.graph_post_pass.nodes:
|
||||
for op in OPS_IN_MODEL:
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
assert all(found[op] for op in OPS_IN_MODEL)
|
||||
@ -14,7 +14,8 @@ import torch
|
||||
|
||||
from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
|
||||
from ..utils import init_test_distributed_environment, multi_process_parallel
|
||||
|
||||
@ -47,6 +48,34 @@ def all_reduce_test_worker(
|
||||
torch.testing.assert_close(t, expected)
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int,
|
||||
pp_size: int, rank: int,
|
||||
distributed_init_port: str):
|
||||
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
|
||||
# so that each worker can see all the GPUs
|
||||
# they will be able to set the device to the correct GPU
|
||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
||||
distributed_init_port)
|
||||
|
||||
num_elements = 8
|
||||
all_tensors = [
|
||||
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
|
||||
(r + 1) for r in range(tp_size)
|
||||
]
|
||||
|
||||
index = rank % tp_size
|
||||
partition_size = num_elements // tp_size
|
||||
all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
|
||||
expected = all_reduce[index * partition_size:(index + 1) * partition_size]
|
||||
t = all_tensors[index]
|
||||
t = tensor_model_parallel_reduce_scatter(t, 0)
|
||||
torch.testing.assert_close(t, expected)
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def all_gather_test_worker(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
296
tests/distributed/test_sequence_parallel.py
Normal file
296
tests/distributed/test_sequence_parallel.py
Normal file
@ -0,0 +1,296 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
WARNING: This test runs in both single-node (4 GPUs) and multi-node
|
||||
(2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
|
||||
important to set the distributed backend to "mp" to avoid Ray scheduling
|
||||
all workers in a node other than the head node, which can cause the test
|
||||
to fail.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, NamedTuple, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import TaskOption
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..models.registry import HF_EXAMPLE_MODELS
|
||||
from ..utils import compare_two_settings, create_new_process_for_each_test
|
||||
|
||||
logger = init_logger("test_sequence_parallel")
|
||||
|
||||
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||
|
||||
|
||||
class ParallelSetup(NamedTuple):
|
||||
tp_size: int
|
||||
sp_enabled: bool
|
||||
eager_mode: bool
|
||||
chunked_prefill: bool
|
||||
|
||||
|
||||
class SPTestOptions(NamedTuple):
|
||||
multi_node_only: bool
|
||||
load_format: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SPTestSettings:
|
||||
parallel_setups: list[ParallelSetup]
|
||||
# NOTE: the length of distributed_backends and
|
||||
# vllm_major_versions should be the same, and they
|
||||
# are first zipped together to iterate over all
|
||||
# test settings.
|
||||
distributed_backends: list[str]
|
||||
# vllm major version: "0" for V0, "1" for V1
|
||||
vllm_major_versions: list[str]
|
||||
task: TaskOption
|
||||
test_options: SPTestOptions
|
||||
|
||||
def __post_init__(self):
|
||||
if len(self.distributed_backends) != len(self.vllm_major_versions):
|
||||
raise ValueError(
|
||||
f"Length mismatch: distributed_backends "
|
||||
f"({len(self.distributed_backends)}) != "
|
||||
f"vllm_major_versions ({len(self.vllm_major_versions)})")
|
||||
|
||||
@staticmethod
|
||||
def detailed(
|
||||
*,
|
||||
tp_base: int = 2,
|
||||
multi_node_only: bool = False,
|
||||
task: TaskOption = "auto",
|
||||
load_format: Optional[str] = None,
|
||||
):
|
||||
return SPTestSettings(
|
||||
parallel_setups=[
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
sp_enabled=True,
|
||||
eager_mode=False,
|
||||
chunked_prefill=False),
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
sp_enabled=True,
|
||||
eager_mode=False,
|
||||
chunked_prefill=True),
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
sp_enabled=True,
|
||||
eager_mode=True,
|
||||
chunked_prefill=False),
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
sp_enabled=True,
|
||||
eager_mode=True,
|
||||
chunked_prefill=True)
|
||||
],
|
||||
distributed_backends=["mp", "ray"],
|
||||
vllm_major_versions=["1", "1"],
|
||||
task=task,
|
||||
test_options=SPTestOptions(multi_node_only=multi_node_only,
|
||||
load_format=load_format),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def fast(
|
||||
*,
|
||||
tp_base: int = 2,
|
||||
task: TaskOption = "auto",
|
||||
multi_node_only: bool = False,
|
||||
load_format: Optional[str] = None,
|
||||
):
|
||||
return SPTestSettings(
|
||||
parallel_setups=[
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
sp_enabled=True,
|
||||
eager_mode=False,
|
||||
chunked_prefill=False),
|
||||
],
|
||||
distributed_backends=["mp", "ray"],
|
||||
vllm_major_versions=["1", "1"],
|
||||
task=task,
|
||||
test_options=SPTestOptions(multi_node_only=multi_node_only,
|
||||
load_format=load_format),
|
||||
)
|
||||
|
||||
def iter_params(self, model_id: str):
|
||||
opts = self.test_options
|
||||
|
||||
for parallel_setup in self.parallel_setups:
|
||||
for backend, vllm_major_version in zip(self.distributed_backends,
|
||||
self.vllm_major_versions):
|
||||
yield (model_id, parallel_setup, backend, vllm_major_version,
|
||||
self.task, opts)
|
||||
|
||||
|
||||
def _compare_sp(
|
||||
model_id: str,
|
||||
parallel_setup: ParallelSetup,
|
||||
distributed_backend: str,
|
||||
vllm_major_version: str,
|
||||
task: TaskOption,
|
||||
test_options: SPTestOptions,
|
||||
num_gpus_available: int,
|
||||
*,
|
||||
method: Literal["generate", "encode"],
|
||||
is_multimodal: bool,
|
||||
):
|
||||
(
|
||||
tp_size,
|
||||
sp_enabled,
|
||||
eager_mode,
|
||||
chunked_prefill,
|
||||
) = parallel_setup
|
||||
|
||||
multi_node_only, load_format = test_options
|
||||
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
trust_remote_code = model_info.trust_remote_code
|
||||
tokenizer_mode = model_info.tokenizer_mode
|
||||
hf_overrides = model_info.hf_overrides
|
||||
|
||||
if load_format == "dummy":
|
||||
# Avoid OOM
|
||||
text_overrides = {
|
||||
"num_hidden_layers": 4,
|
||||
"hidden_size": 512,
|
||||
"intermediate_size": 800,
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 1,
|
||||
}
|
||||
|
||||
if is_multimodal:
|
||||
hf_overrides.update({"text_config": text_overrides})
|
||||
else:
|
||||
hf_overrides.update(text_overrides)
|
||||
else:
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
|
||||
pp_size = 1
|
||||
if num_gpus_available < tp_size * pp_size:
|
||||
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
||||
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
||||
pytest.skip("Skipping multi-node pipeline parallel test for "
|
||||
"multiprocessing distributed backend")
|
||||
if multi_node_only and not VLLM_MULTI_NODE:
|
||||
pytest.skip("Not in multi-node setting")
|
||||
|
||||
common_args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"float16",
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"--max-num-seqs",
|
||||
"8",
|
||||
]
|
||||
if chunked_prefill:
|
||||
common_args.append("--enable-chunked-prefill")
|
||||
if eager_mode:
|
||||
common_args.append("--enforce-eager")
|
||||
if task != "auto":
|
||||
common_args.extend(["--task", task])
|
||||
if trust_remote_code:
|
||||
common_args.append("--trust-remote-code")
|
||||
if tokenizer_mode:
|
||||
common_args.extend(["--tokenizer-mode", tokenizer_mode])
|
||||
if load_format:
|
||||
common_args.extend(["--load-format", load_format])
|
||||
if hf_overrides:
|
||||
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
|
||||
|
||||
compilation_config = {
|
||||
'level': 3,
|
||||
'custom_ops': ["+rms_norm"],
|
||||
'compile_sizes': [4, 8],
|
||||
'splitting_ops': [],
|
||||
'pass_config': {
|
||||
'enable_sequence_parallism': sp_enabled,
|
||||
'enable_noop': True,
|
||||
'enable_fusion': True,
|
||||
},
|
||||
}
|
||||
|
||||
tp_sp_env = tp_env = {
|
||||
"VLLM_USE_V1": vllm_major_version,
|
||||
}
|
||||
|
||||
tp_sp_args = [
|
||||
*common_args,
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--distributed-executor-backend",
|
||||
distributed_backend,
|
||||
"--compilation_config",
|
||||
str(compilation_config),
|
||||
]
|
||||
|
||||
tp_env = {
|
||||
"VLLM_USE_V1": vllm_major_version,
|
||||
}
|
||||
tp_args = [
|
||||
*common_args,
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--distributed-executor-backend",
|
||||
"mp",
|
||||
]
|
||||
|
||||
try:
|
||||
compare_two_settings(model_id,
|
||||
tp_sp_args,
|
||||
tp_args,
|
||||
tp_sp_env,
|
||||
tp_env,
|
||||
method=method)
|
||||
except Exception:
|
||||
testing_ray_compiled_graph = tp_sp_env is not None
|
||||
if testing_ray_compiled_graph and vllm_major_version == "0":
|
||||
# Ray Compiled Graph tests are flaky for V0,
|
||||
# so we don't want to fail the test
|
||||
logger.exception("Ray Compiled Graph tests failed")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
SP_TEXT_GENERATION_MODELS = {
|
||||
# [Decoder-only]
|
||||
"meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.detailed(),
|
||||
}
|
||||
|
||||
SP_TEST_MODELS = [
|
||||
# TODO support other models
|
||||
# [LANGUAGE GENERATION]
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
|
||||
"task", "test_options"),
|
||||
[
|
||||
params for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
|
||||
for params in settings.iter_params(model_id)
|
||||
if model_id in SP_TEST_MODELS
|
||||
],
|
||||
)
|
||||
@create_new_process_for_each_test()
|
||||
def test_tp_sp_generation(
|
||||
model_id: str,
|
||||
parallel_setup: ParallelSetup,
|
||||
distributed_backend: str,
|
||||
vllm_major_version: str,
|
||||
task: TaskOption,
|
||||
test_options: SPTestOptions,
|
||||
num_gpus_available,
|
||||
):
|
||||
_compare_sp(model_id,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
vllm_major_version,
|
||||
task,
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
method="generate",
|
||||
is_multimodal=False)
|
||||
@ -339,7 +339,7 @@ class VllmBackend:
|
||||
|
||||
def configure_post_pass(self):
|
||||
config = self.compilation_config
|
||||
self.post_grad_pass_manager.configure(config.pass_config)
|
||||
self.post_grad_pass_manager.configure(self.vllm_config)
|
||||
|
||||
# Post-grad custom passes are run using the post_grad_custom_post_pass
|
||||
# hook. If a pass for that hook exists, add it to the pass manager.
|
||||
|
||||
@ -15,6 +15,8 @@ import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
|
||||
from .inductor_pass import pass_context
|
||||
|
||||
|
||||
class CompilerInterface:
|
||||
"""
|
||||
@ -312,11 +314,12 @@ class InductorAdaptor(CompilerInterface):
|
||||
torch._functorch.config.patch(
|
||||
enable_remote_autograd_cache=False))
|
||||
|
||||
compiled_graph = compile_fx(
|
||||
graph,
|
||||
example_inputs,
|
||||
inner_compile=hijacked_compile_fx_inner,
|
||||
config_patches=current_config)
|
||||
with pass_context(runtime_shape):
|
||||
compiled_graph = compile_fx(
|
||||
graph,
|
||||
example_inputs,
|
||||
inner_compile=hijacked_compile_fx_inner,
|
||||
config_patches=current_config)
|
||||
|
||||
# We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
|
||||
# compilation cache. So turn off the checks if we disable the
|
||||
|
||||
@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -531,7 +531,7 @@ class FusionPass(VllmInductorPass):
|
||||
_instance: 'Optional[FusionPass]' = None
|
||||
|
||||
@classmethod
|
||||
def instance(cls, config: CompilationConfig.PassConfig):
|
||||
def instance(cls, config: VllmConfig):
|
||||
"""
|
||||
Get the singleton instance of the FusionPass.
|
||||
If the instance exists, the config is updated but
|
||||
@ -540,10 +540,10 @@ class FusionPass(VllmInductorPass):
|
||||
if cls._instance is None:
|
||||
cls._instance = FusionPass(config)
|
||||
else:
|
||||
cls._instance.config = config
|
||||
cls._instance.pass_config = config.compilation_config.pass_config
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config: CompilationConfig.PassConfig):
|
||||
def __init__(self, config: VllmConfig):
|
||||
assert self.__class__._instance is None, \
|
||||
"FusionPass singleton instance already exists"
|
||||
super().__init__(config)
|
||||
|
||||
@ -12,6 +12,22 @@ def is_func(node: fx.Node, target) -> bool:
|
||||
return node.op == "call_function" and node.target == target
|
||||
|
||||
|
||||
# Returns the first specified node with the given op (if it exists)
|
||||
def find_specified_fn_maybe(nodes: Iterable[fx.Node],
|
||||
op: OpOverload) -> Optional[fx.Node]:
|
||||
for node in nodes:
|
||||
if node.target == op:
|
||||
return node
|
||||
return None
|
||||
|
||||
|
||||
# Returns the first specified node with the given op
|
||||
def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
|
||||
node = find_specified_fn_maybe(nodes, op)
|
||||
assert node is not None, f"Could not find {op} in nodes {nodes}"
|
||||
return node
|
||||
|
||||
|
||||
# Returns the first auto_functionalized node with the given op (if it exists)
|
||||
def find_auto_fn_maybe(nodes: Iterable[fx.Node],
|
||||
op: OpOverload) -> Optional[fx.Node]:
|
||||
|
||||
@ -4,6 +4,7 @@ import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -18,6 +19,34 @@ else:
|
||||
from .torch25_custom_graph_pass import ( # noqa: yapf
|
||||
Torch25CustomGraphPass as CustomGraphPass)
|
||||
|
||||
_pass_context = None
|
||||
|
||||
|
||||
class PassContext:
|
||||
|
||||
def __init__(self, runtime_shape: Optional[int]):
|
||||
self.runtime_shape = runtime_shape
|
||||
|
||||
|
||||
def get_pass_context() -> PassContext:
|
||||
"""Get the current pass context."""
|
||||
assert _pass_context is not None
|
||||
return _pass_context
|
||||
|
||||
|
||||
@contextmanager
|
||||
def pass_context(runtime_shape: Optional[int]):
|
||||
"""A context manager that stores the current pass context,
|
||||
usually it is a list of sizes to specialize.
|
||||
"""
|
||||
global _pass_context
|
||||
prev_context = _pass_context
|
||||
_pass_context = PassContext(runtime_shape)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_pass_context = prev_context
|
||||
|
||||
|
||||
class InductorPass(CustomGraphPass):
|
||||
"""
|
||||
@ -62,6 +91,9 @@ class InductorPass(CustomGraphPass):
|
||||
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||
return hashlib.sha256(encoded).hexdigest()
|
||||
|
||||
def is_applicable_for_shape(self, shape: Optional[int]):
|
||||
return True
|
||||
|
||||
|
||||
class CallableInductorPass(InductorPass):
|
||||
"""
|
||||
|
||||
@ -4,13 +4,15 @@ from typing import List
|
||||
|
||||
from torch import fx as fx
|
||||
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .fix_functionalization import FixFunctionalizationPass
|
||||
from .fusion import FusionPass
|
||||
from .inductor_pass import CustomGraphPass, InductorPass
|
||||
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
|
||||
from .noop_elimination import NoOpEliminationPass
|
||||
from .sequence_parallelism import SequenceParallelismPass
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -31,24 +33,29 @@ class PostGradPassManager(CustomGraphPass):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.passes: List[InductorPass] = []
|
||||
self.passes: List[VllmInductorPass] = []
|
||||
|
||||
def __call__(self, graph: fx.Graph):
|
||||
shape = get_pass_context().runtime_shape
|
||||
for pass_ in self.passes:
|
||||
pass_(graph)
|
||||
if pass_.is_applicable_for_shape(shape):
|
||||
pass_(graph)
|
||||
|
||||
# always run fix_functionalization last
|
||||
self.fix_functionalization(graph)
|
||||
|
||||
def configure(self, pass_config: CompilationConfig.PassConfig):
|
||||
self.pass_config = pass_config
|
||||
if pass_config.enable_noop:
|
||||
self.passes += [NoOpEliminationPass(pass_config)]
|
||||
def configure(self, config: VllmConfig):
|
||||
self.pass_config = config.compilation_config.pass_config
|
||||
if self.pass_config.enable_noop:
|
||||
self.passes += [NoOpEliminationPass(config)]
|
||||
|
||||
if pass_config.enable_fusion:
|
||||
self.passes += [FusionPass.instance(pass_config)]
|
||||
if self.pass_config.enable_fusion:
|
||||
self.passes += [FusionPass.instance(config)]
|
||||
|
||||
self.fix_functionalization = FixFunctionalizationPass(pass_config)
|
||||
if self.pass_config.enable_sequence_parallelism:
|
||||
self.passes += [SequenceParallelismPass(config)]
|
||||
|
||||
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||
|
||||
def add(self, pass_: InductorPass):
|
||||
assert isinstance(pass_, InductorPass)
|
||||
|
||||
266
vllm/compilation/sequence_parallelism.py
Normal file
266
vllm/compilation/sequence_parallelism.py
Normal file
@ -0,0 +1,266 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AllReduceRMSNormPattern:
|
||||
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||
self.epsilon = epsilon
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
|
||||
class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern):
    # Matches the subgraph emitted right after the (vocab-parallel)
    # embedding lookup:
    #     embedding -> where(mask) -> all_reduce -> rms_norm
    # and rewrites it to
    #     embedding -> where(mask) -> reduce_scatter -> rms_norm -> all_gather
    # so the RMSNorm runs on a per-rank shard of the token dimension.

    def get_inputs(self):
        """Build dummy inputs used only to trace the pattern graph.

        Only the op sequence matters for matching; the concrete shapes
        here do not constrain what the pattern matches at compile time.
        """
        # Embedding weight table (first arg of aten.embedding).
        arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype)
        # Token indices (second arg of aten.embedding).
        mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]],
                             device=self.device,
                             dtype=torch.long)
        # Boolean condition tensor for aten.where.
        unsqueeze = torch.rand([1, 8, 1], device=self.device,
                               dtype=self.dtype) > 0.5
        # Value selected where the condition is True.
        full_default = torch.zeros([1, 8, 4], device=self.device,
                                   dtype=self.dtype)
        # Pre-allocated output buffer for rms_norm's `result`.
        permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
        # RMSNorm weight vector.
        arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)

        return [arg2_1, mul_6, unsqueeze, full_default, permute, arg3_1]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern -> replacement rewrite on `pm_pass`."""

        def pattern(
            arg2_1: torch.Tensor,
            mul_6: torch.Tensor,
            unsqueeze: torch.Tensor,
            full_default: torch.Tensor,
            permute: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            embedding = torch.ops.aten.embedding.default(arg2_1, mul_6)
            where = torch.ops.aten.where.self(unsqueeze, full_default,
                                              embedding)
            all_reduce = tensor_model_parallel_all_reduce(where)
            # auto_functionalized wraps the mutating rms_norm op;
            # index [1] is the functionalized `result` tensor.
            rmsnorm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.rms_norm.default,
                result=permute,
                input=all_reduce,
                weight=arg3_1,
                epsilon=self.epsilon,
            )

            # Returning `all_reduce` too lets the matcher rewire any other
            # consumers of the all-reduced activation.
            return rmsnorm[1], all_reduce

        def replacement(
            arg2_1: torch.Tensor,
            mul_6: torch.Tensor,
            unsqueeze: torch.Tensor,
            full_default: torch.Tensor,
            permute: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            embedding = torch.ops.aten.embedding.default(arg2_1, mul_6)
            where = torch.ops.aten.where.self(unsqueeze, full_default,
                                              embedding)

            tp = get_tp_group()
            tp_size = get_tensor_model_parallel_world_size()
            # Replace the all_reduce with a reduce_scatter along dim 0
            # (the token dimension), so each rank keeps 1/tp_size tokens.
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                where, dim=0, world_size=tp_size, group_name=tp.unique_name)

            rmsnorm_result = torch.empty_like(reduce_scatter)
            rmsnorm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.rms_norm.default,
                result=rmsnorm_result,
                input=reduce_scatter,
                weight=arg3_1,
                epsilon=self.epsilon,
            )

            # Re-assemble the full normalized activation for the next layer.
            all_gather = torch.ops.vllm.all_gather.default(
                rmsnorm[1],
                dim=0,
                world_size=tp_size,
                group_name=tp.unique_name)

            # NOTE(review): consumers of the matched all_reduce output now
            # receive the *scattered* (sharded) tensor -- presumably
            # downstream matched patterns consume sharded residuals; confirm.
            return all_gather, reduce_scatter

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
    # Matches the per-layer subgraph
    #     all_reduce -> fused_add_rms_norm(residual)
    # and rewrites it to
    #     reduce_scatter -> fused_add_rms_norm -> all_gather
    # so the fused add+norm runs on a per-rank shard of the tokens.

    def get_inputs(self):
        """Build dummy inputs used only to trace the pattern graph."""
        # Output of the layer's final matmul (input of the all_reduce).
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        # Running residual stream added in-place by fused_add_rms_norm.
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4],
                                       device=self.device,
                                       dtype=self.dtype)

        return [
            residual,
            mm_1,
            rms_norm_weights,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern -> replacement rewrite on `pm_pass`."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            all_reduce = tensor_model_parallel_all_reduce(mm_1)

            # auto_functionalized wraps the mutating fused op;
            # index [1] is the normalized output, [2] the updated residual.
            rmsnorm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.fused_add_rms_norm.default,
                input=all_reduce,
                residual=residual,
                weight=rms_norm_weights,
                epsilon=self.epsilon,
            )

            return rmsnorm[1], rmsnorm[2]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            tp = get_tp_group()
            tp_size = get_tensor_model_parallel_world_size()
            # Reduce-scatter along dim 0 (token dim) instead of all-reducing.
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name)

            # TODO is it possible to extract epsilon from somewhere
            rmsnorm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.fused_add_rms_norm.default,
                input=reduce_scatter,
                residual=residual,
                weight=rms_norm_weights,
                epsilon=self.epsilon,
            )

            # Gather the normalized shard back to full size for the next op.
            all_gather = torch.ops.vllm.all_gather.default(
                rmsnorm[1],
                dim=0,
                world_size=tp_size,
                group_name=tp.unique_name)
            # NOTE(review): the residual output (rmsnorm[2]) stays sharded
            # here -- presumably it feeds the next matched layer, which also
            # operates on shards; confirm against the pass's end-to-end test.
            return all_gather, rmsnorm[2]

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
    # Matches the *final* all_reduce -> fused_add_rms_norm, where only the
    # normalized output (not the residual) is consumed downstream, and
    # rewrites it to reduce_scatter -> fused_add_rms_norm -> all_gather.

    def get_inputs(self):
        """Build dummy inputs used only to trace the pattern graph."""
        # Output of the final matmul (input of the all_reduce).
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        # Residual stream consumed by fused_add_rms_norm.
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4],
                                       device=self.device,
                                       dtype=self.dtype)

        return [
            residual,
            mm_1,
            rms_norm_weights,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern -> replacement rewrite on `pm_pass`."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> torch.Tensor:
            all_reduce = tensor_model_parallel_all_reduce(mm_1)

            # auto_functionalized wraps the mutating fused op;
            # index [1] is the normalized output.
            rmsnorm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.fused_add_rms_norm.default,
                input=all_reduce,
                residual=residual,
                weight=rms_norm_weights,
                epsilon=self.epsilon,
            )

            # Unlike the middle pattern, the updated residual is unused.
            return rmsnorm[1]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> torch.Tensor:
            tp = get_tp_group()
            tp_size = get_tensor_model_parallel_world_size()
            # Reduce-scatter along dim 0 (token dim) instead of all-reducing.
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name)

            # TODO is it possible to extract epsilon from somewhere
            rmsnorm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.fused_add_rms_norm.default,
                input=reduce_scatter,
                residual=residual,
                weight=rms_norm_weights,
                epsilon=self.epsilon,
            )

            # Gather the final normalized shard back to full size.
            normalized = torch.ops.vllm.all_gather.default(
                rmsnorm[1],
                dim=0,
                world_size=tp_size,
                group_name=tp.unique_name)

            return normalized

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class SequenceParallelismPass(VllmInductorPass):
    """Inductor pass enabling sequence parallelism.

    Rewrites tensor-parallel ``all_reduce`` + RMSNorm subgraphs into
    ``reduce_scatter`` -> RMSNorm -> ``all_gather`` so the normalization
    work is sharded across TP ranks.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass")
        # One pattern for the post-embedding all_reduce, one for the
        # per-layer fused-add case, one for the final layer.
        pattern_classes = (
            EmbeddingAllReduceRMSNormPattern,
            MiddleAllReduceRMSNormPattern,
            LastAllReduceRMSNormPattern,
        )
        # Register every pattern once per epsilon value used by models.
        for eps in [1e-5, 1e-6]:
            for pattern_cls in pattern_classes:
                pattern_cls(eps, self.dtype,
                            self.device).register(self.patterns)
        # WARNING: This is a hack to clear the pattern matcher cache
        # and allow multiple values of epsilon.
        torch._inductor.pattern_matcher._seen_patterns.clear()

    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
        """Only rewrite graphs whose token count divides evenly across
        the TP ranks (otherwise reduce_scatter shards are uneven)."""
        if shape is None:
            return False
        return shape % get_tensor_model_parallel_world_size() == 0

    def __call__(self, graph: fx.Graph):
        """Apply the registered rewrites in place on `graph`."""
        self.dump_graph(graph, "before_sequence_parallelism_pass")
        match_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", match_count)
        self.dump_graph(graph, "after_sequence_parallelism_pass")
|
||||
@ -4,7 +4,7 @@ import time
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.config import CompilationConfig, VllmConfig
|
||||
# yapf: disable
|
||||
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
|
||||
from vllm.distributed import (
|
||||
@ -24,16 +24,19 @@ class VllmInductorPass(InductorPass):
|
||||
It provides timing, logging, and dumping utilities.
|
||||
"""
|
||||
|
||||
def __init__(self, config: CompilationConfig.PassConfig):
|
||||
self.config = config
|
||||
def __init__(self, config: VllmConfig):
|
||||
self.pass_config = config.compilation_config.pass_config
|
||||
self.dtype = config.model_config.dtype if config.model_config else None
|
||||
self.device = config.device_config.device if config.device_config \
|
||||
else None
|
||||
self.pass_name = self.__class__.__name__
|
||||
|
||||
def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
|
||||
if stage in self.config.dump_graph_stages or always:
|
||||
if stage in self.pass_config.dump_graph_stages or always:
|
||||
# Make sure filename includes rank in the distributed setting
|
||||
parallel = p_is_init() and get_tp_world_size() > 1
|
||||
rank = f"-{get_tp_rank()}" if parallel else ""
|
||||
filepath = self.config.dump_graph_dir / f"{stage}{rank}.py"
|
||||
filepath = self.pass_config.dump_graph_dir / f"{stage}{rank}.py"
|
||||
|
||||
logger.info("%s printing graph to %s", self.pass_name, filepath)
|
||||
with open(filepath, "w") as f:
|
||||
|
||||
@ -3405,11 +3405,13 @@ class CompilationConfig(BaseModel):
|
||||
- enable_fusion: whether to enable the custom fusion pass.
|
||||
- enable_noop: whether to enable the custom no-op elimination pass.
|
||||
TODO(luka) better pass enabling system.
|
||||
- enable_sequence_parallelism: whether to enable sequence parallelism.
|
||||
"""
|
||||
dump_graph_stages: list[str] = Field(default_factory=list)
|
||||
dump_graph_dir: Path = Field(default=Path("."))
|
||||
enable_fusion: bool = True
|
||||
enable_noop: bool = True
|
||||
enable_sequence_parallelism: bool = False
|
||||
|
||||
def uuid(self):
|
||||
"""
|
||||
@ -3418,7 +3420,8 @@ class CompilationConfig(BaseModel):
|
||||
Do not include dump_graph_* in the hash - they don't affect
|
||||
compilation.
|
||||
"""
|
||||
dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
|
||||
dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \
|
||||
"enable_sequence_parallelism"})
|
||||
return InductorPass.hash_dict(dict_)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
@ -3840,6 +3843,8 @@ class VllmConfig:
|
||||
|
||||
if self.compilation_config is None:
|
||||
self.compilation_config = CompilationConfig()
|
||||
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||
self.compilation_config.custom_ops.append("+rms_norm")
|
||||
if envs.VLLM_USE_V1 and self.model_config is not None and \
|
||||
not self.model_config.enforce_eager:
|
||||
# NOTE(woosuk): Currently, we use inductor because the piecewise
|
||||
@ -3847,7 +3852,8 @@ class VllmConfig:
|
||||
# FIXME(woosuk): Disable inductor to reduce the compilation time
|
||||
# and avoid any potential issues with the inductor.
|
||||
# FIXME(rob): Add function to set all of these.
|
||||
self.compilation_config.custom_ops = ["none"]
|
||||
if not self.compilation_config.custom_ops:
|
||||
self.compilation_config.custom_ops = ["none"]
|
||||
self.compilation_config.use_cudagraph = True
|
||||
self.compilation_config.use_inductor = True
|
||||
self.compilation_config.cudagraph_num_of_warmups = 1
|
||||
@ -3856,6 +3862,18 @@ class VllmConfig:
|
||||
self.compilation_config.level = CompilationLevel.PIECEWISE
|
||||
self.compilation_config.set_splitting_ops_for_v1()
|
||||
|
||||
if self.parallel_config is not None and \
|
||||
self.parallel_config.tensor_parallel_size > 1 and \
|
||||
self.parallel_config.pipeline_parallel_size > 1 and \
|
||||
self.compilation_config is not None and \
|
||||
self.compilation_config.pass_config is not None and \
|
||||
self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||
logger.warning_once(
|
||||
"Sequence parallelism is not supported with pipeline "
|
||||
"parallelism. Disabling sequence parallelism.")
|
||||
self.compilation_config.pass_config.\
|
||||
enable_sequence_parallelism = False
|
||||
|
||||
self._set_cudagraph_sizes()
|
||||
|
||||
if self.cache_config is not None and \
|
||||
@ -3895,6 +3913,26 @@ class VllmConfig:
|
||||
if not self.instance_id:
|
||||
self.instance_id = random_uuid()[:5]
|
||||
|
||||
def update_sizes_for_sequence_parallelism(self,
                                          possible_sizes: list) -> list:
    """Return ``possible_sizes`` restricted to multiples of the TP size.

    When sequence parallelism is enabled the token dimension is sharded
    across ``tensor_parallel_size`` ranks, so only sizes divisible by it
    can be captured.  Any dropped sizes are reported with a warning.

    Args:
        possible_sizes: candidate batch/token sizes (may be empty).

    Returns:
        The kept sizes, in their original order.
    """
    tp_size = self.parallel_config.tensor_parallel_size
    # Partition in a single pass instead of filtering the list twice
    # with the same modulo predicate.
    kept_sizes: list = []
    removed_sizes: list = []
    for size in possible_sizes:
        (kept_sizes if size % tp_size == 0 else removed_sizes).append(size)

    if removed_sizes:
        logger.warning(
            "Batch sizes %s are removed because they are not "
            "multiple of tp_size %d when "
            "sequence parallelism is enabled", removed_sizes, tp_size)

    return kept_sizes
|
||||
|
||||
def _set_cudagraph_sizes(self):
|
||||
"""
|
||||
cudagraph batchsize padding logic:
|
||||
@ -3932,6 +3970,11 @@ class VllmConfig:
|
||||
not self.model_config.enforce_eager:
|
||||
|
||||
possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
|
||||
if self.parallel_config.tensor_parallel_size > 1 and \
|
||||
self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||
possible_sizes = self.update_sizes_for_sequence_parallelism(
|
||||
possible_sizes)
|
||||
|
||||
# find the minimum size that is larger than max_num_seqs,
|
||||
# which then becomes the max_batchsize_to_capture
|
||||
larger_sizes = [
|
||||
@ -3955,6 +3998,11 @@ class VllmConfig:
|
||||
not self.model_config.enforce_eager:
|
||||
batch_size_capture_list = [1, 2, 4
|
||||
] + [i for i in range(8, 513, 8)]
|
||||
if self.parallel_config.tensor_parallel_size > 1 and \
|
||||
self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||
batch_size_capture_list = \
|
||||
self.update_sizes_for_sequence_parallelism(batch_size_capture_list)
|
||||
|
||||
max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
batch_size_capture_list = [
|
||||
size for size in batch_size_capture_list
|
||||
|
||||
@ -19,6 +19,12 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor,
|
||||
return get_tp_group().all_gather(input_, dim)
|
||||
|
||||
|
||||
def tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
                                         dim: int = -1) -> torch.Tensor:
    """Reduce-Scatter the input tensor across model parallel group.

    Thin convenience wrapper: delegates to the TP group coordinator's
    ``reduce_scatter``, which shrinks ``dim`` by the TP world size.
    """
    return get_tp_group().reduce_scatter(input_, dim)
|
||||
|
||||
|
||||
def tensor_model_parallel_gather(input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
dim: int = -1) -> Optional[torch.Tensor]:
|
||||
|
||||
@ -61,6 +61,40 @@ class DeviceCommunicatorBase:
|
||||
input_size[dim + 1:])
|
||||
return output_tensor
|
||||
|
||||
def reduce_scatter(self,
                   input_: torch.Tensor,
                   dim: int = -1) -> torch.Tensor:
    """Reduce-scatter ``input_`` along ``dim`` over this communicator's
    process group.

    The scatter dimension shrinks by a factor of ``self.world_size`` in
    the returned tensor.
    """
    n_ranks = self.world_size
    # Bypass the function if we are using only 1 GPU.
    if n_ranks == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Note: This will produce an incorrect answer if we don't make
    # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
    staged = input_.movedim(0, dim).contiguous()

    assert staged.shape[0] % n_ranks == 0
    shard_rows = staged.shape[0] // n_ranks
    scattered = torch.empty((shard_rows, ) + staged.shape[1:],
                            dtype=staged.dtype,
                            device=staged.device)

    # Perform reduce-scatter operation
    torch.distributed.reduce_scatter_tensor(scattered,
                                            staged,
                                            group=self.device_group)

    # Undo the movedim before returning.
    return scattered.movedim(0, dim).contiguous()
|
||||
|
||||
def gather(self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
|
||||
@ -70,6 +70,31 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
torch.distributed.all_reduce(out, group=self.device_group)
|
||||
return out
|
||||
|
||||
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
    """Reduce-scatter ``input_`` along ``dim`` using the pyNCCL
    communicator; the scatter dimension shrinks by ``self.world_size``."""
    n_ranks = self.world_size
    comm = self.pynccl_comm
    assert comm is not None
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Note: This will produce an incorrect answer if we don't make
    # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
    staged = input_.movedim(0, dim).contiguous()

    assert staged.shape[0] % n_ranks == 0
    shard_rows = staged.shape[0] // n_ranks
    out = torch.empty((shard_rows, ) + staged.shape[1:],
                      dtype=staged.dtype,
                      device=staged.device)

    # NOTE(review): the original (un-moved) ``input_`` is passed here
    # rather than the contiguous ``staged`` tensor -- confirm pynccl
    # handles dim != 0 correctly.
    comm.reduce_scatter(out, input_)

    # Reshape before returning
    return out.movedim(0, dim).contiguous()
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
|
||||
"""Sends a tensor to the destination rank in a non-blocking way"""
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
|
||||
@ -113,6 +113,38 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
|
||||
return torch.empty_like(tensor)
|
||||
|
||||
|
||||
def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
                   group_name: str) -> torch.Tensor:
    """Custom-op entry point: reduce-scatter ``tensor`` along ``dim``
    within the group registered under ``group_name``.

    ``world_size`` is unused here; it exists so the fake (meta)
    implementation can derive the output shape from the op signature.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # _groups values are callables (presumably weakrefs) that yield the
    # live group coordinator, or None once the group has been destroyed.
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group.reduce_scatter(tensor, dim)
|
||||
|
||||
|
||||
def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
                        group_name: str) -> torch.Tensor:
    """Meta/fake implementation for the reduce_scatter custom op: same
    shape as the input except ``dim`` shrinks by ``world_size``."""
    out_shape = list(tensor.shape)
    out_shape[dim] //= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
|
||||
def all_gather(tensor: torch.Tensor, dim: int, world_size: int,
               group_name: str) -> torch.Tensor:
    """Custom-op entry point: all-gather ``tensor`` along ``dim`` within
    the group registered under ``group_name``.

    ``world_size`` is unused here; it exists so the fake (meta)
    implementation can derive the output shape from the op signature.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # _groups values are callables (presumably weakrefs) that yield the
    # live group coordinator, or None once the group has been destroyed.
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group.all_gather(tensor, dim)
|
||||
|
||||
|
||||
def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
                    group_name: str) -> torch.Tensor:
    """Meta/fake implementation for the all_gather custom op: same shape
    as the input except ``dim`` grows by ``world_size``."""
    out_shape = list(tensor.shape)
    out_shape[dim] *= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
|
||||
if supports_custom_op():
|
||||
from vllm.platforms import current_platform
|
||||
direct_register_custom_op(
|
||||
@ -123,6 +155,20 @@ if supports_custom_op():
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="reduce_scatter",
|
||||
op_func=reduce_scatter,
|
||||
mutates_args=[],
|
||||
fake_impl=reduce_scatter_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="all_gather",
|
||||
op_func=all_gather,
|
||||
mutates_args=[],
|
||||
fake_impl=all_gather_fake,
|
||||
)
|
||||
|
||||
|
||||
class GroupCoordinator:
|
||||
"""
|
||||
@ -322,6 +368,18 @@ class GroupCoordinator:
|
||||
|
||||
return self.device_communicator.all_gather(input_, dim)
|
||||
|
||||
def reduce_scatter(self,
                   input_: torch.Tensor,
                   dim: int = -1) -> torch.Tensor:
    """Reduce-scatter ``input_`` along ``dim`` across this coordinator's
    group, delegating the actual collective to the device communicator."""
    # Bypass the function if we are using only 1 GPU.
    if self.world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

    return self.device_communicator.reduce_scatter(input_, dim)
|
||||
|
||||
def gather(self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
|
||||
@ -1027,7 +1027,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_scheduled_tokens)
|
||||
else:
|
||||
# Eager mode.
|
||||
num_input_tokens = num_scheduled_tokens
|
||||
# Pad tokens to multiple of tensor_parallel_size when
|
||||
# enabled collective fusion for SP
|
||||
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||||
if self.vllm_config.compilation_config.pass_config. \
|
||||
enable_sequence_parallelism and tp_size > 1:
|
||||
from vllm.utils import round_up
|
||||
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
|
||||
else:
|
||||
num_input_tokens = num_scheduled_tokens
|
||||
attn_metadata.num_input_tokens = num_input_tokens
|
||||
|
||||
# _prepare_inputs may reorder the batch, so we must gather multi
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user