[Models] Improve iteration over layers (#26425)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger 2025-10-08 21:48:33 +01:00 committed by GitHub
parent 4ebc9108a7
commit 93f2c0aa08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 23 additions and 22 deletions

View File

@ -26,6 +26,7 @@
"""Inference-only Apertus model compatible with HuggingFace weights.""" """Inference-only Apertus model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
@ -412,7 +413,9 @@ class ApertusModel(nn.Module):
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = [] aux_hidden_states = []
- for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
+ for idx, layer in enumerate(
+     islice(self.layers, self.start_layer, self.end_layer)
+ ):
if idx in self.aux_hidden_state_layers: if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual) aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)

View File

@ -3,6 +3,7 @@
"""Inference-only FalconH1 model.""" """Inference-only FalconH1 model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional from typing import Optional
import torch import torch
@ -480,8 +481,7 @@ class FalconH1Model(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
- for i in range(self.start_layer, self.end_layer):
-     layer = self.layers[i]
+ for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer( hidden_states = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,

View File

@ -26,6 +26,7 @@
import typing import typing
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import regex as re import regex as re
@ -672,8 +673,9 @@ class HunYuanModel(nn.Module):
cla_factor = _get_cla_factor(self.config) cla_factor = _get_cla_factor(self.config)
prev_kv_states = None prev_kv_states = None
- for i in range(self.start_layer, self.end_layer):
-     layer = self.layers[i]
+ for i, layer in enumerate(
+     islice(self.layers, self.start_layer, self.end_layer)
+ ):
hidden_states, residual, kv_states = layer( hidden_states, residual, kv_states = layer(
positions, positions,
hidden_states, hidden_states,
@ -681,10 +683,7 @@ class HunYuanModel(nn.Module):
prev_kv_states, prev_kv_states,
) )
- if (
-     getattr(self.config, "use_cla", False)
-     and (i - self.start_layer) % cla_factor == 0
- ):
+ if getattr(self.config, "use_cla", False) and i % cla_factor == 0:
prev_kv_states = kv_states prev_kv_states = kv_states
else: else:
prev_kv_states = None prev_kv_states = None

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional from typing import Any, Optional
import torch import torch
@ -492,7 +493,7 @@ class Lfm2MoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
- for layer in self.layers[self.start_layer : self.end_layer]:
+ for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,

View File

@ -35,6 +35,7 @@
import typing import typing
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from itertools import islice
from typing import Optional, Union from typing import Optional, Union
import torch import torch
@ -519,8 +520,7 @@ class FlashModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
- for i in range(self.start_layer, self.end_layer):
-     layer = self.layers[i]
+ for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,

View File

@ -3,6 +3,7 @@
"""PyTorch MAMBA model.""" """PyTorch MAMBA model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Optional from typing import Optional
import torch import torch
@ -162,8 +163,7 @@ class MambaModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
- for i in range(self.start_layer, self.end_layer):
-     layer = self.layers[i]
+ for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, hidden_states=hidden_states, residual=residual positions=positions, hidden_states=hidden_states, residual=residual
) )

View File

@ -26,6 +26,7 @@
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import partial from functools import partial
from itertools import islice
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
import numpy as np import numpy as np
@ -1106,11 +1107,9 @@ class Qwen3LLMModel(Qwen3Model):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
- for layer_idx, layer in enumerate(
-     self.layers[self.start_layer : self.end_layer]
- ):
-     layer_idx = layer_idx + self.start_layer
+ for layer_idx, layer in islice(
+     enumerate(self.layers), self.start_layer, self.end_layer
+ ):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,

View File

@ -26,6 +26,7 @@
import typing import typing
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
import torch import torch
@ -103,11 +104,9 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
- for layer_idx, layer in enumerate(
-     self.layers[self.start_layer : self.end_layer]
- ):
-     layer_idx = layer_idx + self.start_layer
+ for layer_idx, layer in islice(
+     enumerate(self.layers), self.start_layer, self.end_layer
+ ):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,