diff --git a/benchmarks/benchmark_batch_invariance.py b/benchmarks/benchmark_batch_invariance.py index b5c16c42de467..7473a41e51406 100755 --- a/benchmarks/benchmark_batch_invariance.py +++ b/benchmarks/benchmark_batch_invariance.py @@ -104,7 +104,6 @@ def run_benchmark_with_batch_invariant( random.seed(seed) # Set environment variables - os.environ["VLLM_ATTENTION_BACKEND"] = backend if batch_invariant: os.environ["VLLM_BATCH_INVARIANT"] = "1" else: @@ -140,6 +139,7 @@ def run_benchmark_with_batch_invariant( max_model_len=max_model_len, dtype="bfloat16", tensor_parallel_size=tp_size, + attention_config={"backend": backend}, enable_prefix_caching=False, ) init_time = time.perf_counter() - start_init diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index 72d2053102c22..4168c1570d874 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -2,7 +2,7 @@ FROM intel/deep-learning-essentials:2025.2.2-0-devel-ubuntu24.04 AS vllm-base RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \ echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \ - add-apt-repository -y ppa:kobuk-team/intel-graphics + add-apt-repository -y ppa:kobuk-team/intel-graphics-staging RUN apt clean && apt-get update -y && \ apt-get install -y --no-install-recommends --fix-missing \ @@ -47,6 +47,11 @@ RUN --mount=type=cache,target=/root/.cache/pip \ pip install --no-cache-dir \ -r requirements/xpu.txt +# arctic-inference is built from source which needs torch-xpu properly installed +# used for suffix method speculative decoding +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --no-cache-dir arctic-inference==0.1.1 + ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/" COPY . . diff --git a/docs/features/README.md b/docs/features/README.md index e9e5232929b72..b9083b9993159 100644 --- a/docs/features/README.md +++ b/docs/features/README.md @@ -64,7 +64,7 @@ th:not(:first-child) { | [CP](../configuration/optimization.md#chunked-prefill) | [❌](https://github.com/vllm-project/vllm/issues/2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [APC](automatic_prefix_caching.md) | [❌](https://github.com/vllm-project/vllm/issues/3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [🟠](https://github.com/vllm-project/vllm/issues/26963) | +| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/26970) | | [pooling](../models/pooling_models.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | enc-dec | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py index 28ab2cee71a6a..f8a629ed46cee 100644 --- a/tests/compile/distributed/test_fusions_e2e.py +++ b/tests/compile/distributed/test_fusions_e2e.py @@ -557,7 +557,8 @@ def test_rms_group_quant( # To capture subprocess logs, we need to know whether spawn or fork is used. # Force spawn as it is more general. monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + model_kwargs["attention_config"] = {"backend": backend.name} compilation_config = CompilationConfig( # Testing properties diff --git a/tests/compile/test_dynamic_shapes_compilation.py b/tests/compile/test_dynamic_shapes_compilation.py index 9ccb363b088f5..1fda21dea6361 100644 --- a/tests/compile/test_dynamic_shapes_compilation.py +++ b/tests/compile/test_dynamic_shapes_compilation.py @@ -77,6 +77,7 @@ def test_dynamic_shapes_compilation( "evaluate_guards": evaluate_guards, }, }, + max_model_len=1024, ) output = model.generate(prompt) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 6b72c595cd779..7755e9f9b7380 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import itertools import pytest import torch @@ -53,37 +52,61 @@ class TestModel(torch.nn.Module): hidden_size: int, eps: float, group_shape: GroupShape, - cuda_force_torch: bool, + use_aiter: bool = False, + cuda_force_torch: bool = False, + use_aiter_quant_op: bool = True, *args, **kwargs, ): super().__init__(*args, **kwargs) + self.use_aiter = use_aiter + self.use_aiter_quant_op = use_aiter_quant_op self.cuda_force_torch = cuda_force_torch + self.group_shape = group_shape + self.enable_quant_fp8_custom_op = None # Will be set later if applicable + self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] - if group_shape.is_per_group(): - self.wscale = [ - torch.rand( - (hidden_size // group_shape[1], hidden_size // group_shape[1]), - dtype=torch.float32, - ) - for _ in range(3) - ] - else: - self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - static = group_shape == GroupShape.PER_TENSOR + + # Setup quantization scale descriptor + static = group_shape == GroupShape.PER_TENSOR and not use_aiter quant_scale = ScaleDesc(torch.float32, static, group_shape) self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) + + # Setup scales if static: self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] else: self.scale = [None for _ in range(3)] + + # Setup weights self.w = [ torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3) ] - if not group_shape.is_per_group(): + if not group_shape.is_per_group() or use_aiter: self.w = [self.w[0].t() for _ in range(3)] + # Setup weight scales if group_shape.is_per_group(): + scale_size = ( + (hidden_size + 128 - 1) // 128 + if use_aiter + else hidden_size // group_shape[1] + ) + wscale_shape: tuple[int, ...] = (scale_size, scale_size) + else: + wscale_shape = (1,) + self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)] + + # Setup FP8 linear operation + is_per_group = group_shape.is_per_group() + if is_per_group and use_aiter: + self.fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(128, 128), + act_quant_group_shape=group_shape, + use_aiter_and_is_supported=use_aiter_quant_op, + ) + # AITER blockwise doesn't use enable_quant_fp8_custom_op + elif is_per_group: self.fp8_linear = W8A8BlockFp8LinearOp( weight_group_shape=GroupShape(group_shape[1], group_shape[1]), act_quant_group_shape=group_shape, @@ -91,6 +114,13 @@ class TestModel(torch.nn.Module): use_aiter_and_is_supported=False, ) self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled() + elif use_aiter: + self.fp8_linear = Fp8LinearOp( + act_quant_static=False, + act_quant_group_shape=group_shape, + ) + self.fp8_linear.quant_fp8.use_aiter = use_aiter_quant_op + self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() else: with override_cutlass_fp8_supported(not cuda_force_torch): self.fp8_linear = Fp8LinearOp( @@ -100,7 +130,6 @@ class TestModel(torch.nn.Module): self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() self.enable_rms_norm_custom_op = self.norm[0].enabled() - self.group_shape = group_shape def forward(self, x): # avoid having graph input be an arg to a pattern directly @@ -126,19 +155,49 @@ class TestModel(torch.nn.Module): y4, resid = self.norm[3](x4, resid) # use resid here return y4 + def ops_in_model_before(self): + if ( + self.use_aiter + and self.group_shape.is_per_group() + and current_platform.is_fp8_fnuz() + ): + return [rocm_aiter_ops.get_group_quant_op()] + if self.use_aiter and self.group_shape.is_per_group(): + return [torch.ops.vllm.triton_per_token_group_quant_fp8.default] + if self.use_aiter and self.use_aiter_quant_op: + return [rocm_aiter_ops.get_per_token_quant_op()] + if self.use_aiter: + return [QUANT_OPS[self.quant_key]] + if self.enable_quant_fp8_custom_op: + return [QUANT_OPS[self.quant_key]] + return [torch.ops.aten.reciprocal] + def ops_in_model_after(self): + if self.use_aiter and self.group_shape.is_per_group(): + from vllm.compilation.rocm_aiter_fusion import ( + AiterFusedAddRMSFp8GroupQuantPattern, + AiterRMSFp8GroupQuantPattern, + ) + + return [ + AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP, + AiterRMSFp8GroupQuantPattern.FUSED_OP, + ] + if self.use_aiter: + from vllm.compilation.rocm_aiter_fusion import ( + AiterFusedAddRMSNormDynamicQuantPattern, + AiterRMSNormDynamicQuantPattern, + ) + + return [ + AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP, + AiterRMSNormDynamicQuantPattern.FUSED_OP, + ] return [ FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)], FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)], ] - def ops_in_model_before(self): - return ( - [QUANT_OPS[self.quant_key]] - if self.enable_quant_fp8_custom_op - else [torch.ops.aten.reciprocal] - ) - def ops_in_model_before_partial(self): return ( [RMS_OP, RMS_ADD_OP] @@ -155,67 +214,45 @@ GROUP_SHAPES = [ ] -class TestRmsnormGroupFp8QuantModel(torch.nn.Module): - def __init__(self, hidden_size: int, eps: float, **kwargs): - super().__init__() - self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(128, 128), - act_quant_group_shape=GroupShape(1, 128), - cutlass_block_fp8_supported=False, - use_aiter_and_is_supported=True, - ) - self.w = [ - torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - for _ in range(3) - ] +def _run_fusion_test( + model, + fusion_pass, + vllm_config, + dtype, + hidden_size, + num_tokens, +): + """Helper function for common fusion test logic. - scale_hidden_size = (hidden_size + 128 - 1) // 128 - self.wscale = [ - torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32) - for _ in range(3) - ] + Must be called within vllm_config context. + """ + noop_pass = NoOpEliminationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) - self.norm_weight = [torch.ones(hidden_size) for _ in range(4)] - self.eps = eps + backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + backend2 = TestBackend(noop_pass, cleanup_pass) - def forward(self, x): - # avoid having graph input be an arg to a pattern directly - x = resid = torch.relu(x) - y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps) + x = torch.rand(num_tokens, hidden_size) + torch._dynamo.mark_dynamic(x, 0) - x2 = self.w8a8_block_fp8_linear.apply(y, self.w[0], self.wscale[0]) - # make sure resid is used for replacement to work - y2, resid = rocm_aiter_ops.rms_norm2d_with_add( - x2, resid, self.norm_weight[1], self.eps - ) + model_fused = torch.compile(model, backend=backend) + result_fused = model_fused(x) - x3 = self.w8a8_block_fp8_linear.apply(y2, self.w[1], self.wscale[1]) + model_unfused = torch.compile(model, backend=backend2) + result_unfused = model_unfused(x) - y3, resid = rocm_aiter_ops.rms_norm2d_with_add( - x3, resid, self.norm_weight[2], self.eps - ) + if dtype == torch.float16: + ATOL, RTOL = (2e-3, 2e-3) + else: + ATOL, RTOL = (1e-2, 1e-2) - x4 = self.w8a8_block_fp8_linear.apply(y3, self.w[2], self.wscale[2]) + torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL) - y4, resid = rocm_aiter_ops.rms_norm2d_with_add( - x4, resid, self.norm_weight[3], self.eps - ) - return y4 + assert fusion_pass.matched_count == 3 + backend.check_before_ops(model.ops_in_model_before()) + backend.check_after_ops(model.ops_in_model_after()) - def ops_in_model_before(self): - return [ - torch.ops.vllm.rocm_aiter_rms_norm, - torch.ops.vllm.rocm_aiter_group_fp8_quant, - ] - - def ops_in_model_before_partial(self): - return [] - - def ops_in_model_after(self): - return [ - torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant, - torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant, - ] + return backend, backend2 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -223,11 +260,8 @@ class TestRmsnormGroupFp8QuantModel(torch.nn.Module): @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("group_shape", GROUP_SHAPES) -@pytest.mark.parametrize( - "model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op", - list(itertools.product([TestModel], [True, False], [True, False])) - + [(TestRmsnormGroupFp8QuantModel, False, False)], -) +@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) +@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. @pytest.mark.parametrize( @@ -242,23 +276,13 @@ def test_fusion_rmsnorm_quant( num_tokens, eps, group_shape, - model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op, cuda_force_torch, ): - if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND: - pytest.skip("AITER is not supported on this GPU.") - - torch.set_default_device("cuda") - torch.set_default_dtype(dtype) - torch.manual_seed(1) - maybe_create_device_identity() # needed for certain non-cutlass fp8 paths - if not enable_quant_fp8_custom_op and group_shape.is_per_group(): pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization") - # Skip test for 64-bit group shape when running with cutlass or deepgemm if group_shape == GroupShape(1, 64) and ( cutlass_block_fp8_supported() or is_deep_gemm_supported() ): @@ -269,6 +293,7 @@ def test_fusion_rmsnorm_quant( custom_ops.append("+rms_norm") if enable_quant_fp8_custom_op: custom_ops.append("+quant_fp8") + vllm_config = VllmConfig( model_config=ModelConfig(dtype=dtype), compilation_config=CompilationConfig( @@ -279,60 +304,97 @@ def test_fusion_rmsnorm_quant( ), ), ) + with vllm.config.set_current_vllm_config(vllm_config): - # Reshape pass is needed for the fusion pass to work - noop_pass = NoOpEliminationPass(vllm_config) - if model_class is TestRmsnormGroupFp8QuantModel: - from vllm.compilation.rocm_aiter_fusion import ( - RocmAiterRMSNormFp8GroupQuantFusionPass, - ) + # Setup device before model creation + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(1) + maybe_create_device_identity() - fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config) - else: - fusion_pass = RMSNormQuantFusionPass(vllm_config) - cleanup_pass = PostCleanupPass(vllm_config) - - backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) - backend2 = TestBackend(noop_pass, cleanup_pass) - model = model_class( + fusion_pass = RMSNormQuantFusionPass(vllm_config) + model = TestModel( hidden_size=hidden_size, eps=eps, group_shape=group_shape, + use_aiter=False, cuda_force_torch=cuda_force_torch, ) - # First dimension dynamic - x = torch.rand(num_tokens, hidden_size) - torch._dynamo.mark_dynamic(x, 0) - model_fused = torch.compile(model, backend=backend) - result_fused = model_fused(x) - - model_unfused = torch.compile(model, backend=backend2) - result_unfused = model_unfused(x) - - if dtype == torch.float16: - ATOL, RTOL = (2e-3, 2e-3) - else: - ATOL, RTOL = (1e-2, 1e-2) - - torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL) - - assert fusion_pass.matched_count == 3 - backend.check_before_ops(model.ops_in_model_before()) + backend, _ = _run_fusion_test( + model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens + ) backend.check_before_ops( model.ops_in_model_before_partial(), fully_replaced=False ) - backend.check_after_ops(model.ops_in_model_after()) # If RMSNorm custom op is disabled (native/torch impl used), # there's a risk that the fused add doesn't get included in the # replacement and only the rms part gets fused with quant. # Hence, we check only 2 add nodes are left (final fused rmsnorm add). - if ( - not enable_rms_norm_custom_op - and model_class is not TestRmsnormGroupFp8QuantModel - ): + if not enable_rms_norm_custom_op: n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) assert n_add_nodes(backend.graph_pre_pass) == 7 assert n_add_nodes(backend.graph_post_pass) == 2 + + +GROUP_SHAPE_QUANT_OPS_MATCHS = [ + (GroupShape.PER_TOKEN, True), + (GroupShape.PER_TOKEN, False), + (GroupShape(1, 128), True), +] + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [256]) +@pytest.mark.parametrize("num_tokens", [257]) +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.parametrize( + "group_shape, use_aiter_quant_op", GROUP_SHAPE_QUANT_OPS_MATCHS +) +@pytest.mark.skipif( + (not current_platform.is_rocm() or not IS_AITER_FOUND), + reason="Only test on ROCm with aiter package installed", +) +def test_aiter_fusion_rmsnorm_quant( + dtype: torch.dtype, + hidden_size: int, + num_tokens: int, + eps: float, + group_shape: GroupShape, + use_aiter_quant_op: bool, + monkeypatch: pytest.MonkeyPatch, +): + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=["+rms_norm", "+quant_fp8"], + pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True), + ), + ) + + with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m: + from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass + + m.setenv("VLLM_ROCM_USE_AITER", "1") + rocm_aiter_ops.refresh_env_variables() + + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(1) + maybe_create_device_identity() + + fusion_pass = RocmAiterRMSNormFusionPass(vllm_config) + model = TestModel( + hidden_size=hidden_size, + eps=eps, + group_shape=group_shape, + use_aiter=True, + use_aiter_quant_op=use_aiter_quant_op, + ) + + _run_fusion_test( + model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens + ) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 6a62440d95417..783e02ce89bdb 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.attention.backends.mla.common import QueryLenSupport from vllm.v1.attention.backends.utils import CommonAttentionMetadata -from vllm.v1.kv_cache_interface import MLAAttentionSpec +from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ AttentionBackendEnum.CUTLASS_MLA, @@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase): def run_attention_backend( backend: AttentionBackendEnum, - kv_cache_spec: MLAAttentionSpec, + kv_cache_spec: FullAttentionSpec, layer_names: list[str], vllm_config, device: torch.device, @@ -740,7 +740,7 @@ def test_backend_correctness( kv_cache = kv_cache_per_block_size[block_size] # Create kv_cache_spec with the correct block_size for this backend - backend_kv_cache_spec = MLAAttentionSpec( + backend_kv_cache_spec = FullAttentionSpec( block_size=block_size, num_kv_heads=vllm_config.model_config.get_num_kv_heads( vllm_config.parallel_config @@ -748,7 +748,6 @@ def test_backend_correctness( head_size=vllm_config.model_config.get_head_size(), dtype=vllm_config.model_config.dtype, sliding_window=vllm_config.model_config.get_sliding_window(), - cache_dtype_str=vllm_config.cache_config.cache_dtype, ) backend_output = run_attention_backend( diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 03e3bb7594910..299c8219120ae 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -4,6 +4,7 @@ import functools from collections.abc import Callable import torch +from torch._ops import OpOverload import vllm.envs as envs from vllm.platforms import current_platform @@ -433,16 +434,16 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_impl( from aiter import rmsnorm2d_fwd_with_add residual_out = torch.empty_like(residual) - output = torch.empty_like(x) + out = torch.empty_like(x) rmsnorm2d_fwd_with_add( - output, # output + out, # output x, # input residual, # residual input residual_out, # residual output weight, variance_epsilon, ) - return output, residual_out + return out, residual_out def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( @@ -451,7 +452,84 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( weight: torch.Tensor, variance_epsilon: float, ) -> tuple[torch.Tensor, torch.Tensor]: - return torch.empty_like(x), torch.empty_like(residual) + residual_out = torch.empty_like(residual) + out = torch.empty_like(x) + return out, residual_out + + +def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + quant_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + import aiter as rocm_aiter + + assert quant_dtype in [torch.int8, _FP8_DTYPE] + + y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device) + out = torch.empty(x.shape, dtype=quant_dtype, device=x.device) + residual_out = torch.empty_like(x) + + rocm_aiter.rmsnorm2d_fwd_with_add_dynamicquant( + out, + x, + residual, + residual_out, + y_scale, + weight, + epsilon, + use_model_sensitive_rmsnorm=0, + ) + + return out, residual_out, y_scale + + +def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + quant_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device) + out = torch.empty(x.shape, dtype=quant_dtype, device=x.device) + residual_out = torch.empty_like(x) + + return out, residual_out, y_scale + + +def _rocm_aiter_rmsnorm_fused_dynamic_quant_impl( + x: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + quant_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + import aiter as rocm_aiter + + assert quant_dtype in [torch.int8, _FP8_DTYPE] + + y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device) + out = torch.empty(x.shape, dtype=quant_dtype, device=x.device) + + rocm_aiter.rmsnorm2d_fwd_with_dynamicquant( + out, x, y_scale, weight, epsilon, use_model_sensitive_rmsnorm=0 + ) + + return out, y_scale + + +def _rocm_aiter_rmsnorm_fused_dynamic_quant_fake( + x: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + quant_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device) + out = torch.empty(x.shape, dtype=quant_dtype, device=x.device) + + return out, y_scale def _rocm_aiter_per_tensor_quant_impl( @@ -527,7 +605,11 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl( dtype_quant=AITER_FP8_DTYPE, res1=residual, ) - return (x_quant, x_quant_scales, res) + return ( + x_quant, + res, + x_quant_scales, + ) def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake( @@ -541,8 +623,8 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake( scale_shape = (M, (N + group_size - 1) // group_size) return ( torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device), - torch.empty(scale_shape, dtype=torch.float32, device=x.device), torch.empty_like(residual, device=residual.device), + torch.empty(scale_shape, dtype=torch.float32, device=x.device), ) @@ -901,6 +983,20 @@ class rocm_aiter_ops: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm_fused_dynamic_quant", + op_func=_rocm_aiter_rmsnorm_fused_dynamic_quant_impl, + fake_impl=_rocm_aiter_rmsnorm_fused_dynamic_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm_fused_add_dynamic_quant", + op_func=_rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl, + fake_impl=_rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) + direct_register_custom_op( op_name="rocm_aiter_rmsnorm_fp8_group_quant", op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl, @@ -936,13 +1032,54 @@ class rocm_aiter_ops: direct_register_custom_op( op_name="rocm_aiter_per_token_quant", op_func=_rocm_aiter_per_token_quant_impl, - mutates_args=["scale"], fake_impl=_rocm_aiter_per_token_quant_fake, dispatch_key=current_platform.dispatch_key, ) _OPS_REGISTERED = True + @staticmethod + def get_rmsnorm_fused_add_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default + + @staticmethod + def get_rmsnorm_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_rms_norm.default + + @staticmethod + def get_rmsnorm_fused_add_dynamic_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_rmsnorm_fused_add_dynamic_quant.default + + @staticmethod + def get_rmsnorm_fused_dynamic_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_rmsnorm_fused_dynamic_quant.default + + @staticmethod + def get_rmsnorm_group_fused_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default + + @staticmethod + def get_rmsnorm_group_add_fused_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default + + @staticmethod + def get_per_token_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_per_token_quant.default + + @staticmethod + def get_group_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_group_fp8_quant.default + + @staticmethod + def get_act_mul_fused_fp8_group_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default + + @staticmethod + def rms_norm( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon) + @staticmethod def rms_norm2d_with_add( x: torch.Tensor, @@ -954,12 +1091,6 @@ class rocm_aiter_ops: x, residual, weight, variance_epsilon ) - @staticmethod - def rms_norm( - x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float - ) -> torch.Tensor: - return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon) - @staticmethod def gemm_a8w8( A: torch.Tensor, diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index ec9ed34f561b4..7301aa3e5932d 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -6,11 +6,13 @@ import torch from torch._higher_order_ops import auto_functionalized from torch._ops import OpOverload +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import get_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, QuantKey, _normalize_quant_group_shape, kFp8Dynamic64Sym, @@ -150,26 +152,50 @@ class MatcherRotaryEmbedding(MatcherCustomOp): class MatcherRMSNorm(MatcherCustomOp): - def __init__(self, epsilon: float, enabled: bool | None = None): + def __init__( + self, + epsilon: float, + enabled: bool | None = None, + match_rocm_aiter: bool = False, + ): if enabled is None: enabled = RMSNorm.enabled() super().__init__(enabled) self.epsilon = epsilon + self._rmsnorm_op = RMS_OP + self.match_rocm_aiter = match_rocm_aiter + + if match_rocm_aiter: + self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op() def inputs(self): input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) weight = self.empty(16) return [input, weight] + def forward_rocm_aiter( + self, + input: torch.Tensor, + weight: torch.Tensor, + ) -> torch.Tensor: + return self._rmsnorm_op( + x=input, + weight=weight, + variance_epsilon=self.epsilon, + ) + def forward_custom( self, input: torch.Tensor, weight: torch.Tensor, ) -> torch.Tensor: + if self.match_rocm_aiter: + return self.forward_rocm_aiter(input, weight) + result = torch.empty_like(input) _, result = auto_functionalized( - RMS_OP, + self._rmsnorm_op, result=result, input=input, weight=weight, @@ -189,12 +215,23 @@ class MatcherRMSNorm(MatcherCustomOp): class MatcherFusedAddRMSNorm(MatcherCustomOp): - def __init__(self, epsilon: float, enabled: bool | None = None): + def __init__( + self, + epsilon: float, + enabled: bool | None = None, + match_rocm_aiter: bool = False, + ): if enabled is None: enabled = RMSNorm.enabled() super().__init__(enabled) self.epsilon = epsilon + self.match_rocm_aiter = match_rocm_aiter + + self._rmsnorm_op = RMS_ADD_OP + + if match_rocm_aiter: + self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op() def inputs(self): input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) @@ -202,14 +239,27 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp): residual = self.empty(5, 16) return [input, weight, residual] + def forward_rocm_aiter( + self, + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self._rmsnorm_op( + x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon + ) + def forward_custom( self, input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: + if self.match_rocm_aiter: + return self.forward_rocm_aiter(input, weight, residual) + _, result, residual = auto_functionalized( - RMS_ADD_OP, + self._rmsnorm_op, input=input, residual=residual, weight=weight, @@ -236,22 +286,46 @@ class MatcherQuantFP8(MatcherCustomOp): enabled: bool | None = None, has_col_major_scales: bool = False, is_e8m0: bool = False, + match_rocm_aiter: bool = False, ): if enabled is None: enabled = QuantFP8.enabled() super().__init__(enabled) self.quant_key = quant_key - assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" - self.QUANT_OP = QUANT_OPS[quant_key] - self.has_col_major_scales = has_col_major_scales self.is_e8m0 = is_e8m0 + self.match_rocm_aiter = match_rocm_aiter + + if match_rocm_aiter: + assert not quant_key.scale.group_shape.is_per_tensor(), ( + "ROCm aiter fusion pass does not support per tensor quantization" + ) + if quant_key.scale.group_shape.is_per_token(): + self.QUANT_OP = rocm_aiter_ops.get_per_token_quant_op() + else: + assert quant_key.scale.group_shape.col == 128, ( + "ROCm aiter fusion pass currently supports " + "quantization operation with group_size 128" + ) + if current_platform.is_fp8_fnuz(): + self.QUANT_OP = rocm_aiter_ops.get_group_quant_op() + else: + self.QUANT_OP = ( + torch.ops.vllm.triton_per_token_group_quant_fp8.default + ) + + else: + assert quant_key in QUANT_OPS, ( + f"unsupported quantization scheme {quant_key}" + ) + self.QUANT_OP = QUANT_OPS[quant_key] + + assert quant_key.dtype == current_platform.fp8_dtype(), ( + "Only QuantFP8 supported by" + ) + assert quant_key.scale2 is None - assert quant_key.dtype == current_platform.fp8_dtype(), ( - "Only QuantFP8 supported by" - ) - assert quant_key.scale2 is None self.quant_fp8 = QuantFP8( quant_key.scale.static, quant_key.scale.group_shape, @@ -259,11 +333,29 @@ class MatcherQuantFP8(MatcherCustomOp): use_ue8m0=is_e8m0, ) + def forward_rocm_aiter( + self, + input: torch.Tensor, + scale: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + quant_key_group_shape = self.quant_key.scale.group_shape + if quant_key_group_shape == GroupShape.PER_TOKEN: + return self.QUANT_OP( + x=input, + quant_dtype=self.quant_key.dtype, + scale=scale, + ) + else: + return self.QUANT_OP(input, quant_key_group_shape.col) + def forward_custom( self, input: torch.Tensor, scale: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if self.match_rocm_aiter: + return self.forward_rocm_aiter(input, scale) + result = torch.empty( input.shape, device=input.device, dtype=self.quant_key.dtype ) diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 4ebb386f75ed8..4c2dee505a941 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -16,7 +16,7 @@ from .vllm_inductor_pass import VllmInductorPass if rocm_aiter_ops.is_enabled(): from vllm.compilation.rocm_aiter_fusion import ( - RocmAiterRMSNormFp8GroupQuantFusionPass, + RocmAiterRMSNormFusionPass, RocmAiterSiluMulFp8GroupQuantFusionPass, ) @@ -117,7 +117,9 @@ class PostGradPassManager(CustomGraphPass): if self.pass_config.fuse_norm_quant: self.passes += [RMSNormQuantFusionPass(config)] if rocm_aiter_ops.is_enabled(): - self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)] + self.passes += [ + RocmAiterRMSNormFusionPass(config), + ] if self.pass_config.fuse_act_quant: self.passes += [ActivationQuantFusionPass(config)] if rocm_aiter_ops.is_enabled(): diff --git a/vllm/compilation/rocm_aiter_fusion.py b/vllm/compilation/rocm_aiter_fusion.py index 8b5db9de38181..f66bb76b97f05 100644 --- a/vllm/compilation/rocm_aiter_fusion.py +++ b/vllm/compilation/rocm_aiter_fusion.py @@ -9,60 +9,195 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch._ops import OpOverload import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401 +from vllm._aiter_ops import rocm_aiter_ops from vllm.compilation.activation_quant_fusion import ActivationQuantPattern from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + QuantKey, + ScaleDesc, +) from vllm.platforms import current_platform -from .fusion import empty_bf16 +from .fusion import ( + FusedRMSQuantKey, +) from .inductor_pass import enable_fake_mode -from .matcher_utils import MatcherSiluAndMul +from .matcher_utils import ( + MatcherFusedAddRMSNorm, + MatcherQuantFP8, + MatcherRMSNorm, + MatcherSiluAndMul, +) from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) FP8_DTYPE = current_platform.fp8_dtype() -AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default -AITER_RMS_ADD_GROUP_QUANT_OP = ( - torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default -) -AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default -AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default +class AiterRMSNormQuantPattern: + def __init__( + self, epsilon: float, key: FusedRMSQuantKey, match_aiter_quant: bool = True + ): + self.epsilon = epsilon + self.quant_dtype = key.quant.dtype -AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default -TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default - -FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default + self.rmsnorm_matcher = ( + MatcherRMSNorm(epsilon, match_rocm_aiter=True) + if not key.fused_add + else MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True) + ) + self.quant_matcher = MatcherQuantFP8( + key.quant, + match_rocm_aiter=match_aiter_quant, + ) -class AiterRMSFp8GroupQuantPattern: +class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): + """AITER RMSNorm + Dynamic Quantization pattern.""" + + FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_dynamic_quant_op() + + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + match_aiter_quant: bool = True, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True, + ): + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + + super().__init__(epsilon, key, match_aiter_quant) + + def register(self, pm_pass): + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + ): + result_rms = self.rmsnorm_matcher(input, weight) + result, scale = self.quant_matcher(result_rms) + return result, scale + + def replacement( + input: torch.Tensor, + weight: torch.Tensor, + ): + result = self.FUSED_OP( + x=input, + weight=weight, + epsilon=self.epsilon, + quant_dtype=self.quant_dtype, + ) + + return result[0], result[1] + + pm.register_replacement( + pattern, + replacement, + self.rmsnorm_matcher.inputs(), + pm.fwd_only, + pm_pass, + ) + + +class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): + """AITER RMSNorm Fused Add + Dynamic Quantization pattern.""" + + FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_add_dynamic_quant_op() + + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + match_aiter_quant: bool = True, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True, + ): + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=True, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + + super().__init__(epsilon, key, match_aiter_quant) + + def register(self, pm_pass): + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + ): + result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual) + result, scale = self.quant_matcher(result_rms) + + return result, residual_out, scale + + def replacement( + input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor + ): + result = self.FUSED_OP( + x=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + quant_dtype=self.quant_dtype, + ) + + return result[0], result[1], result[2] + + pm.register_replacement( + pattern, + replacement, + self.rmsnorm_matcher.inputs(), + pm.fwd_only, + pm_pass, + ) + + +class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): """ This pattern fuses aiter rms_norm & group fp8 quant custom ops into an aiter rms_norm_group_fp8_quant op. """ - def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload): - self.epsilon = epsilon - self.quant_dtype = quant_dtype - self.quant_op = quant_op + FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_fused_quant_op() + + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape, + match_aiter_quant: bool = True, + symmetric=True, + ): + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + + super().__init__(epsilon, key, match_aiter_quant) def register(self, pm_pass: PatternMatcherPass): def pattern( input: torch.Tensor, weight: torch.Tensor, ): - at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon) - - at2 = self.quant_op(at1, 128) - - return at2[0], at2[1] + result_rms = self.rmsnorm_matcher(input, weight) + result, scale = self.quant_matcher(result_rms) + return result, scale def replacement( input: torch.Tensor, weight: torch.Tensor, ): - at = AITER_RMS_GROUP_QUANT_OP( + at = self.FUSED_OP( x=input, weight=weight, variance_epsilon=self.epsilon, @@ -71,49 +206,52 @@ class AiterRMSFp8GroupQuantPattern: return at[0], at[1] - inputs = [ - empty_bf16(5, 4), # input - empty_bf16(1, 5), # weight - ] - - pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass + ) -class AiterFusedAddRMSFp8GroupQuantPattern: +class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): """ This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops into a aiter rms_norm_with_add_group_fp8_quant op. """ - def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload): - self.epsilon = epsilon - self.quant_dtype = quant_dtype - self.quant_op = quant_op + FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_add_fused_quant_op() + + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape, + match_aiter_quant: bool = True, + symmetric=True, + ): + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=True, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + + super().__init__(epsilon, key, match_aiter_quant) def register(self, pm_pass: PatternMatcherPass): def pattern( input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor, ): - at1 = AITER_RMS_ADD_OP( - x=input, - residual=residual, - weight=weight, - variance_epsilon=self.epsilon, - ) + result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual) + result, scale = self.quant_matcher(result_rms) - at2 = self.quant_op(at1[0], 128) - - # result, scale, residual - return at2[0], at2[1], at1[1] + return result, residual_out, scale def replacement( input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor, ): - at = AITER_RMS_ADD_GROUP_QUANT_OP( + at = self.FUSED_OP( x=input, residual=residual, weight=weight, @@ -124,18 +262,15 @@ class AiterFusedAddRMSFp8GroupQuantPattern: # result, scale, residual return at[0], at[1], at[2] - inputs = [ - empty_bf16(5, 4), # input - empty_bf16(5, 4), # residual - empty_bf16(1, 5), # weight - ] - - pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass + ) -class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass): +class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass): """ - This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op. + This pass fuses aiter rms_norm & vllm/aiter quant custom ops + into a fused rms_norm_quant op. It also supports fused_add_rms_norm. """ @@ -144,20 +279,33 @@ class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass): super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass" + pass_name="rocm_aiter_rms_norm_quant_fusion_pass" ) # Make sure fused add patterns are before simple rms norm, # as the latter is a subset of the former in torch ops for epsilon in [1e-5, 1e-6]: - # Fuse rms_norm + dynamic group fp8 quant - for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]: - AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register( - self.patterns - ) + # Fuse aiter rms_norm + aiter dynamic group fp8 quant + AiterRMSFp8GroupQuantPattern( + epsilon, FP8_DTYPE, GroupShape(1, 128) + ).register(self.patterns) - AiterFusedAddRMSFp8GroupQuantPattern( - epsilon, FP8_DTYPE, quant_op + # Fuse aiter fused_add_rms_norm + aiter dynamic group fp8 quant + AiterFusedAddRMSFp8GroupQuantPattern( + epsilon, FP8_DTYPE, GroupShape(1, 128) + ).register(self.patterns) + + for match_aiter_quant in [True, False]: + # Fuse aiter rms_norm + (aiter / vllm built-in) + # dynamic per-token fp8 quant + AiterRMSNormDynamicQuantPattern( + epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant + ).register(self.patterns) + + # Fuse aiter fused_add_rms_norm + (aiter / vllm built-in) + # dynamic per-token fp8 quant + AiterFusedAddRMSNormDynamicQuantPattern( + epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant ).register(self.patterns) self.dump_patterns(config, self.patterns) @@ -169,6 +317,8 @@ class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass): def uuid(self) -> Any: fusion_patterns = [ + AiterRMSNormDynamicQuantPattern, + AiterFusedAddRMSNormDynamicQuantPattern, AiterRMSFp8GroupQuantPattern, AiterFusedAddRMSFp8GroupQuantPattern, ] @@ -181,6 +331,8 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern): ops into an aiter silu_and_mul_group_fp8_quant op. """ + FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op() + def __init__(self, quant_op: OpOverload): self.silu_and_mul_matcher = MatcherSiluAndMul() self.quant_op = quant_op @@ -196,7 +348,7 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern): def replacement( input: torch.Tensor, ): - at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128) + at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128) return at[0], at[1] inputs = [ @@ -216,6 +368,11 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 """ + AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op() + TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default + + QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP] + @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) @@ -224,7 +381,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass" ) - for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]: + for quant_op in self.QUANT_OPS: AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns) self.dump_patterns(config, self.patterns) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 033cc1f544b3b..7a569ec32eac9 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -186,6 +186,7 @@ class DPMetadata: class ForwardContext: # copy from vllm_config.compilation_config.static_forward_context no_compile_layers: dict[str, Any] + attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] """ Type Dict[str, AttentionMetadata] for v1, map from layer_name of each attention layer to its attention metadata @@ -193,7 +194,6 @@ class ForwardContext: for each microbatch. Set dynamically for each forward pass """ - attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index 8ed42382e3a86..ed9e916455e5f 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -11,9 +11,11 @@ import torch from vllm import envs from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant from vllm.platforms import current_platform logger = init_logger(__name__) +is_batch_invariant = vllm_is_batch_invariant() _LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {} _LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {} @@ -150,7 +152,8 @@ def _get_lora_b_ptr( @functools.lru_cache def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None: user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER - if user_defined_config_folder is not None: + # Avoid optimizing for the batch invariant case. Use default config + if user_defined_config_folder is not None and not is_batch_invariant: gpu_name = torch.cuda.get_device_name() gpu_name = gpu_name.replace(" ", "_") gpu_name = gpu_name.replace("-", "_") @@ -203,11 +206,14 @@ def get_lora_op_configs( # default config default = {} if op_type == "shrink": + split_k = 64 if batch < 128 else 8 + if is_batch_invariant: + split_k = 1 default = { "block_m": 32, "block_n": 16, "block_k": 256 if batch < 128 else 32, - "split_k": 64 if batch < 128 else 8, + "split_k": split_k, "num_warps": 4, "num_ctas": 1, "group_size_m": 8, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 4611b83757a69..dceee42f31e39 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -325,6 +325,7 @@ def flashinfer_trtllm_fp4_moe( local_expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, routed_scaling_factor=None, + tile_tokens_dim=None, routing_method_type=routing_method_type, do_finalize=True, )[0] diff --git a/vllm/model_executor/models/jais2.py b/vllm/model_executor/models/jais2.py index 01e75338a8ced..aacc4abd43e61 100644 --- a/vllm/model_executor/models/jais2.py +++ b/vllm/model_executor/models/jais2.py @@ -48,7 +48,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -167,7 +166,6 @@ class Jais2Attention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=getattr(config, "rope_parameters", None), is_neox_style=is_neox_style, @@ -304,17 +302,12 @@ class Jais2Model(nn.Module): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab + + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank @@ -456,29 +449,15 @@ class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config - self.model = self._init_model( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -487,7 +466,7 @@ class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 2d67551eed9f6..2e39a216a10a0 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -156,7 +156,9 @@ class XPUPlatform(Platform): if vllm_config.lora_config is not None: compilation_config.mode = CompilationMode.NONE - + # decrease triton kernel compilation scratch space for speculative decoding + if vllm_config.speculative_config is not None: + os.environ["IGC_ForceOCLSIMDWidth"] = "16" # noqa: SIM112 # check and update parallel config parallel_config = vllm_config.parallel_config # Only override worker_cls if it's still the default "auto" diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 619c1c2794daa..e9ec96835f277 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -541,11 +541,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): metadata_cls if metadata_cls is not None else MLACommonMetadata ) self.kv_cache_spec = kv_cache_spec - self.q_data_type = ( - current_platform.fp8_dtype() - if (kv_cache_spec is not None and "fp8" in kv_cache_spec.cache_dtype_str) - else vllm_config.model_config.dtype - ) scheduler_config = vllm_config.scheduler_config self.model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config @@ -689,6 +684,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): # For main run, qo_indptr == kv_indptr kv_indptr = qo_indptr.clone() + # Prepare main prefill self._fi_prefill_main.plan( qo_indptr=qo_indptr, @@ -701,7 +697,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): sm_scale=self._global_hyperparameters.sm_scale, window_left=self._global_hyperparameters.window_left, logits_soft_cap=self._global_hyperparameters.logits_soft_cap, - q_data_type=self.q_data_type, + q_data_type=self.model_config.dtype, ) # Prepare context prefills @@ -720,7 +716,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): sm_scale=self._global_hyperparameters.sm_scale, window_left=self._global_hyperparameters.window_left, logits_soft_cap=self._global_hyperparameters.logits_soft_cap, - q_data_type=self.q_data_type, + q_data_type=self.model_config.dtype, ) prefill.prefill_main = self._fi_prefill_main @@ -973,7 +969,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, - q_data_type=self.q_data_type, ) if self._use_cudnn_prefill: @@ -1384,15 +1379,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): return attn_out def _run_prefill_new_tokens_fa( - self, - prefill: MLACommonPrefillMetadata, - q, - k, - v, - return_softmax_lse, - fp8_attention: bool, + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse ): - logger.debug_once("Running FlashAttention prefill new tokens", scope="local") return self._flash_attn_varlen_diff_headdims( q=q, k=k, @@ -1407,23 +1395,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ) def _run_prefill_new_tokens_fi( - self, - prefill: MLACommonPrefillMetadata, - q, - k, - v, - return_softmax_lse, - fp8_attention: bool, + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse ): - logger.debug_once("Running FlashInfer prefill new tokens", scope="local") assert isinstance(prefill, FlashInferPrefillMetadata) assert prefill.prefill_main is not None - if fp8_attention: - logger.debug_once("Running Flashinfer prefill in FP8") - fp8_dtype = current_platform.fp8_dtype() - q = q.to(fp8_dtype) - k = k.to(fp8_dtype) - v = v.to(fp8_dtype) + ret = prefill.prefill_main.run( q=q, k=k, @@ -1436,18 +1412,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): return ret def _run_prefill_new_tokens_cudnn( - self, - prefill: MLACommonPrefillMetadata, - q, - k, - v, - return_softmax_lse, - fp8_attention: bool, + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse ): - logger.debug_once("Running Cudnn prefill new tokens", scope="local") assert isinstance(prefill, CudnnPrefillMetadata) assert prefill.query_seq_lens is not None - assert fp8_attention is False, "Cudnn prefill does not support fp8 attention" output, lse = cudnn_batch_prefill_with_kv_cache( q=q, k_cache=k, @@ -1469,19 +1437,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): return output def _run_prefill_context_chunk_fa( - self, - prefill: MLACommonPrefillMetadata, - chunk_idx: int, - q, - k, - v, - fp8_attention: bool, + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v ): - logger.debug_once("Running FlashAttention prefill context chunk", scope="local") assert prefill.chunked_context is not None - assert fp8_attention is False, ( - "FlashAttention prefill does not support fp8 attention" - ) return self._flash_attn_varlen_diff_headdims( q=q, k=k, @@ -1496,22 +1454,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ) def _run_prefill_context_chunk_fi( - self, - prefill: MLACommonPrefillMetadata, - chunk_idx: int, - q, - k, - v, - fp8_attention: bool, + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v ): - logger.debug_once("Running FlashInfer prefill context chunk", scope="local") assert isinstance(prefill, FlashInferPrefillMetadata) - if fp8_attention: - logger.debug_once("Running FlashInfer prefill in FP8") - fp8_dtype = current_platform.fp8_dtype() - q = q.to(fp8_dtype) - k = k.to(fp8_dtype) - v = v.to(fp8_dtype) + attn_out, lse = prefill.prefill_chunks[chunk_idx].run( q=q, k=k, @@ -1523,20 +1469,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): return attn_out, lse.transpose(0, 1).contiguous() def _run_prefill_context_chunk_cudnn( - self, - prefill: MLACommonPrefillMetadata, - chunk_idx: int, - q, - k, - v, - fp8_attention: bool, + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v ): - logger.debug_once("Running Cudnn prefill context chunk", scope="local") assert isinstance(prefill, CudnnPrefillMetadata) assert prefill.chunked_context is not None assert prefill.chunked_context.seq_lens[chunk_idx] is not None assert prefill.query_seq_lens is not None - assert fp8_attention is False, "Cudnn prefill does not support fp8 attention" return cudnn_batch_prefill_with_kv_cache( q=q, k_cache=k, @@ -1556,28 +1494,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ) def _run_prefill_new_tokens_trtllm_ragged( - self, - prefill: MLACommonPrefillMetadata, - q, - k, - v, - return_softmax_lse, - fp8_attention: bool, + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse ): - logger.debug_once("Running TRT-LLM ragged prefill new tokens", scope="local") """TRT-LLM ragged attention for new tokens (causal).""" from flashinfer.prefill import trtllm_ragged_attention_deepseek assert prefill.query_seq_lens is not None assert prefill.workspace_buffer is not None - if fp8_attention: - logger.debug_once("Running TRT-LLM ragged prefill in FP8") - fp8_dtype = current_platform.fp8_dtype() - q = q.to(fp8_dtype) - k = k.to(fp8_dtype) - v = v.to(fp8_dtype) - ret = trtllm_ragged_attention_deepseek( query=q, key=k, @@ -1604,15 +1528,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): return ret def _run_prefill_context_chunk_trtllm_ragged( - self, - prefill: MLACommonPrefillMetadata, - chunk_idx: int, - q, - k, - v, - fp8_attention: bool, + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v ): - logger.debug_once("Running TRT-LLM ragged prefill context chunk", scope="local") """TRT-LLM ragged attention for context chunks (non-causal).""" from flashinfer.prefill import trtllm_ragged_attention_deepseek @@ -1629,13 +1546,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ) prefill.workspace_buffer.fill_(0) - if fp8_attention: - logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8") - fp8_dtype = current_platform.fp8_dtype() - q = q.to(fp8_dtype) - k = k.to(fp8_dtype) - v = v.to(fp8_dtype) - attn_out, lse = trtllm_ragged_attention_deepseek( query=q, key=k, @@ -1788,7 +1698,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, k_scale: torch.Tensor, - fp8_attention: bool, ): assert attn_metadata.prefill is not None prefill_metadata = attn_metadata.prefill @@ -1827,7 +1736,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): q=q, k=k, v=v, - fp8_attention=fp8_attention, ) if output is None: @@ -1856,7 +1764,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): attn_metadata: MLACommonMetadata, k_scale: torch.Tensor, dcp_world_size: int, - fp8_attention: bool, ): assert k_scale is None, "DCP not support scaled kvcache now." assert attn_metadata.prefill is not None @@ -1933,7 +1840,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): q=q, k=k, v=v, - fp8_attention=fp8_attention, ) if output is None: @@ -1964,7 +1870,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): attn_metadata: MLACommonMetadata, k_scale: torch.Tensor, output: torch.Tensor, - fp8_attention: bool = False, ) -> None: # TODO (zyongye): Prefill function here assert attn_metadata.prefill is not None @@ -1984,7 +1889,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): k=k, v=v, return_softmax_lse=has_context, - fp8_attention=fp8_attention, ) if has_context: @@ -1997,12 +1901,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): attn_metadata, k_scale=None, dcp_world_size=self.dcp_world_size, - fp8_attention=fp8_attention, ) ) else: context_output, context_lse = self._compute_prefill_context( - q, kv_c_and_k_pe_cache, attn_metadata, k_scale, fp8_attention + q, kv_c_and_k_pe_cache, attn_metadata, k_scale ) # unpad if necessary @@ -2123,7 +2026,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): attn_metadata, layer._k_scale, output=output[num_decode_tokens:], - fp8_attention=fp8_attention, ) if has_decode: diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 751862aa9c767..7370f0aefafb4 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -80,17 +80,20 @@ class AttentionSpec(KVCacheSpec): @dataclass(frozen=True) class FullAttentionSpec(AttentionSpec): - sliding_window: int | None = None - attention_chunk_size: int | None = None """ - When hybrid allocator is disabled and the model contains both full - attention layers and sliding window attention layers, sliding - window attention are regarded as full attention in KV cache manager - (blocks are allocated for all tokens), while computed as sliding window + When hybrid allocator is disabled and the model contains both full + attention layers and sliding window attention layers, sliding + window attention are regarded as full attention in KV cache manager + (blocks are allocated for all tokens), while computed as sliding window attention in model runner. In this case, we use FullAttentionSpec and record the sliding window size. + """ + + sliding_window: int | None = None + """ Default to None for not using sliding window attention. """ + attention_chunk_size: int | None = None def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len @@ -390,10 +393,11 @@ class KVCacheConfig: The KV cache configuration of a model. """ - """The number of KV cache blocks""" num_blocks: int - """How should model runner initialize the KV cache tensors for each layer""" + """The number of KV cache blocks""" kv_cache_tensors: list[KVCacheTensor] + """How should model runner initialize the KV cache tensors for each layer""" + kv_cache_groups: list[KVCacheGroupSpec] """ The kv cache groups of the model. For models with only one type of attention, there is only one group that @@ -401,4 +405,3 @@ class KVCacheConfig: For models with multiple types of attention, there will be multiple groups, see `_get_kv_cache_config_uniform_page_size` for more details. """ - kv_cache_groups: list[KVCacheGroupSpec]