mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 23:37:54 +08:00
[Feature] support sequence parallelism using compilation pass (#16155)
Signed-off-by: cascade812 <cascade812@outlook.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
ed7a29d9f8
commit
690fe019f0
@ -299,6 +299,7 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- pytest -v -s compile/test_pass_manager.py
|
- pytest -v -s compile/test_pass_manager.py
|
||||||
- pytest -v -s compile/test_fusion.py
|
- pytest -v -s compile/test_fusion.py
|
||||||
|
- pytest -v -s compile/test_sequence_parallelism.py
|
||||||
|
|
||||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@ -583,6 +584,8 @@ steps:
|
|||||||
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
|
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
|
||||||
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
|
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
|
||||||
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
|
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
|
||||||
|
# test sequence parallel
|
||||||
|
- pytest -v -s distributed/test_sequence_parallel.py
|
||||||
# this test fails consistently.
|
# this test fails consistently.
|
||||||
# TODO: investigate and fix
|
# TODO: investigate and fix
|
||||||
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
|
|||||||
kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
||||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.config import CompilationConfig
|
from vllm.config import CompilationConfig, VllmConfig
|
||||||
|
|
||||||
from .backend import TestBackend
|
from .backend import TestBackend
|
||||||
|
|
||||||
@ -49,13 +49,15 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
|
|||||||
do_fusion: bool):
|
do_fusion: bool):
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
|
vllm_config = VllmConfig()
|
||||||
enable_noop=True)
|
vllm_config.compilation_config = CompilationConfig(pass_config= \
|
||||||
noop_pass = NoOpEliminationPass(config)
|
CompilationConfig.PassConfig(enable_fusion=do_fusion,
|
||||||
fusion_pass = FusionPass.instance(config)
|
enable_noop=True))
|
||||||
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
|
fusion_pass = FusionPass.instance(vllm_config)
|
||||||
|
|
||||||
passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
|
passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
|
||||||
func_pass = FixFunctionalizationPass(config)
|
func_pass = FixFunctionalizationPass(vllm_config)
|
||||||
backend_func = TestBackend(*passes, func_pass)
|
backend_func = TestBackend(*passes, func_pass)
|
||||||
backend_no_func = TestBackend(*passes)
|
backend_no_func = TestBackend(*passes)
|
||||||
|
|
||||||
|
|||||||
@ -77,12 +77,13 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
|
|||||||
|
|
||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||||
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
|
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
|
||||||
|
vllm_config.compilation_config.pass_config = \
|
||||||
|
CompilationConfig.PassConfig(enable_fusion=True,
|
||||||
|
enable_noop=True)
|
||||||
with vllm.config.set_current_vllm_config(vllm_config):
|
with vllm.config.set_current_vllm_config(vllm_config):
|
||||||
# Reshape pass is needed for the fusion pass to work
|
# Reshape pass is needed for the fusion pass to work
|
||||||
config = CompilationConfig.PassConfig(enable_fusion=True,
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
enable_noop=True)
|
fusion_pass = FusionPass.instance(vllm_config)
|
||||||
noop_pass = NoOpEliminationPass(config)
|
|
||||||
fusion_pass = FusionPass.instance(config)
|
|
||||||
|
|
||||||
backend = TestBackend(noop_pass, fusion_pass)
|
backend = TestBackend(noop_pass, fusion_pass)
|
||||||
model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)
|
model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import torch
|
|||||||
|
|
||||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||||
from vllm.compilation.pass_manager import PostGradPassManager
|
from vllm.compilation.pass_manager import PostGradPassManager
|
||||||
from vllm.config import CompilationConfig
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
|
|
||||||
# dummy custom pass that doesn't inherit
|
# dummy custom pass that doesn't inherit
|
||||||
@ -16,7 +16,7 @@ def simple_callable(graph: torch.fx.Graph):
|
|||||||
|
|
||||||
# Should fail to add directly to the pass manager
|
# Should fail to add directly to the pass manager
|
||||||
def test_bad_callable():
|
def test_bad_callable():
|
||||||
config = CompilationConfig().pass_config
|
config = VllmConfig()
|
||||||
|
|
||||||
pass_manager = PostGradPassManager()
|
pass_manager = PostGradPassManager()
|
||||||
pass_manager.configure(config)
|
pass_manager.configure(config)
|
||||||
@ -43,7 +43,7 @@ class ProperPass(InductorPass):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_pass_manager_uuid(callable):
|
def test_pass_manager_uuid(callable):
|
||||||
config = CompilationConfig().pass_config
|
config = VllmConfig()
|
||||||
|
|
||||||
pass_manager = PostGradPassManager()
|
pass_manager = PostGradPassManager()
|
||||||
pass_manager.configure(config)
|
pass_manager.configure(config)
|
||||||
@ -64,7 +64,8 @@ def test_pass_manager_uuid(callable):
|
|||||||
|
|
||||||
# UUID should be different due to config change
|
# UUID should be different due to config change
|
||||||
config2 = copy.deepcopy(config)
|
config2 = copy.deepcopy(config)
|
||||||
config2.enable_fusion = not config2.enable_fusion
|
config2.compilation_config.pass_config.enable_fusion = not \
|
||||||
|
config2.compilation_config.pass_config.enable_fusion
|
||||||
pass_manager3 = PostGradPassManager()
|
pass_manager3 = PostGradPassManager()
|
||||||
pass_manager3.configure(config2)
|
pass_manager3.configure(config2)
|
||||||
pass_manager3.add(callable)
|
pass_manager3.add(callable)
|
||||||
|
|||||||
190
tests/compile/test_sequence_parallelism.py
Normal file
190
tests/compile/test_sequence_parallelism.py
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||||
|
from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
|
||||||
|
find_specified_fn,
|
||||||
|
find_specified_fn_maybe, is_func)
|
||||||
|
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
||||||
|
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
|
||||||
|
VllmConfig)
|
||||||
|
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||||
|
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||||
|
initialize_model_parallel)
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
|
from ..utils import multi_gpu_test
|
||||||
|
from .backend import TestBackend
|
||||||
|
|
||||||
|
# Collective op the test expects in the graph before the
# sequence-parallelism pass has run.
OPS_IN_MODEL_BEFORE = [torch.ops.vllm.all_reduce.default]

# Collective ops the test expects in the graph after the pass has run.
OPS_IN_MODEL_AFTER = [
    torch.ops.vllm.reduce_scatter.default,
    torch.ops.vllm.all_gather.default,
]

# Op looked up when checking the functionalization fix-up pass.
OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default]

# NOTE(review): `prompts` is not referenced in the visible part of this
# file; kept in case something outside this view relies on it.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestModel(torch.nn.Module):
    """Small module: matmul -> tensor-parallel all-reduce -> RMSNorm.

    The forward graph contains the all-reduce followed by a residual
    RMSNorm, which is what the test looks for before and after the
    sequence-parallelism pass.
    """

    def __init__(self, hidden_size=16, intermediate_size=32):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        # Projection weight stored as (intermediate_size, hidden_size).
        weight = torch.empty((intermediate_size, hidden_size))
        self.gate_proj = torch.nn.Parameter(weight)
        self.norm = RMSNorm(hidden_size, 1e-05)

        # Initialize weights
        torch.nn.init.normal_(self.gate_proj, std=0.02)

    def forward(self, hidden_states, residual):
        """
        Run projection, all-reduce and residual RMSNorm.

        Args:
            hidden_states: Input tensor; flattened to
                (-1, hidden_size) before the matmul.
            residual: Residual tensor handed to the norm layer.

        Returns:
            The two values produced by the norm layer, as a tuple
            (normalized output, residual output).
        """
        flat = hidden_states.reshape(-1, self.hidden_size)

        # Matrix multiplication against the transposed weight.
        # NOTE(review): the product's last dim is intermediate_size while
        # self.norm was built with hidden_size - confirm this is intended.
        weight_t = self.gate_proj.permute(1, 0)
        projected = torch.mm(flat, weight_t)

        # Tensor-parallel all-reduce
        reduced = tensor_model_parallel_all_reduce(projected)

        # Residual RMSNorm
        normed, new_residual = self.norm(reduced, residual)

        return normed, new_residual
