diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py
index 88ad4f81df505..d396d3940f67f 100644
--- a/tests/compile/test_async_tp.py
+++ b/tests/compile/test_async_tp.py
@@ -142,7 +142,7 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
         return [torch.ops.vllm.reduce_scatter.default]
 
     def ops_in_model_after(self):
-        return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
+        return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
 
 
 class TestAGScaledMMModel(_BaseScaledMMModel):
@@ -195,7 +195,7 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
         return [torch.ops.vllm.reduce_scatter.default]
 
     def ops_in_model_after(self):
-        return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
+        return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
 
 
 class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
@@ -243,9 +243,15 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
 @pytest.mark.parametrize("seq_len", [16])
 @pytest.mark.parametrize("hidden_size", [16])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dynamic", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
 def test_async_tp_pass_replace(
-    test_model: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype
+    test_model: str,
+    batch_size: int,
+    seq_len: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    dynamic: bool,
 ):
     if (
         test_model
@@ -269,7 +275,15 @@ def test_async_tp_pass_replace(
     # torch.distributed and cuda
     torch.multiprocessing.spawn(
         fn,
-        args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
+        args=(
+            num_processes,
+            test_model,
+            batch_size,
+            seq_len,
+            hidden_size,
+            dtype,
+            dynamic,
+        ),
         nprocs=nprocs,
     )
 
@@ -284,6 +298,7 @@ def async_tp_pass_on_test_model(
     seq_len: int,
     hidden_size: int,
     dtype: torch.dtype,
+    dynamic: bool,
 ):
     current_platform.seed_everything(0)
 
@@ -331,6 +346,9 @@ def async_tp_pass_on_test_model(
         (batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
     )
 
+    if dynamic:
+        torch._dynamo.mark_dynamic(hidden_states, 0)
+
     compiled_model = torch.compile(model, backend=backend)
     compiled_model(hidden_states)
 
diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index 970d390f32b45..988a1069cd9e7 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -172,7 +172,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
             # Calculate output shape: input @ mat2 with scatter_dim reduced
             output_shape = [*input.shape[:-1], mat2.shape[1]]
             scatter_dim = 0
-            gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
+            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                 input,
                 mat2,
                 scale_a,
@@ -307,7 +307,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
             # Calculate output shape: input @ mat2 with scatter_dim reduced
             output_shape = [*input.shape[:-1], mat2.shape[1]]
             scatter_dim = 0
-            gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
+            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                 input,
                 mat2,
                 scale_a,
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index aee5507ade467..cb5a75c59f096 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -37,6 +37,8 @@ from unittest.mock import patch
 
 import torch
 import torch.distributed
+import torch.distributed._functional_collectives as funcol
+import torch.distributed._symmetric_memory
 from torch.distributed import Backend, ProcessGroup
 from typing_extensions import deprecated
 
@@ -159,6 +161,90 @@ def all_gather_fake(
     return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
 
 
+def patched_fused_scaled_matmul_reduce_scatter_fake(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+    reduce_op: str,
+    orig_scatter_dim: int,
+    scatter_dim_after_maybe_reshape: int,
+    group_name: str,
+    output_shape: list[int],
+    bias: torch.Tensor | None = None,
+    result_scale: torch.Tensor | None = None,
+    out_dtype: torch.dtype | None = None,
+    use_fast_accum: bool = False,
+) -> torch.Tensor:
+    # Copied from
+    # https://github.com/pytorch/pytorch/blob/50c338c2da905062449e4d9ac807832d1b5cd90e/torch/distributed/_symmetric_memory/__init__.py#L1189
+    if A_scale.numel() > 1:
+        if A_scale.shape[:-1] != A.shape[:-1]:
+            raise ValueError(
+                "For row-wise scaling, the leading dims of A_scale "
+                "must match the leading dims of A "
+                f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
+            )
+        A_scale = A_scale.flatten(0, -2).contiguous()
+    elif A_scale.numel() != 1:
+        raise ValueError(
+            "Invalid A_scale shape "
+            f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
+        )
+
+    C = torch._scaled_mm(
+        A.flatten(0, -2).contiguous(),
+        B,
+        A_scale,
+        B_scale,
+        bias,
+        result_scale,
+        out_dtype,
+        use_fast_accum,
+    )
+    C = C.view(*output_shape[:-1], B.shape[1])
+    res = funcol.reduce_scatter_tensor(
+        C,
+        reduce_op,
+        orig_scatter_dim,  # need original scatter dim for 3D+ output tensor here
+        group_name,
+    )
+    res = funcol.wait_tensor(res)
+    return res
+
+
+def patched_fused_scaled_matmul_reduce_scatter(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+    reduce_op: str,
+    orig_scatter_dim: int,
+    scatter_dim_after_maybe_reshape: int,
+    group_name: str,
+    output_shape: list[int],
+    bias: torch.Tensor | None = None,
+    result_scale: torch.Tensor | None = None,
+    out_dtype: torch.dtype | None = None,
+    use_fast_accum: bool = False,
+) -> torch.Tensor:
+    return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
+        A,
+        B,
+        A_scale,
+        B_scale,
+        reduce_op,
+        orig_scatter_dim,
+        scatter_dim_after_maybe_reshape,
+        group_name,
+        output_shape,
+        bias,
+        result_scale,
+        out_dtype,
+        use_fast_accum,
+    )
+
+
 if supports_custom_op():
     direct_register_custom_op(
         op_name="all_reduce",
@@ -178,6 +264,15 @@ if supports_custom_op():
         fake_impl=all_gather_fake,
     )
 
+    # TODO: Remove this once the pytorch fix
+    # (https://github.com/pytorch/pytorch/pull/165086) gets released,
+    # in either 2.9.1 or 2.10
+    direct_register_custom_op(
+        op_name="patched_fused_scaled_matmul_reduce_scatter",
+        op_func=patched_fused_scaled_matmul_reduce_scatter,
+        fake_impl=patched_fused_scaled_matmul_reduce_scatter_fake,
+    )
+
 
 class GroupCoordinator:
     """
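Background for reviewers: the patch wraps `torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter` in a vLLM-registered custom op that carries its own fake (meta) implementation, so the async-TP fusion pattern can be traced by `torch.compile` under dynamic shapes until the upstream fix (pytorch/pytorch#165086) ships. The sketch below shows that mechanism in isolation; it is a hypothetical, minimal example, not vLLM code. It assumes PyTorch >= 2.4 and uses `torch.library.custom_op` directly rather than vLLM's `direct_register_custom_op` helper, and the `demo::matmul_like` op exists only for illustration.

```python
import torch
from torch.library import custom_op


@custom_op("demo::matmul_like", mutates_args=())
def matmul_like(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Eager implementation: runs on real tensors at execution time.
    return a @ b


@matmul_like.register_fake
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Fake implementation: only produces output metadata (shape/dtype/device),
    # which is what dynamo/inductor consult when tracing with symbolic sizes.
    return a.new_empty(a.shape[0], b.shape[1])


def fn(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Call through the op registry, as the fusion pass does for
    # torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.
    return torch.ops.demo.matmul_like(a, b)


a = torch.randn(16, 8)
b = torch.randn(8, 4)
torch._dynamo.mark_dynamic(a, 0)  # make the batch dim symbolic, as in the new dynamic=True test
out = torch.compile(fn)(a, b)
```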