mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-12 03:47:15 +08:00
[Models] Improve iteration over layers (#26425)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
parent
4ebc9108a7
commit
93f2c0aa08
@@ -26,6 +26,7 @@
 """Inference-only Apertus model compatible with HuggingFace weights."""
 
 from collections.abc import Iterable
+from itertools import islice
 from typing import Any, Optional, Union
 
 import torch
@@ -412,7 +413,9 @@ class ApertusModel(nn.Module):
             residual = intermediate_tensors["residual"]
 
         aux_hidden_states = []
-        for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
+        for idx, layer in enumerate(
+            islice(self.layers, self.start_layer, self.end_layer)
+        ):
             if idx in self.aux_hidden_state_layers:
                 aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
|
|||||||
@@ -3,6 +3,7 @@
 """Inference-only FalconH1 model."""
 
 from collections.abc import Iterable
+from itertools import islice
 from typing import Optional
 
 import torch
@@ -480,8 +481,7 @@ class FalconH1Model(nn.Module):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
 
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
+        for layer in islice(self.layers, self.start_layer, self.end_layer):
             hidden_states = layer(
                 positions=positions,
                 hidden_states=hidden_states,
|
|||||||
@@ -26,6 +26,7 @@
 
 import typing
 from collections.abc import Callable, Iterable
+from itertools import islice
 from typing import Any, Optional, Union
 
 import regex as re
@@ -672,8 +673,9 @@ class HunYuanModel(nn.Module):
 
         cla_factor = _get_cla_factor(self.config)
         prev_kv_states = None
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
+        for i, layer in enumerate(
+            islice(self.layers, self.start_layer, self.end_layer)
+        ):
             hidden_states, residual, kv_states = layer(
                 positions,
                 hidden_states,
@@ -681,10 +683,7 @@ class HunYuanModel(nn.Module):
                 prev_kv_states,
             )
 
-            if (
-                getattr(self.config, "use_cla", False)
-                and (i - self.start_layer) % cla_factor == 0
-            ):
+            if getattr(self.config, "use_cla", False) and i % cla_factor == 0:
                 prev_kv_states = kv_states
             else:
                 prev_kv_states = None
|
|||||||
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable
+from itertools import islice
 from typing import Any, Optional
 
 import torch
@@ -492,7 +493,7 @@ class Lfm2MoeModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for layer in self.layers[self.start_layer : self.end_layer]:
+        for layer in islice(self.layers, self.start_layer, self.end_layer):
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
|
|||||||
@@ -35,6 +35,7 @@
 
 import typing
 from collections.abc import Callable, Iterable
+from itertools import islice
 from typing import Optional, Union
 
 import torch
@@ -519,8 +520,7 @@ class FlashModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
+        for layer in islice(self.layers, self.start_layer, self.end_layer):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
|
|||||||
@@ -3,6 +3,7 @@
 """PyTorch MAMBA model."""
 
 from collections.abc import Iterable
+from itertools import islice
 from typing import Optional
 
 import torch
@@ -162,8 +163,7 @@ class MambaModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
+        for layer in islice(self.layers, self.start_layer, self.end_layer):
             hidden_states, residual = layer(
                 positions=positions, hidden_states=hidden_states, residual=residual
             )
|
|||||||
@@ -26,6 +26,7 @@
 
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
+from itertools import islice
 from typing import Any, Callable, Optional, Union
 
 import numpy as np
@@ -1106,11 +1107,9 @@ class Qwen3LLMModel(Qwen3Model):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for layer_idx, layer in enumerate(
-            self.layers[self.start_layer : self.end_layer]
+        for layer_idx, layer in islice(
+            enumerate(self.layers), self.start_layer, self.end_layer
         ):
-            layer_idx = layer_idx + self.start_layer
-
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
|
|||||||
@@ -26,6 +26,7 @@
 
 import typing
 from collections.abc import Iterable
+from itertools import islice
 from typing import Callable, Optional, Union
 
 import torch
@@ -103,11 +104,9 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for layer_idx, layer in enumerate(
-            self.layers[self.start_layer : self.end_layer]
+        for layer_idx, layer in islice(
+            enumerate(self.layers), self.start_layer, self.end_layer
         ):
-            layer_idx = layer_idx + self.start_layer
-
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
|
|||||||
Loading…
x
Reference in New Issue
Block a user