[Feature] Add async tensor parallelism for scaled mm (#20155)

Signed-off-by: cascade812 <cascade812@outlook.com>
This commit is contained in:
cascade 2025-07-30 14:23:41 -07:00 committed by GitHub
parent f12d9256b3
commit 287f527f54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 381 additions and 8 deletions

View File

@ -22,6 +22,8 @@ from ..utils import (compare_two_settings, create_new_process_for_each_test,
multi_gpu_test)
from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
prompts = [
"Hello, my name is",
"The president of the United States is",
@ -32,9 +34,10 @@ prompts = [
class TestMMRSModel(torch.nn.Module):
def __init__(self, hidden_size=16):
def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__()
self.hidden_size = hidden_size
self.dtype = dtype
self.gate_proj = torch.nn.Parameter(torch.empty(
(self.hidden_size * 2, hidden_size)),
requires_grad=False)
@ -64,9 +67,10 @@ class TestMMRSModel(torch.nn.Module):
class TestAGMMModel(torch.nn.Module):
def __init__(self, hidden_size=16):
def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__()
self.hidden_size = hidden_size
self.dtype = dtype
self.weight = torch.nn.Parameter(torch.empty(
(hidden_size, hidden_size)),
requires_grad=False)
@ -91,8 +95,125 @@ class TestAGMMModel(torch.nn.Module):
return [torch.ops.symm_mem.fused_all_gather_matmul.default]
class _BaseScaledMMModel(torch.nn.Module):
def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__()
self.hidden_size = hidden_size
self.dtype = dtype
self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\
.contiguous().transpose(0, 1)
# Initialize scale_b for _scaled_mm.
self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
class TestScaledMMRSModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the scaled_mm + reduce scatter in the FX graph
"""
fp8_input = input.to(FP8_DTYPE)
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
scaled_mm = torch._scaled_mm(fp8_input,
self.weight,
scale_a=scale_a,
scale_b=self.scale_b,
out_dtype=self.dtype)
reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
return reduce_scatter
def ops_in_model_before(self):
return [torch.ops.vllm.reduce_scatter.default]
def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
class TestAGScaledMMModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the all gather + scaled_mm in the FX graph
"""
# Reshape input
fp8_input = input.to(FP8_DTYPE)
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
scaled_mm = torch._scaled_mm(all_gather,
self.weight,
scale_a=scale_a,
scale_b=self.scale_b,
out_dtype=self.dtype)
return scaled_mm
def ops_in_model_before(self):
return [torch.ops.vllm.all_gather.default]
def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the cutlass_scaled_mm + reduce scatter
in the FX graph
"""
fp8_input = input.to(FP8_DTYPE)
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]),
dtype=self.dtype,
device=input.device)
torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a,
self.scale_b, None)
reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
return reduce_scatter
def ops_in_model_before(self):
return [torch.ops.vllm.reduce_scatter.default]
def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the all gather + cutlass_scaled_mm
in the FX graph
"""
# Reshape input
fp8_input = input.to(FP8_DTYPE)
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]),
dtype=self.dtype,
device=all_gather.device)
torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight,
scale_a, self.scale_b, None)
return mm_out
def ops_in_model_before(self):
return [torch.ops.vllm.all_gather.default]
def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel])
@pytest.mark.parametrize("test_model", [
TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel,
TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@ -101,6 +222,14 @@ class TestAGMMModel(torch.nn.Module):
reason="Only test on CUDA")
def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
if test_model in (TestScaledMMRSModel, TestAGScaledMMModel,
TestCutlassScaledMMRSModel,
TestAGCutlassScaledMMModel) and dtype == torch.float16:
pytest.skip(
"Only bf16 high precision output types are supported for " \
"per-token (row-wise) scaling"
)
num_processes = 2
def run_torch_spawn(fn, nprocs):
@ -155,7 +284,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass)
model = test_model_cls(hidden_size)
model = test_model_cls(hidden_size,
dtype) # Pass dtype to model constructor
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype,
@ -174,7 +304,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
@create_new_process_for_each_test()
@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"])
@pytest.mark.parametrize("model_id", [
"meta-llama/Llama-3.2-1B-Instruct",
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
])
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("async_tp_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"])

View File

