mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 22:27:14 +08:00
[compile] Add patched_fused_scaled_matmul_reduce_scatter (#26604)
Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
parent
d0bed837ac
commit
a25f2adee9
@ -142,7 +142,7 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
|
|||||||
return [torch.ops.vllm.reduce_scatter.default]
|
return [torch.ops.vllm.reduce_scatter.default]
|
||||||
|
|
||||||
def ops_in_model_after(self):
|
def ops_in_model_after(self):
|
||||||
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
|
return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
|
||||||
|
|
||||||
|
|
||||||
class TestAGScaledMMModel(_BaseScaledMMModel):
|
class TestAGScaledMMModel(_BaseScaledMMModel):
|
||||||
@ -195,7 +195,7 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
|||||||
return [torch.ops.vllm.reduce_scatter.default]
|
return [torch.ops.vllm.reduce_scatter.default]
|
||||||
|
|
||||||
def ops_in_model_after(self):
|
def ops_in_model_after(self):
|
||||||
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
|
return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
|
||||||
|
|
||||||
|
|
||||||
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
||||||
@ -243,9 +243,15 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
|||||||
@pytest.mark.parametrize("seq_len", [16])
|
@pytest.mark.parametrize("seq_len", [16])
|
||||||
@pytest.mark.parametrize("hidden_size", [16])
|
@pytest.mark.parametrize("hidden_size", [16])
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
|
@pytest.mark.parametrize("dynamic", [True, False])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||||
def test_async_tp_pass_replace(
|
def test_async_tp_pass_replace(
|
||||||
test_model: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype
|
test_model: str,
|
||||||
|
batch_size: int,
|
||||||
|
seq_len: int,
|
||||||
|
hidden_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
dynamic: bool,
|
||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
test_model
|
test_model
|
||||||
@ -269,7 +275,15 @@ def test_async_tp_pass_replace(
|
|||||||
# torch.distributed and cuda
|
# torch.distributed and cuda
|
||||||
torch.multiprocessing.spawn(
|
torch.multiprocessing.spawn(
|
||||||
fn,
|
fn,
|
||||||
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
|
args=(
|
||||||
|
num_processes,
|
||||||
|
test_model,
|
||||||
|
batch_size,
|
||||||
|
seq_len,
|
||||||
|
hidden_size,
|
||||||
|
dtype,
|
||||||
|
dynamic,
|
||||||
|
),
|
||||||
nprocs=nprocs,
|
nprocs=nprocs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -284,6 +298,7 @@ def async_tp_pass_on_test_model(
|
|||||||
seq_len: int,
|
seq_len: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
dynamic: bool,
|
||||||
):
|
):
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
|
|
||||||
@ -331,6 +346,9 @@ def async_tp_pass_on_test_model(
|
|||||||
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
|
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if dynamic:
|
||||||
|
torch._dynamo.mark_dynamic(hidden_states, 0)
|
||||||
|
|
||||||
compiled_model = torch.compile(model, backend=backend)
|
compiled_model = torch.compile(model, backend=backend)
|
||||||
compiled_model(hidden_states)
|
compiled_model(hidden_states)
|
||||||
|
|
||||||
|
|||||||
@ -172,7 +172,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
|
|||||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||||
scatter_dim = 0
|
scatter_dim = 0
|
||||||
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
|
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||||
input,
|
input,
|
||||||
mat2,
|
mat2,
|
||||||
scale_a,
|
scale_a,
|
||||||
@ -307,7 +307,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
|
|||||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||||
scatter_dim = 0
|
scatter_dim = 0
|
||||||
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
|
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||||
input,
|
input,
|
||||||
mat2,
|
mat2,
|
||||||
scale_a,
|
scale_a,
|
||||||
|
|||||||
@ -37,6 +37,8 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
import torch.distributed._functional_collectives as funcol
|
||||||
|
import torch.distributed._symmetric_memory
|
||||||
from torch.distributed import Backend, ProcessGroup
|
from torch.distributed import Backend, ProcessGroup
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
@ -159,6 +161,90 @@ def all_gather_fake(
|
|||||||
return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
|
return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||||
|
|
||||||
|
|
||||||
|
def patched_fused_scaled_matmul_reduce_scatter_fake(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
A_scale: torch.Tensor,
|
||||||
|
B_scale: torch.Tensor,
|
||||||
|
reduce_op: str,
|
||||||
|
orig_scatter_dim: int,
|
||||||
|
scatter_dim_after_maybe_reshape: int,
|
||||||
|
group_name: str,
|
||||||
|
output_shape: list[int],
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
result_scale: torch.Tensor | None = None,
|
||||||
|
out_dtype: torch.dtype | None = None,
|
||||||
|
use_fast_accum: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Copied from
|
||||||
|
# https://github.com/pytorch/pytorch/blob/50c338c2da905062449e4d9ac807832d1b5cd90e/torch/distributed/_symmetric_memory/__init__.py#L1189
|
||||||
|
if A_scale.numel() > 1:
|
||||||
|
if A_scale.shape[:-1] != A.shape[:-1]:
|
||||||
|
raise ValueError(
|
||||||
|
"For row-wise scaling, the leading dims of A_scale "
|
||||||
|
"must match the leading dims of A "
|
||||||
|
f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
|
||||||
|
)
|
||||||
|
A_scale = A_scale.flatten(0, -2).contiguous()
|
||||||
|
elif A_scale.numel() != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid A_scale shape "
|
||||||
|
f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
|
||||||
|
)
|
||||||
|
|
||||||
|
C = torch._scaled_mm(
|
||||||
|
A.flatten(0, -2).contiguous(),
|
||||||
|
B,
|
||||||
|
A_scale,
|
||||||
|
B_scale,
|
||||||
|
bias,
|
||||||
|
result_scale,
|
||||||
|
out_dtype,
|
||||||
|
use_fast_accum,
|
||||||
|
)
|
||||||
|
C = C.view(*output_shape[:-1], B.shape[1])
|
||||||
|
res = funcol.reduce_scatter_tensor(
|
||||||
|
C,
|
||||||
|
reduce_op,
|
||||||
|
orig_scatter_dim, # need original scatter dim for 3D+ output tensor here
|
||||||
|
group_name,
|
||||||
|
)
|
||||||
|
res = funcol.wait_tensor(res)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def patched_fused_scaled_matmul_reduce_scatter(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
A_scale: torch.Tensor,
|
||||||
|
B_scale: torch.Tensor,
|
||||||
|
reduce_op: str,
|
||||||
|
orig_scatter_dim: int,
|
||||||
|
scatter_dim_after_maybe_reshape: int,
|
||||||
|
group_name: str,
|
||||||
|
output_shape: list[int],
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
result_scale: torch.Tensor | None = None,
|
||||||
|
out_dtype: torch.dtype | None = None,
|
||||||
|
use_fast_accum: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
A_scale,
|
||||||
|
B_scale,
|
||||||
|
reduce_op,
|
||||||
|
orig_scatter_dim,
|
||||||
|
scatter_dim_after_maybe_reshape,
|
||||||
|
group_name,
|
||||||
|
output_shape,
|
||||||
|
bias,
|
||||||
|
result_scale,
|
||||||
|
out_dtype,
|
||||||
|
use_fast_accum,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if supports_custom_op():
|
if supports_custom_op():
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="all_reduce",
|
op_name="all_reduce",
|
||||||
@ -178,6 +264,15 @@ if supports_custom_op():
|
|||||||
fake_impl=all_gather_fake,
|
fake_impl=all_gather_fake,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: Remove this once the pytorch fix
|
||||||
|
# (https://github.com/pytorch/pytorch/pull/165086) gets released,
|
||||||
|
# in either 2.9.1 or 2.10
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="patched_fused_scaled_matmul_reduce_scatter",
|
||||||
|
op_func=patched_fused_scaled_matmul_reduce_scatter,
|
||||||
|
fake_impl=patched_fused_scaled_matmul_reduce_scatter_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GroupCoordinator:
|
class GroupCoordinator:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user