|
||||||
|
|
||||||
|
|
||||||
|
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
def test_sequence_parallelism_pass(batch_size: int, seq_len: int,
                                   hidden_size: int, dtype: torch.dtype):
    """Spawn two worker processes that each run the pass check."""
    num_processes = 2

    # torch.multiprocessing.spawn is needed here; other launch methods
    # run into problems with torch.distributed and CUDA.
    worker_args = (num_processes, batch_size, seq_len, hidden_size, dtype)
    torch.multiprocessing.spawn(sequence_parallelism_pass_on_test_model,
                                args=worker_args,
                                nprocs=num_processes)
|
||||||
|
|
||||||
|
|
||||||
|
def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
                                            batch_size: int, seq_len: int,
                                            hidden_size: int,
                                            dtype: torch.dtype):
    """Per-rank worker: compile TestModel and inspect the resulting graphs.

    Sets up the device and distributed state for ``local_rank``, compiles
    the model once with only the sequence-parallelism pass and once with
    the functionalization fix-up added, then checks which ops appear in
    the graphs recorded by the test backends.

    Args:
        local_rank: Index of this process; also selects ``cuda:<rank>``.
        world_size: Number of processes / tensor-parallel size.
        batch_size: Combined with ``seq_len`` to size the token dimension.
        seq_len: Combined with ``batch_size`` to size the token dimension.
        hidden_size: Last dimension of the generated inputs.
        dtype: Default dtype and dtype of the generated inputs.
    """
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12345',
    })

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # configure vllm config for SequenceParallelismPass
    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig(
        pass_config=CompilationConfig.PassConfig(
            enable_sequence_parallelism=True, ), )
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

    # this is a fake model name to construct the model config
    # in the vllm_config, it's not really used.
    model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
    vllm_config.model_config = ModelConfig(model=model_name,
                                           task="auto",
                                           tokenizer=model_name,
                                           tokenizer_mode="auto",
                                           trust_remote_code=True,
                                           dtype=dtype,
                                           seed=42)

    sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
    backend_no_func = TestBackend(sequence_parallelism_pass)
    func_pass = FixFunctionalizationPass(vllm_config)
    backend_func = TestBackend(sequence_parallelism_pass, func_pass)

    model = TestModel(hidden_size, hidden_size * 2)
    hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                dtype=dtype)
    residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)

    compiled_model_no_func = torch.compile(model, backend=backend_no_func)
    compiled_model_no_func(hidden_states, residual)
    compiled_model_func = torch.compile(model, backend=backend_func)
    compiled_model_func(hidden_states, residual)

    # Check substitution worked
    pre_nodes = backend_no_func.graph_pre_pass.nodes
    post_nodes = backend_no_func.graph_post_pass.nodes

    # In pre-nodes, all reduce should be there,
    # reduce scatter and all gather should not
    for op in OPS_IN_MODEL_BEFORE:
        find_specified_fn(pre_nodes, op)
    for op in OPS_IN_MODEL_AFTER:
        assert find_specified_fn_maybe(pre_nodes, op) is None

    # In post-nodes, reduce scatter and all gather should be there,
    # all reduce should not
    for op in OPS_IN_MODEL_AFTER:
        find_specified_fn(post_nodes, op)
    for op in OPS_IN_MODEL_BEFORE:
        assert find_specified_fn_maybe(post_nodes, op) is None

    # check if the functionalization pass is applied
    func_post_nodes = backend_func.graph_post_pass.nodes
    for op in OPS_IN_MODEL:
        find_auto_fn(post_nodes, op)
        assert find_auto_fn_maybe(func_post_nodes, op) is None

    # make sure the ops were all de-functionalized; collect the missing
    # ones so a failure is an AssertionError naming them rather than a
    # KeyError from an unset dict entry
    missing = [
        op for op in OPS_IN_MODEL
        if not any(is_func(node, op) for node in func_post_nodes)
    ]
    assert not missing, f"Ops not de-functionalized: {missing}"
|
||||||
@ -14,7 +14,8 @@ import torch
|
|||||||
|
|
||||||
from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
|
from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
|
||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce,
|
||||||
|
tensor_model_parallel_reduce_scatter)
|
||||||
|
|
||||||
from ..utils import init_test_distributed_environment, multi_process_parallel
|
from ..utils import init_test_distributed_environment, multi_process_parallel
|
||||||
|
|
||||||
@ -47,6 +48,34 @@ def all_reduce_test_worker(
|
|||||||
torch.testing.assert_close(t, expected)
|
torch.testing.assert_close(t, expected)
|
||||||
|
|
||||||
|
|
||||||
|
@ray.remote(num_gpus=1, max_calls=1)
def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int,
                               pp_size: int, rank: int,
                               distributed_init_port: str):
    """Check tensor-parallel reduce-scatter against a locally computed sum.

    Each rank contributes ``arange(8) * (rank_index + 1)`` and expects to
    receive its own slice of the element-wise sum over all ranks.
    """
    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs
    # they will be able to set the device to the correct GPU
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)

    num_elements = 8
    per_rank_inputs = []
    for r in range(tp_size):
        base = torch.arange(num_elements, dtype=torch.float32, device="cuda")
        per_rank_inputs.append(base * (r + 1))

    tp_index = rank % tp_size
    chunk = num_elements // tp_size

    # Reference result: full sum, then the slice owned by this rank.
    summed = torch.stack(per_rank_inputs, dim=0).sum(dim=0)
    expected = summed[tp_index * chunk:(tp_index + 1) * chunk]

    scattered = tensor_model_parallel_reduce_scatter(
        per_rank_inputs[tp_index], 0)
    torch.testing.assert_close(scattered, expected)