@ -15,10 +15,13 @@ from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from .vllm_inductor_pass import VllmInductorPass
FP8_DTYPE = current_platform.fp8_dtype()
if find_spec("flashinfer"):
try:
import flashinfer.comm as flashinfer_comm
@ -28,7 +31,6 @@ if find_spec("flashinfer"):
flashinfer_comm = None
else:
flashinfer_comm = None
from vllm.platforms import current_platform
logger = init_logger(__name__)
@ -118,6 +120,230 @@ class AllGatherGEMMPattern(BasePattern):
pm.fwd_only, pm_pass)
class ScaledMMReduceScatterPattern(BasePattern):
def get_inputs(self):
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
mm_weight = torch.empty([16, 16], device=self.device,
dtype=FP8_DTYPE).contiguous().transpose(0, 1)
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
return [input, mm_weight, scale_a, scale_b]
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor) -> torch.Tensor:
scaled_mm = torch.ops.aten._scaled_mm.default(input,
mat2=mat2,
scale_a=scale_a,
scale_b=scale_b,
bias=None,
scale_result=None,
out_dtype=self.dtype)
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
scaled_mm,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name)
return reduce_scatter
def replacement(input: torch.Tensor, mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor) -> torch.Tensor:
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
scale_b,
"avg",
scatter_dim=0,
out_dtype=self.dtype,
group_name=self.tp.device_group.group_name,
)
return gemm_rs
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class AllGatherScaledMMPattern(BasePattern):
def get_inputs(self):
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
weight = torch.empty([16, 16], device=self.device,
dtype=FP8_DTYPE).contiguous().transpose(0, 1)
s1 = x.shape[0] * self.tp_size
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
return [x, weight, scale_a, scale_b]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
all_gather = torch.ops.vllm.all_gather.default(
x,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name)
return torch.ops.aten._scaled_mm.default(all_gather,
mat2=weight,
scale_a=scale_a,
scale_b=scale_b,
bias=None,
scale_result=None,
out_dtype=self.dtype)
def replacement(x: torch.Tensor, weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor) -> torch.Tensor:
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
x,
[weight],
scale_a,
[scale_b],
gather_dim=0,
biases=[None],
result_scales=[None],
out_dtypes=[self.dtype],
use_fast_accum=[False],
group_name=self.tp.device_group.group_name,
)
return mm_outputs
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class CutlassScaledMMReduceScatterPattern(BasePattern):
def get_inputs(self):
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
mm_weight = torch.empty([16, 16], device=self.device,
dtype=FP8_DTYPE).contiguous().transpose(0, 1)
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
cutlass_mm_output = torch.empty([16, 16],
device=self.device,
dtype=self.dtype)
return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor,
cutlass_mm_output: torch.Tensor) -> torch.Tensor:
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.cutlass_scaled_mm.default,
out=cutlass_mm_output,
a=input,
b=weight,
a_scales=scale_a,
b_scales=scale_b,
bias=None)
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
cutlass_scaled_mm[1],
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name)
return reduce_scatter
def replacement(input: torch.Tensor, mat2: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor,
cutlass_mm_output: torch.Tensor) -> torch.Tensor:
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
scale_b,
"avg",
scatter_dim=0,
out_dtype=self.dtype,
group_name=self.tp.device_group.group_name,
)
return gemm_rs
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class AllGatherCutlassScaledMMPattern(BasePattern):
def get_inputs(self):
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
weight = torch.empty([16, 16], device=self.device,
dtype=FP8_DTYPE).contiguous().transpose(0, 1)
s1 = x.shape[0] * self.tp_size
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
s2 = weight.shape[1]
output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
return [x, weight, scale_a, scale_b, output]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
output: torch.Tensor,
) -> torch.Tensor:
all_gather = torch.ops.vllm.all_gather.default(
x,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name)
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.cutlass_scaled_mm.default,
out=output,
a=all_gather,
b=weight,
a_scales=scale_a,
b_scales=scale_b,
bias=None)
return cutlass_scaled_mm[1]
def replacement(x: torch.Tensor, weight: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor,
output: torch.Tensor) -> torch.Tensor:
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
x,
[weight],
scale_a,
[scale_b],
gather_dim=0,
biases=[None],
result_scales=[None],
out_dtypes=[self.dtype],
use_fast_accum=[False],
group_name=self.tp.device_group.group_name,
)
return mm_outputs
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class AsyncTPPass(VllmInductorPass):
def __init__(self, config: VllmConfig):
@ -133,6 +359,20 @@ class AsyncTPPass(VllmInductorPass):
AllGatherGEMMPattern(self.model_dtype,
self.device).register(self.patterns)
# These fusions are enabled only for bfloat16 models because
# `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
# only supports bfloat16 as the output dtype.
if self.model_dtype == torch.bfloat16:
ScaledMMReduceScatterPattern(self.model_dtype,
self.device).register(self.patterns)
AllGatherScaledMMPattern(self.model_dtype,
self.device).register(self.patterns)
CutlassScaledMMReduceScatterPattern(
self.model_dtype, self.device).register(self.patterns)
AllGatherCutlassScaledMMPattern(
self.model_dtype, self.device).register(self.patterns)
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
# only do replace for specific shapes
tp_size = get_tensor_model_parallel_world_size()
@ -142,7 +382,7 @@ class AsyncTPPass(VllmInductorPass):
self.begin()
self.dump_graph(graph, "before_async_tp_pass")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", count)
logger.debug("Replaced %s patterns with async TP pass.", count)
self.dump_graph(graph, "after_async_tp_pass")
self.end_and_log()

View File

@ -477,6 +477,6 @@ class SequenceParallelismPass(VllmInductorPass):
self.begin()
self.dump_graph(graph, "before_sequence_parallelism_pass")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", count)
logger.debug("Replaced %s patterns with sequence parallelism", count)
self.dump_graph(graph, "after_sequence_parallelism_pass")
self.end_and_log()