From 93f2c0aa083173256796b607db2bcb19ec132696 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Wed, 8 Oct 2025 21:48:33 +0100 Subject: [PATCH] [Models] Improve iteration over layers (#26425) Signed-off-by: Lukas Geiger --- vllm/model_executor/models/apertus.py | 5 ++++- vllm/model_executor/models/falcon_h1.py | 4 ++-- vllm/model_executor/models/hunyuan_v1.py | 11 +++++------ vllm/model_executor/models/lfm2_moe.py | 3 ++- vllm/model_executor/models/longcat_flash.py | 4 ++-- vllm/model_executor/models/mamba.py | 4 ++-- vllm/model_executor/models/qwen3_vl.py | 7 +++---- vllm/model_executor/models/qwen3_vl_moe.py | 7 +++---- 8 files changed, 23 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py index 7432070827214..c5d3d49d67602 100644 --- a/vllm/model_executor/models/apertus.py +++ b/vllm/model_executor/models/apertus.py @@ -26,6 +26,7 @@ """Inference-only Apertus model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -412,7 +413,9 @@ class ApertusModel(nn.Module): residual = intermediate_tensors["residual"] aux_hidden_states = [] - for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]): + for idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer) + ): if idx in self.aux_hidden_state_layers: aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 8af08711038d4..db938dda5d637 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -3,6 +3,7 @@ """Inference-only FalconH1 model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional import torch @@ -480,8 +481,7 @@ class FalconH1Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer( positions=positions, hidden_states=hidden_states, diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index d33406b7be2b2..220147eb90a73 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -26,6 +26,7 @@ import typing from collections.abc import Callable, Iterable +from itertools import islice from typing import Any, Optional, Union import regex as re @@ -672,8 +673,9 @@ class HunYuanModel(nn.Module): cla_factor = _get_cla_factor(self.config) prev_kv_states = None - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for i, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer) + ): hidden_states, residual, kv_states = layer( positions, hidden_states, @@ -681,10 +683,7 @@ class HunYuanModel(nn.Module): prev_kv_states, ) - if ( - getattr(self.config, "use_cla", False) - and (i - self.start_layer) % cla_factor == 0 - ): + if getattr(self.config, "use_cla", False) and i % cla_factor == 0: prev_kv_states = kv_states else: prev_kv_states = None diff --git a/vllm/model_executor/models/lfm2_moe.py b/vllm/model_executor/models/lfm2_moe.py index 728bd90be1170..f7903a7af53fe 100644 --- a/vllm/model_executor/models/lfm2_moe.py +++ b/vllm/model_executor/models/lfm2_moe.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable +from itertools import islice from typing import Any, Optional import torch @@ -492,7 +493,7 @@ class Lfm2MoeModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer : self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, diff --git a/vllm/model_executor/models/longcat_flash.py b/vllm/model_executor/models/longcat_flash.py index 5020da37df897..17ec6b7d2b06a 100644 --- a/vllm/model_executor/models/longcat_flash.py +++ b/vllm/model_executor/models/longcat_flash.py @@ -35,6 +35,7 @@ import typing from collections.abc import Callable, Iterable +from itertools import islice from typing import Optional, Union import torch @@ -519,8 +520,7 @@ class FlashModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index fa11f92cce33b..1638aab137aaf 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -3,6 +3,7 @@ """PyTorch MAMBA model.""" from collections.abc import Iterable +from itertools import islice from typing import Optional import torch @@ -162,8 +163,7 @@ class MambaModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual ) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 1c532376256d2..76a7cc3210c62 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -26,6 +26,7 @@ from collections.abc import Iterable, Mapping, Sequence from functools import partial +from itertools import islice from typing import Any, Callable, Optional, Union import numpy as np @@ -1106,11 +1107,9 @@ class Qwen3LLMModel(Qwen3Model): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer_idx, layer in enumerate( - self.layers[self.start_layer : self.end_layer] + for layer_idx, layer in islice( + enumerate(self.layers), self.start_layer, self.end_layer ): - layer_idx = layer_idx + self.start_layer - hidden_states, residual = layer( positions, hidden_states, diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py index cd8046d04248e..db7bcb0436595 100644 --- a/vllm/model_executor/models/qwen3_vl_moe.py +++ b/vllm/model_executor/models/qwen3_vl_moe.py @@ -26,6 +26,7 @@ import typing from collections.abc import Iterable +from itertools import islice from typing import Callable, Optional, Union import torch @@ -103,11 +104,9 @@ class Qwen3MoeLLMModel(Qwen3MoeModel): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer_idx, layer in enumerate( - self.layers[self.start_layer : self.end_layer] + for layer_idx, layer in islice( + enumerate(self.layers), self.start_layer, self.end_layer ): - layer_idx = layer_idx + self.start_layer - hidden_states, residual = layer( positions, hidden_states,