|
||||||
|
|
||||||
|
|
||||||
@ray.remote(num_gpus=1, max_calls=1)
|
@ray.remote(num_gpus=1, max_calls=1)
|
||||||
def all_gather_test_worker(
|
def all_gather_test_worker(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
|||||||
296
tests/distributed/test_sequence_parallel.py
Normal file
296
tests/distributed/test_sequence_parallel.py
Normal file
@ -0,0 +1,296 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""
|
||||||
|
WARNING: This test runs in both single-node (4 GPUs) and multi-node
|
||||||
|
(2 nodes with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
|
||||||
|
important to set the distributed backend to "mp" to avoid Ray scheduling
|
||||||
|
all workers in a node other than the head node, which can cause the test
|
||||||
|
to fail.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal, NamedTuple, Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.config import TaskOption
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
from ..models.registry import HF_EXAMPLE_MODELS
|
||||||
|
from ..utils import compare_two_settings, create_new_process_for_each_test
|
||||||
|
|
||||||
|
logger = init_logger("test_sequence_parallel")

# True only when the environment sets VLLM_MULTI_NODE to exactly "1".
VLLM_MULTI_NODE = os.environ.get("VLLM_MULTI_NODE", "0") == "1"
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelSetup(NamedTuple):
    """One execution configuration for a TP + sequence-parallel run."""

    # Tensor-parallel size passed as --tensor-parallel-size.
    tp_size: int
    # Value written into the pass config's sequence-parallelism switch.
    sp_enabled: bool
    # When true, --enforce-eager is added to the server arguments.
    eager_mode: bool
    # When true, --enable-chunked-prefill is added to the server arguments.
    chunked_prefill: bool
|
||||||
|
|
||||||
|
|
||||||
|
class SPTestOptions(NamedTuple):
    """Options that gate or tweak a test case independent of parallelism."""

    # When true, the case is skipped unless VLLM_MULTI_NODE is set.
    multi_node_only: bool
    # Forwarded as --load-format when truthy; "dummy" also shrinks the model.
    load_format: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SPTestSettings:
    """Bundle of parallel setups, backends and options for one model."""

    parallel_setups: list[ParallelSetup]
    # NOTE: distributed_backends and vllm_major_versions must have the
    # same length; iter_params() zips them together pairwise to
    # enumerate the test settings.
    distributed_backends: list[str]
    # vllm major version: "0" for V0, "1" for V1
    vllm_major_versions: list[str]
    task: TaskOption
    test_options: SPTestOptions

    def __post_init__(self):
        """Reject settings whose backend and version lists differ in length."""
        lengths_match = (len(self.distributed_backends) == len(
            self.vllm_major_versions))
        if not lengths_match:
            raise ValueError(
                f"Length mismatch: distributed_backends "
                f"({len(self.distributed_backends)}) != "
                f"vllm_major_versions ({len(self.vllm_major_versions)})")

    @staticmethod
    def detailed(
        *,
        tp_base: int = 2,
        multi_node_only: bool = False,
        task: TaskOption = "auto",
        load_format: Optional[str] = None,
    ):
        """Settings covering every eager / chunked-prefill combination."""
        # Order: (compiled, no chunk), (compiled, chunk),
        #        (eager, no chunk), (eager, chunk).
        setups = [
            ParallelSetup(tp_size=tp_base,
                          sp_enabled=True,
                          eager_mode=eager,
                          chunked_prefill=chunked)
            for eager in (False, True) for chunked in (False, True)
        ]
        options = SPTestOptions(multi_node_only=multi_node_only,
                                load_format=load_format)
        return SPTestSettings(
            parallel_setups=setups,
            distributed_backends=["mp", "ray"],
            vllm_major_versions=["1", "1"],
            task=task,
            test_options=options,
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 2,
        task: TaskOption = "auto",
        multi_node_only: bool = False,
        load_format: Optional[str] = None,
    ):
        """Settings with a single compiled, non-chunked setup."""
        only_setup = ParallelSetup(tp_size=tp_base,
                                   sp_enabled=True,
                                   eager_mode=False,
                                   chunked_prefill=False)
        options = SPTestOptions(multi_node_only=multi_node_only,
                                load_format=load_format)
        return SPTestSettings(
            parallel_setups=[only_setup],
            distributed_backends=["mp", "ray"],
            vllm_major_versions=["1", "1"],
            task=task,
            test_options=options,
        )

    def iter_params(self, model_id: str):
        """Yield one parameter tuple per (setup, backend/version pair)."""
        options = self.test_options

        for setup in self.parallel_setups:
            yield from ((model_id, setup, backend, major_version, self.task,
                         options)
                        for backend, major_version in zip(
                            self.distributed_backends,
                            self.vllm_major_versions))
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_sp(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: SPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"],
    is_multimodal: bool,
):
    """Compare a TP run with the sequence-parallel pass config to plain TP.

    Builds two argument lists for ``compare_two_settings``: one with the
    compilation config (including the sequence-parallelism switch) on the
    requested distributed backend, and a baseline with tensor parallelism
    only on the "mp" backend.

    Args:
        model_id: Model looked up in ``HF_EXAMPLE_MODELS``.
        parallel_setup: TP size and SP / eager / chunked-prefill switches.
        distributed_backend: Executor backend for the SP run.
        vllm_major_version: Value exported as ``VLLM_USE_V1``.
        task: Passed as ``--task`` unless it is "auto".
        test_options: Multi-node gating and optional load format.
        num_gpus_available: GPUs available; the case is skipped if too few.
        method: Forwarded to ``compare_two_settings``.
        is_multimodal: If true, dummy-weight overrides are nested under
            ``text_config``.
    """
    (
        tp_size,
        sp_enabled,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup

    multi_node_only, load_format = test_options

    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")

    trust_remote_code = model_info.trust_remote_code
    tokenizer_mode = model_info.tokenizer_mode
    # Work on a copy: the overrides added below must not leak into the
    # model-info entry that other test cases read.
    hf_overrides = dict(model_info.hf_overrides or {})

    if load_format == "dummy":
        # Avoid OOM
        text_overrides = {
            "num_hidden_layers": 4,
            "hidden_size": 512,
            "intermediate_size": 800,
            "num_attention_heads": 4,
            "num_key_value_heads": 1,
        }

        if is_multimodal:
            hf_overrides.update({"text_config": text_overrides})
        else:
            hf_overrides.update(text_overrides)
    else:
        model_info.check_available_online(on_fail="skip")

    pp_size = 1
    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip("Skipping multi-node sequence parallel test for "
                    "multiprocessing distributed backend")
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if task != "auto":
        common_args.extend(["--task", task])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

    compilation_config = {
        'level': 3,
        'custom_ops': ["+rms_norm"],
        'compile_sizes': [4, 8],
        'splitting_ops': [],
        'pass_config': {
            # Must match the PassConfig field name exactly; a misspelled
            # key would leave sequence parallelism unconfigured.
            'enable_sequence_parallelism': sp_enabled,
            'enable_noop': True,
            'enable_fusion': True,
        },
    }

    # Separate dict objects for the two runs, same content.
    tp_sp_env = {
        "VLLM_USE_V1": vllm_major_version,
    }
    tp_env = {
        "VLLM_USE_V1": vllm_major_version,
    }

    tp_sp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
        "--compilation_config",
        # Passed as the str() of the dict (Python-literal form).
        str(compilation_config),
    ]

    # Baseline: tensor parallelism only, always on the "mp" backend.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(model_id,
                             tp_sp_args,
                             tp_args,
                             tp_sp_env,
                             tp_env,
                             method=method)
    except Exception:
        if vllm_major_version == "0":
            # Ray Compiled Graph tests are flaky for V0,
            # so we don't want to fail the test
            logger.exception("Ray Compiled Graph tests failed")
        else:
            raise
|
||||||
|
|
||||||
|
|
||||||
|
# Per-model settings for text-generation sequence-parallel tests.
SP_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.detailed(),
}

# Models from the table above that are actually parametrized.
# TODO support other models
SP_TEST_MODELS = [
    # [LANGUAGE GENERATION]
    "meta-llama/Llama-3.2-1B-Instruct",
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "task", "test_options"),
    [
        case for name, settings in SP_TEXT_GENERATION_MODELS.items()
        if name in SP_TEST_MODELS for case in settings.iter_params(name)
    ],
)
@create_new_process_for_each_test()
def test_tp_sp_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: SPTestOptions,
    num_gpus_available,
):
    """Run the TP-vs-TP+SP comparison for a text-generation model."""
    _compare_sp(
        model_id,
        parallel_setup,
        distributed_backend,
        vllm_major_version,
        task,
        test_options,
        num_gpus_available,
        method="generate",
        is_multimodal=False,
    )
