[compile] Add patched_fused_scaled_matmul_reduce_scatter (#26604)

Signed-off-by: angelayi <yiangela7@gmail.com>
Angela Yi 2025-10-11 05:44:43 -07:00 committed by GitHub
parent d0bed837ac
commit a25f2adee9
3 changed files with 119 additions and 6 deletions


@@ -142,7 +142,7 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
return [torch.ops.vllm.reduce_scatter.default]
def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
class TestAGScaledMMModel(_BaseScaledMMModel):
@@ -195,7 +195,7 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
return [torch.ops.vllm.reduce_scatter.default]
def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
@@ -243,9 +243,15 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("dynamic", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_async_tp_pass_replace(
test_model: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype
test_model: str,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
dynamic: bool,
):
if (
test_model
@@ -269,7 +275,15 @@ def test_async_tp_pass_replace(
# torch.distributed and cuda
torch.multiprocessing.spawn(
fn,
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
args=(
num_processes,
test_model,
batch_size,
seq_len,
hidden_size,
dtype,
dynamic,
),
nprocs=nprocs,
)
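For reference, torch.multiprocessing.spawn supplies the worker's local rank as the first positional argument and forwards args= after it, which is why the tuple above mirrors the parameter list of async_tp_pass_on_test_model. A minimal sketch of that calling convention (worker, world_size, and tag are illustrative names, not part of the test):

import torch.multiprocessing as mp

def worker(local_rank: int, world_size: int, tag: str) -> None:
    # spawn passes local_rank automatically; world_size and tag come from args=.
    print(local_rank, world_size, tag)

if __name__ == "__main__":
    mp.spawn(worker, args=(2, "demo"), nprocs=2)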
@@ -284,6 +298,7 @@ def async_tp_pass_on_test_model(
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
dynamic: bool,
):
current_platform.seed_everything(0)
@@ -331,6 +346,9 @@ def async_tp_pass_on_test_model(
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
)
if dynamic:
torch._dynamo.mark_dynamic(hidden_states, 0)
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
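The new dynamic flag exercises torch._dynamo.mark_dynamic, which tells the compiler to treat the given dimension as symbolic so the traced graph is reused across batch sizes instead of specializing on one size. A minimal standalone sketch (the function and shapes here are made up for illustration):

import torch

def _matmul(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return x @ w

w = torch.randn(16, 16)
x = torch.randn(8, 16)
# Mark dim 0 of x as dynamic so the compiled graph uses a symbolic batch size.
torch._dynamo.mark_dynamic(x, 0)
compiled = torch.compile(_matmul)
compiled(x, w)
# A different batch size should reuse the same graph rather than recompile.
compiled(torch.randn(32, 16), w)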


@@ -172,7 +172,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
@@ -307,7 +307,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
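In both patterns the replacement keeps the same shape bookkeeping as the original symm_mem op: the GEMM output keeps the input's leading dims and takes mat2's last dim, and the reduce-scatter then splits scatter_dim across the tensor-parallel group. A small sketch of that shape arithmetic (the concrete sizes and tp_size are made-up values):

# Hypothetical sizes for illustration only.
seq, hidden, out, tp_size = 16, 32, 64, 4
input_shape = (seq, hidden)
mat2_shape = (hidden, out)
# Matmul output: leading dims of input plus mat2's column count.
output_shape = [*input_shape[:-1], mat2_shape[1]]   # [16, 64]
# Reduce-scatter along scatter_dim=0 leaves each rank 1/tp_size of that dim.
scatter_dim = 0
per_rank_shape = list(output_shape)
per_rank_shape[scatter_dim] //= tp_size             # [4, 64]
print(output_shape, per_rank_shape)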


@@ -37,6 +37,8 @@ from unittest.mock import patch
import torch
import torch.distributed
import torch.distributed._functional_collectives as funcol
import torch.distributed._symmetric_memory
from torch.distributed import Backend, ProcessGroup
from typing_extensions import deprecated
@@ -159,6 +161,90 @@ def all_gather_fake(
return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
def patched_fused_scaled_matmul_reduce_scatter_fake(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
reduce_op: str,
orig_scatter_dim: int,
scatter_dim_after_maybe_reshape: int,
group_name: str,
output_shape: list[int],
bias: torch.Tensor | None = None,
result_scale: torch.Tensor | None = None,
out_dtype: torch.dtype | None = None,
use_fast_accum: bool = False,
) -> torch.Tensor:
# Copied from
# https://github.com/pytorch/pytorch/blob/50c338c2da905062449e4d9ac807832d1b5cd90e/torch/distributed/_symmetric_memory/__init__.py#L1189
if A_scale.numel() > 1:
if A_scale.shape[:-1] != A.shape[:-1]:
raise ValueError(
"For row-wise scaling, the leading dims of A_scale "
"must match the leading dims of A "
f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
)
A_scale = A_scale.flatten(0, -2).contiguous()
elif A_scale.numel() != 1:
raise ValueError(
"Invalid A_scale shape "
f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
)
C = torch._scaled_mm(
A.flatten(0, -2).contiguous(),
B,
A_scale,
B_scale,
bias,
result_scale,
out_dtype,
use_fast_accum,
)
C = C.view(*output_shape[:-1], B.shape[1])
res = funcol.reduce_scatter_tensor(
C,
reduce_op,
orig_scatter_dim, # need original scatter dim for 3D+ output tensor here
group_name,
)
res = funcol.wait_tensor(res)
return res
def patched_fused_scaled_matmul_reduce_scatter(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
reduce_op: str,
orig_scatter_dim: int,
scatter_dim_after_maybe_reshape: int,
group_name: str,
output_shape: list[int],
bias: torch.Tensor | None = None,
result_scale: torch.Tensor | None = None,
out_dtype: torch.dtype | None = None,
use_fast_accum: bool = False,
) -> torch.Tensor:
return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
A,
B,
A_scale,
B_scale,
reduce_op,
orig_scatter_dim,
scatter_dim_after_maybe_reshape,
group_name,
output_shape,
bias,
result_scale,
out_dtype,
use_fast_accum,
)
if supports_custom_op():
direct_register_custom_op(
op_name="all_reduce",
@@ -178,6 +264,15 @@ if supports_custom_op():
fake_impl=all_gather_fake,
)
# TODO: Remove this once the pytorch fix
# (https://github.com/pytorch/pytorch/pull/165086) gets released,
# in either 2.9.1 or 2.10
direct_register_custom_op(
op_name="patched_fused_scaled_matmul_reduce_scatter",
op_func=patched_fused_scaled_matmul_reduce_scatter,
fake_impl=patched_fused_scaled_matmul_reduce_scatter_fake,
)
class GroupCoordinator:
"""