[BugFix][torch.compile] Fix fused_scaled_matmul_reduce_scatter signature for PyTorch 2.8 (#26038)
Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent 1e6848a65d
commit f4ba2061cf
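For reference, the call sites changed below suggest that PyTorch 2.8's torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter takes its arguments positionally and extends the schema with output_shape, bias, result_scale, and use_fast_accum. A minimal sketch of the new calling convention — the wrapper name fused_gemm_reduce_scatter and its parameter list are illustrative, not vLLM code; only the positional argument order is taken from this diff:

import torch

def fused_gemm_reduce_scatter(
    input: torch.Tensor,
    mat2: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    group_name: str,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    # Illustrative wrapper mirroring the replacement functions below.
    scatter_dim = 0
    # Shape of input @ mat2, from which scatter_dim is reduced across ranks.
    output_shape = [*input.shape[:-1], mat2.shape[1]]
    # PyTorch 2.8: all arguments passed positionally.
    return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        input,
        mat2,
        scale_a,
        scale_b,
        "avg",        # reduce op
        scatter_dim,  # orig_scatter_dim
        scatter_dim,  # scatter_dim_after_maybe_reshape
        group_name,
        output_shape,
        None,         # bias
        None,         # result_scale
        out_dtype,    # out_dtype
        False,        # use_fast_accum
    )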
@@ -400,8 +400,6 @@ steps:
   - pytest -v -s compile/test_fusion_attn.py
   - pytest -v -s compile/test_functionalization.py
   - pytest -v -s compile/test_silu_mul_quant_fusion.py
-  - pytest -v -s compile/test_sequence_parallelism.py
-  - pytest -v -s compile/test_async_tp.py
   - pytest -v -s compile/test_fusion_all_reduce.py
   - pytest -v -s compile/test_decorator.py
   - pytest -v -s compile/test_noop_elimination.py
@@ -1093,6 +1091,8 @@ steps:
   working_dir: "/vllm-workspace/"
   num_gpus: 2
   commands:
+    - pytest -v -s tests/compile/test_async_tp.py
+    - pytest -v -s tests/compile/test_sequence_parallelism.py
     - pytest -v -s tests/distributed/test_context_parallel.py
     - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
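The two hunks above move the async-TP and sequence-parallelism fusion tests out of the single-GPU compile section and into this num_gpus: 2 block, where the reduce-scatter collectives they exercise can actually run across two ranks.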
@@ -169,15 +169,23 @@ class ScaledMMReduceScatterPattern(BasePattern):
             scale_a: torch.Tensor,
             scale_b: torch.Tensor,
         ) -> torch.Tensor:
+            # Calculate output shape: input @ mat2 with scatter_dim reduced
+            output_shape = [*input.shape[:-1], mat2.shape[1]]
+            scatter_dim = 0
             gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
                 input,
                 mat2,
                 scale_a,
                 scale_b,
                 "avg",
-                scatter_dim=0,
-                out_dtype=self.dtype,
-                group_name=self.tp.device_group.group_name,
+                scatter_dim,  # orig_scatter_dim
+                scatter_dim,  # scatter_dim_after_maybe_reshape
+                self.tp.device_group.group_name,
+                output_shape,
+                None,  # bias
+                None,  # result_scale
+                self.dtype,  # out_dtype
+                False,  # use_fast_accum
             )

             return gemm_rs
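Note that the 2.8 schema wants the scatter dimension twice — once as the original scatter dim and once as the scatter dim after any reshape of the input; both are 0 here, presumably because the pattern operates on the already-flattened 2-D input. A quick, self-contained check of the output_shape computation (tensor sizes made up for illustration):

import torch

# Hypothetical sizes: 128 tokens, hidden 4096, intermediate 11008.
input = torch.empty(128, 4096)
mat2 = torch.empty(4096, 11008)
output_shape = [*input.shape[:-1], mat2.shape[1]]
assert output_shape == [128, 11008]  # shape of input @ mat2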
@@ -296,15 +304,23 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
             scale_b: torch.Tensor,
             cutlass_mm_output: torch.Tensor,
         ) -> torch.Tensor:
+            # Calculate output shape: input @ mat2 with scatter_dim reduced
+            output_shape = [*input.shape[:-1], mat2.shape[1]]
+            scatter_dim = 0
             gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
                 input,
                 mat2,
                 scale_a,
                 scale_b,
                 "avg",
-                scatter_dim=0,
-                out_dtype=self.dtype,
-                group_name=self.tp.device_group.group_name,
+                scatter_dim,  # orig_scatter_dim
+                scatter_dim,  # scatter_dim_after_maybe_reshape
+                self.tp.device_group.group_name,
+                output_shape,
+                None,  # bias
+                None,  # result_scale
+                self.dtype,  # out_dtype
+                False,  # use_fast_accum
             )

             return gemm_rs
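The CutlassScaledMMReduceScatterPattern replacement above receives the identical update, so both the torch-native and CUTLASS scaled-mm reduce-scatter fusions trace the op with the full positional schema that PyTorch 2.8 appears to expect.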
|