|
||||||
@ -339,7 +339,7 @@ class VllmBackend:
|
|||||||
|
|
||||||
def configure_post_pass(self):
|
def configure_post_pass(self):
|
||||||
config = self.compilation_config
|
config = self.compilation_config
|
||||||
self.post_grad_pass_manager.configure(config.pass_config)
|
self.post_grad_pass_manager.configure(self.vllm_config)
|
||||||
|
|
||||||
# Post-grad custom passes are run using the post_grad_custom_post_pass
|
# Post-grad custom passes are run using the post_grad_custom_post_pass
|
||||||
# hook. If a pass for that hook exists, add it to the pass manager.
|
# hook. If a pass for that hook exists, add it to the pass manager.
|
||||||
|
|||||||
@ -15,6 +15,8 @@ import vllm.envs as envs
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.utils import is_torch_equal_or_newer
|
from vllm.utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
|
from .inductor_pass import pass_context
|
||||||
|
|
||||||
|
|
||||||
class CompilerInterface:
|
class CompilerInterface:
|
||||||
"""
|
"""
|
||||||
@ -312,11 +314,12 @@ class InductorAdaptor(CompilerInterface):
|
|||||||
torch._functorch.config.patch(
|
torch._functorch.config.patch(
|
||||||
enable_remote_autograd_cache=False))
|
enable_remote_autograd_cache=False))
|
||||||
|
|
||||||
compiled_graph = compile_fx(
|
with pass_context(runtime_shape):
|
||||||
graph,
|
compiled_graph = compile_fx(
|
||||||
example_inputs,
|
graph,
|
||||||
inner_compile=hijacked_compile_fx_inner,
|
example_inputs,
|
||||||
config_patches=current_config)
|
inner_compile=hijacked_compile_fx_inner,
|
||||||
|
config_patches=current_config)
|
||||||
|
|
||||||
# We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
|
# We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
|
||||||
# compilation cache. So turn off the checks if we disable the
|
# compilation cache. So turn off the checks if we disable the
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
|||||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||||
from torch._ops import OpOverload
|
from torch._ops import OpOverload
|
||||||
|
|
||||||
from vllm.config import CompilationConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
@ -531,7 +531,7 @@ class FusionPass(VllmInductorPass):
|
|||||||
_instance: 'Optional[FusionPass]' = None
|
_instance: 'Optional[FusionPass]' = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def instance(cls, config: CompilationConfig.PassConfig):
|
def instance(cls, config: VllmConfig):
|
||||||
"""
|
"""
|
||||||
Get the singleton instance of the FusionPass.
|
Get the singleton instance of the FusionPass.
|
||||||
If the instance exists, the config is updated but
|
If the instance exists, the config is updated but
|
||||||
@ -540,10 +540,10 @@ class FusionPass(VllmInductorPass):
|
|||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = FusionPass(config)
|
cls._instance = FusionPass(config)
|
||||||
else:
|
else:
|
||||||
cls._instance.config = config
|
cls._instance.pass_config = config.compilation_config.pass_config
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self, config: CompilationConfig.PassConfig):
|
def __init__(self, config: VllmConfig):
|
||||||
assert self.__class__._instance is None, \
|
assert self.__class__._instance is None, \
|
||||||
"FusionPass singleton instance already exists"
|
"FusionPass singleton instance already exists"
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@ -12,6 +12,22 @@ def is_func(node: fx.Node, target) -> bool:
|
|||||||
return node.op == "call_function" and node.target == target
|
return node.op == "call_function" and node.target == target
|
||||||
|
|
||||||
|
|
||||||
|
# Returns the first specified node with the given op (if it exists)
|
||||||
|
def find_specified_fn_maybe(nodes: Iterable[fx.Node],
                            op: OpOverload) -> Optional[fx.Node]:
    """Return the first node in ``nodes`` whose ``target`` equals ``op``.

    Nodes are examined in iteration order and only ``node.target`` is
    compared. Returns ``None`` when no node matches (including when
    ``nodes`` is empty).
    """
    return next((node for node in nodes if node.target == op), None)
|
||||||
|
|
||||||
|
|
||||||
|
# Returns the first specified node with the given op
|
||||||
|
def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    """Return the first node in ``nodes`` whose ``target`` equals ``op``.

    Strict variant of ``find_specified_fn_maybe``: the node is required
    to exist.

    Raises:
        AssertionError: if no node in ``nodes`` has ``op`` as its target.
    """
    node = find_specified_fn_maybe(nodes, op)
    assert node is not None, f"Could not find {op} in nodes {nodes}"
    return node
|
||||||
|
|
||||||
|
|
||||||
# Returns the first auto_functionalized node with the given op (if it exists)
|
# Returns the first auto_functionalized node with the given op (if it exists)
|
||||||
def find_auto_fn_maybe(nodes: Iterable[fx.Node],
|
def find_auto_fn_maybe(nodes: Iterable[fx.Node],
|
||||||
op: OpOverload) -> Optional[fx.Node]:
|
op: OpOverload) -> Optional[fx.Node]:
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import hashlib
|
|||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import types
|
import types
|
||||||
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, Dict, Optional, Union
|
from typing import Any, Callable, Dict, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -18,6 +19,34 @@ else:
|
|||||||
from .torch25_custom_graph_pass import ( # noqa: yapf
|
from .torch25_custom_graph_pass import ( # noqa: yapf
|
||||||
Torch25CustomGraphPass as CustomGraphPass)
|
Torch25CustomGraphPass as CustomGraphPass)
|
||||||
|
|
||||||
|
# Pass context currently in effect; None outside of a pass_context() block.
_pass_context: Optional["PassContext"] = None


class PassContext:
    """State made available to compilation passes via get_pass_context()."""

    def __init__(self, runtime_shape: Optional[int]):
        # Stored exactly as given; None is an accepted value.
        self.runtime_shape = runtime_shape


def get_pass_context() -> PassContext:
    """Get the current pass context.

    Raises:
        AssertionError: if no pass_context() block is currently active.
    """
    assert _pass_context is not None, (
        "get_pass_context() called outside of a pass_context() block")
    return _pass_context


@contextmanager
def pass_context(runtime_shape: Optional[int]):
    """Install a PassContext holding ``runtime_shape`` for a ``with`` block.

    The previously active context (possibly None) is restored on exit,
    even if the body raises, so these blocks can be nested.
    """
    global _pass_context
    prev_context = _pass_context
    _pass_context = PassContext(runtime_shape)
    try:
        yield
    finally:
        _pass_context = prev_context
|
||||||
|
|
||||||
|
|
||||||
class InductorPass(CustomGraphPass):
|
class InductorPass(CustomGraphPass):
|
||||||
"""
|
"""
|
||||||
@ -62,6 +91,9 @@ class InductorPass(CustomGraphPass):
|
|||||||
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||||
return hashlib.sha256(encoded).hexdigest()
|
return hashlib.sha256(encoded).hexdigest()
|
||||||
|
|
||||||
|
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
    """Report whether this pass should run for the given runtime shape.

    This default accepts every shape, including ``None``. Passes that
    are only valid for some shapes override it.
    """
    return True
