mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:45:49 +08:00
[Feature] Add async tensor parallelism for scaled mm (#20155)
Signed-off-by: cascade812 <cascade812@outlook.com>
This commit is contained in:
parent
f12d9256b3
commit
287f527f54
@ -22,6 +22,8 @@ from ..utils import (compare_two_settings, create_new_process_for_each_test,
|
||||
multi_gpu_test)
|
||||
from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
@ -32,9 +34,10 @@ prompts = [
|
||||
|
||||
class TestMMRSModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16):
|
||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.gate_proj = torch.nn.Parameter(torch.empty(
|
||||
(self.hidden_size * 2, hidden_size)),
|
||||
requires_grad=False)
|
||||
@ -64,9 +67,10 @@ class TestMMRSModel(torch.nn.Module):
|
||||
|
||||
class TestAGMMModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16):
|
||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.weight = torch.nn.Parameter(torch.empty(
|
||||
(hidden_size, hidden_size)),
|
||||
requires_grad=False)
|
||||
@ -91,8 +95,125 @@ class TestAGMMModel(torch.nn.Module):
|
||||
return [torch.ops.symm_mem.fused_all_gather_matmul.default]
|
||||
|
||||
|
||||
class _BaseScaledMMModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\
|
||||
.contiguous().transpose(0, 1)
|
||||
|
||||
# Initialize scale_b for _scaled_mm.
|
||||
self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
|
||||
|
||||
|
||||
class TestScaledMMRSModel(_BaseScaledMMModel):
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the scaled_mm + reduce scatter in the FX graph
|
||||
|
||||
"""
|
||||
fp8_input = input.to(FP8_DTYPE)
|
||||
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
||||
scaled_mm = torch._scaled_mm(fp8_input,
|
||||
self.weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=self.scale_b,
|
||||
out_dtype=self.dtype)
|
||||
reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
|
||||
return reduce_scatter
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.reduce_scatter.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
|
||||
|
||||
|
||||
class TestAGScaledMMModel(_BaseScaledMMModel):
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the all gather + scaled_mm in the FX graph
|
||||
"""
|
||||
# Reshape input
|
||||
fp8_input = input.to(FP8_DTYPE)
|
||||
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
|
||||
|
||||
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
||||
scaled_mm = torch._scaled_mm(all_gather,
|
||||
self.weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=self.scale_b,
|
||||
out_dtype=self.dtype)
|
||||
return scaled_mm
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_gather.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
|
||||
|
||||
|
||||
class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the cutlass_scaled_mm + reduce scatter
|
||||
in the FX graph
|
||||
|
||||
"""
|
||||
fp8_input = input.to(FP8_DTYPE)
|
||||
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
||||
mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]),
|
||||
dtype=self.dtype,
|
||||
device=input.device)
|
||||
torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a,
|
||||
self.scale_b, None)
|
||||
reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
|
||||
return reduce_scatter
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.reduce_scatter.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
|
||||
|
||||
|
||||
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the all gather + cutlass_scaled_mm
|
||||
in the FX graph
|
||||
"""
|
||||
# Reshape input
|
||||
fp8_input = input.to(FP8_DTYPE)
|
||||
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
|
||||
|
||||
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
||||
|
||||
mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]),
|
||||
dtype=self.dtype,
|
||||
device=all_gather.device)
|
||||
torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight,
|
||||
scale_a, self.scale_b, None)
|
||||
return mm_out
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_gather.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel])
|
||||
@pytest.mark.parametrize("test_model", [
|
||||
TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel,
|
||||
TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [16])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@ -101,6 +222,14 @@ class TestAGMMModel(torch.nn.Module):
|
||||
reason="Only test on CUDA")
|
||||
def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype):
|
||||
if test_model in (TestScaledMMRSModel, TestAGScaledMMModel,
|
||||
TestCutlassScaledMMRSModel,
|
||||
TestAGCutlassScaledMMModel) and dtype == torch.float16:
|
||||
pytest.skip(
|
||||
"Only bf16 high precision output types are supported for " \
|
||||
"per-token (row-wise) scaling"
|
||||
)
|
||||
|
||||
num_processes = 2
|
||||
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
@ -155,7 +284,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
||||
async_tp_pass = AsyncTPPass(vllm_config)
|
||||
backend = TestBackend(async_tp_pass)
|
||||
|
||||
model = test_model_cls(hidden_size)
|
||||
model = test_model_cls(hidden_size,
|
||||
dtype) # Pass dtype to model constructor
|
||||
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
||||
dtype=dtype,
|
||||
@ -174,7 +304,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"])
|
||||
@pytest.mark.parametrize("model_id", [
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
|
||||
])
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize("async_tp_enabled", [True])
|
||||
@pytest.mark.parametrize("distributed_backend", ["mp"])
|
||||
|
||||
@ -15,10 +15,13 @@ from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
if find_spec("flashinfer"):
|
||||
try:
|
||||
import flashinfer.comm as flashinfer_comm
|
||||
@ -28,7 +31,6 @@ if find_spec("flashinfer"):
|
||||
flashinfer_comm = None
|
||||
else:
|
||||
flashinfer_comm = None
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -118,6 +120,230 @@ class AllGatherGEMMPattern(BasePattern):
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class ScaledMMReduceScatterPattern(BasePattern):
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
mm_weight = torch.empty([16, 16], device=self.device,
|
||||
dtype=FP8_DTYPE).contiguous().transpose(0, 1)
|
||||
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
return [input, mm_weight, scale_a, scale_b]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(input: torch.Tensor, mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor) -> torch.Tensor:
|
||||
scaled_mm = torch.ops.aten._scaled_mm.default(input,
|
||||
mat2=mat2,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=None,
|
||||
scale_result=None,
|
||||
out_dtype=self.dtype)
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
scaled_mm,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(input: torch.Tensor, mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor) -> torch.Tensor:
|
||||
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
scatter_dim=0,
|
||||
out_dtype=self.dtype,
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AllGatherScaledMMPattern(BasePattern):
|
||||
|
||||
def get_inputs(self):
|
||||
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = torch.empty([16, 16], device=self.device,
|
||||
dtype=FP8_DTYPE).contiguous().transpose(0, 1)
|
||||
|
||||
s1 = x.shape[0] * self.tp_size
|
||||
|
||||
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [x, weight, scale_a, scale_b]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name)
|
||||
|
||||
return torch.ops.aten._scaled_mm.default(all_gather,
|
||||
mat2=weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=None,
|
||||
scale_result=None,
|
||||
out_dtype=self.dtype)
|
||||
|
||||
def replacement(x: torch.Tensor, weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
|
||||
x,
|
||||
[weight],
|
||||
scale_a,
|
||||
[scale_b],
|
||||
gather_dim=0,
|
||||
biases=[None],
|
||||
result_scales=[None],
|
||||
out_dtypes=[self.dtype],
|
||||
use_fast_accum=[False],
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class CutlassScaledMMReduceScatterPattern(BasePattern):
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
mm_weight = torch.empty([16, 16], device=self.device,
|
||||
dtype=FP8_DTYPE).contiguous().transpose(0, 1)
|
||||
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
cutlass_mm_output = torch.empty([16, 16],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor,
|
||||
cutlass_mm_output: torch.Tensor) -> torch.Tensor:
|
||||
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.cutlass_scaled_mm.default,
|
||||
out=cutlass_mm_output,
|
||||
a=input,
|
||||
b=weight,
|
||||
a_scales=scale_a,
|
||||
b_scales=scale_b,
|
||||
bias=None)
|
||||
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
cutlass_scaled_mm[1],
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(input: torch.Tensor, mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor,
|
||||
cutlass_mm_output: torch.Tensor) -> torch.Tensor:
|
||||
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
scatter_dim=0,
|
||||
out_dtype=self.dtype,
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AllGatherCutlassScaledMMPattern(BasePattern):
|
||||
|
||||
def get_inputs(self):
|
||||
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = torch.empty([16, 16], device=self.device,
|
||||
dtype=FP8_DTYPE).contiguous().transpose(0, 1)
|
||||
|
||||
s1 = x.shape[0] * self.tp_size
|
||||
|
||||
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
s2 = weight.shape[1]
|
||||
output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [x, weight, scale_a, scale_b, output]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name)
|
||||
|
||||
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.cutlass_scaled_mm.default,
|
||||
out=output,
|
||||
a=all_gather,
|
||||
b=weight,
|
||||
a_scales=scale_a,
|
||||
b_scales=scale_b,
|
||||
bias=None)
|
||||
return cutlass_scaled_mm[1]
|
||||
|
||||
def replacement(x: torch.Tensor, weight: torch.Tensor,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor,
|
||||
output: torch.Tensor) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
|
||||
x,
|
||||
[weight],
|
||||
scale_a,
|
||||
[scale_b],
|
||||
gather_dim=0,
|
||||
biases=[None],
|
||||
result_scales=[None],
|
||||
out_dtypes=[self.dtype],
|
||||
use_fast_accum=[False],
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AsyncTPPass(VllmInductorPass):
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
@ -133,6 +359,20 @@ class AsyncTPPass(VllmInductorPass):
|
||||
AllGatherGEMMPattern(self.model_dtype,
|
||||
self.device).register(self.patterns)
|
||||
|
||||
# These fusions are enabled only for bfloat16 models because
|
||||
# `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
|
||||
# only supports bfloat16 as the output dtype.
|
||||
if self.model_dtype == torch.bfloat16:
|
||||
ScaledMMReduceScatterPattern(self.model_dtype,
|
||||
self.device).register(self.patterns)
|
||||
AllGatherScaledMMPattern(self.model_dtype,
|
||||
self.device).register(self.patterns)
|
||||
|
||||
CutlassScaledMMReduceScatterPattern(
|
||||
self.model_dtype, self.device).register(self.patterns)
|
||||
AllGatherCutlassScaledMMPattern(
|
||||
self.model_dtype, self.device).register(self.patterns)
|
||||
|
||||
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
|
||||
# only do replace for specific shapes
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@ -142,7 +382,7 @@ class AsyncTPPass(VllmInductorPass):
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before_async_tp_pass")
|
||||
count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", count)
|
||||
logger.debug("Replaced %s patterns with async TP pass.", count)
|
||||
self.dump_graph(graph, "after_async_tp_pass")
|
||||
self.end_and_log()
|
||||
|
||||
|
||||
@ -477,6 +477,6 @@ class SequenceParallelismPass(VllmInductorPass):
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before_sequence_parallelism_pass")
|
||||
count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", count)
|
||||
logger.debug("Replaced %s patterns with sequence parallelism", count)
|
||||
self.dump_graph(graph, "after_sequence_parallelism_pass")
|
||||
self.end_and_log()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user