[Models] Improve iteration over layers (#26425)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger 2025-10-08 21:48:33 +01:00 committed by GitHub
parent 4ebc9108a7
commit 93f2c0aa08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 23 additions and 22 deletions

View File

@ -26,6 +26,7 @@
"""Inference-only Apertus model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
@ -412,7 +413,9 @@ class ApertusModel(nn.Module):
residual = intermediate_tensors["residual"]
aux_hidden_states = []
for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual)

View File

@ -3,6 +3,7 @@
"""Inference-only FalconH1 model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional
import torch
@ -480,8 +481,7 @@ class FalconH1Model(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(
positions=positions,
hidden_states=hidden_states,

View File

@ -26,6 +26,7 @@
import typing
from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any, Optional, Union
import regex as re
@ -672,8 +673,9 @@ class HunYuanModel(nn.Module):
cla_factor = _get_cla_factor(self.config)
prev_kv_states = None
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for i, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
hidden_states, residual, kv_states = layer(
positions,
hidden_states,
@ -681,10 +683,7 @@ class HunYuanModel(nn.Module):
prev_kv_states,
)
if (
getattr(self.config, "use_cla", False)
and (i - self.start_layer) % cla_factor == 0
):
if getattr(self.config, "use_cla", False) and i % cla_factor == 0:
prev_kv_states = kv_states
else:
prev_kv_states = None

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional
import torch
@ -492,7 +493,7 @@ class Lfm2MoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer : self.end_layer]:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,

View File

@ -35,6 +35,7 @@
import typing
from collections.abc import Callable, Iterable
from itertools import islice
from typing import Optional, Union
import torch
@ -519,8 +520,7 @@ class FlashModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions,
hidden_states,

View File

@ -3,6 +3,7 @@
"""PyTorch MAMBA model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional
import torch
@ -162,8 +163,7 @@ class MambaModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions=positions, hidden_states=hidden_states, residual=residual
)

View File

@ -26,6 +26,7 @@
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from itertools import islice
from typing import Any, Callable, Optional, Union
import numpy as np
@ -1106,11 +1107,9 @@ class Qwen3LLMModel(Qwen3Model):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer_idx, layer in enumerate(
self.layers[self.start_layer : self.end_layer]
for layer_idx, layer in islice(
enumerate(self.layers), self.start_layer, self.end_layer
):
layer_idx = layer_idx + self.start_layer
hidden_states, residual = layer(
positions,
hidden_states,

View File

@ -26,6 +26,7 @@
import typing
from collections.abc import Iterable
from itertools import islice
from typing import Callable, Optional, Union
import torch
@ -103,11 +104,9 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer_idx, layer in enumerate(
self.layers[self.start_layer : self.end_layer]
for layer_idx, layer in islice(
enumerate(self.layers), self.start_layer, self.end_layer
):
layer_idx = layer_idx + self.start_layer
hidden_states, residual = layer(
positions,
hidden_states,