|
||||||
|
|
||||||
|
|
||||||
class CallableInductorPass(InductorPass):
|
class CallableInductorPass(InductorPass):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -4,13 +4,15 @@ from typing import List
|
|||||||
|
|
||||||
from torch import fx as fx
|
from torch import fx as fx
|
||||||
|
|
||||||
from vllm.config import CompilationConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from .fix_functionalization import FixFunctionalizationPass
|
from .fix_functionalization import FixFunctionalizationPass
|
||||||
from .fusion import FusionPass
|
from .fusion import FusionPass
|
||||||
from .inductor_pass import CustomGraphPass, InductorPass
|
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
|
||||||
from .noop_elimination import NoOpEliminationPass
|
from .noop_elimination import NoOpEliminationPass
|
||||||
|
from .sequence_parallelism import SequenceParallelismPass
|
||||||
|
from .vllm_inductor_pass import VllmInductorPass
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -31,24 +33,29 @@ class PostGradPassManager(CustomGraphPass):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.passes: List[InductorPass] = []
|
self.passes: List[VllmInductorPass] = []
|
||||||
|
|
||||||
def __call__(self, graph: fx.Graph):
|
def __call__(self, graph: fx.Graph):
|
||||||
|
shape = get_pass_context().runtime_shape
|
||||||
for pass_ in self.passes:
|
for pass_ in self.passes:
|
||||||
pass_(graph)
|
if pass_.is_applicable_for_shape(shape):
|
||||||
|
pass_(graph)
|
||||||
|
|
||||||
# always run fix_functionalization last
|
# always run fix_functionalization last
|
||||||
self.fix_functionalization(graph)
|
self.fix_functionalization(graph)
|
||||||
|
|
||||||
def configure(self, pass_config: CompilationConfig.PassConfig):
|
def configure(self, config: VllmConfig):
|
||||||
self.pass_config = pass_config
|
self.pass_config = config.compilation_config.pass_config
|
||||||
if pass_config.enable_noop:
|
if self.pass_config.enable_noop:
|
||||||
self.passes += [NoOpEliminationPass(pass_config)]
|
self.passes += [NoOpEliminationPass(config)]
|
||||||
|
|
||||||
if pass_config.enable_fusion:
|
if self.pass_config.enable_fusion:
|
||||||
self.passes += [FusionPass.instance(pass_config)]
|
self.passes += [FusionPass.instance(config)]
|
||||||
|
|
||||||
self.fix_functionalization = FixFunctionalizationPass(pass_config)
|
if self.pass_config.enable_sequence_parallelism:
|
||||||
|
self.passes += [SequenceParallelismPass(config)]
|
||||||
|
|
||||||
|
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||||
|
|
||||||
def add(self, pass_: InductorPass):
|
def add(self, pass_: InductorPass):
|
||||||
assert isinstance(pass_, InductorPass)
|
assert isinstance(pass_, InductorPass)
|
||||||
|
|||||||
266
vllm/compilation/sequence_parallelism.py
Normal file
266
vllm/compilation/sequence_parallelism.py
Normal file
@ -0,0 +1,266 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch._inductor.pattern_matcher as pm
|
||||||
|
import torch.fx as fx
|
||||||
|
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||||
|
from vllm.distributed.parallel_state import (
|
||||||
|
get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
from .vllm_inductor_pass import VllmInductorPass
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AllReduceRMSNormPattern:
    """Common state for the all-reduce + RMSNorm rewrite patterns.

    Holds the RMSNorm epsilon that is baked into the pattern, plus the
    dtype and device that subclasses use to build example inputs.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
        self.epsilon = epsilon
        self.dtype = dtype
        self.device = device
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern):
    """Rewrite ``embedding -> where -> all_reduce -> rms_norm``.

    The all-reduce is replaced by a reduce-scatter along dim 0, RMSNorm
    runs on the scattered tensor, and an all-gather along dim 0 rebuilds
    the normalized output.
    """

    def get_inputs(self):
        """Build the example tensors used to trace pattern/replacement."""
        arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype)
        mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]],
                             device=self.device,
                             dtype=torch.long)
        unsqueeze = torch.rand([1, 8, 1], device=self.device,
                               dtype=self.dtype) > 0.5
        full_default = torch.zeros([1, 8, 4], device=self.device,
                                   dtype=self.dtype)
        permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
        arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)

        return [arg2_1, mul_6, unsqueeze, full_default, permute, arg3_1]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            arg2_1: torch.Tensor,
            mul_6: torch.Tensor,
            unsqueeze: torch.Tensor,
            full_default: torch.Tensor,
            permute: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            embedding = torch.ops.aten.embedding.default(arg2_1, mul_6)
            where = torch.ops.aten.where.self(unsqueeze, full_default,
                                              embedding)
            all_reduce = tensor_model_parallel_all_reduce(where)
            rmsnorm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.rms_norm.default,
                result=permute,
                input=all_reduce,
                weight=arg3_1,
                epsilon=self.epsilon,
            )

            return rmsnorm[1], all_reduce

        def replacement(
            arg2_1: torch.Tensor,
            mul_6: torch.Tensor,
            unsqueeze: torch.Tensor,
            full_default: torch.Tensor,
            permute: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            embedding = torch.ops.aten.embedding.default(arg2_1, mul_6)
            where = torch.ops.aten.where.self(unsqueeze, full_default,
                                              embedding)

            tp = get_tp_group()
            tp_size = get_tensor_model_parallel_world_size()
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                where, dim=0, world_size=tp_size, group_name=tp.unique_name)

            # The scattered tensor is smaller than ``permute``, so a fresh
            # result buffer of the scattered shape is allocated here.
            rmsnorm_result = torch.empty_like(reduce_scatter)
            rmsnorm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.rms_norm.default,
                result=rmsnorm_result,
                input=reduce_scatter,
                weight=arg3_1,
                epsilon=self.epsilon,
            )

            all_gather = torch.ops.vllm.all_gather.default(
                rmsnorm[1],
                dim=0,
                world_size=tp_size,
                group_name=tp.unique_name)

            # Second output is the scattered tensor, standing in for the
            # pattern's full all-reduce result.
            return all_gather, reduce_scatter

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
|
||||||
|
|
||||||
|
|
||||||
|
class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
    """Rewrite ``all_reduce -> fused_add_rms_norm`` with both outputs used.

    The all-reduce becomes a reduce-scatter along dim 0, the fused
    add + RMSNorm runs on the scattered tensor, and only the normalized
    output is all-gathered; the updated residual is returned as produced
    by the norm op.
    """

    def get_inputs(self):
        """Build the example tensors used to trace pattern/replacement."""
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4],
                                       device=self.device,
                                       dtype=self.dtype)

        # Order must match the parameter order of pattern()/replacement().
        return [
            residual,
            mm_1,
            rms_norm_weights,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            all_reduce = tensor_model_parallel_all_reduce(mm_1)

            rmsnorm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.fused_add_rms_norm.default,
                input=all_reduce,
                residual=residual,
                weight=rms_norm_weights,
                epsilon=self.epsilon,
            )

            return rmsnorm[1], rmsnorm[2]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            tp = get_tp_group()
            tp_size = get_tensor_model_parallel_world_size()
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name)

            # TODO is it possible to extract epsilon from somewhere
            rmsnorm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.fused_add_rms_norm.default,
                input=reduce_scatter,
                residual=residual,
                weight=rms_norm_weights,
                epsilon=self.epsilon,
            )

            all_gather = torch.ops.vllm.all_gather.default(
                rmsnorm[1],
                dim=0,
                world_size=tp_size,
                group_name=tp.unique_name)
            return all_gather, rmsnorm[2]

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
|
||||||
|
|
||||||
|
|
||||||
|
class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
    """Rewrite ``all_reduce -> fused_add_rms_norm`` with one output used.

    Same rewrite as the two-output variant, except that only the
    normalized tensor is returned; the residual output of the norm op is
    not part of the match.
    """

    def get_inputs(self):
        """Build the example tensors used to trace pattern/replacement."""
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4],
                                       device=self.device,
                                       dtype=self.dtype)

        # Order must match the parameter order of pattern()/replacement().
        return [
            residual,
            mm_1,
            rms_norm_weights,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        # Both functions return a single tensor, so they are annotated
        # ``torch.Tensor`` rather than a 2-tuple.
        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> torch.Tensor:
            all_reduce = tensor_model_parallel_all_reduce(mm_1)

            rmsnorm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.fused_add_rms_norm.default,
                input=all_reduce,
                residual=residual,
                weight=rms_norm_weights,
                epsilon=self.epsilon,
            )

            return rmsnorm[1]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> torch.Tensor:
            tp = get_tp_group()
            tp_size = get_tensor_model_parallel_world_size()
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name)

            # TODO is it possible to extract epsilon from somewhere
            rmsnorm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.fused_add_rms_norm.default,
                input=reduce_scatter,
                residual=residual,
                weight=rms_norm_weights,
                epsilon=self.epsilon,
            )

            normalized = torch.ops.vllm.all_gather.default(
                rmsnorm[1],
                dim=0,
                world_size=tp_size,
                group_name=tp.unique_name)

            return normalized

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceParallelismPass(VllmInductorPass):
    """Pattern-matcher pass replacing all_reduce + RMSNorm sequences.

    Registers the embedding, middle and last all-reduce/RMSNorm patterns
    for each supported epsilon and applies them to the graph when called.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass")
        for epsilon in [1e-5, 1e-6]:
            EmbeddingAllReduceRMSNormPattern(
                epsilon, self.dtype, self.device).register(self.patterns)

            MiddleAllReduceRMSNormPattern(epsilon, self.dtype,
                                          self.device).register(self.patterns)

            LastAllReduceRMSNormPattern(epsilon, self.dtype,
                                        self.device).register(self.patterns)
            # WARNING: This is a hack to clear the pattern matcher cache
            # and allow multiple values of epsilon. It has to run once per
            # epsilon, i.e. inside the loop.
            torch._inductor.pattern_matcher._seen_patterns.clear()

    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
        """Return True only for a known shape divisible by the TP size."""
        # only do replace for specific shapes
        tp_size = get_tensor_model_parallel_world_size()
        return shape is not None and shape % tp_size == 0

    def __call__(self, graph: fx.Graph):
        """Apply the registered patterns to ``graph`` in place."""
        self.dump_graph(graph, "before_sequence_parallelism_pass")
        count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", count)
        self.dump_graph(graph, "after_sequence_parallelism_pass")
|
||||||
@ -4,7 +4,7 @@ import time
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.config import CompilationConfig
|
from vllm.config import CompilationConfig, VllmConfig
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
|
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
@ -24,16 +24,19 @@ class VllmInductorPass(InductorPass):
|
|||||||
It provides timing, logging, and dumping utilities.
|
It provides timing, logging, and dumping utilities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: CompilationConfig.PassConfig):
|
def __init__(self, config: VllmConfig):
|
||||||
self.config = config
|
self.pass_config = config.compilation_config.pass_config
|
||||||
|
self.dtype = config.model_config.dtype if config.model_config else None
|
||||||
|
self.device = config.device_config.device if config.device_config \
|
||||||
|
else None
|
||||||
self.pass_name = self.__class__.__name__
|
self.pass_name = self.__class__.__name__
|
||||||
|
|
||||||
def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
|
def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
|
||||||
if stage in self.config.dump_graph_stages or always:
|
if stage in self.pass_config.dump_graph_stages or always:
|
||||||
# Make sure filename includes rank in the distributed setting
|
# Make sure filename includes rank in the distributed setting
|
||||||
parallel = p_is_init() and get_tp_world_size() > 1
|
parallel = p_is_init() and get_tp_world_size() > 1
|
||||||
rank = f"-{get_tp_rank()}" if parallel else ""
|
rank = f"-{get_tp_rank()}" if parallel else ""
|
||||||
filepath = self.config.dump_graph_dir / f"{stage}{rank}.py"
|
filepath = self.pass_config.dump_graph_dir / f"{stage}{rank}.py"
|
||||||
|
|
||||||
logger.info("%s printing graph to %s", self.pass_name, filepath)
|
logger.info("%s printing graph to %s", self.pass_name, filepath)
|
||||||
with open(filepath, "w") as f:
|
with open(filepath, "w") as f:
|
||||||
|
|||||||
@ -3405,11 +3405,13 @@ class CompilationConfig(BaseModel):
|
|||||||
- enable_fusion: whether to enable the custom fusion pass.
|
- enable_fusion: whether to enable the custom fusion pass.
|
||||||
- enable_noop: whether to enable the custom no-op elimination pass.
|
- enable_noop: whether to enable the custom no-op elimination pass.
|
||||||
TODO(luka) better pass enabling system.
|
TODO(luka) better pass enabling system.
|
||||||
|
- enable_sequence_parallelism: whether to enable sequence parallelism.
|
||||||
"""
|
"""
|
||||||
dump_graph_stages: list[str] = Field(default_factory=list)
|
dump_graph_stages: list[str] = Field(default_factory=list)
|
||||||
dump_graph_dir: Path = Field(default=Path("."))
|
dump_graph_dir: Path = Field(default=Path("."))
|
||||||
enable_fusion: bool = True
|
enable_fusion: bool = True
|
||||||
enable_noop: bool = True
|
enable_noop: bool = True
|
||||||
|
enable_sequence_parallelism: bool = False
|
||||||
|
|
||||||
def uuid(self):
|
def uuid(self):
|
||||||
"""
|
"""
|
||||||
@ -3418,7 +3420,8 @@ class CompilationConfig(BaseModel):
|
|||||||
Do not include dump_graph_* in the hash - they don't affect
|
Do not include dump_graph_* in the hash - they don't affect
|
||||||
compilation.
|
compilation.
|
||||||
"""
|
"""
|
||||||
dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
|
dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \
|
||||||
|
"enable_sequence_parallelism"})
|
||||||
return InductorPass.hash_dict(dict_)
|
return InductorPass.hash_dict(dict_)
|
||||||
|
|
||||||
def model_post_init(self, __context: Any) -> None:
|
def model_post_init(self, __context: Any) -> None:
|
||||||
@ -3840,6 +3843,8 @@ class VllmConfig:
|
|||||||
|
|
||||||
if self.compilation_config is None:
|
if self.compilation_config is None:
|
||||||
self.compilation_config = CompilationConfig()
|
self.compilation_config = CompilationConfig()
|
||||||
|
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||||
|
self.compilation_config.custom_ops.append("+rms_norm")
|
||||||
if envs.VLLM_USE_V1 and self.model_config is not None and \
|
if envs.VLLM_USE_V1 and self.model_config is not None and \
|
||||||
not self.model_config.enforce_eager:
|
not self.model_config.enforce_eager:
|
||||||
# NOTE(woosuk): Currently, we use inductor because the piecewise
|
# NOTE(woosuk): Currently, we use inductor because the piecewise
|
||||||
@ -3847,7 +3852,8 @@ class VllmConfig:
|
|||||||
# FIXME(woosuk): Disable inductor to reduce the compilation time
|
# FIXME(woosuk): Disable inductor to reduce the compilation time
|
||||||
# and avoid any potential issues with the inductor.
|
# and avoid any potential issues with the inductor.
|
||||||
# FIXME(rob): Add function to set all of these.
|
# FIXME(rob): Add function to set all of these.
|
||||||
self.compilation_config.custom_ops = ["none"]
|
if not self.compilation_config.custom_ops:
|
||||||
|
self.compilation_config.custom_ops = ["none"]
|
||||||
self.compilation_config.use_cudagraph = True
|
self.compilation_config.use_cudagraph = True
|
||||||
self.compilation_config.use_inductor = True
|
self.compilation_config.use_inductor = True
|
||||||
self.compilation_config.cudagraph_num_of_warmups = 1
|
self.compilation_config.cudagraph_num_of_warmups = 1
|
||||||
@ -3856,6 +3862,18 @@ class VllmConfig:
|
|||||||
self.compilation_config.level = CompilationLevel.PIECEWISE
|
self.compilation_config.level = CompilationLevel.PIECEWISE
|
||||||
self.compilation_config.set_splitting_ops_for_v1()
|
self.compilation_config.set_splitting_ops_for_v1()
|
||||||
|
|
||||||
|
if self.parallel_config is not None and \
|
||||||
|
self.parallel_config.tensor_parallel_size > 1 and \
|
||||||
|
self.parallel_config.pipeline_parallel_size > 1 and \
|
||||||
|
self.compilation_config is not None and \
|
||||||
|
self.compilation_config.pass_config is not None and \
|
||||||
|
self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||||
|
logger.warning_once(
|
||||||
|
"Sequence parallelism is not supported with pipeline "
|
||||||
|
"parallelism. Disabling sequence parallelism.")
|
||||||
|
self.compilation_config.pass_config.\
|
||||||
|
enable_sequence_parallelism = False
|
||||||
|
|
||||||
self._set_cudagraph_sizes()
|
self._set_cudagraph_sizes()
|
||||||
|
|
||||||
if self.cache_config is not None and \
|
if self.cache_config is not None and \
|
||||||
@ -3895,6 +3913,26 @@ class VllmConfig:
|
|||||||
if not self.instance_id:
|
if not self.instance_id:
|
||||||
self.instance_id = random_uuid()[:5]
|
self.instance_id = random_uuid()[:5]
|
||||||
|
|
||||||
|
def update_sizes_for_sequence_parallelism(self,
                                          possible_sizes: list) -> list:
    """Keep only the sizes that are a multiple of the tensor-parallel size.

    Args:
        possible_sizes: candidate batch sizes; not modified.

    Returns:
        A new list with the sizes divisible by
        ``parallel_config.tensor_parallel_size``, in their original order.
        A warning listing the dropped sizes is logged if any were removed.
    """
    tp_size = self.parallel_config.tensor_parallel_size

    # Split in a single pass instead of filtering the list twice.
    kept_sizes = []
    removed_sizes = []
    for size in possible_sizes:
        if size % tp_size == 0:
            kept_sizes.append(size)
        else:
            removed_sizes.append(size)

    if removed_sizes:
        logger.warning(
            "Batch sizes %s are removed because they are not "
            "multiple of tp_size %d when "
            "sequence parallelism is enabled", removed_sizes, tp_size)

    return kept_sizes
