[bugfix] fix aria model and add torch.compile (#10645)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Parent commit: 6e9ff050c8
This commit:   45ac4ff270
@@ -29,7 +29,7 @@ from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaMLP,
                                               LlamaModel)
 from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
                                               is_pp_missing_parameter,
-                                              make_layers, maybe_prefix,
+                                              maybe_prefix,
                                               merge_multimodal_embeddings)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -363,27 +363,9 @@ class AriaMoELMModel(LlamaModel):
     """

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__(vllm_config=vllm_config, prefix=prefix)
-
-        config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
-
-        # FIXME: this is a hack to disable the compilation of the model
-        self.do_not_compile = True
-
-        self.layers = None
-
-        self.start_layer, self.end_layer, self.layers = make_layers(
-            config.num_hidden_layers,
-            lambda prefix: MoEDecoderLayer(
-                config=config,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                prefix=prefix,
-            ),
-            lambda prefix: MoEDecoderLayer(
-            prefix=f"{prefix}.layers",
-        )
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         layer_type=MoEDecoderLayer)

     # Adapted from LlamaModel.load_weights with the modification of adding
     # the expert weights mapping to `stacked_params_mapping`
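The net effect of this hunk: AriaMoELMModel no longer rebuilds the decoder stack by hand, and it drops the do_not_compile escape hatch; it simply tells the parent LlamaModel which decoder-layer class to instantiate, so the parent's compiled forward path now covers Aria as well. Below is a minimal, self-contained sketch of that pattern; the names DenseLayer, MoELayer, BaseDecoder and MoEDecoder are illustrative stand-ins, not vLLM classes.

# Sketch of the "inject the layer class into the base constructor" pattern
# applied by this commit. All names here are illustrative, not vLLM APIs.
from typing import Type

from torch import nn


class DenseLayer(nn.Module):
    def __init__(self, hidden: int):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.proj(x)


class MoELayer(DenseLayer):
    # Stand-in for a mixture-of-experts decoder layer.
    pass


class BaseDecoder(nn.Module):
    # The base class owns the layer-construction loop, parameterized by
    # layer_type, so subclasses never duplicate it.
    def __init__(self, hidden: int, num_layers: int,
                 layer_type: Type[DenseLayer] = DenseLayer):
        super().__init__()
        self.layers = nn.ModuleList(
            layer_type(hidden) for _ in range(num_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class MoEDecoder(BaseDecoder):
    # Mirrors the new AriaMoELMModel.__init__: just pass the layer class up.
    def __init__(self, hidden: int, num_layers: int):
        super().__init__(hidden, num_layers, layer_type=MoELayer)

Because the construction loop lives in the base class, whatever machinery is attached to the base class (in vLLM's case, the @support_torch_compile decorator on LlamaModel) applies to the subclass without any per-model workaround.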
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

 import torch
 from torch import nn
@@ -273,7 +273,11 @@ class LlamaDecoderLayer(nn.Module):
 @support_torch_compile
 class LlamaModel(nn.Module):

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
         super().__init__()

         config = vllm_config.model_config.hf_config
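On the llama.py side, LlamaModel (which already carries the @support_torch_compile decorator, shown as unchanged context above) gains a layer_type keyword defaulting to LlamaDecoderLayer, so existing callers are unaffected while subclasses can swap in their own decoder layer. For readers unfamiliar with class-level compile decorators, here is a rough sketch of the general idea; it mimics only the shape of @support_torch_compile and is not vLLM's implementation, which additionally handles dynamo registration, compilation-config gating, and piecewise compilation.

# Hedged sketch of a "compile the model's forward" class decorator.
# compile_forward and TinyModel are illustrative names, not vLLM APIs.
import torch
from torch import nn


def compile_forward(cls):
    uncompiled_forward = cls.forward

    def forward(self, *args, **kwargs):
        # Compile lazily on first use and cache the compiled callable
        # on the instance.
        if getattr(self, "_compiled_forward", None) is None:
            self._compiled_forward = torch.compile(uncompiled_forward)
        return self._compiled_forward(self, *args, **kwargs)

    cls.forward = forward
    return cls


@compile_forward
class TinyModel(nn.Module):
    def __init__(self, hidden: int = 32):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


if __name__ == "__main__":
    model = TinyModel()
    out = model(torch.randn(4, 32))  # first call triggers compilation
    print(out.shape)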
@@ -299,10 +303,10 @@ class LlamaModel(nn.Module):
             self.embed_tokens = PPMissingLayer()
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: LlamaDecoderLayer(config=config,
-                                             cache_config=cache_config,
-                                             quant_config=quant_config,
-                                             prefix=prefix),
+            lambda prefix: layer_type(config=config,
+                                      cache_config=cache_config,
+                                      quant_config=quant_config,
+                                      prefix=prefix),
             prefix=f"{prefix}.layers",
         )
         if get_pp_group().is_last_rank:
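Inside LlamaModel.__init__, the make_layers call now instantiates layer_type instead of the hard-coded LlamaDecoderLayer; since the factory passed to make_layers is a plain callable, this lambda is the only line that has to change. A rough sketch of a make_layers-style helper is below; make_simple_layers and its signature are illustrative, and it deliberately ignores the pipeline-parallel start/end partitioning that vLLM's real make_layers performs.

# Illustrative make_layers-style helper: build N layers from a factory that
# receives each layer's dotted prefix. Not vLLM's API.
from typing import Callable, Tuple

from torch import nn


def make_simple_layers(
    num_layers: int,
    layer_fn: Callable[[str], nn.Module],
    prefix: str,
) -> Tuple[int, int, nn.ModuleList]:
    layers = nn.ModuleList(
        layer_fn(f"{prefix}.{i}") for i in range(num_layers))
    # In vLLM, start_layer / end_layer would differ per pipeline rank;
    # here every rank owns the full stack.
    return 0, num_layers, layers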