Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Feature] Add async tensor parallelism for scaled mm (#20155)

Signed-off-by: cascade812 <cascade812@outlook.com>

Parent: f12d9256b3
Commit: 287f527f54
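Summary (added for context; not part of the original commit message): this change extends vLLM's async tensor-parallelism (async TP) inductor pass, which previously fused only plain matmuls with their neighbouring collectives, to FP8 scaled matmuls. The hunks below touch three areas: the async TP compile test gains dtype-aware test models, four new scaled-mm test models, and an FP8 end-to-end checkpoint; the collective-fusion pass gains four pattern-matcher rewrites that replace torch._scaled_mm / torch.ops._C.cutlass_scaled_mm adjacent to an all-gather or reduce-scatter with torch.ops.symm_mem.fused_all_gather_scaled_matmul or torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter, registered only for bfloat16 models because per-token (row-wise) scaling only supports bf16 output; and two debug log messages are made more specific.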
@@ -22,6 +22,8 @@ from ..utils import (compare_two_settings, create_new_process_for_each_test,
                       multi_gpu_test)
 from .backend import TestBackend
 
+FP8_DTYPE = current_platform.fp8_dtype()
+
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -32,9 +34,10 @@ prompts = [
 
 class TestMMRSModel(torch.nn.Module):
 
-    def __init__(self, hidden_size=16):
+    def __init__(self, hidden_size=16, dtype=torch.float16):
         super().__init__()
         self.hidden_size = hidden_size
+        self.dtype = dtype
         self.gate_proj = torch.nn.Parameter(torch.empty(
             (self.hidden_size * 2, hidden_size)),
                                             requires_grad=False)
@@ -64,9 +67,10 @@ class TestMMRSModel(torch.nn.Module):
 
 class TestAGMMModel(torch.nn.Module):
 
-    def __init__(self, hidden_size=16):
+    def __init__(self, hidden_size=16, dtype=torch.float16):
         super().__init__()
         self.hidden_size = hidden_size
+        self.dtype = dtype
         self.weight = torch.nn.Parameter(torch.empty(
             (hidden_size, hidden_size)),
                                          requires_grad=False)
@@ -91,8 +95,125 @@ class TestAGMMModel(torch.nn.Module):
         return [torch.ops.symm_mem.fused_all_gather_matmul.default]
 
 
+class _BaseScaledMMModel(torch.nn.Module):
+
+    def __init__(self, hidden_size=16, dtype=torch.float16):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.dtype = dtype
+        self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\
+            .contiguous().transpose(0, 1)
+
+        # Initialize scale_b for _scaled_mm.
+        self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
+
+
+class TestScaledMMRSModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing the scaled_mm + reduce scatter in the FX graph
+
+        """
+        fp8_input = input.to(FP8_DTYPE)
+        scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
+        scaled_mm = torch._scaled_mm(fp8_input,
+                                     self.weight,
+                                     scale_a=scale_a,
+                                     scale_b=self.scale_b,
+                                     out_dtype=self.dtype)
+        reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
+        return reduce_scatter
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.reduce_scatter.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
+
+
+class TestAGScaledMMModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing the all gather + scaled_mm in the FX graph
+        """
+        # Reshape input
+        fp8_input = input.to(FP8_DTYPE)
+        all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
+
+        scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
+        scaled_mm = torch._scaled_mm(all_gather,
+                                     self.weight,
+                                     scale_a=scale_a,
+                                     scale_b=self.scale_b,
+                                     out_dtype=self.dtype)
+        return scaled_mm
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.all_gather.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
+
+
+class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing the cutlass_scaled_mm + reduce scatter
+        in the FX graph
+
+        """
+        fp8_input = input.to(FP8_DTYPE)
+        scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
+        mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]),
+                             dtype=self.dtype,
+                             device=input.device)
+        torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a,
+                                       self.scale_b, None)
+        reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
+        return reduce_scatter
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.reduce_scatter.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
+
+
+class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing the all gather + cutlass_scaled_mm
+        in the FX graph
+        """
+        # Reshape input
+        fp8_input = input.to(FP8_DTYPE)
+        all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
+
+        scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
+
+        mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]),
+                             dtype=self.dtype,
+                             device=all_gather.device)
+        torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight,
+                                       scale_a, self.scale_b, None)
+        return mm_out
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.all_gather.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
+
+
 @multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel])
+@pytest.mark.parametrize("test_model", [
+    TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel,
+    TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel
+])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seq_len", [16])
 @pytest.mark.parametrize("hidden_size", [16])
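The four new test models above all use the same scale layout: scale_a has shape (num_tokens, 1), i.e. per-token (row-wise) activation scales, and scale_b has shape (1, hidden_size), i.e. per-channel weight scales. That layout is why only bfloat16 output is supported and float16 is skipped in the next hunk. A minimal single-GPU sketch of the scaled matmul these models build, with the distributed collectives left out (assumptions: a CUDA GPU that supports FP8 row-wise _scaled_mm, and FP8_DTYPE resolving to torch.float8_e4m3fn as it does on NVIDIA hardware):

import torch

FP8_DTYPE = torch.float8_e4m3fn  # assumption: what current_platform.fp8_dtype() returns on NVIDIA GPUs

hidden_size = 16
x = torch.randn(8, hidden_size, device="cuda")
w = torch.randn(hidden_size, hidden_size, device="cuda")

fp8_x = x.to(FP8_DTYPE)
fp8_w = w.to(FP8_DTYPE).t()  # _scaled_mm expects a column-major second operand

# Per-token (row-wise) activation scales and per-channel weight scales,
# mirroring _BaseScaledMMModel / TestScaledMMRSModel above.
scale_a = torch.ones(fp8_x.shape[0], 1, device="cuda", dtype=torch.float32)
scale_b = torch.ones(1, hidden_size, device="cuda", dtype=torch.float32)

out = torch._scaled_mm(fp8_x, fp8_w,
                       scale_a=scale_a,
                       scale_b=scale_b,
                       out_dtype=torch.bfloat16)  # row-wise scaling requires bf16 output
print(out.shape)  # torch.Size([8, 16])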
@@ -101,6 +222,14 @@ class TestAGMMModel(torch.nn.Module):
                     reason="Only test on CUDA")
 def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
                                hidden_size: int, dtype: torch.dtype):
+    if test_model in (TestScaledMMRSModel, TestAGScaledMMModel,
+                      TestCutlassScaledMMRSModel,
+                      TestAGCutlassScaledMMModel) and dtype == torch.float16:
+        pytest.skip(
+            "Only bf16 high precision output types are supported for " \
+            "per-token (row-wise) scaling"
+        )
+
     num_processes = 2
 
     def run_torch_spawn(fn, nprocs):
@@ -155,7 +284,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
     async_tp_pass = AsyncTPPass(vllm_config)
     backend = TestBackend(async_tp_pass)
 
-    model = test_model_cls(hidden_size)
+    model = test_model_cls(hidden_size,
+                           dtype)  # Pass dtype to model constructor
 
     hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                 dtype=dtype,
@@ -174,7 +304,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
 
 
 @create_new_process_for_each_test()
-@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"])
+@pytest.mark.parametrize("model_id", [
+    "meta-llama/Llama-3.2-1B-Instruct",
+    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
+])
 @pytest.mark.parametrize("tp_size", [2])
 @pytest.mark.parametrize("async_tp_enabled", [True])
 @pytest.mark.parametrize("distributed_backend", ["mp"])
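The end-to-end async TP test now also covers an FP8-quantized checkpoint: in addition to the unquantized meta-llama/Llama-3.2-1B-Instruct baseline, RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8 is run with tp_size=2, async_tp_enabled=True, and the mp distributed backend, so the new scaled-mm fusions are exercised on a real quantized model and not only on the synthetic test modules above. The remaining hunks apply to the compilation passes themselves.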
@@ -15,10 +15,13 @@ from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
 from .vllm_inductor_pass import VllmInductorPass
 
+FP8_DTYPE = current_platform.fp8_dtype()
+
 if find_spec("flashinfer"):
     try:
         import flashinfer.comm as flashinfer_comm
@@ -28,7 +31,6 @@ if find_spec("flashinfer"):
         flashinfer_comm = None
 else:
     flashinfer_comm = None
-from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
@@ -118,6 +120,230 @@ class AllGatherGEMMPattern(BasePattern):
                                 pm.fwd_only, pm_pass)
 
 
+class ScaledMMReduceScatterPattern(BasePattern):
+
+    def get_inputs(self):
+        input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
+        mm_weight = torch.empty([16, 16], device=self.device,
+                                dtype=FP8_DTYPE).contiguous().transpose(0, 1)
+        scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
+        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
+        return [input, mm_weight, scale_a, scale_b]
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(input: torch.Tensor, mat2: torch.Tensor,
+                    scale_a: torch.Tensor,
+                    scale_b: torch.Tensor) -> torch.Tensor:
+            scaled_mm = torch.ops.aten._scaled_mm.default(input,
+                                                          mat2=mat2,
+                                                          scale_a=scale_a,
+                                                          scale_b=scale_b,
+                                                          bias=None,
+                                                          scale_result=None,
+                                                          out_dtype=self.dtype)
+            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
+                scaled_mm,
+                dim=0,
+                world_size=self.tp_size,
+                group_name=self.tp.unique_name)
+            return reduce_scatter
+
+        def replacement(input: torch.Tensor, mat2: torch.Tensor,
+                        scale_a: torch.Tensor,
+                        scale_b: torch.Tensor) -> torch.Tensor:
+            gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
+                input,
+                mat2,
+                scale_a,
+                scale_b,
+                "avg",
+                scatter_dim=0,
+                out_dtype=self.dtype,
+                group_name=self.tp.device_group.group_name,
+            )
+
+            return gemm_rs
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
+class AllGatherScaledMMPattern(BasePattern):
+
+    def get_inputs(self):
+        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
+        weight = torch.empty([16, 16], device=self.device,
+                             dtype=FP8_DTYPE).contiguous().transpose(0, 1)
+
+        s1 = x.shape[0] * self.tp_size
+
+        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
+        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
+
+        return [x, weight, scale_a, scale_b]
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+        ) -> torch.Tensor:
+            all_gather = torch.ops.vllm.all_gather.default(
+                x,
+                dim=0,
+                world_size=self.tp_size,
+                group_name=self.tp.unique_name)
+
+            return torch.ops.aten._scaled_mm.default(all_gather,
+                                                     mat2=weight,
+                                                     scale_a=scale_a,
+                                                     scale_b=scale_b,
+                                                     bias=None,
+                                                     scale_result=None,
+                                                     out_dtype=self.dtype)
+
+        def replacement(x: torch.Tensor, weight: torch.Tensor,
+                        scale_a: torch.Tensor,
+                        scale_b: torch.Tensor) -> torch.Tensor:
+            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
+                x,
+                [weight],
+                scale_a,
+                [scale_b],
+                gather_dim=0,
+                biases=[None],
+                result_scales=[None],
+                out_dtypes=[self.dtype],
+                use_fast_accum=[False],
+                group_name=self.tp.device_group.group_name,
+            )
+            return mm_outputs
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
+class CutlassScaledMMReduceScatterPattern(BasePattern):
+
+    def get_inputs(self):
+        input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
+        mm_weight = torch.empty([16, 16], device=self.device,
+                                dtype=FP8_DTYPE).contiguous().transpose(0, 1)
+        scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
+        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
+
+        cutlass_mm_output = torch.empty([16, 16],
+                                        device=self.device,
+                                        dtype=self.dtype)
+        return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(input: torch.Tensor, weight: torch.Tensor,
+                    scale_a: torch.Tensor, scale_b: torch.Tensor,
+                    cutlass_mm_output: torch.Tensor) -> torch.Tensor:
+            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
+                torch.ops._C.cutlass_scaled_mm.default,
+                out=cutlass_mm_output,
+                a=input,
+                b=weight,
+                a_scales=scale_a,
+                b_scales=scale_b,
+                bias=None)
+
+            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
+                cutlass_scaled_mm[1],
+                dim=0,
+                world_size=self.tp_size,
+                group_name=self.tp.unique_name)
+            return reduce_scatter
+
+        def replacement(input: torch.Tensor, mat2: torch.Tensor,
+                        scale_a: torch.Tensor, scale_b: torch.Tensor,
+                        cutlass_mm_output: torch.Tensor) -> torch.Tensor:
+            gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
+                input,
+                mat2,
+                scale_a,
+                scale_b,
+                "avg",
+                scatter_dim=0,
+                out_dtype=self.dtype,
+                group_name=self.tp.device_group.group_name,
+            )
+
+            return gemm_rs
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
+class AllGatherCutlassScaledMMPattern(BasePattern):
+
+    def get_inputs(self):
+        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
+        weight = torch.empty([16, 16], device=self.device,
+                             dtype=FP8_DTYPE).contiguous().transpose(0, 1)
+
+        s1 = x.shape[0] * self.tp_size
+
+        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
+        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
+
+        s2 = weight.shape[1]
+        output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
+
+        return [x, weight, scale_a, scale_b, output]
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+            output: torch.Tensor,
+        ) -> torch.Tensor:
+            all_gather = torch.ops.vllm.all_gather.default(
+                x,
+                dim=0,
+                world_size=self.tp_size,
+                group_name=self.tp.unique_name)
+
+            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
+                torch.ops._C.cutlass_scaled_mm.default,
+                out=output,
+                a=all_gather,
+                b=weight,
+                a_scales=scale_a,
+                b_scales=scale_b,
+                bias=None)
+            return cutlass_scaled_mm[1]
+
+        def replacement(x: torch.Tensor, weight: torch.Tensor,
+                        scale_a: torch.Tensor, scale_b: torch.Tensor,
+                        output: torch.Tensor) -> torch.Tensor:
+            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
+                x,
+                [weight],
+                scale_a,
+                [scale_b],
+                gather_dim=0,
+                biases=[None],
+                result_scales=[None],
+                out_dtypes=[self.dtype],
+                use_fast_accum=[False],
+                group_name=self.tp.device_group.group_name,
+            )
+            return mm_outputs
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
 class AsyncTPPass(VllmInductorPass):
 
     def __init__(self, config: VllmConfig):
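Each of the four pattern classes above follows the same recipe as the existing AllGatherGEMMPattern: get_inputs() supplies example tensors that fix the shapes, dtypes, and layouts used for tracing; pattern() is the subgraph to search for in the post-grad FX graph; replacement() is the fused symm_mem op that gets swapped in; and pm.register_replacement(...) traces both with pm.fwd_only and records them in the pass's PatternMatcherPass. A stripped-down, purely illustrative sketch of that mechanism on a toy algebraic rewrite (the toy pattern itself is an assumption for illustration, not from this diff):

import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                              register_replacement)

toy_patterns = PatternMatcherPass()

def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Subgraph to search for: x * y + y
    return torch.add(torch.mul(x, y), y)

def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Algebraically identical rewrite: (x + 1) * y
    return torch.mul(torch.add(x, 1), y)

# The example inputs only pin down shapes/dtypes for tracing, like get_inputs() above.
register_replacement(pattern, replacement,
                     [torch.empty(4, 4), torch.empty(4, 4)],
                     fwd_only, toy_patterns)
# toy_patterns.apply(fx_graph) would then rewrite every matching subgraph.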
@@ -133,6 +359,20 @@ class AsyncTPPass(VllmInductorPass):
         AllGatherGEMMPattern(self.model_dtype,
                              self.device).register(self.patterns)
 
+        # These fusions are enabled only for bfloat16 models because
+        # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
+        # only supports bfloat16 as the output dtype.
+        if self.model_dtype == torch.bfloat16:
+            ScaledMMReduceScatterPattern(self.model_dtype,
+                                         self.device).register(self.patterns)
+            AllGatherScaledMMPattern(self.model_dtype,
+                                     self.device).register(self.patterns)
+
+            CutlassScaledMMReduceScatterPattern(
+                self.model_dtype, self.device).register(self.patterns)
+            AllGatherCutlassScaledMMPattern(
+                self.model_dtype, self.device).register(self.patterns)
+
     def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
         # only do replace for specific shapes
         tp_size = get_tensor_model_parallel_world_size()
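For orientation, the test hunks earlier show how this pass is driven. A rough sketch of that flow, based on the context lines in the test diff (the multi-GPU process setup, the construction of vllm_config with async TP enabled in its compilation pass config, and device placement are all elided and assumed to follow the existing test, which is not shown in full here; TestBackend is the test-suite compile backend that runs the pass):

# Sketch only: assumes an initialized 2-GPU tensor-parallel group and a
# vllm_config whose compilation pass config enables async TP.
async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass)

model = TestScaledMMRSModel(hidden_size=16, dtype=torch.bfloat16)
hidden_states = torch.randn(8 * 16, 16, dtype=torch.bfloat16)  # device handling elided

compiled_model = torch.compile(model, backend=backend)
out = compiled_model(hidden_states)
# In the compiled FX graph, scaled_mm + reduce_scatter has been replaced by
# torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.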
@@ -142,7 +382,7 @@ class AsyncTPPass(VllmInductorPass):
         self.begin()
         self.dump_graph(graph, "before_async_tp_pass")
         count = self.patterns.apply(graph)
-        logger.debug("Replaced %s patterns", count)
+        logger.debug("Replaced %s patterns with async TP pass.", count)
         self.dump_graph(graph, "after_async_tp_pass")
         self.end_and_log()
 
@@ -477,6 +477,6 @@ class SequenceParallelismPass(VllmInductorPass):
         self.begin()
         self.dump_graph(graph, "before_sequence_parallelism_pass")
         count = self.patterns.apply(graph)
-        logger.debug("Replaced %s patterns", count)
+        logger.debug("Replaced %s patterns with sequence parallelism", count)
         self.dump_graph(graph, "after_sequence_parallelism_pass")
         self.end_and_log()