|
||||||
|
|
||||||
def _set_cudagraph_sizes(self):
|
def _set_cudagraph_sizes(self):
|
||||||
"""
|
"""
|
||||||
cudagraph batchsize padding logic:
|
cudagraph batchsize padding logic:
|
||||||
@ -3932,6 +3970,11 @@ class VllmConfig:
|
|||||||
not self.model_config.enforce_eager:
|
not self.model_config.enforce_eager:
|
||||||
|
|
||||||
possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
|
possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
|
||||||
|
if self.parallel_config.tensor_parallel_size > 1 and \
|
||||||
|
self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||||
|
possible_sizes = self.update_sizes_for_sequence_parallelism(
|
||||||
|
possible_sizes)
|
||||||
|
|
||||||
# find the minimum size that is larger than max_num_seqs,
|
# find the minimum size that is larger than max_num_seqs,
|
||||||
# which then becomes the max_batchsize_to_capture
|
# which then becomes the max_batchsize_to_capture
|
||||||
larger_sizes = [
|
larger_sizes = [
|
||||||
@ -3955,6 +3998,11 @@ class VllmConfig:
|
|||||||
not self.model_config.enforce_eager:
|
not self.model_config.enforce_eager:
|
||||||
batch_size_capture_list = [1, 2, 4
|
batch_size_capture_list = [1, 2, 4
|
||||||
] + [i for i in range(8, 513, 8)]
|
] + [i for i in range(8, 513, 8)]
|
||||||
|
if self.parallel_config.tensor_parallel_size > 1 and \
|
||||||
|
self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||||
|
batch_size_capture_list = \
|
||||||
|
self.update_sizes_for_sequence_parallelism(batch_size_capture_list)
|
||||||
|
|
||||||
max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||||
batch_size_capture_list = [
|
batch_size_capture_list = [
|
||||||
size for size in batch_size_capture_list
|
size for size in batch_size_capture_list
|
||||||
|
|||||||
@ -19,6 +19,12 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor,
|
|||||||
return get_tp_group().all_gather(input_, dim)
|
return get_tp_group().all_gather(input_, dim)
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
                                         dim: int = -1) -> torch.Tensor:
    """Reduce-Scatter the input tensor across model parallel group.

    Thin wrapper that forwards ``input_`` and ``dim`` to the tensor-parallel
    group's ``reduce_scatter``.
    """
    return get_tp_group().reduce_scatter(input_, dim)
|
||||||
|
|
||||||
|
|
||||||
def tensor_model_parallel_gather(input_: torch.Tensor,
|
def tensor_model_parallel_gather(input_: torch.Tensor,
|
||||||
dst: int = 0,
|
dst: int = 0,
|
||||||
dim: int = -1) -> Optional[torch.Tensor]:
|
dim: int = -1) -> Optional[torch.Tensor]:
|
||||||
|
|||||||
@ -61,6 +61,40 @@ class DeviceCommunicatorBase:
|
|||||||
input_size[dim + 1:])
|
input_size[dim + 1:])
|
||||||
return output_tensor
|
return output_tensor
|
||||||
|
|
||||||
|
def reduce_scatter(self,
                   input_: torch.Tensor,
                   dim: int = -1) -> torch.Tensor:
    """Reduce ``input_`` across the group and scatter it along ``dim``.

    Args:
        input_: tensor to reduce; ``input_.shape[dim]`` must be divisible
            by the world size.
        dim: dimension to scatter along (negative values allowed).

    Returns:
        ``input_`` itself when the world size is 1; otherwise a new
        contiguous tensor whose size along ``dim`` is divided by the
        world size.

    Raises:
        AssertionError: if ``dim`` is out of range or the size along
            ``dim`` is not divisible by the world size.
    """
    world_size = self.world_size
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # reduce_scatter_tensor splits along the leading dimension, so bring
    # the scatter dimension to the front (movedim(dim, 0), not
    # movedim(0, dim), which differs for rank >= 3 tensors).
    # Note: This will produce an incorrect answer if we don't make
    # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
    input_tensor = input_.movedim(dim, 0).contiguous()

    assert input_tensor.shape[0] % world_size == 0, (
        f"Size {input_tensor.shape[0]} along dim {dim} is not divisible "
        f"by world size {world_size}")
    chunk_size = input_tensor.shape[0] // world_size
    output_shape = (chunk_size, ) + input_tensor.shape[1:]

    output_tensor = torch.empty(output_shape,
                                dtype=input_tensor.dtype,
                                device=input_tensor.device)

    # Perform reduce-scatter operation
    torch.distributed.reduce_scatter_tensor(output_tensor,
                                            input_tensor,
                                            group=self.device_group)

    # Move the (now shorter) scatter dimension back to its original place.
    return output_tensor.movedim(0, dim).contiguous()
|
||||||
|
|
||||||
def gather(self,
|
def gather(self,
|
||||||
input_: torch.Tensor,
|
input_: torch.Tensor,
|
||||||
dst: int = 0,
|
dst: int = 0,
|
||||||
|
|||||||
@ -70,6 +70,31 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
|||||||
torch.distributed.all_reduce(out, group=self.device_group)
|
torch.distributed.all_reduce(out, group=self.device_group)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def reduce_scatter(self, input_: torch.Tensor,
                   dim: int = -1) -> torch.Tensor:
    """Reduce-scatter ``input_`` along ``dim`` using the pynccl communicator.

    Args:
        input_: tensor to reduce; ``input_.shape[dim]`` must be divisible
            by the world size.
        dim: dimension to scatter along (negative values allowed).

    Returns:
        A new contiguous tensor whose size along ``dim`` is divided by the
        world size.

    Raises:
        AssertionError: if no pynccl communicator is available or the size
            along ``dim`` is not divisible by the world size.
    """
    world_size = self.world_size
    pynccl_comm = self.pynccl_comm
    assert pynccl_comm is not None
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Bring the scatter dimension to the front (movedim(dim, 0), not
    # movedim(0, dim), which differs for rank >= 3 tensors).
    # Note: This will produce an incorrect answer if we don't make
    # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
    input_tensor = input_.movedim(dim, 0).contiguous()

    assert input_tensor.shape[0] % world_size == 0
    chunk_size = input_tensor.shape[0] // world_size
    output_shape = (chunk_size, ) + input_tensor.shape[1:]

    output = torch.empty(output_shape,
                         dtype=input_tensor.dtype,
                         device=input_tensor.device)

    # Pass the contiguous, scatter-dim-first tensor that ``output`` was
    # sized from, not the caller's original ``input_``.
    pynccl_comm.reduce_scatter(output, input_tensor)

    # Reshape before returning
    return output.movedim(0, dim).contiguous()
|
||||||
|
|
||||||
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
|
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
|
||||||
"""Sends a tensor to the destination rank in a non-blocking way"""
|
"""Sends a tensor to the destination rank in a non-blocking way"""
|
||||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||||
|
|||||||
@ -113,6 +113,38 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
|
|||||||
return torch.empty_like(tensor)
|
return torch.empty_like(tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
                   group_name: str) -> torch.Tensor:
    """Run ``reduce_scatter`` on the group registered as ``group_name``.

    ``world_size`` is not used here; it is part of the signature shared
    with ``reduce_scatter_fake``, which needs it to compute the output
    shape.

    Raises:
        AssertionError: if ``group_name`` was never registered.
        ValueError: if the group has already been destroyed.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # Entries in ``_groups`` are callables that yield the group, or None
    # once it is gone.
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group.reduce_scatter(tensor, dim)
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
                        group_name: str) -> torch.Tensor:
    """Shape-only stand-in for ``reduce_scatter``.

    Returns an uninitialized tensor with the dtype and device of ``tensor``
    whose size along ``dim`` is ``tensor.shape[dim] // world_size``. No
    communication happens and ``group_name`` is ignored.
    """
    new_shape = list(tensor.shape)
    new_shape[dim] = tensor.shape[dim] // world_size
    return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather(tensor: torch.Tensor, dim: int, world_size: int,
               group_name: str) -> torch.Tensor:
    """Run ``all_gather`` on the group registered as ``group_name``.

    ``world_size`` is not used here; it is part of the signature shared
    with ``all_gather_fake``, which needs it to compute the output shape.

    Raises:
        AssertionError: if ``group_name`` was never registered.
        ValueError: if the group has already been destroyed.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # Entries in ``_groups`` are callables that yield the group, or None
    # once it is gone.
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group.all_gather(tensor, dim)
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
                    group_name: str) -> torch.Tensor:
    """Shape-only stand-in for ``all_gather``.

    Returns an uninitialized tensor with the dtype and device of ``tensor``
    whose size along ``dim`` is ``tensor.shape[dim] * world_size``. No
    communication happens and ``group_name`` is ignored.
    """
    new_shape = list(tensor.shape)
    new_shape[dim] = tensor.shape[dim] * world_size
    return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||||
|
|
||||||
|
|
||||||
if supports_custom_op():
|
if supports_custom_op():
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
@ -123,6 +155,20 @@ if supports_custom_op():
|
|||||||
dispatch_key=current_platform.dispatch_key,
|
dispatch_key=current_platform.dispatch_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="reduce_scatter",
|
||||||
|
op_func=reduce_scatter,
|
||||||
|
mutates_args=[],
|
||||||
|
fake_impl=reduce_scatter_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="all_gather",
|
||||||
|
op_func=all_gather,
|
||||||
|
mutates_args=[],
|
||||||
|
fake_impl=all_gather_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GroupCoordinator:
|
class GroupCoordinator:
|
||||||
"""
|
"""
|
||||||
@ -322,6 +368,18 @@ class GroupCoordinator:
|
|||||||
|
|
||||||
return self.device_communicator.all_gather(input_, dim)
|
return self.device_communicator.all_gather(input_, dim)
|
||||||
|
|
||||||
|
def reduce_scatter(self,
                   input_: torch.Tensor,
                   dim: int = -1) -> torch.Tensor:
    """Reduce-scatter ``input_`` along ``dim`` via the device communicator.

    Returns ``input_`` unchanged when the group has a single member;
    otherwise validates ``dim`` and delegates to
    ``self.device_communicator.reduce_scatter``.

    Raises:
        AssertionError: if ``dim`` is out of range for ``input_``.
    """
    world_size = self.world_size
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

    return self.device_communicator.reduce_scatter(input_, dim)
|
||||||
|
|
||||||
def gather(self,
|
def gather(self,
|
||||||
input_: torch.Tensor,
|
input_: torch.Tensor,
|
||||||
dst: int = 0,
|
dst: int = 0,
|
||||||
|
|||||||
@ -1027,7 +1027,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_scheduled_tokens)
|
num_scheduled_tokens)
|
||||||
else:
|
else:
|
||||||
# Eager mode.
|
# Eager mode.
|
||||||
num_input_tokens = num_scheduled_tokens
|
# Pad tokens to multiple of tensor_parallel_size when
|
||||||
|
# enabled collective fusion for SP
|
||||||
|
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||||||
|
if self.vllm_config.compilation_config.pass_config. \
|
||||||
|
enable_sequence_parallelism and tp_size > 1:
|
||||||
|
from vllm.utils import round_up
|
||||||
|
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
|
||||||
|
else:
|
||||||
|
num_input_tokens = num_scheduled_tokens
|
||||||
attn_metadata.num_input_tokens = num_input_tokens
|
attn_metadata.num_input_tokens = num_input_tokens
|
||||||
|
|
||||||
# _prepare_inputs may reorder the batch, so we must gather multi
|
# _prepare_inputs may reorder the batch, so we must gather multi
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user