Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-03-25 11:48:01 +08:00)

Commit 94f7c225e8: Merge branch 'main' into elvischenv/update-flashinfer
@@ -104,7 +104,6 @@ def run_benchmark_with_batch_invariant(
    random.seed(seed)

    # Set environment variables
    os.environ["VLLM_ATTENTION_BACKEND"] = backend
    if batch_invariant:
        os.environ["VLLM_BATCH_INVARIANT"] = "1"
    else:
@@ -140,6 +139,7 @@
        max_model_len=max_model_len,
        dtype="bfloat16",
        tensor_parallel_size=tp_size,
        attention_config={"backend": backend},
        enable_prefix_caching=False,
    )
    init_time = time.perf_counter() - start_init

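For reference, a minimal sketch (not part of the diff) of the environment toggle this hunk exercises, assuming the variable names shown above; clearing the flag in the else branch is an assumption, since that branch is cut off here:

```python
import os

def set_batch_invariant_env(backend: str, batch_invariant: bool) -> None:
    # Select the attention backend and opt in/out of batch-invariant kernels
    # before the engine is constructed, mirroring the benchmark setup above.
    os.environ["VLLM_ATTENTION_BACKEND"] = backend
    if batch_invariant:
        os.environ["VLLM_BATCH_INVARIANT"] = "1"
    else:
        os.environ.pop("VLLM_BATCH_INVARIANT", None)
```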
@@ -2,7 +2,7 @@ FROM intel/deep-learning-essentials:2025.2.2-0-devel-ubuntu24.04 AS vllm-base

RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
    echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \
    add-apt-repository -y ppa:kobuk-team/intel-graphics
    add-apt-repository -y ppa:kobuk-team/intel-graphics-staging

RUN apt clean && apt-get update -y && \
    apt-get install -y --no-install-recommends --fix-missing \
@@ -47,6 +47,11 @@ RUN --mount=type=cache,target=/root/.cache/pip \
    pip install --no-cache-dir \
    -r requirements/xpu.txt

# arctic-inference is built from source which needs torch-xpu properly installed
# used for suffix method speculative decoding
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install --no-cache-dir arctic-inference==0.1.1

ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/"

COPY . .

@@ -64,7 +64,7 @@ th:not(:first-child) {
| [CP](../configuration/optimization.md#chunked-prefill) | [❌](https://github.com/vllm-project/vllm/issues/2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [APC](automatic_prefix_caching.md) | [❌](https://github.com/vllm-project/vllm/issues/3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [🟠](https://github.com/vllm-project/vllm/issues/26963) |
| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ |
| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/26970) |
| [pooling](../models/pooling_models.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |

@@ -557,7 +557,8 @@ def test_rms_group_quant(
    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)

    model_kwargs["attention_config"] = {"backend": backend.name}

    compilation_config = CompilationConfig(
        # Testing properties

@@ -77,6 +77,7 @@ def test_dynamic_shapes_compilation(
                "evaluate_guards": evaluate_guards,
            },
        },
        max_model_len=1024,
    )

    output = model.generate(prompt)

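As a standalone illustration of the monkeypatching shown in this hunk (a hedged sketch, not taken from the diff), forcing the spawn start method keeps worker subprocess logs capturable regardless of the platform default:

```python
import os
import pytest

def test_forces_spawn_method(monkeypatch: pytest.MonkeyPatch) -> None:
    # Force spawn so subprocess logs can be captured, as the test above does.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    assert os.environ["VLLM_WORKER_MULTIPROC_METHOD"] == "spawn"
```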
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools

import pytest
import torch
@ -53,37 +52,61 @@ class TestModel(torch.nn.Module):
|
||||
hidden_size: int,
|
||||
eps: float,
|
||||
group_shape: GroupShape,
|
||||
cuda_force_torch: bool,
|
||||
use_aiter: bool = False,
|
||||
cuda_force_torch: bool = False,
|
||||
use_aiter_quant_op: bool = True,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.use_aiter = use_aiter
|
||||
self.use_aiter_quant_op = use_aiter_quant_op
|
||||
self.cuda_force_torch = cuda_force_torch
|
||||
self.group_shape = group_shape
|
||||
self.enable_quant_fp8_custom_op = None # Will be set later if applicable
|
||||
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
|
||||
if group_shape.is_per_group():
|
||||
self.wscale = [
|
||||
torch.rand(
|
||||
(hidden_size // group_shape[1], hidden_size // group_shape[1]),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
for _ in range(3)
|
||||
]
|
||||
else:
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
static = group_shape == GroupShape.PER_TENSOR
|
||||
|
||||
# Setup quantization scale descriptor
|
||||
static = group_shape == GroupShape.PER_TENSOR and not use_aiter
|
||||
quant_scale = ScaleDesc(torch.float32, static, group_shape)
|
||||
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
|
||||
|
||||
# Setup scales
|
||||
if static:
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
else:
|
||||
self.scale = [None for _ in range(3)]
|
||||
|
||||
# Setup weights
|
||||
self.w = [
|
||||
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
|
||||
]
|
||||
if not group_shape.is_per_group():
|
||||
if not group_shape.is_per_group() or use_aiter:
|
||||
self.w = [self.w[0].t() for _ in range(3)]
|
||||
|
||||
        # Setup weight scales
        if group_shape.is_per_group():
            scale_size = (
                (hidden_size + 128 - 1) // 128
                if use_aiter
                else hidden_size // group_shape[1]
            )
            wscale_shape: tuple[int, ...] = (scale_size, scale_size)
        else:
            wscale_shape = (1,)
        self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)]

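The ceil-division above decides how many 128-wide blocks of scales cover the hidden dimension. A standalone sketch of the same arithmetic (illustrative only, not part of the diff):

```python
import math

def block_scale_shape(hidden_size: int, block: int = 128) -> tuple[int, int]:
    # One scale per (block x block) tile of the weight; ceil-div covers
    # hidden sizes that are not an exact multiple of the block size.
    n_blocks = math.ceil(hidden_size / block)
    return (n_blocks, n_blocks)

assert block_scale_shape(256) == (2, 2)
assert block_scale_shape(300) == (3, 3)
```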
|
||||
# Setup FP8 linear operation
|
||||
is_per_group = group_shape.is_per_group()
|
||||
if is_per_group and use_aiter:
|
||||
self.fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(128, 128),
|
||||
act_quant_group_shape=group_shape,
|
||||
use_aiter_and_is_supported=use_aiter_quant_op,
|
||||
)
|
||||
# AITER blockwise doesn't use enable_quant_fp8_custom_op
|
||||
elif is_per_group:
|
||||
self.fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
|
||||
act_quant_group_shape=group_shape,
|
||||
@ -91,6 +114,13 @@ class TestModel(torch.nn.Module):
|
||||
use_aiter_and_is_supported=False,
|
||||
)
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
|
||||
elif use_aiter:
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=False,
|
||||
act_quant_group_shape=group_shape,
|
||||
)
|
||||
self.fp8_linear.quant_fp8.use_aiter = use_aiter_quant_op
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
||||
else:
|
||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
@ -100,7 +130,6 @@ class TestModel(torch.nn.Module):
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
||||
|
||||
self.enable_rms_norm_custom_op = self.norm[0].enabled()
|
||||
self.group_shape = group_shape
|
||||
|
||||
def forward(self, x):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
@ -126,19 +155,49 @@ class TestModel(torch.nn.Module):
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
return y4
|
||||
|
||||
def ops_in_model_before(self):
|
||||
if (
|
||||
self.use_aiter
|
||||
and self.group_shape.is_per_group()
|
||||
and current_platform.is_fp8_fnuz()
|
||||
):
|
||||
return [rocm_aiter_ops.get_group_quant_op()]
|
||||
if self.use_aiter and self.group_shape.is_per_group():
|
||||
return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
|
||||
if self.use_aiter and self.use_aiter_quant_op:
|
||||
return [rocm_aiter_ops.get_per_token_quant_op()]
|
||||
if self.use_aiter:
|
||||
return [QUANT_OPS[self.quant_key]]
|
||||
if self.enable_quant_fp8_custom_op:
|
||||
return [QUANT_OPS[self.quant_key]]
|
||||
return [torch.ops.aten.reciprocal]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
if self.use_aiter and self.group_shape.is_per_group():
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
AiterFusedAddRMSFp8GroupQuantPattern,
|
||||
AiterRMSFp8GroupQuantPattern,
|
||||
)
|
||||
|
||||
return [
|
||||
AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
|
||||
AiterRMSFp8GroupQuantPattern.FUSED_OP,
|
||||
]
|
||||
if self.use_aiter:
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
AiterFusedAddRMSNormDynamicQuantPattern,
|
||||
AiterRMSNormDynamicQuantPattern,
|
||||
)
|
||||
|
||||
return [
|
||||
AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
|
||||
AiterRMSNormDynamicQuantPattern.FUSED_OP,
|
||||
]
|
||||
return [
|
||||
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
|
||||
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
|
||||
]
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return (
|
||||
[QUANT_OPS[self.quant_key]]
|
||||
if self.enable_quant_fp8_custom_op
|
||||
else [torch.ops.aten.reciprocal]
|
||||
)
|
||||
|
||||
def ops_in_model_before_partial(self):
|
||||
return (
|
||||
[RMS_OP, RMS_ADD_OP]
|
||||
@ -155,67 +214,45 @@ GROUP_SHAPES = [
|
||||
]
|
||||
|
||||
|
||||
class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, eps: float, **kwargs):
|
||||
super().__init__()
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(128, 128),
|
||||
act_quant_group_shape=GroupShape(1, 128),
|
||||
cutlass_block_fp8_supported=False,
|
||||
use_aiter_and_is_supported=True,
|
||||
)
|
||||
self.w = [
|
||||
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
for _ in range(3)
|
||||
]
|
||||
def _run_fusion_test(
|
||||
model,
|
||||
fusion_pass,
|
||||
vllm_config,
|
||||
dtype,
|
||||
hidden_size,
|
||||
num_tokens,
|
||||
):
|
||||
"""Helper function for common fusion test logic.
|
||||
|
||||
scale_hidden_size = (hidden_size + 128 - 1) // 128
|
||||
self.wscale = [
|
||||
torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32)
|
||||
for _ in range(3)
|
||||
]
|
||||
Must be called within vllm_config context.
|
||||
"""
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
self.norm_weight = [torch.ones(hidden_size) for _ in range(4)]
|
||||
self.eps = eps
|
||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
||||
backend2 = TestBackend(noop_pass, cleanup_pass)
|
||||
|
||||
def forward(self, x):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
x = resid = torch.relu(x)
|
||||
y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps)
|
||||
x = torch.rand(num_tokens, hidden_size)
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
x2 = self.w8a8_block_fp8_linear.apply(y, self.w[0], self.wscale[0])
|
||||
# make sure resid is used for replacement to work
|
||||
y2, resid = rocm_aiter_ops.rms_norm2d_with_add(
|
||||
x2, resid, self.norm_weight[1], self.eps
|
||||
)
|
||||
model_fused = torch.compile(model, backend=backend)
|
||||
result_fused = model_fused(x)
|
||||
|
||||
x3 = self.w8a8_block_fp8_linear.apply(y2, self.w[1], self.wscale[1])
|
||||
model_unfused = torch.compile(model, backend=backend2)
|
||||
result_unfused = model_unfused(x)
|
||||
|
||||
y3, resid = rocm_aiter_ops.rms_norm2d_with_add(
|
||||
x3, resid, self.norm_weight[2], self.eps
|
||||
)
|
||||
if dtype == torch.float16:
|
||||
ATOL, RTOL = (2e-3, 2e-3)
|
||||
else:
|
||||
ATOL, RTOL = (1e-2, 1e-2)
|
||||
|
||||
x4 = self.w8a8_block_fp8_linear.apply(y3, self.w[2], self.wscale[2])
|
||||
torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
|
||||
|
||||
y4, resid = rocm_aiter_ops.rms_norm2d_with_add(
|
||||
x4, resid, self.norm_weight[3], self.eps
|
||||
)
|
||||
return y4
|
||||
assert fusion_pass.matched_count == 3
|
||||
backend.check_before_ops(model.ops_in_model_before())
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
torch.ops.vllm.rocm_aiter_rms_norm,
|
||||
torch.ops.vllm.rocm_aiter_group_fp8_quant,
|
||||
]
|
||||
|
||||
def ops_in_model_before_partial(self):
|
||||
return []
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [
|
||||
torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant,
|
||||
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
|
||||
]
|
||||
    return backend, backend2

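The helper compares fused and unfused outputs with dtype-dependent tolerances. A minimal sketch of that comparison logic (names are illustrative, not from the diff):

```python
import torch

def assert_fused_matches_unfused(fused: torch.Tensor, unfused: torch.Tensor) -> None:
    # Tighter tolerance for fp16; bf16 has fewer mantissa bits, so allow more slack.
    if fused.dtype == torch.float16:
        atol, rtol = 2e-3, 2e-3
    else:
        atol, rtol = 1e-2, 1e-2
    torch.testing.assert_close(fused, unfused, atol=atol, rtol=rtol)
```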
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@ -223,11 +260,8 @@ class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
|
||||
@pytest.mark.parametrize("num_tokens", [257])
|
||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
|
||||
@pytest.mark.parametrize(
|
||||
"model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op",
|
||||
list(itertools.product([TestModel], [True, False], [True, False]))
|
||||
+ [(TestRmsnormGroupFp8QuantModel, False, False)],
|
||||
)
|
||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
||||
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
|
||||
# cuda_force_torch used to test torch code path on platforms that
|
||||
# cutlass_fp8_supported() == True.
|
||||
@pytest.mark.parametrize(
|
||||
@ -242,23 +276,13 @@ def test_fusion_rmsnorm_quant(
|
||||
num_tokens,
|
||||
eps,
|
||||
group_shape,
|
||||
model_class,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
cuda_force_torch,
|
||||
):
|
||||
if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND:
|
||||
pytest.skip("AITER is not supported on this GPU.")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(1)
|
||||
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
|
||||
|
||||
if not enable_quant_fp8_custom_op and group_shape.is_per_group():
|
||||
pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")
|
||||
|
||||
# Skip test for 64-bit group shape when running with cutlass or deepgemm
|
||||
if group_shape == GroupShape(1, 64) and (
|
||||
cutlass_block_fp8_supported() or is_deep_gemm_supported()
|
||||
):
|
||||
@ -269,6 +293,7 @@ def test_fusion_rmsnorm_quant(
|
||||
custom_ops.append("+rms_norm")
|
||||
if enable_quant_fp8_custom_op:
|
||||
custom_ops.append("+quant_fp8")
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
compilation_config=CompilationConfig(
|
||||
@ -279,60 +304,97 @@ def test_fusion_rmsnorm_quant(
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
# Reshape pass is needed for the fusion pass to work
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
if model_class is TestRmsnormGroupFp8QuantModel:
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
RocmAiterRMSNormFp8GroupQuantFusionPass,
|
||||
)
|
||||
# Setup device before model creation
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(1)
|
||||
maybe_create_device_identity()
|
||||
|
||||
fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
|
||||
else:
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
||||
backend2 = TestBackend(noop_pass, cleanup_pass)
|
||||
model = model_class(
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
model = TestModel(
|
||||
hidden_size=hidden_size,
|
||||
eps=eps,
|
||||
group_shape=group_shape,
|
||||
use_aiter=False,
|
||||
cuda_force_torch=cuda_force_torch,
|
||||
)
|
||||
# First dimension dynamic
|
||||
x = torch.rand(num_tokens, hidden_size)
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
model_fused = torch.compile(model, backend=backend)
|
||||
result_fused = model_fused(x)
|
||||
|
||||
model_unfused = torch.compile(model, backend=backend2)
|
||||
result_unfused = model_unfused(x)
|
||||
|
||||
if dtype == torch.float16:
|
||||
ATOL, RTOL = (2e-3, 2e-3)
|
||||
else:
|
||||
ATOL, RTOL = (1e-2, 1e-2)
|
||||
|
||||
torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
|
||||
|
||||
assert fusion_pass.matched_count == 3
|
||||
backend.check_before_ops(model.ops_in_model_before())
|
||||
backend, _ = _run_fusion_test(
|
||||
model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
|
||||
)
|
||||
backend.check_before_ops(
|
||||
model.ops_in_model_before_partial(), fully_replaced=False
|
||||
)
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
|
||||
# If RMSNorm custom op is disabled (native/torch impl used),
|
||||
# there's a risk that the fused add doesn't get included in the
|
||||
# replacement and only the rms part gets fused with quant.
|
||||
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
|
||||
if (
|
||||
not enable_rms_norm_custom_op
|
||||
and model_class is not TestRmsnormGroupFp8QuantModel
|
||||
):
|
||||
if not enable_rms_norm_custom_op:
|
||||
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
|
||||
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
|
||||
assert n_add_nodes(backend.graph_pre_pass) == 7
|
||||
assert n_add_nodes(backend.graph_post_pass) == 2
|
||||
|
||||
|
||||
GROUP_SHAPE_QUANT_OPS_MATCHS = [
|
||||
(GroupShape.PER_TOKEN, True),
|
||||
(GroupShape.PER_TOKEN, False),
|
||||
(GroupShape(1, 128), True),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("hidden_size", [256])
|
||||
@pytest.mark.parametrize("num_tokens", [257])
|
||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||
@pytest.mark.parametrize(
|
||||
"group_shape, use_aiter_quant_op", GROUP_SHAPE_QUANT_OPS_MATCHS
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
(not current_platform.is_rocm() or not IS_AITER_FOUND),
|
||||
reason="Only test on ROCm with aiter package installed",
|
||||
)
|
||||
def test_aiter_fusion_rmsnorm_quant(
|
||||
dtype: torch.dtype,
|
||||
hidden_size: int,
|
||||
num_tokens: int,
|
||||
eps: float,
|
||||
group_shape: GroupShape,
|
||||
use_aiter_quant_op: bool,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=["+rms_norm", "+quant_fp8"],
|
||||
pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True),
|
||||
),
|
||||
)
|
||||
|
||||
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
|
||||
from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass
|
||||
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
rocm_aiter_ops.refresh_env_variables()
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(1)
|
||||
maybe_create_device_identity()
|
||||
|
||||
fusion_pass = RocmAiterRMSNormFusionPass(vllm_config)
|
||||
model = TestModel(
|
||||
hidden_size=hidden_size,
|
||||
eps=eps,
|
||||
group_shape=group_shape,
|
||||
use_aiter=True,
|
||||
use_aiter_quant_op=use_aiter_quant_op,
|
||||
)
|
||||
|
||||
_run_fusion_test(
|
||||
model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import MLAAttentionSpec
from vllm.v1.kv_cache_interface import FullAttentionSpec

BACKENDS_TO_TEST = [
    AttentionBackendEnum.CUTLASS_MLA,
@@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):

def run_attention_backend(
    backend: AttentionBackendEnum,
    kv_cache_spec: MLAAttentionSpec,
    kv_cache_spec: FullAttentionSpec,
    layer_names: list[str],
    vllm_config,
    device: torch.device,
@@ -740,7 +740,7 @@ def test_backend_correctness(
        kv_cache = kv_cache_per_block_size[block_size]

        # Create kv_cache_spec with the correct block_size for this backend
        backend_kv_cache_spec = MLAAttentionSpec(
        backend_kv_cache_spec = FullAttentionSpec(
            block_size=block_size,
            num_kv_heads=vllm_config.model_config.get_num_kv_heads(
                vllm_config.parallel_config
@@ -748,7 +748,6 @@ def test_backend_correctness(
            head_size=vllm_config.model_config.get_head_size(),
            dtype=vllm_config.model_config.dtype,
            sliding_window=vllm_config.model_config.get_sliding_window(),
            cache_dtype_str=vllm_config.cache_config.cache_dtype,
        )

        backend_output = run_attention_backend(

@ -4,6 +4,7 @@ import functools
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch._ops import OpOverload
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.platforms import current_platform
|
||||
@ -433,16 +434,16 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_impl(
|
||||
from aiter import rmsnorm2d_fwd_with_add
|
||||
|
||||
residual_out = torch.empty_like(residual)
|
||||
output = torch.empty_like(x)
|
||||
out = torch.empty_like(x)
|
||||
rmsnorm2d_fwd_with_add(
|
||||
output, # output
|
||||
out, # output
|
||||
x, # input
|
||||
residual, # residual input
|
||||
residual_out, # residual output
|
||||
weight,
|
||||
variance_epsilon,
|
||||
)
|
||||
return output, residual_out
|
||||
return out, residual_out
|
||||
|
||||
|
||||
def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
|
||||
@ -451,7 +452,84 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return torch.empty_like(x), torch.empty_like(residual)
|
||||
residual_out = torch.empty_like(residual)
|
||||
out = torch.empty_like(x)
|
||||
return out, residual_out
|
||||
|
||||
|
||||
def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
import aiter as rocm_aiter
|
||||
|
||||
assert quant_dtype in [torch.int8, _FP8_DTYPE]
|
||||
|
||||
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
|
||||
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
|
||||
residual_out = torch.empty_like(x)
|
||||
|
||||
rocm_aiter.rmsnorm2d_fwd_with_add_dynamicquant(
|
||||
out,
|
||||
x,
|
||||
residual,
|
||||
residual_out,
|
||||
y_scale,
|
||||
weight,
|
||||
epsilon,
|
||||
use_model_sensitive_rmsnorm=0,
|
||||
)
|
||||
|
||||
return out, residual_out, y_scale
|
||||
|
||||
|
||||
def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
|
||||
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
|
||||
residual_out = torch.empty_like(x)
|
||||
|
||||
return out, residual_out, y_scale
|
||||
|
||||
|
||||
def _rocm_aiter_rmsnorm_fused_dynamic_quant_impl(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
import aiter as rocm_aiter
|
||||
|
||||
assert quant_dtype in [torch.int8, _FP8_DTYPE]
|
||||
|
||||
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
|
||||
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
|
||||
|
||||
rocm_aiter.rmsnorm2d_fwd_with_dynamicquant(
|
||||
out, x, y_scale, weight, epsilon, use_model_sensitive_rmsnorm=0
|
||||
)
|
||||
|
||||
return out, y_scale
|
||||
|
||||
|
||||
def _rocm_aiter_rmsnorm_fused_dynamic_quant_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
    out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)

    return out, y_scale

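For intuition, an unfused reference (an assumption-laden sketch, not the AITER kernel; it uses the OCP e4m3 dtype, whereas ROCm typically uses the fnuz variant) of what the fused RMSNorm + dynamic per-token quant op computes:

```python
import torch

def rmsnorm_then_dynamic_quant(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    quant_dtype: torch.dtype = torch.float8_e4m3fn,
) -> tuple[torch.Tensor, torch.Tensor]:
    # RMSNorm over the last dimension, then one dynamic scale per token (row).
    y = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
    scale = y.abs().amax(dim=-1, keepdim=True) / torch.finfo(quant_dtype).max
    scale = scale.clamp(min=1e-12)
    return (y / scale).to(quant_dtype), scale
```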
def _rocm_aiter_per_tensor_quant_impl(
|
||||
@@ -527,7 +605,11 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
        dtype_quant=AITER_FP8_DTYPE,
        res1=residual,
    )
    return (x_quant, x_quant_scales, res)
    return (
        x_quant,
        res,
        x_quant_scales,
    )


def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
@@ -541,8 +623,8 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
    scale_shape = (M, (N + group_size - 1) // group_size)
    return (
        torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
        torch.empty(scale_shape, dtype=torch.float32, device=x.device),
        torch.empty_like(residual, device=residual.device),
        torch.empty(scale_shape, dtype=torch.float32, device=x.device),
    )

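The fake above sizes the scale tensor with one entry per 128-column group of each row. A small reference sketch of that 1x128 group quantization (assumes the row length is an exact multiple of the group size; the cast to fp8 is omitted for clarity):

```python
import torch

def per_token_group_quant(
    x: torch.Tensor, group_size: int = 128, fp8_max: float = 448.0
) -> tuple[torch.Tensor, torch.Tensor]:
    m, n = x.shape
    assert n % group_size == 0, "sketch assumes an exact multiple of group_size"
    n_groups = n // group_size  # matches (N + group_size - 1) // group_size here
    g = x.view(m, n_groups, group_size)
    scale = g.abs().amax(dim=-1, keepdim=True) / fp8_max  # one scale per group
    return (g / scale).view(m, n), scale.view(m, n_groups)
```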
@@ -901,6 +983,20 @@ class rocm_aiter_ops:
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_rmsnorm_fused_dynamic_quant",
            op_func=_rocm_aiter_rmsnorm_fused_dynamic_quant_impl,
            fake_impl=_rocm_aiter_rmsnorm_fused_dynamic_quant_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_rmsnorm_fused_add_dynamic_quant",
            op_func=_rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl,
            fake_impl=_rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_rmsnorm_fp8_group_quant",
            op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl,
@@ -936,13 +1032,54 @@ class rocm_aiter_ops:
        direct_register_custom_op(
            op_name="rocm_aiter_per_token_quant",
            op_func=_rocm_aiter_per_token_quant_impl,
            mutates_args=["scale"],
            fake_impl=_rocm_aiter_per_token_quant_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        _OPS_REGISTERED = True

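Each registration above pairs a real implementation with a fake_impl so torch.compile can trace the op on meta tensors without running the kernel. A toy standalone analogue using torch.library (the "demo::scale_rows" name is hypothetical and unrelated to vLLM's ops; requires PyTorch 2.4+):

```python
import torch

@torch.library.custom_op("demo::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # Real implementation: scale each row of x by the matching entry of s.
    return x * s[:, None]

@scale_rows.register_fake
def _(x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # Fake implementation: only shapes/dtypes, never touches real data.
    return torch.empty_like(x)
```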
@staticmethod
|
||||
def get_rmsnorm_fused_add_op() -> OpOverload:
|
||||
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
|
||||
|
||||
@staticmethod
|
||||
def get_rmsnorm_op() -> OpOverload:
|
||||
return torch.ops.vllm.rocm_aiter_rms_norm.default
|
||||
|
||||
@staticmethod
|
||||
def get_rmsnorm_fused_add_dynamic_quant_op() -> OpOverload:
|
||||
return torch.ops.vllm.rocm_aiter_rmsnorm_fused_add_dynamic_quant.default
|
||||
|
||||
@staticmethod
|
||||
def get_rmsnorm_fused_dynamic_quant_op() -> OpOverload:
|
||||
return torch.ops.vllm.rocm_aiter_rmsnorm_fused_dynamic_quant.default
|
||||
|
||||
@staticmethod
|
||||
def get_rmsnorm_group_fused_quant_op() -> OpOverload:
|
||||
return torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
|
||||
|
||||
@staticmethod
|
||||
def get_rmsnorm_group_add_fused_quant_op() -> OpOverload:
|
||||
return torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
|
||||
|
||||
@staticmethod
|
||||
def get_per_token_quant_op() -> OpOverload:
|
||||
return torch.ops.vllm.rocm_aiter_per_token_quant.default
|
||||
|
||||
@staticmethod
|
||||
def get_group_quant_op() -> OpOverload:
|
||||
return torch.ops.vllm.rocm_aiter_group_fp8_quant.default
|
||||
|
||||
@staticmethod
|
||||
def get_act_mul_fused_fp8_group_quant_op() -> OpOverload:
|
||||
return torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
|
||||
|
||||
@staticmethod
|
||||
def rms_norm(
|
||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)
|
||||
|
||||
@staticmethod
|
||||
def rms_norm2d_with_add(
|
||||
x: torch.Tensor,
|
||||
@ -954,12 +1091,6 @@ class rocm_aiter_ops:
|
||||
x, residual, weight, variance_epsilon
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def rms_norm(
|
||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)
|
||||
|
||||
@staticmethod
|
||||
def gemm_a8w8(
|
||||
A: torch.Tensor,
|
||||
|
||||
@ -6,11 +6,13 @@ import torch
|
||||
from torch._higher_order_ops import auto_functionalized
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
_normalize_quant_group_shape,
|
||||
kFp8Dynamic64Sym,
|
||||
@ -150,26 +152,50 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
|
||||
|
||||
|
||||
class MatcherRMSNorm(MatcherCustomOp):
|
||||
def __init__(self, epsilon: float, enabled: bool | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
enabled: bool | None = None,
|
||||
match_rocm_aiter: bool = False,
|
||||
):
|
||||
if enabled is None:
|
||||
enabled = RMSNorm.enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.epsilon = epsilon
|
||||
self._rmsnorm_op = RMS_OP
|
||||
self.match_rocm_aiter = match_rocm_aiter
|
||||
|
||||
if match_rocm_aiter:
|
||||
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op()
|
||||
|
||||
def inputs(self):
|
||||
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
|
||||
weight = self.empty(16)
|
||||
return [input, weight]
|
||||
|
||||
def forward_rocm_aiter(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return self._rmsnorm_op(
|
||||
x=input,
|
||||
weight=weight,
|
||||
variance_epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if self.match_rocm_aiter:
|
||||
return self.forward_rocm_aiter(input, weight)
|
||||
|
||||
result = torch.empty_like(input)
|
||||
_, result = auto_functionalized(
|
||||
RMS_OP,
|
||||
self._rmsnorm_op,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
@ -189,12 +215,23 @@ class MatcherRMSNorm(MatcherCustomOp):
|
||||
|
||||
|
||||
class MatcherFusedAddRMSNorm(MatcherCustomOp):
|
||||
def __init__(self, epsilon: float, enabled: bool | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
enabled: bool | None = None,
|
||||
match_rocm_aiter: bool = False,
|
||||
):
|
||||
if enabled is None:
|
||||
enabled = RMSNorm.enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.epsilon = epsilon
|
||||
self.match_rocm_aiter = match_rocm_aiter
|
||||
|
||||
self._rmsnorm_op = RMS_ADD_OP
|
||||
|
||||
if match_rocm_aiter:
|
||||
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op()
|
||||
|
||||
def inputs(self):
|
||||
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
|
||||
@ -202,14 +239,27 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
|
||||
residual = self.empty(5, 16)
|
||||
return [input, weight, residual]
|
||||
|
||||
def forward_rocm_aiter(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return self._rmsnorm_op(
|
||||
x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon
|
||||
)
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.match_rocm_aiter:
|
||||
return self.forward_rocm_aiter(input, weight, residual)
|
||||
|
||||
_, result, residual = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
self._rmsnorm_op,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
@ -236,22 +286,46 @@ class MatcherQuantFP8(MatcherCustomOp):
|
||||
enabled: bool | None = None,
|
||||
has_col_major_scales: bool = False,
|
||||
is_e8m0: bool = False,
|
||||
match_rocm_aiter: bool = False,
|
||||
):
|
||||
if enabled is None:
|
||||
enabled = QuantFP8.enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.quant_key = quant_key
|
||||
assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
|
||||
self.QUANT_OP = QUANT_OPS[quant_key]
|
||||
|
||||
self.has_col_major_scales = has_col_major_scales
|
||||
self.is_e8m0 = is_e8m0
|
||||
self.match_rocm_aiter = match_rocm_aiter
|
||||
|
||||
if match_rocm_aiter:
|
||||
assert not quant_key.scale.group_shape.is_per_tensor(), (
|
||||
"ROCm aiter fusion pass does not support per tensor quantization"
|
||||
)
|
||||
if quant_key.scale.group_shape.is_per_token():
|
||||
self.QUANT_OP = rocm_aiter_ops.get_per_token_quant_op()
|
||||
else:
|
||||
assert quant_key.scale.group_shape.col == 128, (
|
||||
"ROCm aiter fusion pass currently supports "
|
||||
"quantization operation with group_size 128"
|
||||
)
|
||||
if current_platform.is_fp8_fnuz():
|
||||
self.QUANT_OP = rocm_aiter_ops.get_group_quant_op()
|
||||
else:
|
||||
self.QUANT_OP = (
|
||||
torch.ops.vllm.triton_per_token_group_quant_fp8.default
|
||||
)
|
||||
|
||||
else:
|
||||
assert quant_key in QUANT_OPS, (
|
||||
f"unsupported quantization scheme {quant_key}"
|
||||
)
|
||||
self.QUANT_OP = QUANT_OPS[quant_key]
|
||||
|
||||
assert quant_key.dtype == current_platform.fp8_dtype(), (
|
||||
"Only QuantFP8 supported by"
|
||||
)
|
||||
assert quant_key.scale2 is None
|
||||
|
||||
assert quant_key.dtype == current_platform.fp8_dtype(), (
|
||||
"Only QuantFP8 supported by"
|
||||
)
|
||||
assert quant_key.scale2 is None
|
||||
self.quant_fp8 = QuantFP8(
|
||||
quant_key.scale.static,
|
||||
quant_key.scale.group_shape,
|
||||
@ -259,11 +333,29 @@ class MatcherQuantFP8(MatcherCustomOp):
|
||||
use_ue8m0=is_e8m0,
|
||||
)
|
||||
|
||||
def forward_rocm_aiter(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
quant_key_group_shape = self.quant_key.scale.group_shape
|
||||
if quant_key_group_shape == GroupShape.PER_TOKEN:
|
||||
return self.QUANT_OP(
|
||||
x=input,
|
||||
quant_dtype=self.quant_key.dtype,
|
||||
scale=scale,
|
||||
)
|
||||
else:
|
||||
return self.QUANT_OP(input, quant_key_group_shape.col)
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.match_rocm_aiter:
|
||||
return self.forward_rocm_aiter(input, scale)
|
||||
|
||||
result = torch.empty(
|
||||
input.shape, device=input.device, dtype=self.quant_key.dtype
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@ from .vllm_inductor_pass import VllmInductorPass

if rocm_aiter_ops.is_enabled():
    from vllm.compilation.rocm_aiter_fusion import (
        RocmAiterRMSNormFp8GroupQuantFusionPass,
        RocmAiterRMSNormFusionPass,
        RocmAiterSiluMulFp8GroupQuantFusionPass,
    )

@@ -117,7 +117,9 @@ class PostGradPassManager(CustomGraphPass):
        if self.pass_config.fuse_norm_quant:
            self.passes += [RMSNormQuantFusionPass(config)]
            if rocm_aiter_ops.is_enabled():
                self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)]
                self.passes += [
                    RocmAiterRMSNormFusionPass(config),
                ]
        if self.pass_config.fuse_act_quant:
            self.passes += [ActivationQuantFusionPass(config)]
            if rocm_aiter_ops.is_enabled():

@ -9,60 +9,195 @@ from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fusion import empty_bf16
|
||||
from .fusion import (
|
||||
FusedRMSQuantKey,
|
||||
)
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherSiluAndMul
|
||||
from .matcher_utils import (
|
||||
MatcherFusedAddRMSNorm,
|
||||
MatcherQuantFP8,
|
||||
MatcherRMSNorm,
|
||||
MatcherSiluAndMul,
|
||||
)
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
|
||||
AITER_RMS_ADD_GROUP_QUANT_OP = (
|
||||
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
|
||||
)
|
||||
|
||||
AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
|
||||
AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
|
||||
class AiterRMSNormQuantPattern:
|
||||
def __init__(
|
||||
self, epsilon: float, key: FusedRMSQuantKey, match_aiter_quant: bool = True
|
||||
):
|
||||
self.epsilon = epsilon
|
||||
self.quant_dtype = key.quant.dtype
|
||||
|
||||
AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default
|
||||
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
|
||||
|
||||
FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
|
||||
self.rmsnorm_matcher = (
|
||||
MatcherRMSNorm(epsilon, match_rocm_aiter=True)
|
||||
if not key.fused_add
|
||||
else MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)
|
||||
)
|
||||
self.quant_matcher = MatcherQuantFP8(
|
||||
key.quant,
|
||||
match_rocm_aiter=match_aiter_quant,
|
||||
)
|
||||
|
||||
|
||||
class AiterRMSFp8GroupQuantPattern:
|
||||
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||
"""AITER RMSNorm + Dynamic Quantization pattern."""
|
||||
|
||||
FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_dynamic_quant_op()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
match_aiter_quant: bool = True,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True,
|
||||
):
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
|
||||
super().__init__(epsilon, key, match_aiter_quant)
|
||||
|
||||
def register(self, pm_pass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
):
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
return result, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
):
|
||||
result = self.FUSED_OP(
|
||||
x=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
quant_dtype=self.quant_dtype,
|
||||
)
|
||||
|
||||
return result[0], result[1]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||
"""AITER RMSNorm Fused Add + Dynamic Quantization pattern."""
|
||||
|
||||
FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_add_dynamic_quant_op()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
match_aiter_quant: bool = True,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True,
|
||||
):
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
|
||||
super().__init__(epsilon, key, match_aiter_quant)
|
||||
|
||||
def register(self, pm_pass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
):
|
||||
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
|
||||
return result, residual_out, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
|
||||
):
|
||||
result = self.FUSED_OP(
|
||||
x=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
quant_dtype=self.quant_dtype,
|
||||
)
|
||||
|
||||
return result[0], result[1], result[2]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
||||
"""
|
||||
This pattern fuses aiter rms_norm & group fp8 quant custom
|
||||
ops into an aiter rms_norm_group_fp8_quant op.
|
||||
"""
|
||||
|
||||
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
|
||||
self.epsilon = epsilon
|
||||
self.quant_dtype = quant_dtype
|
||||
self.quant_op = quant_op
|
||||
FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_fused_quant_op()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape,
|
||||
match_aiter_quant: bool = True,
|
||||
symmetric=True,
|
||||
):
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
|
||||
super().__init__(epsilon, key, match_aiter_quant)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
):
|
||||
at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon)
|
||||
|
||||
at2 = self.quant_op(at1, 128)
|
||||
|
||||
return at2[0], at2[1]
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
return result, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
):
|
||||
at = AITER_RMS_GROUP_QUANT_OP(
|
||||
at = self.FUSED_OP(
|
||||
x=input,
|
||||
weight=weight,
|
||||
variance_epsilon=self.epsilon,
|
||||
@ -71,49 +206,52 @@ class AiterRMSFp8GroupQuantPattern:
|
||||
|
||||
return at[0], at[1]
|
||||
|
||||
inputs = [
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(1, 5), # weight
|
||||
]
|
||||
|
||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AiterFusedAddRMSFp8GroupQuantPattern:
|
||||
class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
||||
"""
|
||||
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
|
||||
into a aiter rms_norm_with_add_group_fp8_quant op.
|
||||
"""
|
||||
|
||||
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
|
||||
self.epsilon = epsilon
|
||||
self.quant_dtype = quant_dtype
|
||||
self.quant_op = quant_op
|
||||
FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_add_fused_quant_op()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape,
|
||||
match_aiter_quant: bool = True,
|
||||
symmetric=True,
|
||||
):
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
|
||||
super().__init__(epsilon, key, match_aiter_quant)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
):
|
||||
at1 = AITER_RMS_ADD_OP(
|
||||
x=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
variance_epsilon=self.epsilon,
|
||||
)
|
||||
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
|
||||
at2 = self.quant_op(at1[0], 128)
|
||||
|
||||
# result, scale, residual
|
||||
return at2[0], at2[1], at1[1]
|
||||
return result, residual_out, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
):
|
||||
at = AITER_RMS_ADD_GROUP_QUANT_OP(
|
||||
at = self.FUSED_OP(
|
||||
x=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
@ -124,18 +262,15 @@ class AiterFusedAddRMSFp8GroupQuantPattern:
|
||||
# result, scale, residual
|
||||
return at[0], at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(5, 4), # residual
|
||||
empty_bf16(1, 5), # weight
|
||||
]
|
||||
|
||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
|
||||
This pass fuses aiter rms_norm & vllm/aiter quant custom ops
|
||||
into a fused rms_norm_quant op.
|
||||
It also supports fused_add_rms_norm.
|
||||
"""
|
||||
|
||||
@ -144,20 +279,33 @@ class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass"
|
||||
pass_name="rocm_aiter_rms_norm_quant_fusion_pass"
|
||||
)
|
||||
|
||||
# Make sure fused add patterns are before simple rms norm,
|
||||
# as the latter is a subset of the former in torch ops
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# Fuse rms_norm + dynamic group fp8 quant
|
||||
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
|
||||
AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register(
|
||||
self.patterns
|
||||
)
|
||||
# Fuse aiter rms_norm + aiter dynamic group fp8 quant
|
||||
AiterRMSFp8GroupQuantPattern(
|
||||
epsilon, FP8_DTYPE, GroupShape(1, 128)
|
||||
).register(self.patterns)
|
||||
|
||||
AiterFusedAddRMSFp8GroupQuantPattern(
|
||||
epsilon, FP8_DTYPE, quant_op
|
||||
# Fuse aiter fused_add_rms_norm + aiter dynamic group fp8 quant
|
||||
AiterFusedAddRMSFp8GroupQuantPattern(
|
||||
epsilon, FP8_DTYPE, GroupShape(1, 128)
|
||||
).register(self.patterns)
|
||||
|
||||
for match_aiter_quant in [True, False]:
|
||||
# Fuse aiter rms_norm + (aiter / vllm built-in)
|
||||
# dynamic per-token fp8 quant
|
||||
AiterRMSNormDynamicQuantPattern(
|
||||
epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
|
||||
).register(self.patterns)
|
||||
|
||||
# Fuse aiter fused_add_rms_norm + (aiter / vllm built-in)
|
||||
# dynamic per-token fp8 quant
|
||||
AiterFusedAddRMSNormDynamicQuantPattern(
|
||||
epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
|
||||
).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
@ -169,6 +317,8 @@ class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
|
||||
def uuid(self) -> Any:
|
||||
fusion_patterns = [
|
||||
AiterRMSNormDynamicQuantPattern,
|
||||
AiterFusedAddRMSNormDynamicQuantPattern,
|
||||
AiterRMSFp8GroupQuantPattern,
|
||||
AiterFusedAddRMSFp8GroupQuantPattern,
|
||||
]
|
||||
@ -181,6 +331,8 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
|
||||
ops into an aiter silu_and_mul_group_fp8_quant op.
|
||||
"""
|
||||
|
||||
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
|
||||
|
||||
def __init__(self, quant_op: OpOverload):
|
||||
self.silu_and_mul_matcher = MatcherSiluAndMul()
|
||||
self.quant_op = quant_op
|
||||
@ -196,7 +348,7 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
):
|
||||
at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
|
||||
at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
|
||||
return at[0], at[1]
|
||||
|
||||
inputs = [
|
||||
@ -216,6 +368,11 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
||||
"""
|
||||
|
||||
AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
|
||||
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
|
||||
|
||||
QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
@ -224,7 +381,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
|
||||
)
|
||||
|
||||
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
|
||||
for quant_op in self.QUANT_OPS:
|
||||
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@@ -186,6 +186,7 @@ class DPMetadata:
class ForwardContext:
    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
    """
    Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
    attention layer to its attention metadata
@@ -193,7 +194,6 @@ class ForwardContext:
    for each microbatch.
    Set dynamically for each forward pass
    """
    attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass

@@ -11,9 +11,11 @@ import torch

from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.platforms import current_platform

logger = init_logger(__name__)
is_batch_invariant = vllm_is_batch_invariant()

_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
@@ -150,7 +152,8 @@ def _get_lora_b_ptr(
@functools.lru_cache
def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
    user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
    if user_defined_config_folder is not None:
    # Avoid optimizing for the batch invariant case. Use default config
    if user_defined_config_folder is not None and not is_batch_invariant:
        gpu_name = torch.cuda.get_device_name()
        gpu_name = gpu_name.replace(" ", "_")
        gpu_name = gpu_name.replace("-", "_")
@@ -203,11 +206,14 @@ def get_lora_op_configs(
    # default config
    default = {}
    if op_type == "shrink":
        split_k = 64 if batch < 128 else 8
        if is_batch_invariant:
            split_k = 1
        default = {
            "block_m": 32,
            "block_n": 16,
            "block_k": 256 if batch < 128 else 32,
            "split_k": 64 if batch < 128 else 8,
            "split_k": split_k,
            "num_warps": 4,
            "num_ctas": 1,
            "group_size_m": 8,

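The change above pins split_k to 1 in batch-invariant mode. A short sketch of the selection logic and why it matters (split-K changes the order in which partial sums are combined, which can perturb results across runs and batch sizes):

```python
def pick_split_k(batch: int, batch_invariant: bool) -> int:
    # Splitting the K reduction across blocks is faster, but the partial sums
    # are combined in a non-fixed order; a single split keeps results stable.
    if batch_invariant:
        return 1
    return 64 if batch < 128 else 8
```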
@@ -325,6 +325,7 @@ def flashinfer_trtllm_fp4_moe(
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        routed_scaling_factor=None,
        tile_tokens_dim=None,
        routing_method_type=routing_method_type,
        do_finalize=True,
    )[0]

@@ -48,7 +48,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
@@ -167,7 +166,6 @@ class Jais2Attention(nn.Module):

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=getattr(config, "rope_parameters", None),
            is_neox_style=is_neox_style,
@@ -304,17 +302,12 @@ class Jais2Model(nn.Module):

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        self.padding_idx = config.pad_token_id
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab

        self.vocab_size = config.vocab_size
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
@@ -456,29 +449,15 @@ class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.model = self._init_model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config
                    else lora_config.lora_vocab_padding_size
                ),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
@@ -487,7 +466,7 @@ class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size, config.vocab_size, logit_scale
                config.vocab_size, scale=logit_scale
            )
        else:
            self.lm_head = PPMissingLayer()

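The removed lines sized the vocabulary for extra LoRA tokens and padded the LM head accordingly. A rough sketch of that sizing (hedged; the padding value here is an assumption, not the constant removed above):

```python
def padded_vocab_size(vocab_size: int, lora_extra: int, max_loras: int, pad: int = 64) -> int:
    # Extra vocab slots per LoRA adapter, then round up to the padding size
    # so the embedding/LM-head kernels keep their alignment requirements.
    total = vocab_size + lora_extra * max(max_loras, 1)
    return -(-total // pad) * pad  # ceil-div, then scale back up

assert padded_vocab_size(32000, 256, 4) == 33024
```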
@@ -156,7 +156,9 @@ class XPUPlatform(Platform):

        if vllm_config.lora_config is not None:
            compilation_config.mode = CompilationMode.NONE

        # decrease triton kernel compilation scratch space for speculative decoding
        if vllm_config.speculative_config is not None:
            os.environ["IGC_ForceOCLSIMDWidth"] = "16"  # noqa: SIM112
        # check and update parallel config
        parallel_config = vllm_config.parallel_config
        # Only override worker_cls if it's still the default "auto"

@ -541,11 +541,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
metadata_cls if metadata_cls is not None else MLACommonMetadata
|
||||
)
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.q_data_type = (
|
||||
current_platform.fp8_dtype()
|
||||
if (kv_cache_spec is not None and "fp8" in kv_cache_spec.cache_dtype_str)
|
||||
else vllm_config.model_config.dtype
|
||||
)
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
@ -689,6 +684,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
|
||||
# For main run, qo_indptr == kv_indptr
|
||||
kv_indptr = qo_indptr.clone()
|
||||
|
||||
# Prepare main prefill
|
||||
self._fi_prefill_main.plan(
|
||||
qo_indptr=qo_indptr,
|
||||
@ -701,7 +697,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
sm_scale=self._global_hyperparameters.sm_scale,
|
||||
window_left=self._global_hyperparameters.window_left,
|
||||
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
|
||||
q_data_type=self.q_data_type,
|
||||
q_data_type=self.model_config.dtype,
|
||||
)
|
||||
|
||||
# Prepare context prefills
|
||||
@ -720,7 +716,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
sm_scale=self._global_hyperparameters.sm_scale,
|
||||
window_left=self._global_hyperparameters.window_left,
|
||||
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
|
||||
q_data_type=self.q_data_type,
|
||||
q_data_type=self.model_config.dtype,
|
||||
)
|
||||
|
||||
prefill.prefill_main = self._fi_prefill_main
|
||||
@ -973,7 +969,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
query_start_loc=prefill_query_start_loc,
|
||||
max_query_len=max_query_len,
|
||||
chunked_context=chunked_context_metadata,
|
||||
q_data_type=self.q_data_type,
|
||||
)
|
||||
|
||||
if self._use_cudnn_prefill:
|
||||
@ -1384,15 +1379,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
return attn_out
|
||||
|
||||
def _run_prefill_new_tokens_fa(
|
||||
self,
|
||||
prefill: MLACommonPrefillMetadata,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse,
|
||||
fp8_attention: bool,
|
||||
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
|
||||
):
|
||||
logger.debug_once("Running FlashAttention prefill new tokens", scope="local")
|
||||
return self._flash_attn_varlen_diff_headdims(
|
||||
q=q,
|
||||
k=k,
|
||||
@ -1407,23 +1395,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
)
|
||||
|
||||
def _run_prefill_new_tokens_fi(
|
||||
self,
|
||||
prefill: MLACommonPrefillMetadata,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse,
|
||||
fp8_attention: bool,
|
||||
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
|
||||
):
|
||||
logger.debug_once("Running FlashInfer prefill new tokens", scope="local")
|
||||
assert isinstance(prefill, FlashInferPrefillMetadata)
|
||||
assert prefill.prefill_main is not None
|
||||
if fp8_attention:
|
||||
logger.debug_once("Running Flashinfer prefill in FP8")
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
q = q.to(fp8_dtype)
|
||||
k = k.to(fp8_dtype)
|
||||
v = v.to(fp8_dtype)
|
||||
|
||||
ret = prefill.prefill_main.run(
|
||||
q=q,
|
||||
k=k,
|
||||
@ -1436,18 +1412,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
return ret
|
||||
|
||||
def _run_prefill_new_tokens_cudnn(
|
||||
self,
|
||||
prefill: MLACommonPrefillMetadata,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse,
|
||||
fp8_attention: bool,
|
||||
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
|
||||
):
|
||||
logger.debug_once("Running Cudnn prefill new tokens", scope="local")
|
||||
assert isinstance(prefill, CudnnPrefillMetadata)
|
||||
assert prefill.query_seq_lens is not None
|
||||
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
|
||||
output, lse = cudnn_batch_prefill_with_kv_cache(
|
||||
q=q,
|
||||
k_cache=k,
|
||||
@ -1469,19 +1437,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
return output
|
||||
|
||||
def _run_prefill_context_chunk_fa(
|
||||
self,
|
||||
prefill: MLACommonPrefillMetadata,
|
||||
chunk_idx: int,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
fp8_attention: bool,
|
||||
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
|
||||
):
|
||||
logger.debug_once("Running FlashAttention prefill context chunk", scope="local")
|
||||
assert prefill.chunked_context is not None
|
||||
assert fp8_attention is False, (
|
||||
"FlashAttention prefill does not support fp8 attention"
|
||||
)
|
||||
return self._flash_attn_varlen_diff_headdims(
|
||||
q=q,
|
||||
k=k,
|
||||
@ -1496,22 +1454,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
)
|
||||
|
||||
def _run_prefill_context_chunk_fi(
|
||||
self,
|
||||
prefill: MLACommonPrefillMetadata,
|
||||
chunk_idx: int,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
fp8_attention: bool,
|
||||
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
|
||||
):
|
||||
logger.debug_once("Running FlashInfer prefill context chunk", scope="local")
|
||||
assert isinstance(prefill, FlashInferPrefillMetadata)
|
||||
if fp8_attention:
|
||||
logger.debug_once("Running FlashInfer prefill in FP8")
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
q = q.to(fp8_dtype)
|
||||
k = k.to(fp8_dtype)
|
||||
v = v.to(fp8_dtype)
|
||||
|
||||
attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
|
||||
q=q,
|
||||
k=k,
|
||||
@ -1523,20 +1469,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
return attn_out, lse.transpose(0, 1).contiguous()
|
||||
|
||||
def _run_prefill_context_chunk_cudnn(
|
||||
self,
|
||||
prefill: MLACommonPrefillMetadata,
|
||||
chunk_idx: int,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
fp8_attention: bool,
|
||||
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
|
||||
):
|
||||
logger.debug_once("Running Cudnn prefill context chunk", scope="local")
|
||||
assert isinstance(prefill, CudnnPrefillMetadata)
|
||||
assert prefill.chunked_context is not None
|
||||
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
|
||||
assert prefill.query_seq_lens is not None
|
||||
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
|
||||
return cudnn_batch_prefill_with_kv_cache(
|
||||
q=q,
|
||||
k_cache=k,
|
||||
@ -1556,28 +1494,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
)
|
||||
|
||||
def _run_prefill_new_tokens_trtllm_ragged(
|
||||
self,
|
||||
prefill: MLACommonPrefillMetadata,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse,
|
||||
fp8_attention: bool,
|
||||
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
|
||||
):
|
||||
logger.debug_once("Running TRT-LLM ragged prefill new tokens", scope="local")
|
||||
"""TRT-LLM ragged attention for new tokens (causal)."""
|
||||
from flashinfer.prefill import trtllm_ragged_attention_deepseek
|
||||
|
||||
assert prefill.query_seq_lens is not None
|
||||
assert prefill.workspace_buffer is not None
|
||||
|
||||
if fp8_attention:
|
||||
logger.debug_once("Running TRT-LLM ragged prefill in FP8")
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
q = q.to(fp8_dtype)
|
||||
k = k.to(fp8_dtype)
|
||||
v = v.to(fp8_dtype)
|
||||
|
||||
ret = trtllm_ragged_attention_deepseek(
|
||||
query=q,
|
||||
key=k,
|
||||
@ -1604,15 +1528,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
return ret
|
||||
|
||||
def _run_prefill_context_chunk_trtllm_ragged(
|
||||
self,
|
||||
prefill: MLACommonPrefillMetadata,
|
||||
chunk_idx: int,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
fp8_attention: bool,
|
||||
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
|
||||
):
|
||||
logger.debug_once("Running TRT-LLM ragged prefill context chunk", scope="local")
|
||||
"""TRT-LLM ragged attention for context chunks (non-causal)."""
|
||||
from flashinfer.prefill import trtllm_ragged_attention_deepseek
|
||||
|
||||
@ -1629,13 +1546,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
)
|
||||
prefill.workspace_buffer.fill_(0)
|
||||
|
||||
if fp8_attention:
|
||||
logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
q = q.to(fp8_dtype)
|
||||
k = k.to(fp8_dtype)
|
||||
v = v.to(fp8_dtype)
|
||||
|
||||
attn_out, lse = trtllm_ragged_attention_deepseek(
|
||||
query=q,
|
||||
key=k,
|
||||
@ -1788,7 +1698,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
k_scale: torch.Tensor,
|
||||
fp8_attention: bool,
|
||||
):
|
||||
assert attn_metadata.prefill is not None
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
@ -1827,7 +1736,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
fp8_attention=fp8_attention,
|
||||
)
|
||||
|
||||
if output is None:
|
||||
@ -1856,7 +1764,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
attn_metadata: MLACommonMetadata,
|
||||
k_scale: torch.Tensor,
|
||||
dcp_world_size: int,
|
||||
fp8_attention: bool,
|
||||
):
|
||||
assert k_scale is None, "DCP not support scaled kvcache now."
|
||||
assert attn_metadata.prefill is not None
|
||||
@ -1933,7 +1840,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
fp8_attention=fp8_attention,
|
||||
)
|
||||
|
||||
if output is None:
|
||||
@ -1964,7 +1870,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
attn_metadata: MLACommonMetadata,
|
||||
k_scale: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
fp8_attention: bool = False,
|
||||
) -> None:
|
||||
# TODO (zyongye): Prefill function here
|
||||
assert attn_metadata.prefill is not None
|
||||
@ -1984,7 +1889,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
k=k,
|
||||
v=v,
|
||||
return_softmax_lse=has_context,
|
||||
fp8_attention=fp8_attention,
|
||||
)
|
||||
|
||||
if has_context:
|
||||
@ -1997,12 +1901,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
attn_metadata,
|
||||
k_scale=None,
|
||||
dcp_world_size=self.dcp_world_size,
|
||||
fp8_attention=fp8_attention,
|
||||
)
|
||||
)
|
||||
else:
|
||||
context_output, context_lse = self._compute_prefill_context(
|
||||
q, kv_c_and_k_pe_cache, attn_metadata, k_scale, fp8_attention
|
||||
q, kv_c_and_k_pe_cache, attn_metadata, k_scale
|
||||
)
|
||||
|
||||
# unpad if necessary
|
||||
@ -2123,7 +2026,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
attn_metadata,
|
||||
layer._k_scale,
|
||||
output=output[num_decode_tokens:],
|
||||
fp8_attention=fp8_attention,
|
||||
)
|
||||
|
||||
if has_decode:
|
||||
|
||||
@ -80,17 +80,20 @@ class AttentionSpec(KVCacheSpec):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FullAttentionSpec(AttentionSpec):
|
||||
sliding_window: int | None = None
|
||||
attention_chunk_size: int | None = None
|
||||
"""
|
||||
When hybrid allocator is disabled and the model contains both full
|
||||
attention layers and sliding window attention layers, sliding
|
||||
window attention are regarded as full attention in KV cache manager
|
||||
(blocks are allocated for all tokens), while computed as sliding window
|
||||
When hybrid allocator is disabled and the model contains both full
|
||||
attention layers and sliding window attention layers, sliding
|
||||
window attention are regarded as full attention in KV cache manager
|
||||
(blocks are allocated for all tokens), while computed as sliding window
|
||||
attention in model runner.
|
||||
In this case, we use FullAttentionSpec and record the sliding window size.
|
||||
"""
|
||||
|
||||
sliding_window: int | None = None
|
||||
"""
|
||||
Default to None for not using sliding window attention.
|
||||
"""
|
||||
attention_chunk_size: int | None = None
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
@ -390,10 +393,11 @@ class KVCacheConfig:
|
||||
The KV cache configuration of a model.
|
||||
"""
|
||||
|
||||
"""The number of KV cache blocks"""
|
||||
num_blocks: int
|
||||
"""How should model runner initialize the KV cache tensors for each layer"""
|
||||
"""The number of KV cache blocks"""
|
||||
kv_cache_tensors: list[KVCacheTensor]
|
||||
"""How should model runner initialize the KV cache tensors for each layer"""
|
||||
kv_cache_groups: list[KVCacheGroupSpec]
|
||||
"""
|
||||
The kv cache groups of the model.
|
||||
For models with only one type of attention, there is only one group that
|
||||
@ -401,4 +405,3 @@ class KVCacheConfig:
|
||||
For models with multiple types of attention, there will be multiple groups,
|
||||
see `_get_kv_cache_config_uniform_page_size` for more details.
|
||||
"""
|
||||
kv_cache_groups: list[KVCacheGroupSpec]
|
||||
|
||||