mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-27 21:52:32 +08:00
[Models] Improve iteration over layers (#26425)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
parent
4ebc9108a7
commit
93f2c0aa08
@ -26,6 +26,7 @@
|
||||
"""Inference-only Apertus model compatible with HuggingFace weights."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -412,7 +413,9 @@ class ApertusModel(nn.Module):
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
aux_hidden_states = []
|
||||
for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
|
||||
for idx, layer in enumerate(
|
||||
islice(self.layers, self.start_layer, self.end_layer)
|
||||
):
|
||||
if idx in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(hidden_states + residual)
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
"""Inference-only FalconH1 model."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@ -480,8 +481,7 @@ class FalconH1Model(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
|
||||
@ -26,6 +26,7 @@
|
||||
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
from itertools import islice
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import regex as re
|
||||
@ -672,8 +673,9 @@ class HunYuanModel(nn.Module):
|
||||
|
||||
cla_factor = _get_cla_factor(self.config)
|
||||
prev_kv_states = None
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for i, layer in enumerate(
|
||||
islice(self.layers, self.start_layer, self.end_layer)
|
||||
):
|
||||
hidden_states, residual, kv_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
@ -681,10 +683,7 @@ class HunYuanModel(nn.Module):
|
||||
prev_kv_states,
|
||||
)
|
||||
|
||||
if (
|
||||
getattr(self.config, "use_cla", False)
|
||||
and (i - self.start_layer) % cla_factor == 0
|
||||
):
|
||||
if getattr(self.config, "use_cla", False) and i % cla_factor == 0:
|
||||
prev_kv_states = kv_states
|
||||
else:
|
||||
prev_kv_states = None
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
@ -492,7 +493,7 @@ class Lfm2MoeModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for layer in self.layers[self.start_layer : self.end_layer]:
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
|
||||
@ -35,6 +35,7 @@
|
||||
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
from itertools import islice
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
@ -519,8 +520,7 @@ class FlashModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
"""PyTorch MAMBA model."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@ -162,8 +163,7 @@ class MambaModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states, residual = layer(
|
||||
positions=positions, hidden_states=hidden_states, residual=residual
|
||||
)
|
||||
|
||||
@ -26,6 +26,7 @@
|
||||
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import partial
|
||||
from itertools import islice
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
@ -1106,11 +1107,9 @@ class Qwen3LLMModel(Qwen3Model):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for layer_idx, layer in enumerate(
|
||||
self.layers[self.start_layer : self.end_layer]
|
||||
for layer_idx, layer in islice(
|
||||
enumerate(self.layers), self.start_layer, self.end_layer
|
||||
):
|
||||
layer_idx = layer_idx + self.start_layer
|
||||
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
|
||||
@ -26,6 +26,7 @@
|
||||
|
||||
import typing
|
||||
from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -103,11 +104,9 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for layer_idx, layer in enumerate(
|
||||
self.layers[self.start_layer : self.end_layer]
|
||||
for layer_idx, layer in islice(
|
||||
enumerate(self.layers), self.start_layer, self.end_layer
|
||||
):
|
||||
layer_idx = layer_idx + self.start_layer
|
||||
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user