mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 21:34:55 +08:00
[compile] Add patched_fused_scaled_matmul_reduce_scatter (#26604)
Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
parent
d0bed837ac
commit
a25f2adee9
@ -142,7 +142,7 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
|
||||
return [torch.ops.vllm.reduce_scatter.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
|
||||
return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
|
||||
|
||||
|
||||
class TestAGScaledMMModel(_BaseScaledMMModel):
|
||||
@ -195,7 +195,7 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
||||
return [torch.ops.vllm.reduce_scatter.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
|
||||
return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
|
||||
|
||||
|
||||
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
||||
@ -243,9 +243,15 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
||||
@pytest.mark.parametrize("seq_len", [16])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("dynamic", [True, False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||
def test_async_tp_pass_replace(
|
||||
test_model: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype
|
||||
test_model: str,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
dynamic: bool,
|
||||
):
|
||||
if (
|
||||
test_model
|
||||
@ -269,7 +275,15 @@ def test_async_tp_pass_replace(
|
||||
# torch.distributed and cuda
|
||||
torch.multiprocessing.spawn(
|
||||
fn,
|
||||
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
|
||||
args=(
|
||||
num_processes,
|
||||
test_model,
|
||||
batch_size,
|
||||
seq_len,
|
||||
hidden_size,
|
||||
dtype,
|
||||
dynamic,
|
||||
),
|
||||
nprocs=nprocs,
|
||||
)
|
||||
|
||||
@ -284,6 +298,7 @@ def async_tp_pass_on_test_model(
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
dynamic: bool,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
@ -331,6 +346,9 @@ def async_tp_pass_on_test_model(
|
||||
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
|
||||
)
|
||||
|
||||
if dynamic:
|
||||
torch._dynamo.mark_dynamic(hidden_states, 0)
|
||||
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states)
|
||||
|
||||
|
||||
@ -172,7 +172,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
|
||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||
scatter_dim = 0
|
||||
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
|
||||
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
@ -307,7 +307,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
|
||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||
scatter_dim = 0
|
||||
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
|
||||
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
|
||||
@ -37,6 +37,8 @@ from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
import torch.distributed._symmetric_memory
|
||||
from torch.distributed import Backend, ProcessGroup
|
||||
from typing_extensions import deprecated
|
||||
|
||||
@ -159,6 +161,90 @@ def all_gather_fake(
|
||||
return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
|
||||
def patched_fused_scaled_matmul_reduce_scatter_fake(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
A_scale: torch.Tensor,
|
||||
B_scale: torch.Tensor,
|
||||
reduce_op: str,
|
||||
orig_scatter_dim: int,
|
||||
scatter_dim_after_maybe_reshape: int,
|
||||
group_name: str,
|
||||
output_shape: list[int],
|
||||
bias: torch.Tensor | None = None,
|
||||
result_scale: torch.Tensor | None = None,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
use_fast_accum: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# Copied from
|
||||
# https://github.com/pytorch/pytorch/blob/50c338c2da905062449e4d9ac807832d1b5cd90e/torch/distributed/_symmetric_memory/__init__.py#L1189
|
||||
if A_scale.numel() > 1:
|
||||
if A_scale.shape[:-1] != A.shape[:-1]:
|
||||
raise ValueError(
|
||||
"For row-wise scaling, the leading dims of A_scale "
|
||||
"must match the leading dims of A "
|
||||
f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
|
||||
)
|
||||
A_scale = A_scale.flatten(0, -2).contiguous()
|
||||
elif A_scale.numel() != 1:
|
||||
raise ValueError(
|
||||
"Invalid A_scale shape "
|
||||
f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
|
||||
)
|
||||
|
||||
C = torch._scaled_mm(
|
||||
A.flatten(0, -2).contiguous(),
|
||||
B,
|
||||
A_scale,
|
||||
B_scale,
|
||||
bias,
|
||||
result_scale,
|
||||
out_dtype,
|
||||
use_fast_accum,
|
||||
)
|
||||
C = C.view(*output_shape[:-1], B.shape[1])
|
||||
res = funcol.reduce_scatter_tensor(
|
||||
C,
|
||||
reduce_op,
|
||||
orig_scatter_dim, # need original scatter dim for 3D+ output tensor here
|
||||
group_name,
|
||||
)
|
||||
res = funcol.wait_tensor(res)
|
||||
return res
|
||||
|
||||
|
||||
def patched_fused_scaled_matmul_reduce_scatter(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
A_scale: torch.Tensor,
|
||||
B_scale: torch.Tensor,
|
||||
reduce_op: str,
|
||||
orig_scatter_dim: int,
|
||||
scatter_dim_after_maybe_reshape: int,
|
||||
group_name: str,
|
||||
output_shape: list[int],
|
||||
bias: torch.Tensor | None = None,
|
||||
result_scale: torch.Tensor | None = None,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
use_fast_accum: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
|
||||
A,
|
||||
B,
|
||||
A_scale,
|
||||
B_scale,
|
||||
reduce_op,
|
||||
orig_scatter_dim,
|
||||
scatter_dim_after_maybe_reshape,
|
||||
group_name,
|
||||
output_shape,
|
||||
bias,
|
||||
result_scale,
|
||||
out_dtype,
|
||||
use_fast_accum,
|
||||
)
|
||||
|
||||
|
||||
if supports_custom_op():
|
||||
direct_register_custom_op(
|
||||
op_name="all_reduce",
|
||||
@ -178,6 +264,15 @@ if supports_custom_op():
|
||||
fake_impl=all_gather_fake,
|
||||
)
|
||||
|
||||
# TODO: Remove this once the pytorch fix
|
||||
# (https://github.com/pytorch/pytorch/pull/165086) gets released,
|
||||
# in either 2.9.1 or 2.10
|
||||
direct_register_custom_op(
|
||||
op_name="patched_fused_scaled_matmul_reduce_scatter",
|
||||
op_func=patched_fused_scaled_matmul_reduce_scatter,
|
||||
fake_impl=patched_fused_scaled_matmul_reduce_scatter_fake,
|
||||
)
|
||||
|
||||
|
||||
class GroupCoordinator:
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user